feat(api): enable customizable calls functions (#457)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
Aaron Pham 2024-09-02 12:22:48 -04:00 committed by GitHub
parent d520f09333
commit 7266661413
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 291 additions and 267 deletions

View File

@ -7,13 +7,73 @@ local Utils = require("avante.utils")
---@field hint ToggleBind.wrap ---@field hint ToggleBind.wrap
--- ---
---@class avante.Api ---@class avante.Api
---@field ask fun(): boolean ---@field ask fun(question:string?): boolean
---@field edit fun(): nil ---@field edit fun(question:string?): nil
---@field refresh fun(): nil ---@field refresh fun(): nil
---@field build fun(): boolean ---@field build fun(): boolean
---@field switch_provider fun(target: string): nil
---@field toggle avante.ApiToggle ---@field toggle avante.ApiToggle
return setmetatable({}, { local M = {}
---@param target Provider
M.switch_provider = function(target)
require("avante.providers").refresh(target)
end
---@param question? string
M.ask = function(question)
if not require("avante").toggle() then
return false
end
if question == nil or question == "" then
return true
end
vim.api.nvim_exec_autocmds("User", { pattern = "AvanteInputSubmitted", data = { request = question } })
return true
end
---@param question? string
M.edit = function(question)
local _, selection = require("avante").get()
if not selection then
return
end
selection:create_editing_input()
if question ~= nil or question ~= "" then
vim.api.nvim_exec_autocmds("User", { pattern = "AvanteEditSubmitted", data = { request = question } })
end
end
M.refresh = function()
local sidebar, _ = require("avante").get()
if not sidebar then
return
end
if not sidebar:is_open() then
return
end
local curbuf = vim.api.nvim_get_current_buf()
local focused = sidebar.result.bufnr == curbuf or sidebar.input.bufnr == curbuf
if focused or not sidebar:is_open() then
return
end
local listed = vim.api.nvim_get_option_value("buflisted", { buf = curbuf })
if Utils.is_sidebar_buffer(curbuf) or not listed then
return
end
local curwin = vim.api.nvim_get_current_win()
sidebar:close()
sidebar.code.winid = curwin
sidebar.code.bufnr = curbuf
sidebar:render()
end
return setmetatable(M, {
__index = function(t, k) __index = function(t, k)
local module = require("avante") local module = require("avante")
---@class AvailableApi: ApiCaller ---@class AvailableApi: ApiCaller

View File

@ -7,10 +7,10 @@ local Utils = require("avante.utils")
local M = {} local M = {}
---@class avante.Config ---@class avante.Config
---@field silent_warning boolean will be determined from debug ---@field silent_warning? boolean will be determined from debug
M.defaults = { M.defaults = {
debug = false, debug = false,
---@alias Provider "claude" | "openai" | "azure" | "gemini" | "cohere" | "copilot" | string ---@alias Provider "claude" | "openai" | "azure" | "gemini" | "cohere" | "copilot" | [string]
provider = "claude", -- Only recommend using Claude provider = "claude", -- Only recommend using Claude
---@alias Tokenizer "tiktoken" | "hf" ---@alias Tokenizer "tiktoken" | "hf"
-- Used for counting tokens and encoding text. -- Used for counting tokens and encoding text.
@ -111,7 +111,6 @@ M.defaults = {
ours = "co", ours = "co",
theirs = "ct", theirs = "ct",
all_theirs = "ca", all_theirs = "ca",
none = "c0",
both = "cb", both = "cb",
cursor = "cc", cursor = "cc",
next = "]x", next = "]x",
@ -130,6 +129,7 @@ M.defaults = {
edit = "<leader>ae", edit = "<leader>ae",
refresh = "<leader>ar", refresh = "<leader>ar",
toggle = { toggle = {
default = "<leader>at",
debug = "<leader>ad", debug = "<leader>ad",
hint = "<leader>ah", hint = "<leader>ah",
}, },
@ -151,11 +151,9 @@ M.defaults = {
border = "rounded", border = "rounded",
}, },
}, },
--- @class AvanteConflictUserConfig --- @class AvanteConflictConfig
diff = { diff = {
autojump = true, autojump = true,
---@type string | fun(): any
list_opener = "copen",
}, },
--- @class AvanteHintsConfig --- @class AvanteHintsConfig
hints = { hints = {
@ -166,11 +164,14 @@ M.defaults = {
---@type avante.Config ---@type avante.Config
M.options = {} M.options = {}
---@class avante.ConflictConfig: AvanteConflictUserConfig ---@class avante.ConflictConfig: AvanteConflictConfig
---@field mappings AvanteConflictMappings ---@field mappings AvanteConflictMappings
---@field highlights AvanteConflictHighlights ---@field highlights AvanteConflictHighlights
M.diff = {} M.diff = {}
---@type Provider[]
M.providers = {}
---@param opts? avante.Config ---@param opts? avante.Config
function M.setup(opts) function M.setup(opts)
vim.validate({ opts = { opts, "table", true } }) vim.validate({ opts = { opts, "table", true } })
@ -190,6 +191,16 @@ function M.setup(opts)
-- set silent_warning to true if debug is false -- set silent_warning to true if debug is false
M.options.silent_warning = not M.options.debug M.options.silent_warning = not M.options.debug
end end
M.providers = vim
.iter(M.defaults)
:filter(function(_, value)
return type(value) == "table" and value.endpoint ~= nil
end)
:fold({}, function(acc, k)
acc = vim.list_extend({}, acc)
acc = vim.list_extend(acc, { k })
return acc
end)
vim.validate({ provider = { M.options.provider, "string", false } }) vim.validate({ provider = { M.options.provider, "string", false } })
@ -205,6 +216,7 @@ function M.setup(opts)
M.options.vendors[k] = type(v) == "function" and v() or v M.options.vendors[k] = type(v) == "function" and v() or v
end end
vim.validate({ vendors = { M.options.vendors, "table", true } }) vim.validate({ vendors = { M.options.vendors, "table", true } })
M.providers = vim.list_extend(M.providers, vim.tbl_keys(M.options.vendors))
end end
end end
@ -228,6 +240,9 @@ function M.override(opts)
if next(M.options.vendors) ~= nil then if next(M.options.vendors) ~= nil then
for k, v in pairs(M.options.vendors) do for k, v in pairs(M.options.vendors) do
M.options.vendors[k] = type(v) == "function" and v() or v M.options.vendors[k] = type(v) == "function" and v() or v
if not vim.tbl_contains(M.providers, k) then
M.providers = vim.list_extend(M.providers, { k })
end
end end
vim.validate({ vendors = { M.options.vendors, "table", true } }) vim.validate({ vendors = { M.options.vendors, "table", true } })
end end

View File

@ -1,15 +1,12 @@
-- This file COPY and MODIFIED based on: https://github.com/akinsho/git-conflict.nvim/blob/main/lua/git-conflict.lua local api = vim.api
local M = {}
local Config = require("avante.config") local Config = require("avante.config")
local Utils = require("avante.utils") local Utils = require("avante.utils")
local Highlights = require("avante.highlights") local Highlights = require("avante.highlights")
local fn = vim.fn local H = {}
local api = vim.api local M = {}
local fmt = string.format
local map = vim.keymap.set
-----------------------------------------------------------------------------// -----------------------------------------------------------------------------//
-- REFERENCES: -- REFERENCES:
-----------------------------------------------------------------------------// -----------------------------------------------------------------------------//
@ -31,7 +28,6 @@ local map = vim.keymap.set
--- @class AvanteConflictHighlights --- @class AvanteConflictHighlights
--- @field current string --- @field current string
--- @field incoming string --- @field incoming string
--- @field ancestor string?
---@class RangeMark ---@class RangeMark
---@field label integer ---@field label integer
@ -40,7 +36,6 @@ local map = vim.keymap.set
--- @class PositionMarks --- @class PositionMarks
--- @field current RangeMark --- @field current RangeMark
--- @field incoming RangeMark --- @field incoming RangeMark
--- @field ancestor RangeMark
--- @class Range --- @class Range
--- @field range_start integer --- @field range_start integer
@ -69,7 +64,6 @@ local SIDES = {
THEIRS = "theirs", THEIRS = "theirs",
ALL_THEIRS = "all_theirs", ALL_THEIRS = "all_theirs",
BOTH = "both", BOTH = "both",
BASE = "base",
NONE = "none", NONE = "none",
CURSOR = "cursor", CURSOR = "cursor",
} }
@ -78,7 +72,6 @@ local SIDES = {
local name_map = { local name_map = {
ours = "current", ours = "current",
theirs = "incoming", theirs = "incoming",
base = "ancestor",
both = "both", both = "both",
none = "none", none = "none",
cursor = "cursor", cursor = "cursor",
@ -86,10 +79,8 @@ local name_map = {
local CURRENT_HL = "AvanteConflictCurrent" local CURRENT_HL = "AvanteConflictCurrent"
local INCOMING_HL = "AvanteConflictIncoming" local INCOMING_HL = "AvanteConflictIncoming"
local ANCESTOR_HL = "AvanteConflictAncestor"
local CURRENT_LABEL_HL = "AvanteConflictCurrentLabel" local CURRENT_LABEL_HL = "AvanteConflictCurrentLabel"
local INCOMING_LABEL_HL = "AvanteConflictIncomingLabel" local INCOMING_LABEL_HL = "AvanteConflictIncomingLabel"
local ANCESTOR_LABEL_HL = "AvanteConflictAncestorLabel"
local PRIORITY = vim.highlight.priorities.user local PRIORITY = vim.highlight.priorities.user
local NAMESPACE = api.nvim_create_namespace("avante-conflict") local NAMESPACE = api.nvim_create_namespace("avante-conflict")
local KEYBINDING_NAMESPACE = api.nvim_create_namespace("avante-conflict-keybinding") local KEYBINDING_NAMESPACE = api.nvim_create_namespace("avante-conflict-keybinding")
@ -98,7 +89,6 @@ local AUGROUP_NAME = "avante_conflicts"
local conflict_start = "^<<<<<<<" local conflict_start = "^<<<<<<<"
local conflict_middle = "^=======" local conflict_middle = "^======="
local conflict_end = "^>>>>>>>" local conflict_end = "^>>>>>>>"
local conflict_ancestor = "^|||||||"
-----------------------------------------------------------------------------// -----------------------------------------------------------------------------//
@ -208,16 +198,7 @@ local function highlight_conflicts(positions, lines)
position.marks = { position.marks = {
current = { label = curr_label_id, content = curr_id }, current = { label = curr_label_id, content = curr_id },
incoming = { label = inc_label_id, content = inc_id }, incoming = { label = inc_label_id, content = inc_id },
ancestor = {},
} }
if not vim.tbl_isempty(position.ancestor) then
local ancestor_start = position.ancestor.range_start
local ancestor_end = position.ancestor.range_end
local ancestor_label = lines[ancestor_start + 1] .. " (Base changes)"
local id = hl_range(bufnr, ANCESTOR_HL, ancestor_start + 1, ancestor_end + 1)
local label_id = draw_section_label(bufnr, ANCESTOR_LABEL_HL, ancestor_label, ancestor_start)
position.marks.ancestor = { label = label_id, content = id }
end
end end
end end
@ -228,7 +209,7 @@ end
---@return ConflictPosition[] ---@return ConflictPosition[]
local function detect_conflicts(lines) local function detect_conflicts(lines)
local positions = {} local positions = {}
local position, has_middle, has_ancestor = nil, false, false local position, has_middle = nil, false
for index, line in ipairs(lines) do for index, line in ipairs(lines) do
local lnum = index - 1 local lnum = index - 1
if line:match(conflict_start) then if line:match(conflict_start) then
@ -236,25 +217,12 @@ local function detect_conflicts(lines)
current = { range_start = lnum, content_start = lnum + 1 }, current = { range_start = lnum, content_start = lnum + 1 },
middle = {}, middle = {},
incoming = {}, incoming = {},
ancestor = {},
} }
end end
if position ~= nil and line:match(conflict_ancestor) then
has_ancestor = true
position.ancestor.range_start = lnum
position.ancestor.content_start = lnum + 1
position.current.range_end = lnum - 1
position.current.content_end = lnum - 1
end
if position ~= nil and line:match(conflict_middle) then if position ~= nil and line:match(conflict_middle) then
has_middle = true has_middle = true
if has_ancestor then position.current.range_end = lnum - 1
position.ancestor.content_end = lnum - 1 position.current.content_end = lnum - 1
position.ancestor.range_end = lnum - 1
else
position.current.range_end = lnum - 1
position.current.content_end = lnum - 1
end
position.middle.range_start = lnum position.middle.range_start = lnum
position.middle.range_end = lnum + 1 position.middle.range_end = lnum + 1
position.incoming.range_start = lnum + 1 position.incoming.range_start = lnum + 1
@ -265,7 +233,7 @@ local function detect_conflicts(lines)
position.incoming.content_end = lnum - 1 position.incoming.content_end = lnum - 1
positions[#positions + 1] = position positions[#positions + 1] = position
position, has_middle, has_ancestor = nil, false, false position, has_middle = nil, false
end end
end end
return #positions > 0, positions return #positions > 0, positions
@ -376,7 +344,7 @@ local function parse_buffer(bufnr, range_start, range_end)
else else
M.clear(bufnr) M.clear(bufnr)
end end
if prev_conflicts ~= has_conflict or not vim.b[bufnr].conflict_mappings_set then if prev_conflicts ~= has_conflict or not vim.b[bufnr].avante_conflict_mappings_set then
local pattern = has_conflict and "AvanteConflictDetected" or "AvanteConflictResolved" local pattern = has_conflict and "AvanteConflictDetected" or "AvanteConflictResolved"
api.nvim_exec_autocmds("User", { pattern = pattern }) api.nvim_exec_autocmds("User", { pattern = pattern })
end end
@ -384,6 +352,8 @@ end
---Process a buffer if the changed tick has changed ---Process a buffer if the changed tick has changed
---@param bufnr integer? ---@param bufnr integer?
---@param range_start integer?
---@param range_end integer?
function M.process(bufnr, range_start, range_end) function M.process(bufnr, range_start, range_end)
bufnr = bufnr or api.nvim_get_current_buf() bufnr = bufnr or api.nvim_get_current_buf()
if visited_buffers[bufnr] and visited_buffers[bufnr].tick == vim.b[bufnr].changedtick then if visited_buffers[bufnr] and visited_buffers[bufnr].tick == vim.b[bufnr].changedtick then
@ -392,129 +362,62 @@ function M.process(bufnr, range_start, range_end)
parse_buffer(bufnr, range_start, range_end) parse_buffer(bufnr, range_start, range_end)
end end
-----------------------------------------------------------------------------//
-- Commands
-----------------------------------------------------------------------------//
local function set_commands()
local command = api.nvim_create_user_command
command("AvanteConflictListQf", function()
M.conflicts_to_qf_items(function(items)
if #items > 0 then
fn.setqflist(items, "r")
if type(Config.diff.list_opener) == "function" then
Config.diff.list_opener()
else
vim.cmd(Config.diff.list_opener)
end
end
end)
end, { nargs = 0 })
command("AvanteConflictChooseOurs", function()
M.choose("ours")
end, { nargs = 0 })
command("AvanteConflictChooseTheirs", function()
M.choose("theirs")
end, { nargs = 0 })
command("AvanteConflictChooseAllTheirs", function()
M.choose("all_theirs")
end, { nargs = 0 })
command("AvanteConflictChooseBoth", function()
M.choose("both")
end, { nargs = 0 })
command("AvanteConflictChooseCursor", function()
M.choose("cursor")
end, { nargs = 0 })
command("AvanteConflictChooseBase", function()
M.choose("base")
end, { nargs = 0 })
command("AvanteConflictChooseNone", function()
M.choose("none")
end, { nargs = 0 })
command("AvanteConflictNextConflict", function()
M.find_next("ours")
end, { nargs = 0 })
command("AvanteConflictPrevConflict", function()
M.find_prev("ours")
end, { nargs = 0 })
end
-----------------------------------------------------------------------------// -----------------------------------------------------------------------------//
-- Mappings -- Mappings
-----------------------------------------------------------------------------// -----------------------------------------------------------------------------//
local function set_plug_mappings()
local function opts(desc)
return { silent = true, desc = "Git Conflict: " .. desc }
end
map({ "n", "v" }, "<Plug>(git-conflict-ours)", "<Cmd>AvanteConflictChooseOurs<CR>", opts("Choose Ours"))
map({ "n", "v" }, "<Plug>(git-conflict-both)", "<Cmd>AvanteConflictChooseBoth<CR>", opts("Choose Both"))
map({ "n", "v" }, "<Plug>(git-conflict-none)", "<Cmd>AvanteConflictChooseNone<CR>", opts("Choose None"))
map({ "n", "v" }, "<Plug>(git-conflict-theirs)", "<Cmd>AvanteConflictChooseTheirs<CR>", opts("Choose Theirs"))
map(
{ "n", "v" },
"<Plug>(git-conflict-all-theirs)",
"<Cmd>AvanteConflictChooseAllTheirs<CR>",
opts("Choose All Theirs")
)
map("n", "<Plug>(git-conflict-cursor)", "<Cmd>AvanteConflictChooseCursor<CR>", opts("Choose Cursor"))
map("n", "<Plug>(git-conflict-next-conflict)", "<Cmd>AvanteConflictNextConflict<CR>", opts("Next Conflict"))
map("n", "<Plug>(git-conflict-prev-conflict)", "<Cmd>AvanteConflictPrevConflict<CR>", opts("Previous Conflict"))
end
---@param bufnr integer given buffer id ---@param bufnr integer given buffer id
local function setup_buffer_mappings(bufnr) H.setup_buffer_mappings = function(bufnr)
---@param desc string ---@param desc string
local function opts(desc) local function opts(desc)
return { silent = true, buffer = bufnr, desc = "Git Conflict: " .. desc } return { silent = true, buffer = bufnr, desc = "avante(conflict): " .. desc }
end end
map({ "n", "v" }, Config.diff.mappings.ours, "<Plug>(git-conflict-ours)", opts("Choose Ours")) vim.keymap.set({ "n", "v" }, Config.diff.mappings.ours, function()
map({ "n", "v" }, Config.diff.mappings.both, "<Plug>(git-conflict-both)", opts("Choose Both")) M.choose("ours")
map({ "n", "v" }, Config.diff.mappings.none, "<Plug>(git-conflict-none)", opts("Choose None")) end, opts("choose ours"))
map({ "n", "v" }, Config.diff.mappings.theirs, "<Plug>(git-conflict-theirs)", opts("Choose Theirs")) vim.keymap.set({ "n", "v" }, Config.diff.mappings.both, function()
map({ "n", "v" }, Config.diff.mappings.all_theirs, "<Plug>(git-conflict-all-theirs)", opts("Choose All Theirs")) M.choose("both")
map({ "v", "v" }, Config.diff.mappings.ours, "<Plug>(git-conflict-ours)", opts("Choose Ours")) end, opts("choose both"))
map("n", Config.diff.mappings.cursor, "<Plug>(git-conflict-cursor)", opts("Choose Cursor")) vim.keymap.set({ "n", "v" }, Config.diff.mappings.theirs, function()
-- map('V', Config.diff.mappings.ours, '<Plug>(git-conflict-ours)', opts('Choose Ours')) M.choose("theirs")
map("n", Config.diff.mappings.prev, "<Plug>(git-conflict-prev-conflict)", opts("Previous Conflict")) end, opts("choose theirs"))
map("n", Config.diff.mappings.next, "<Plug>(git-conflict-next-conflict)", opts("Next Conflict")) vim.keymap.set({ "n", "v" }, Config.diff.mappings.all_theirs, function()
vim.b[bufnr].conflict_mappings_set = true M.choose("all_theirs")
end, opts("choose all theirs"))
vim.keymap.set("n", Config.diff.mappings.cursor, function()
M.choose("cursor")
end, opts("choose under cursor"))
vim.keymap.set("n", Config.diff.mappings.prev, function()
M.find_prev("ours")
end, opts("previous conflict"))
vim.keymap.set("n", Config.diff.mappings.next, function()
M.find_next("ours")
end, opts("next conflict"))
vim.b[bufnr].avante_conflict_mappings_set = true
end end
---@param key string ---@param bufnr integer
---@param mode "'n'|'v'|'o'|'nv'|'nvo'"? H.clear_buffer_mappings = function(bufnr)
---@return boolean if not bufnr or not vim.b[bufnr].avante_conflict_mappings_set then
local function is_mapped(key, mode)
return fn.hasmapto(key, mode or "n") > 0
end
local function clear_buffer_mappings(bufnr)
if not bufnr or not vim.b[bufnr].conflict_mappings_set then
return return
end end
for _, mapping in pairs(Config.diff.mappings) do for _, mapping in pairs(Config.diff.mappings) do
if is_mapped(mapping) then if vim.fn.hasmapto(mapping, "n") > 0 then
api.nvim_buf_del_keymap(bufnr, "n", mapping) api.nvim_buf_del_keymap(bufnr, "n", mapping)
end end
end end
vim.b[bufnr].conflict_mappings_set = false vim.b[bufnr].avante_conflict_mappings_set = false
end end
M.augroup = api.nvim_create_augroup(AUGROUP_NAME, { clear = true })
function M.setup() function M.setup()
Highlights.conflict_highlights()
set_commands()
set_plug_mappings()
local augroup = api.nvim_create_augroup(AUGROUP_NAME, { clear = true })
local previous_inlay_enabled = nil local previous_inlay_enabled = nil
api.nvim_create_autocmd("User", { api.nvim_create_autocmd("User", {
group = augroup, group = M.augroup,
pattern = "AvanteConflictDetected", pattern = "AvanteConflictDetected",
callback = function(ev) callback = function(ev)
vim.diagnostic.enable(false, { bufnr = ev.buf }) vim.diagnostic.enable(false, { bufnr = ev.buf })
@ -522,12 +425,12 @@ function M.setup()
previous_inlay_enabled = vim.lsp.inlay_hint.is_enabled({ bufnr = ev.buf }) previous_inlay_enabled = vim.lsp.inlay_hint.is_enabled({ bufnr = ev.buf })
vim.lsp.inlay_hint.enable(false, { bufnr = ev.buf }) vim.lsp.inlay_hint.enable(false, { bufnr = ev.buf })
end end
setup_buffer_mappings(ev.buf) H.setup_buffer_mappings(ev.buf)
end, end,
}) })
api.nvim_create_autocmd("User", { api.nvim_create_autocmd("User", {
group = AUGROUP_NAME, group = M.augroup,
pattern = "AvanteConflictResolved", pattern = "AvanteConflictResolved",
callback = function(ev) callback = function(ev)
vim.diagnostic.enable(true, { bufnr = ev.buf }) vim.diagnostic.enable(true, { bufnr = ev.buf })
@ -535,7 +438,7 @@ function M.setup()
vim.lsp.inlay_hint.enable(previous_inlay_enabled, { bufnr = ev.buf }) vim.lsp.inlay_hint.enable(previous_inlay_enabled, { bufnr = ev.buf })
previous_inlay_enabled = nil previous_inlay_enabled = nil
end end
clear_buffer_mappings(ev.buf) H.clear_buffer_mappings(ev.buf)
end, end,
}) })
@ -565,7 +468,7 @@ local function quickfix_items_from_positions(item, items, visited_buf)
if vim.tbl_contains({ name_map.ours, name_map.theirs, name_map.base }, key) and not vim.tbl_isempty(value) then if vim.tbl_contains({ name_map.ours, name_map.theirs, name_map.base }, key) and not vim.tbl_isempty(value) then
local lnum = value.range_start + 1 local lnum = value.range_start + 1
local next_item = vim.deepcopy(item) local next_item = vim.deepcopy(item)
next_item.text = fmt("%s change", key, lnum) next_item.text = string.format("%s change", key, lnum)
next_item.lnum = lnum next_item.lnum = lnum
next_item.col = 0 next_item.col = 0
table.insert(items, next_item) table.insert(items, next_item)
@ -711,9 +614,6 @@ function M.process_position(bufnr, side, position, enable_autojump)
api.nvim_buf_set_lines(0, pos_start, pos_end, false, lines) api.nvim_buf_set_lines(0, pos_start, pos_end, false, lines)
api.nvim_buf_del_extmark(0, NAMESPACE, position.marks.incoming.label) api.nvim_buf_del_extmark(0, NAMESPACE, position.marks.incoming.label)
api.nvim_buf_del_extmark(0, NAMESPACE, position.marks.current.label) api.nvim_buf_del_extmark(0, NAMESPACE, position.marks.current.label)
if position.marks.ancestor.label then
api.nvim_buf_del_extmark(0, NAMESPACE, position.marks.ancestor.label)
end
parse_buffer(bufnr) parse_buffer(bufnr)
if enable_autojump and Config.diff.autojump then if enable_autojump and Config.diff.autojump then
M.find_next(side) M.find_next(side)

View File

@ -58,6 +58,8 @@ M.setup = function()
api.nvim_set_hl(M.hint_ns, "NormalFloat", { fg = normal_float.fg, bg = normal_float.bg }) api.nvim_set_hl(M.hint_ns, "NormalFloat", { fg = normal_float.fg, bg = normal_float.bg })
api.nvim_set_hl(M.input_ns, "NormalFloat", { fg = normal_float.fg, bg = normal_float.bg }) api.nvim_set_hl(M.input_ns, "NormalFloat", { fg = normal_float.fg, bg = normal_float.bg })
api.nvim_set_hl(M.input_ns, "FloatBorder", { fg = normal.fg, bg = normal.bg }) api.nvim_set_hl(M.input_ns, "FloatBorder", { fg = normal.fg, bg = normal.bg })
M.conflict_highlights()
end end
---@param opts? AvanteConflictHighlights ---@param opts? AvanteConflictHighlights

View File

@ -4,6 +4,7 @@ local Utils = require("avante.utils")
local Sidebar = require("avante.sidebar") local Sidebar = require("avante.sidebar")
local Selection = require("avante.selection") local Selection = require("avante.selection")
local Config = require("avante.config") local Config = require("avante.config")
local Diff = require("avante.diff")
---@class Avante ---@class Avante
local M = { local M = {
@ -20,36 +21,57 @@ M.did_setup = false
local H = {} local H = {}
H.commands = function() H.commands = function()
---@param n string
---@param c vim.api.keyset.user_command.callback
---@param o vim.api.keyset.user_command.opts
local cmd = function(n, c, o) local cmd = function(n, c, o)
o = vim.tbl_extend("force", { nargs = 0 }, o or {}) o = vim.tbl_extend("force", { nargs = 0 }, o or {})
api.nvim_create_user_command("Avante" .. n, c, o) api.nvim_create_user_command("Avante" .. n, c, o)
end end
cmd("Ask", function() cmd("Ask", function(opts)
M.ask() require("avante.api").ask(vim.trim(opts.args))
end, { desc = "avante: ask AI for code suggestions" }) end, { desc = "avante: ask AI for code suggestions", nargs = "*" })
cmd("Edit", function() cmd("Toggle", function()
M.edit() M.toggle()
end, { desc = "avante: edit selected block" }) end, { desc = "avante: toggle AI panel" })
cmd("Edit", function(opts)
require("avante.api").edit(vim.trim(opts.args))
end, { desc = "avante: edit selected block", nargs = "*" })
cmd("Refresh", function() cmd("Refresh", function()
M.refresh() require("avante.api").refresh()
end, { desc = "avante: refresh windows" }) end, { desc = "avante: refresh windows" })
cmd("Build", function() cmd("Build", function()
M.build() M.build()
end, { desc = "avante: build dependencies" }) end, { desc = "avante: build dependencies" })
cmd("SwitchProvider", function(opts)
require("avante.api").switch_provider(vim.trim(opts.args or ""))
end, {
nargs = 1,
desc = "avante: switch provider",
complete = function(_, line, _)
local prefix = line:match("AvanteSwitchProvider%s*(.*)$") or ""
---@param key string
return vim.tbl_filter(function(key)
return key:find(prefix, 1, true) == 1
end, Config.providers)
end,
})
end end
H.keymaps = function() H.keymaps = function()
vim.keymap.set({ "n", "v" }, "<Plug>(AvanteAsk)", function() vim.keymap.set({ "n", "v" }, "<Plug>(AvanteAsk)", function()
M.ask() require("avante.api").ask()
end, { noremap = true }) end, { noremap = true })
vim.keymap.set("v", "<Plug>(AvanteEdit)", function() vim.keymap.set("v", "<Plug>(AvanteEdit)", function()
M.edit() require("avante.api").edit()
end, { noremap = true }) end, { noremap = true })
vim.keymap.set("n", "<Plug>(AvanteRefresh)", function() vim.keymap.set("n", "<Plug>(AvanteRefresh)", function()
M.refresh() require("avante.api").refresh()
end, { noremap = true })
vim.keymap.set("n", "<Plug>(AvanteToggle)", function()
M.toggle()
end, { noremap = true }) end, { noremap = true })
--- the following is kinda considered as internal mappings.
vim.keymap.set("n", "<Plug>(AvanteToggleDebug)", function() vim.keymap.set("n", "<Plug>(AvanteToggleDebug)", function()
M.toggle.debug() M.toggle.debug()
end) end)
@ -57,16 +79,41 @@ H.keymaps = function()
M.toggle.hint() M.toggle.hint()
end) end)
vim.keymap.set({ "n", "v" }, "<Plug>(AvanteConflictOurs)", function()
Diff.choose("ours")
end)
vim.keymap.set({ "n", "v" }, "<Plug>(AvanteConflictBoth)", function()
Diff.choose("both")
end)
vim.keymap.set({ "n", "v" }, "<Plug>(AvanteConflictTheirs)", function()
Diff.choose("theirs")
end)
vim.keymap.set({ "n", "v" }, "<Plug>(AvanteConflictAllTheirs)", function()
Diff.choose("all_theirs")
end)
vim.keymap.set({ "n", "v" }, "<Plug>(AvanteConflictCursor)", function()
Diff.choose("cursor")
end)
vim.keymap.set("n", "<Plug>(AvanteConflictNextConflict)", function()
Diff.find_next("ours")
end)
vim.keymap.set("n", "<Plug>(AvanteConflictPrevConflict)", function()
Diff.find_prev("ours")
end)
if Config.behaviour.auto_set_keymaps then if Config.behaviour.auto_set_keymaps then
Utils.safe_keymap_set({ "n", "v" }, Config.mappings.ask, function() Utils.safe_keymap_set({ "n", "v" }, Config.mappings.ask, function()
M.ask() require("avante.api").ask()
end, { desc = "avante: ask" }) end, { desc = "avante: ask" })
Utils.safe_keymap_set("v", Config.mappings.edit, function() Utils.safe_keymap_set("v", Config.mappings.edit, function()
M.edit() require("avante.api").edit()
end, { desc = "avante: edit" }) end, { desc = "avante: edit" })
Utils.safe_keymap_set("n", Config.mappings.refresh, function() Utils.safe_keymap_set("n", Config.mappings.refresh, function()
M.refresh() require("avante.api").refresh()
end, { desc = "avante: refresh" }) end, { desc = "avante: refresh" })
Utils.safe_keymap_set("n", Config.mappings.toggle.default, function()
M.toggle()
end, { desc = "avante: toggle" })
Utils.safe_keymap_set("n", Config.mappings.toggle.debug, function() Utils.safe_keymap_set("n", Config.mappings.toggle.debug, function()
M.toggle.debug() M.toggle.debug()
end, { desc = "avante: toggle debug" }) end, { desc = "avante: toggle debug" })
@ -138,7 +185,7 @@ H.autocmds = function()
api.nvim_create_autocmd("VimResized", { api.nvim_create_autocmd("VimResized", {
group = H.augroup, group = H.augroup,
callback = function() callback = function()
local sidebar, _ = M._get() local sidebar, _ = M.get()
if not sidebar then if not sidebar then
return return
end end
@ -188,7 +235,7 @@ end
---@param current boolean? false to disable setting current, otherwise use this to track across tabs. ---@param current boolean? false to disable setting current, otherwise use this to track across tabs.
---@return avante.Sidebar, avante.Selection ---@return avante.Sidebar, avante.Selection
function M._get(current) function M.get(current)
local tab = api.nvim_get_current_tabpage() local tab = api.nvim_get_current_tabpage()
local sidebar = M.sidebars[tab] local sidebar = M.sidebars[tab]
local selection = M.selections[tab] local selection = M.selections[tab]
@ -241,7 +288,7 @@ M.toggle.hint = H.api(Utils.toggle_wrap({
setmetatable(M.toggle, { setmetatable(M.toggle, {
__index = M.toggle, __index = M.toggle,
__call = function() __call = function()
local sidebar, _ = M._get() local sidebar, _ = M.get()
if not sidebar then if not sidebar then
M._init(api.nvim_get_current_tabpage()) M._init(api.nvim_get_current_tabpage())
M.current.sidebar:open() M.current.sidebar:open()
@ -252,6 +299,7 @@ setmetatable(M.toggle, {
end, end,
}) })
---@param path string
local function to_windows_path(path) local function to_windows_path(path)
local winpath = path:gsub("/", "\\") local winpath = path:gsub("/", "\\")
@ -299,48 +347,6 @@ M.build = H.api(function()
return vim.tbl_contains({ 0 }, job.code) and true or false return vim.tbl_contains({ 0 }, job.code) and true or false
end) end)
M.ask = H.api(function()
M.toggle()
end)
M.edit = H.api(function()
local _, selection = M._get()
if not selection then
return
end
selection:create_editing_input()
end)
M.refresh = H.api(function()
local sidebar, _ = M._get()
if not sidebar then
return
end
if not sidebar:is_open() then
return
end
local curbuf = vim.api.nvim_get_current_buf()
local focused = sidebar.result.bufnr == curbuf or sidebar.input.bufnr == curbuf
if focused or not sidebar:is_open() then
return
end
local listed = vim.api.nvim_get_option_value("buflisted", { buf = curbuf })
if Utils.is_sidebar_buffer(curbuf) or not listed then
return
end
local curwin = vim.api.nvim_get_current_win()
sidebar:close()
sidebar.code.winid = curwin
sidebar.code.bufnr = curbuf
sidebar:render()
end)
---@param opts? avante.Config ---@param opts? avante.Config
function M.setup(opts) function M.setup(opts)
if vim.fn.has("nvim-0.10") == 0 then if vim.fn.has("nvim-0.10") == 0 then

View File

@ -300,11 +300,8 @@ M.setup = function()
---@type AvanteProviderFunctor ---@type AvanteProviderFunctor
local provider = M[Config.provider] local provider = M[Config.provider]
E.setup({ provider = provider }) E.setup({ provider = provider })
M.commands()
end end
---@private
---@param provider Provider ---@param provider Provider
function M.refresh(provider) function M.refresh(provider)
require("avante.config").override({ provider = provider }) require("avante.config").override({ provider = provider })
@ -315,31 +312,6 @@ function M.refresh(provider)
Utils.info("Switch to provider: " .. provider, { once = true, title = "Avante" }) Utils.info("Switch to provider: " .. provider, { once = true, title = "Avante" })
end end
local default_providers = { "openai", "claude", "azure", "gemini" }
---@private
M.commands = function()
api.nvim_create_user_command("AvanteSwitchProvider", function(args)
local cmd = vim.trim(args.args or "")
M.refresh(cmd)
end, {
nargs = 1,
desc = "avante: switch provider",
complete = function(_, line)
if line:match("^%s*AvanteSwitchProvider %w") then
return {}
end
local prefix = line:match("^%s*AvanteSwitchProvider (%w*)") or ""
-- join two tables
local Keys = vim.list_extend({}, default_providers)
Keys = vim.list_extend(Keys, vim.tbl_keys(Config.vendors or {}))
return vim.tbl_filter(function(key)
return key:find(prefix) == 1
end, Keys)
end,
})
end
---@param opts AvanteProvider | AvanteSupportedProvider | AvanteProviderFunctor ---@param opts AvanteProvider | AvanteSupportedProvider | AvanteProviderFunctor
---@return AvanteDefaultBaseProvider, table<string, any> ---@return AvanteDefaultBaseProvider, table<string, any>
M.parse_config = function(opts) M.parse_config = function(opts)

View File

@ -367,10 +367,8 @@ function Selection:create_editing_input()
end, end,
}) })
local function submit_input() ---@param input string
local lines = api.nvim_buf_get_lines(bufnr, 0, -1, false) local function submit_input(input)
local input = lines[1] or ""
local full_response = "" local full_response = ""
local start_line = self.selection.range.start.line local start_line = self.selection.range.start.line
local finish_line = self.selection.range.finish.line local finish_line = self.selection.range.finish.line
@ -426,8 +424,18 @@ function Selection:create_editing_input()
}) })
end end
vim.keymap.set("i", Config.mappings.submit.insert, submit_input, { buffer = bufnr, noremap = true, silent = true }) ---@return string
vim.keymap.set("n", Config.mappings.submit.normal, submit_input, { buffer = bufnr, noremap = true, silent = true }) local get_bufnr_input = function()
local lines = api.nvim_buf_get_lines(bufnr, 0, -1, false)
return lines[1] or ""
end
vim.keymap.set("i", Config.mappings.submit.insert, function()
submit_input(get_bufnr_input())
end, { buffer = bufnr, noremap = true, silent = true })
vim.keymap.set("n", Config.mappings.submit.normal, function()
submit_input(get_bufnr_input())
end, { buffer = bufnr, noremap = true, silent = true })
vim.keymap.set("n", "<Esc>", function() vim.keymap.set("n", "<Esc>", function()
self:close_editing_input() self:close_editing_input()
end, { buffer = bufnr }) end, { buffer = bufnr })
@ -461,6 +469,15 @@ function Selection:create_editing_input()
end end
end, end,
}) })
api.nvim_create_autocmd("User", {
pattern = "AvanteEditSubmitted",
callback = function(ev)
if ev.data and ev.data.request then
submit_input(ev.data.request)
end
end,
})
end end
function Selection:setup_autocmds() function Selection:setup_autocmds()

View File

@ -545,7 +545,7 @@ function Sidebar:apply(current_cursor)
Diff.process(self.code.bufnr) Diff.process(self.code.bufnr)
api.nvim_win_set_cursor(self.code.winid, { 1, 0 }) api.nvim_win_set_cursor(self.code.winid, { 1, 0 })
vim.defer_fn(function() vim.defer_fn(function()
vim.cmd("AvanteConflictNextConflict") Diff.find_next("ours")
vim.cmd("normal! zz") vim.cmd("normal! zz")
end, 1000) end, 1000)
end, 10) end, 10)
@ -577,13 +577,6 @@ local base_win_options = {
statusline = "", statusline = "",
} }
local function get_win_options()
-- return vim.tbl_deep_extend("force", base_win_options, {
-- fillchars = "eob: ,vert: ,horiz: ,horizup: ,horizdown: ,vertleft: ,vertright:" .. (code_vert_char ~= nil and code_vert_char or " ") .. ",verthoriz: ",
-- })
return base_win_options
end
function Sidebar:render_header(winid, bufnr, header_text, hl, reverse_hl) function Sidebar:render_header(winid, bufnr, header_text, hl, reverse_hl)
if not bufnr or not api.nvim_buf_is_valid(bufnr) then if not bufnr or not api.nvim_buf_is_valid(bufnr) then
return return
@ -937,10 +930,9 @@ function Sidebar:refresh_winids()
end end
function Sidebar:resize() function Sidebar:resize()
local new_layout = Config.get_sidebar_layout_options()
for _, comp in pairs(self) do for _, comp in pairs(self) do
if comp and type(comp) == "table" and comp.winid and api.nvim_win_is_valid(comp.winid) then if comp and type(comp) == "table" and comp.winid and api.nvim_win_is_valid(comp.winid) then
api.nvim_win_set_width(comp.winid, new_layout.width) api.nvim_win_set_width(comp.winid, Config.get_window_width())
end end
end end
self:render_result() self:render_result()
@ -1220,7 +1212,7 @@ function Sidebar:create_selected_code()
winid = self.input.winid, winid = self.input.winid,
}, },
buf_options = buf_options, buf_options = buf_options,
win_options = get_win_options(), win_options = base_win_options,
position = "top", position = "top",
size = { size = {
height = selected_code_size + 3, height = selected_code_size + 3,
@ -1407,7 +1399,7 @@ function Sidebar:create_input()
type = "win", type = "win",
winid = self.result.winid, winid = self.result.winid,
}, },
win_options = vim.tbl_deep_extend("force", get_win_options(), { signcolumn = "yes" }), win_options = vim.tbl_deep_extend("force", base_win_options, { signcolumn = "yes" }),
position = get_position(), position = get_position(),
size = get_size(), size = get_size(),
}) })
@ -1571,6 +1563,15 @@ function Sidebar:create_input()
}) })
self:refresh_winids() self:refresh_winids()
api.nvim_create_autocmd("User", {
pattern = "AvanteInputSubmitted",
callback = function(ev)
if ev.data and ev.data.request then
handle_submit(ev.data.request)
end
end,
})
end end
function Sidebar:get_selected_code_size() function Sidebar:get_selected_code_size()
@ -1623,7 +1624,7 @@ function Sidebar:render()
bufhidden = "wipe", bufhidden = "wipe",
filetype = "Avante", filetype = "Avante",
}), }),
win_options = get_win_options(), win_options = base_win_options,
size = { size = {
width = get_width(), width = get_width(),
height = get_height(), height = get_height(),

View File

@ -12,10 +12,61 @@
---@class vim.api.keyset.create_autocmd.opts: vim.api.keyset.create_autocmd ---@class vim.api.keyset.create_autocmd.opts: vim.api.keyset.create_autocmd
---@field callback? fun(ev:vim.api.create_autocmd.callback.args):boolean? ---@field callback? fun(ev:vim.api.create_autocmd.callback.args):boolean?
--- @param event string | string[] (string|array) Event(s) that will trigger the handler ---@param event string | string[] (string|array) Event(s) that will trigger the handler
--- @param opts vim.api.keyset.create_autocmd.opts ---@param opts vim.api.keyset.create_autocmd.opts
--- @return integer ---@return integer
function vim.api.nvim_create_autocmd(event, opts) end function vim.api.nvim_create_autocmd(event, opts) end
---@class vim.api.keyset.user_command.callback_opts
---@field name string
---@field args string
---@field fargs string[]
---@field nargs? integer | string
---@field bang? boolean
---@field line1? integer
---@field line2? integer
---@field range? integer
---@field count? integer
---@field reg? string
---@field mods? string
---@field smods? UserCommandSmods
---@class UserCommandSmods
---@field browse boolean
---@field confirm boolean
---@field emsg_silent boolean
---@field hide boolean
---@field horizontal boolean
---@field keepalt boolean
---@field keepjumps boolean
---@field keepmarks boolean
---@field keeppatterns boolean
---@field lockmarks boolean
---@field noautocmd boolean
---@field noswapfile boolean
---@field sandbox boolean
---@field silent boolean
---@field split string
---@field tab integer
---@field unsilent boolean
---@field verbose integer
---@field vertical boolean
---@class vim.api.keyset.user_command.opts: vim.api.keyset.user_command
---@field nargs? integer | string
---@field range? integer
---@field bang? boolean
---@field desc? string
---@field force? boolean
---@field complete? fun(prefix: string, line: string, pos?: integer): string[]
---@field preview? fun(opts: vim.api.keyset.user_command.callback_opts, ns: integer, buf: integer): nil
---@alias vim.api.keyset.user_command.callback fun(opts?: vim.api.keyset.user_command.callback_opts):nil
---@param name string
---@param command vim.api.keyset.user_command.callback
---@param opts? vim.api.keyset.user_command.opts
function vim.api.nvim_create_user_command(name, command, opts) end
---@type boolean ---@type boolean
vim.g.avante_login = vim.g.avante_login vim.g.avante_login = vim.g.avante_login