merge from upstream

This commit is contained in:
zhangkun9038@dingtalk.com 2025-02-17 16:30:34 +08:00
commit 15834922c4
21 changed files with 479 additions and 199 deletions

View File

@ -3,3 +3,6 @@ rustflags = ["-C", "link-arg=-undefined", "-C", "link-arg=dynamic_lookup"]
[target.aarch64-apple-darwin] [target.aarch64-apple-darwin]
rustflags = ["-C", "link-arg=-undefined", "-C", "link-arg=dynamic_lookup"] rustflags = ["-C", "link-arg=-undefined", "-C", "link-arg=dynamic_lookup"]
[target.x86_64-unknown-linux-musl]
rustflags = ["-C", "target-feature=-crt-static"]

View File

@ -60,6 +60,7 @@ For building binary if you wish to build from source, then `cargo` is required.
timeout = 30000, -- timeout in milliseconds timeout = 30000, -- timeout in milliseconds
temperature = 0, -- adjust if needed temperature = 0, -- adjust if needed
max_tokens = 4096, max_tokens = 4096,
reasoning_effort = "high" -- only supported for "o" models
}, },
}, },
-- if you want to build from source then do `make BUILD_FROM_SOURCE=true` -- if you want to build from source then do `make BUILD_FROM_SOURCE=true`
@ -380,12 +381,44 @@ This is achieved by emulating nvim-cmp using blink.compat
```lua ```lua
file_selector = { file_selector = {
--- @alias FileSelectorProvider "native" | "fzf" | "mini.pick" | "snacks" | "telescope" | string --- @alias FileSelectorProvider "native" | "fzf" | "mini.pick" | "snacks" | "telescope" | string | fun(params: avante.file_selector.IParams|nil): nil
provider = "fzf", provider = "fzf",
-- Options override for custom providers -- Options override for custom providers
provider_opts = {}, provider_opts = {},
} }
``` ```
To create a customized file_selector, you can specify a customized function to launch a picker to select items and pass the selected items to the `handler` callback.
```lua
file_selector = {
---@param params avante.file_selector.IParams
provider = function(params)
local filepaths = params.filepaths ---@type string[]
local title = params.title ---@type string
local handler = params.handler ---@type fun(selected_filepaths: string[]|nil): nil
-- Launch your customized picker with the items built from `filepaths`, then in the `on_confirm` callback,
-- pass the selected items (convert back to file paths) to the `handler` function.
local items = __your_items_formatter__(filepaths)
__your_picker__({
items = items,
on_cancel = function()
handler(nil)
end,
on_confirm = function(selected_items)
local selected_filepaths = {}
for _, item in ipairs(selected_items) do
table.insert(selected_filepaths, item.filepath)
end
handler(selected_filepaths)
end
})
end,
}
```
Choose a selector other that native, the default as that currently has an issue Choose a selector other that native, the default as that currently has an issue
For lazyvim users copy the full config for blink.cmp from the website or extend the options For lazyvim users copy the full config for blink.cmp from the website or extend the options
```lua ```lua
@ -471,8 +504,10 @@ Given its early stage, `avante.nvim` currently supports the following basic func
> For Amazon Bedrock: > For Amazon Bedrock:
> >
> ```sh > ```sh
> export BEDROCK_KEYS=aws_access_key_id,aws_secret_access_key,aws_region > export BEDROCK_KEYS=aws_access_key_id,aws_secret_access_key,aws_region[,aws_session_token]
>
> ``` > ```
> Note: The aws_session_token is optional and only needed when using temporary AWS credentials
1. Open a code file in Neovim. 1. Open a code file in Neovim.
2. Use the `:AvanteAsk` command to query the AI about the code. 2. Use the `:AvanteAsk` command to query the AI about the code.
@ -548,15 +583,16 @@ For more information, see [Custom Providers](https://github.com/yetone/avante.nv
## Web Search Engines ## Web Search Engines
Avante's tools include some web search engines, currently support [tavily](https://tavily.com/) and [serpapi](https://serpapi.com/). The default is tavily, and can be changed through configuring `Config.web_search_engine.provider`: Avante's tools include some web search engines, currently support [tavily](https://tavily.com/), [serpapi](https://serpapi.com/), [searchapi](https://www.searchapi.io/) and google's [programmable search engine](https://developers.google.com/custom-search/v1/overview). The default is tavily, and can be changed through configuring `Config.web_search_engine.provider`:
```lua ```lua
web_search_engine = { web_search_engine = {
provider = "tavily", -- tavily or serpapi provider = "tavily", -- tavily, serpapi, searchapi or google
} }
``` ```
You need to set the environment variable `TAVILY_API_KEY` or `SERPAPI_API_KEY` to use tavily or serpapi. You need to set the environment variable `TAVILY_API_KEY` , `SERPAPI_API_KEY`, `SEARCHAPI_API_KEY` to use tavily or serpapi or searchapi.
To use google, set the `GOOGLE_SEARCH_API_KEY` as the [API key](https://developers.google.com/custom-search/v1/overview), and `GOOGLE_SEARCH_ENGINE_ID` as the [search engine](https://programmablesearchengine.google.com) ID.
## Disable Tools ## Disable Tools

View File

@ -5,6 +5,11 @@
local Utils = require("avante.utils") local Utils = require("avante.utils")
---@class avante.file_selector.IParams
---@field public title string
---@field public filepaths string[]
---@field public handler fun(filepaths: string[]|nil): nil
---@class avante.CoreConfig: avante.Config ---@class avante.CoreConfig: avante.Config
local M = {} local M = {}
---@class avante.Config ---@class avante.Config
@ -28,11 +33,10 @@ M._defaults = {
tavily = { tavily = {
api_key_name = "TAVILY_API_KEY", api_key_name = "TAVILY_API_KEY",
extra_request_body = { extra_request_body = {
time_range = "d",
include_answer = "basic", include_answer = "basic",
}, },
---@type WebSearchEngineProviderResponseBodyFormatter ---@type WebSearchEngineProviderResponseBodyFormatter
format_response_body = function(body) return body.anwser, nil end, format_response_body = function(body) return body.answer, nil end,
}, },
serpapi = { serpapi = {
api_key_name = "SERPAPI_API_KEY", api_key_name = "SERPAPI_API_KEY",
@ -52,11 +56,65 @@ M._defaults = {
title = result.title, title = result.title,
link = result.link, link = result.link,
snippet = result.snippet, snippet = result.snippet,
date = result.date,
} }
end end
) )
:take(10)
:totable()
return vim.json.encode(jsn), nil
end
return "", nil
end,
},
searchapi = {
api_key_name = "SEARCHAPI_API_KEY",
extra_request_body = {
engine = "google",
},
---@type WebSearchEngineProviderResponseBodyFormatter
format_response_body = function(body)
if body.answer_box ~= nil then return body.answer_box.result, nil end
if body.organic_results ~= nil then
local jsn = vim
.iter(body.organic_results)
:map(
function(result)
return {
title = result.title,
link = result.link,
snippet = result.snippet,
date = result.date,
}
end
)
:take(10)
:totable()
return vim.json.encode(jsn), nil
end
return "", nil
end,
},
google = {
api_key_name = "GOOGLE_SEARCH_API_KEY",
engine_id_name = "GOOGLE_SEARCH_ENGINE_ID",
extra_request_body = {},
---@type WebSearchEngineProviderResponseBodyFormatter
format_response_body = function(body)
if body.items ~= nil then
local jsn = vim
.iter(body.items)
:map(
function(result)
return {
title = result.title,
link = result.link,
snippet = result.snippet,
}
end
)
:take(10)
:totable() :totable()
if #jsn > 5 then jsn = vim.list_slice(jsn, 1, 5) end
return vim.json.encode(jsn), nil return vim.json.encode(jsn), nil
end end
return "", nil return "", nil
@ -307,7 +365,7 @@ M._defaults = {
}, },
--- @class AvanteFileSelectorConfig --- @class AvanteFileSelectorConfig
file_selector = { file_selector = {
--- @alias FileSelectorProvider "native" | "fzf" | "mini.pick" | "snacks" | "telescope" | string --- @alias FileSelectorProvider "native" | "fzf" | "mini.pick" | "snacks" | "telescope" | string | fun(params: avante.file_selector.IParams|nil): nil
provider = "native", provider = "native",
-- Options override for custom providers -- Options override for custom providers
provider_opts = {}, provider_opts = {},

View File

@ -330,6 +330,11 @@ function FileSelector:show_select_ui()
self:snacks_picker_ui(handler) self:snacks_picker_ui(handler)
elseif Config.file_selector.provider == "telescope" then elseif Config.file_selector.provider == "telescope" then
self:telescope_ui(handler) self:telescope_ui(handler)
elseif type(Config.file_selector.provider) == "function" then
local title = string.format("%s:", PROMPT_TITLE) ---@type string
local filepaths = self:get_filepaths() ---@type string[]
local params = { title = title, filepaths = filepaths, handler = handler } ---@type avante.file_selector.IParams
Config.file_selector.provider(params)
else else
Utils.error("Unknown file selector provider: " .. Config.file_selector.provider) Utils.error("Unknown file selector provider: " .. Config.file_selector.provider)
end end
@ -363,9 +368,9 @@ end
function FileSelector:get_selected_files_contents() function FileSelector:get_selected_files_contents()
local contents = {} local contents = {}
for _, file_path in ipairs(self.selected_filepaths) do for _, file_path in ipairs(self.selected_filepaths) do
local lines, filetype, error = Utils.read_file_from_buf_or_disk(file_path) local lines, error = Utils.read_file_from_buf_or_disk(file_path)
lines = lines or {} lines = lines or {}
filetype = filetype or "unknown" local filetype = Utils.get_filetype(file_path)
if error ~= nil then if error ~= nil then
Utils.error("error reading file: " .. error) Utils.error("error reading file: " .. error)
else else

View File

@ -23,9 +23,7 @@ M.check = function()
end end
-- Optional dependencies -- Optional dependencies
local has_devicons = Utils.has("nvim-web-devicons") if Utils.icons_enabled() then
local has_mini_icons = Utils.has("mini.icons") or Utils.has("mini.nvim")
if has_devicons or has_mini_icons then
H.ok("Found icons plugin (nvim-web-devicons or mini.icons)") H.ok("Found icons plugin (nvim-web-devicons or mini.icons)")
else else
H.warn("No icons plugin found (nvim-web-devicons or mini.icons). Icons will not be displayed") H.warn("No icons plugin found (nvim-web-devicons or mini.icons). Icons will not be displayed")

View File

@ -25,8 +25,8 @@ M.generate_prompts = function(opts)
local Provider = opts.provider or P[Config.provider] local Provider = opts.provider or P[Config.provider]
local mode = opts.mode or "planning" local mode = opts.mode or "planning"
---@type AvanteProviderFunctor | AvanteBedrockProviderFunctor ---@type AvanteProviderFunctor | AvanteBedrockProviderFunctor
local _, body_opts = P.parse_config(Provider) local _, request_body = P.parse_config(Provider)
local max_tokens = body_opts.max_tokens or 4096 local max_tokens = request_body.max_tokens or 4096
-- Check if the instructions contains an image path -- Check if the instructions contains an image path
local image_paths = {} local image_paths = {}
@ -449,7 +449,7 @@ M.stream = function(opts)
return original_on_stop(stop_opts) return original_on_stop(stop_opts)
end) end)
end end
if Config.dual_boost.enabled then if Config.dual_boost.enabled and opts.mode == "planning" then
M._dual_boost_stream(opts, P[Config.dual_boost.first_provider], P[Config.dual_boost.second_provider]) M._dual_boost_stream(opts, P[Config.dual_boost.first_provider], P[Config.dual_boost.second_provider])
else else
M._stream(opts) M._stream(opts)

View File

@ -7,6 +7,7 @@ local M = {}
---@param rel_path string ---@param rel_path string
---@return string ---@return string
local function get_abs_path(rel_path) local function get_abs_path(rel_path)
if Path:new(rel_path):is_absolute() then return rel_path end
local project_root = Utils.get_project_root() local project_root = Utils.get_project_root()
return Path:new(project_root):joinpath(rel_path):absolute() return Path:new(project_root):joinpath(rel_path):absolute()
end end
@ -41,13 +42,12 @@ function M.list_files(opts, on_log)
add_dirs = true, add_dirs = true,
depth = opts.depth, depth = opts.depth,
}) })
local result = "" local filepaths = {}
for _, file in ipairs(files) do for _, file in ipairs(files) do
local uniform_path = Utils.uniform_path(file) local uniform_path = Utils.uniform_path(file)
result = result .. uniform_path .. "\n" table.insert(filepaths, uniform_path)
end end
result = result:gsub("\n$", "") return vim.json.encode(filepaths), nil
return result, nil
end end
---@param opts { rel_path: string, keyword: string } ---@param opts { rel_path: string, keyword: string }
@ -62,12 +62,11 @@ function M.search_files(opts, on_log)
local files = Utils.scan_directory_respect_gitignore({ local files = Utils.scan_directory_respect_gitignore({
directory = abs_path, directory = abs_path,
}) })
local result = "" local filepaths = {}
for _, file in ipairs(files) do for _, file in ipairs(files) do
if file:find(opts.keyword) then result = result .. file .. "\n" end if file:find(opts.keyword) then table.insert(filepaths, file) end
end end
result = result:gsub("\n$", "") return vim.json.encode(filepaths), nil
return result, nil
end end
---@param opts { rel_path: string, keyword: string } ---@param opts { rel_path: string, keyword: string }
@ -104,7 +103,9 @@ function M.search(opts, on_log)
if on_log then on_log("Running command: " .. cmd) end if on_log then on_log("Running command: " .. cmd) end
local result = vim.fn.system(cmd) local result = vim.fn.system(cmd)
return result or "", nil local filepaths = vim.split(result, "\n")
return vim.json.encode(filepaths), nil
end end
---@param opts { rel_path: string } ---@param opts { rel_path: string }
@ -183,9 +184,10 @@ function M.rename_file(opts, on_log)
end end
---@param opts { rel_path: string, new_rel_path: string } ---@param opts { rel_path: string, new_rel_path: string }
---@param on_log? fun(log: string): nil
---@return boolean success ---@return boolean success
---@return string|nil error ---@return string|nil error
function M.copy_file(opts) function M.copy_file(opts, on_log)
local abs_path = get_abs_path(opts.rel_path) local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end
@ -193,38 +195,47 @@ function M.copy_file(opts)
local new_abs_path = get_abs_path(opts.new_rel_path) local new_abs_path = get_abs_path(opts.new_rel_path)
if not has_permission_to_access(new_abs_path) then return false, "No permission to access path: " .. new_abs_path end if not has_permission_to_access(new_abs_path) then return false, "No permission to access path: " .. new_abs_path end
if Path:new(new_abs_path):exists() then return false, "File already exists: " .. new_abs_path end if Path:new(new_abs_path):exists() then return false, "File already exists: " .. new_abs_path end
if on_log then on_log("Copying file: " .. abs_path .. " to " .. new_abs_path) end
Path:new(new_abs_path):write(Path:new(abs_path):read()) Path:new(new_abs_path):write(Path:new(abs_path):read())
return true, nil return true, nil
end end
---@param opts { rel_path: string } ---@param opts { rel_path: string }
---@param on_log? fun(log: string): nil
---@return boolean success ---@return boolean success
---@return string|nil error ---@return string|nil error
function M.delete_file(opts) function M.delete_file(opts, on_log)
local abs_path = get_abs_path(opts.rel_path) local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end
if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end
if not M.confirm("Are you sure you want to delete the file: " .. abs_path) then return false, "User canceled" end if not M.confirm("Are you sure you want to delete the file: " .. abs_path) then return false, "User canceled" end
if on_log then on_log("Deleting file: " .. abs_path) end
os.remove(abs_path) os.remove(abs_path)
return true, nil return true, nil
end end
---@param opts { rel_path: string } ---@param opts { rel_path: string }
---@param on_log? fun(log: string): nil
---@return boolean success ---@return boolean success
---@return string|nil error ---@return string|nil error
function M.create_dir(opts) function M.create_dir(opts, on_log)
local abs_path = get_abs_path(opts.rel_path) local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if Path:new(abs_path):exists() then return false, "Directory already exists: " .. abs_path end if Path:new(abs_path):exists() then return false, "Directory already exists: " .. abs_path end
if not M.confirm("Are you sure you want to create the directory: " .. abs_path) then
return false, "User canceled"
end
if on_log then on_log("Creating directory: " .. abs_path) end
Path:new(abs_path):mkdir({ parents = true }) Path:new(abs_path):mkdir({ parents = true })
return true, nil return true, nil
end end
---@param opts { rel_path: string, new_rel_path: string } ---@param opts { rel_path: string, new_rel_path: string }
---@param on_log? fun(log: string): nil
---@return boolean success ---@return boolean success
---@return string|nil error ---@return string|nil error
function M.rename_dir(opts) function M.rename_dir(opts, on_log)
local abs_path = get_abs_path(opts.rel_path) local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "Directory not found: " .. abs_path end if not Path:new(abs_path):exists() then return false, "Directory not found: " .. abs_path end
@ -235,14 +246,16 @@ function M.rename_dir(opts)
if not M.confirm("Are you sure you want to rename directory " .. abs_path .. " to " .. new_abs_path .. "?") then if not M.confirm("Are you sure you want to rename directory " .. abs_path .. " to " .. new_abs_path .. "?") then
return false, "User canceled" return false, "User canceled"
end end
if on_log then on_log("Renaming directory: " .. abs_path .. " to " .. new_abs_path) end
os.rename(abs_path, new_abs_path) os.rename(abs_path, new_abs_path)
return true, nil return true, nil
end end
---@param opts { rel_path: string } ---@param opts { rel_path: string }
---@param on_log? fun(log: string): nil
---@return boolean success ---@return boolean success
---@return string|nil error ---@return string|nil error
function M.delete_dir(opts) function M.delete_dir(opts, on_log)
local abs_path = get_abs_path(opts.rel_path) local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "Directory not found: " .. abs_path end if not Path:new(abs_path):exists() then return false, "Directory not found: " .. abs_path end
@ -250,6 +263,7 @@ function M.delete_dir(opts)
if not M.confirm("Are you sure you want to delete the directory: " .. abs_path) then if not M.confirm("Are you sure you want to delete the directory: " .. abs_path) then
return false, "User canceled" return false, "User canceled"
end end
if on_log then on_log("Deleting directory: " .. abs_path) end
os.remove(abs_path) os.remove(abs_path)
return true, nil return true, nil
end end
@ -326,6 +340,45 @@ function M.web_search(opts, on_log)
if resp.status ~= 200 then return nil, "Error: " .. resp.body end if resp.status ~= 200 then return nil, "Error: " .. resp.body end
local jsn = vim.json.decode(resp.body) local jsn = vim.json.decode(resp.body)
return search_engine.format_response_body(jsn) return search_engine.format_response_body(jsn)
elseif provider_type == "searchapi" then
local query_params = vim.tbl_deep_extend("force", {
api_key = api_key,
q = opts.query,
}, search_engine.extra_request_body)
local query_string = ""
for key, value in pairs(query_params) do
query_string = query_string .. key .. "=" .. vim.uri_encode(value) .. "&"
end
local resp = curl.get("https://searchapi.io/api/v1/search?" .. query_string, {
headers = {
["Content-Type"] = "application/json",
},
})
if resp.status ~= 200 then return nil, "Error: " .. resp.body end
local jsn = vim.json.decode(resp.body)
return search_engine.format_response_body(jsn)
elseif provider_type == "google" then
local engine_id = os.getenv(search_engine.engine_id_name)
if engine_id == nil or engine_id == "" then
return nil, "Environment variable " .. search_engine.engine_id_namee .. " is not set"
end
local query_params = vim.tbl_deep_extend("force", {
key = api_key,
cx = engine_id,
q = opts.query,
}, search_engine.extra_request_body)
local query_string = ""
for key, value in pairs(query_params) do
query_string = query_string .. key .. "=" .. vim.uri_encode(value) .. "&"
end
local resp = curl.get("https://www.googleapis.com/customsearch/v1?" .. query_string, {
headers = {
["Content-Type"] = "application/json",
},
})
if resp.status ~= 200 then return nil, "Error: " .. resp.body end
local jsn = vim.json.decode(resp.body)
return search_engine.format_response_body(jsn)
end end
end end

View File

@ -18,31 +18,34 @@ M.parse_response = O.parse_response
M.parse_response_without_stream = O.parse_response_without_stream M.parse_response_without_stream = O.parse_response_without_stream
M.parse_curl_args = function(provider, prompt_opts) M.parse_curl_args = function(provider, prompt_opts)
local base, body_opts = P.parse_config(provider) local provider_conf, request_body = P.parse_config(provider)
local headers = { local headers = {
["Content-Type"] = "application/json", ["Content-Type"] = "application/json",
} }
if P.env.require_api_key(base) then headers["api-key"] = provider.parse_api_key() end if P.env.require_api_key(provider_conf) then headers["api-key"] = provider.parse_api_key() end
-- NOTE: When using "o" series set the supported parameters only -- NOTE: When using "o" series set the supported parameters only
if O.is_o_series_model(base.model) then if O.is_o_series_model(provider_conf.model) then
body_opts.max_tokens = nil request_body.max_tokens = nil
body_opts.temperature = 1 request_body.temperature = 1
end end
return { return {
url = Utils.url_join( url = Utils.url_join(
base.endpoint, provider_conf.endpoint,
"/openai/deployments/" .. base.deployment .. "/chat/completions?api-version=" .. base.api_version "/openai/deployments/"
.. provider_conf.deployment
.. "/chat/completions?api-version="
.. provider_conf.api_version
), ),
proxy = base.proxy, proxy = provider_conf.proxy,
insecure = base.allow_insecure, insecure = provider_conf.allow_insecure,
headers = headers, headers = headers,
body = vim.tbl_deep_extend("force", { body = vim.tbl_deep_extend("force", {
messages = M.parse_messages(prompt_opts), messages = M.parse_messages(prompt_opts),
stream = true, stream = true,
}, body_opts), }, request_body),
} }
end end

View File

@ -1,5 +1,4 @@
local Utils = require("avante.utils") local Utils = require("avante.utils")
local Clipboard = require("avante.clipboard")
local P = require("avante.providers") local P = require("avante.providers")
---@alias AvanteBedrockPayloadBuilder fun(prompt_opts: AvantePromptOptions, body_opts: table<string, any>): table<string, any> ---@alias AvanteBedrockPayloadBuilder fun(prompt_opts: AvantePromptOptions, body_opts: table<string, any>): table<string, any>
@ -17,17 +16,14 @@ M.api_key_name = "BEDROCK_KEYS"
M.use_xml_format = true M.use_xml_format = true
M.load_model_handler = function() M.load_model_handler = function()
local base, _ = P.parse_config(P["bedrock"]) local provider_conf, _ = P.parse_config(P["bedrock"])
local bedrock_model = base.model local bedrock_model = provider_conf.model
if base.model:match("anthropic") then bedrock_model = "claude" end if provider_conf.model:match("anthropic") then bedrock_model = "claude" end
local ok, model_module = pcall(require, "avante.providers.bedrock." .. bedrock_model) local ok, model_module = pcall(require, "avante.providers.bedrock." .. bedrock_model)
if ok then if ok then return model_module end
return model_module local error_msg = "Bedrock model handler not found: " .. bedrock_model
else error(error_msg)
local error_msg = "Bedrock model handler not found: " .. bedrock_model
Utils.error(error_msg, { once = true, title = "Avante" })
end
end end
M.parse_response = function(ctx, data_stream, event_state, opts) M.parse_response = function(ctx, data_stream, event_state, opts)
@ -46,8 +42,8 @@ M.parse_stream_data = function(data, opts)
-- The `type` field in the decoded JSON determines how the response is handled. -- The `type` field in the decoded JSON determines how the response is handled.
local bedrock_match = data:gmatch("event(%b{})") local bedrock_match = data:gmatch("event(%b{})")
for bedrock_data_match in bedrock_match do for bedrock_data_match in bedrock_match do
local data = vim.json.decode(bedrock_data_match) local jsn = vim.json.decode(bedrock_data_match)
local data_stream = vim.base64.decode(data.bytes) local data_stream = vim.base64.decode(jsn.bytes)
local json = vim.json.decode(data_stream) local json = vim.json.decode(data_stream)
M.parse_response({}, data_stream, json.type, opts) M.parse_response({}, data_stream, json.type, opts)
end end
@ -60,10 +56,12 @@ M.parse_curl_args = function(provider, prompt_opts)
local base, body_opts = P.parse_config(provider) local base, body_opts = P.parse_config(provider)
local api_key = provider.parse_api_key() local api_key = provider.parse_api_key()
if api_key == nil then error("Cannot get the bedrock api key!") end
local parts = vim.split(api_key, ",") local parts = vim.split(api_key, ",")
local aws_access_key_id = parts[1] local aws_access_key_id = parts[1]
local aws_secret_access_key = parts[2] local aws_secret_access_key = parts[2]
local aws_region = parts[3] local aws_region = parts[3]
local aws_session_token = parts[4]
local endpoint = string.format( local endpoint = string.format(
"https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke-with-response-stream", "https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke-with-response-stream",
@ -75,6 +73,8 @@ M.parse_curl_args = function(provider, prompt_opts)
["Content-Type"] = "application/json", ["Content-Type"] = "application/json",
} }
if aws_session_token and aws_session_token ~= "" then headers["x-amz-security-token"] = aws_session_token end
local body_payload = M.build_bedrock_payload(prompt_opts, body_opts) local body_payload = M.build_bedrock_payload(prompt_opts, body_opts)
local rawArgs = { local rawArgs = {
@ -105,7 +105,6 @@ M.on_error = function(result)
end end
local error_msg = body.error.message local error_msg = body.error.message
local error_type = body.error.type
Utils.error(error_msg, { once = true, title = "Avante" }) Utils.error(error_msg, { once = true, title = "Avante" })
end end

View File

@ -226,7 +226,7 @@ end
---@param prompt_opts AvantePromptOptions ---@param prompt_opts AvantePromptOptions
---@return table ---@return table
M.parse_curl_args = function(provider, prompt_opts) M.parse_curl_args = function(provider, prompt_opts)
local base, body_opts = P.parse_config(provider) local provider_conf, request_body = P.parse_config(provider)
local headers = { local headers = {
["Content-Type"] = "application/json", ["Content-Type"] = "application/json",
@ -234,7 +234,7 @@ M.parse_curl_args = function(provider, prompt_opts)
["anthropic-beta"] = "prompt-caching-2024-07-31", ["anthropic-beta"] = "prompt-caching-2024-07-31",
} }
if P.env.require_api_key(base) then headers["x-api-key"] = provider.parse_api_key() end if P.env.require_api_key(provider_conf) then headers["x-api-key"] = provider.parse_api_key() end
local messages = M.parse_messages(prompt_opts) local messages = M.parse_messages(prompt_opts)
@ -246,12 +246,12 @@ M.parse_curl_args = function(provider, prompt_opts)
end end
return { return {
url = Utils.url_join(base.endpoint, "/v1/messages"), url = Utils.url_join(provider_conf.endpoint, "/v1/messages"),
proxy = base.proxy, proxy = provider_conf.proxy,
insecure = base.allow_insecure, insecure = provider_conf.allow_insecure,
headers = headers, headers = headers,
body = vim.tbl_deep_extend("force", { body = vim.tbl_deep_extend("force", {
model = base.model, model = provider_conf.model,
system = { system = {
{ {
type = "text", type = "text",
@ -262,7 +262,7 @@ M.parse_curl_args = function(provider, prompt_opts)
messages = messages, messages = messages,
tools = tools, tools = tools,
stream = true, stream = true,
}, body_opts), }, request_body),
} }
end end

View File

@ -70,7 +70,7 @@ M.parse_stream_data = function(data, opts)
end end
M.parse_curl_args = function(provider, prompt_opts) M.parse_curl_args = function(provider, prompt_opts)
local base, body_opts = P.parse_config(provider) local provider_conf, request_body = P.parse_config(provider)
local headers = { local headers = {
["Accept"] = "application/json", ["Accept"] = "application/json",
@ -82,17 +82,17 @@ M.parse_curl_args = function(provider, prompt_opts)
.. "." .. "."
.. vim.version().patch, .. vim.version().patch,
} }
if P.env.require_api_key(base) then headers["Authorization"] = "Bearer " .. provider.parse_api_key() end if P.env.require_api_key(provider_conf) then headers["Authorization"] = "Bearer " .. provider.parse_api_key() end
return { return {
url = Utils.url_join(base.endpoint, "/chat"), url = Utils.url_join(provider_conf.endpoint, "/chat"),
proxy = base.proxy, proxy = provider_conf.proxy,
insecure = base.allow_insecure, insecure = provider_conf.allow_insecure,
headers = headers, headers = headers,
body = vim.tbl_deep_extend("force", { body = vim.tbl_deep_extend("force", {
model = base.model, model = provider_conf.model,
stream = true, stream = true,
}, M.parse_messages(prompt_opts), body_opts), }, M.parse_messages(prompt_opts), request_body),
} }
end end

View File

@ -249,7 +249,7 @@ M.parse_curl_args = function(provider, prompt_opts)
-- (this should rarely happen, as we refresh the token in the background) -- (this should rarely happen, as we refresh the token in the background)
H.refresh_token(false, false) H.refresh_token(false, false)
local base, body_opts = P.parse_config(provider) local provider_conf, request_body = P.parse_config(provider)
local tools = {} local tools = {}
if prompt_opts.tools then if prompt_opts.tools then
@ -259,10 +259,10 @@ M.parse_curl_args = function(provider, prompt_opts)
end end
return { return {
url = H.chat_completion_url(base.endpoint), url = H.chat_completion_url(provider_conf.endpoint),
timeout = base.timeout, timeout = provider_conf.timeout,
proxy = base.proxy, proxy = provider_conf.proxy,
insecure = base.allow_insecure, insecure = provider_conf.allow_insecure,
headers = { headers = {
["Content-Type"] = "application/json", ["Content-Type"] = "application/json",
["Authorization"] = "Bearer " .. M.state.github_token.token, ["Authorization"] = "Bearer " .. M.state.github_token.token,
@ -270,11 +270,11 @@ M.parse_curl_args = function(provider, prompt_opts)
["Editor-Version"] = ("Neovim/%s.%s.%s"):format(vim.version().major, vim.version().minor, vim.version().patch), ["Editor-Version"] = ("Neovim/%s.%s.%s"):format(vim.version().major, vim.version().minor, vim.version().patch),
}, },
body = vim.tbl_deep_extend("force", { body = vim.tbl_deep_extend("force", {
model = base.model, model = provider_conf.model,
messages = M.parse_messages(prompt_opts), messages = M.parse_messages(prompt_opts),
stream = true, stream = true,
tools = tools, tools = tools,
}, body_opts), }, request_body),
} }
end end

View File

@ -82,26 +82,29 @@ M.parse_response = function(ctx, data_stream, _, opts)
end end
M.parse_curl_args = function(provider, prompt_opts) M.parse_curl_args = function(provider, prompt_opts)
local base, body_opts = P.parse_config(provider) local provider_conf, request_body = P.parse_config(provider)
body_opts = vim.tbl_deep_extend("force", body_opts, { request_body = vim.tbl_deep_extend("force", request_body, {
generationConfig = { generationConfig = {
temperature = body_opts.temperature, temperature = request_body.temperature,
maxOutputTokens = body_opts.max_tokens, maxOutputTokens = request_body.max_tokens,
}, },
}) })
body_opts.temperature = nil request_body.temperature = nil
body_opts.max_tokens = nil request_body.max_tokens = nil
local api_key = provider.parse_api_key() local api_key = provider.parse_api_key()
if api_key == nil then error("Cannot get the gemini api key!") end if api_key == nil then error("Cannot get the gemini api key!") end
return { return {
url = Utils.url_join(base.endpoint, base.model .. ":streamGenerateContent?alt=sse&key=" .. api_key), url = Utils.url_join(
proxy = base.proxy, provider_conf.endpoint,
insecure = base.allow_insecure, provider_conf.model .. ":streamGenerateContent?alt=sse&key=" .. api_key
),
proxy = provider_conf.proxy,
insecure = provider_conf.allow_insecure,
headers = { ["Content-Type"] = "application/json" }, headers = { ["Content-Type"] = "application/json" },
body = vim.tbl_deep_extend("force", {}, M.parse_messages(prompt_opts), body_opts), body = vim.tbl_deep_extend("force", {}, M.parse_messages(prompt_opts), request_body),
} }
end end

View File

@ -64,6 +64,7 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil }
---@field __inherited_from? string ---@field __inherited_from? string
---@field temperature? number ---@field temperature? number
---@field max_tokens? number ---@field max_tokens? number
---@field reasoning_effort? string
--- ---
---@class AvanteLLMUsage ---@class AvanteLLMUsage
---@field input_tokens number ---@field input_tokens number
@ -347,9 +348,9 @@ M = setmetatable(M, {
if t[k].has == nil then t[k].has = function() return E.parse_envvar(t[k]) ~= nil end end if t[k].has == nil then t[k].has = function() return E.parse_envvar(t[k]) ~= nil end end
if t[k].setup == nil then if t[k].setup == nil then
local base = M.parse_config(t[k]) local provider_conf = M.parse_config(t[k])
t[k].setup = function() t[k].setup = function()
if E.require_api_key(base) then t[k].parse_api_key() end if E.require_api_key(provider_conf) then t[k].parse_api_key() end
require("avante.tokenizers").setup(t[k].tokenizer_id) require("avante.tokenizers").setup(t[k].tokenizer_id)
end end
end end

View File

@ -223,7 +223,7 @@ M.parse_response = function(ctx, data_stream, _, opts)
end end
ctx.last_think_content = choice.delta.reasoning ctx.last_think_content = choice.delta.reasoning
opts.on_chunk(choice.delta.reasoning) opts.on_chunk(choice.delta.reasoning)
elseif choice.delta.tool_calls then elseif choice.delta.tool_calls and choice.delta.tool_calls ~= vim.NIL then
local tool_call = choice.delta.tool_calls[1] local tool_call = choice.delta.tool_calls[1]
if not ctx.tool_use_list then ctx.tool_use_list = {} end if not ctx.tool_use_list then ctx.tool_use_list = {} end
if not ctx.tool_use_list[tool_call.index + 1] then if not ctx.tool_use_list[tool_call.index + 1] then
@ -272,13 +272,14 @@ end
local Log = require("avante.utils.log") local Log = require("avante.utils.log")
M.parse_curl_args = function(provider, prompt_opts) M.parse_curl_args = function(provider, prompt_opts)
local base, body_opts = P.parse_config(provider) local provider_conf, request_body = P.parse_config(provider)
local disable_tools = base.disable_tools or false local disable_tools = provider_conf.disable_tools or false
local headers = { local headers = {
["Content-Type"] = "application/json", ["Content-Type"] = "application/json",
} }
<<<<<<< HEAD
-- Add appid header for baidu provider -- Add appid header for baidu provider
if Config.provider == "baidu" then if Config.provider == "baidu" then
local baidu_config = Config.get_provider("baidu") local baidu_config = Config.get_provider("baidu")
@ -289,6 +290,9 @@ M.parse_curl_args = function(provider, prompt_opts)
end end
if P.env.require_api_key(base) then if P.env.require_api_key(base) then
=======
if P.env.require_api_key(provider_conf) then
>>>>>>> b6ae4dfe7fe443362f5f31d71797173ec12c2598
local api_key = provider.parse_api_key() local api_key = provider.parse_api_key()
if api_key == nil then if api_key == nil then
error(Config.provider .. " API key is not set, please set it in your environment variable or config file") error(Config.provider .. " API key is not set, please set it in your environment variable or config file")
@ -296,18 +300,18 @@ M.parse_curl_args = function(provider, prompt_opts)
headers["Authorization"] = "Bearer " .. api_key headers["Authorization"] = "Bearer " .. api_key
end end
if M.is_openrouter(base.endpoint) then if M.is_openrouter(provider_conf.endpoint) then
headers["HTTP-Referer"] = "https://github.com/yetone/avante.nvim" headers["HTTP-Referer"] = "https://github.com/yetone/avante.nvim"
headers["X-Title"] = "Avante.nvim" headers["X-Title"] = "Avante.nvim"
body_opts.include_reasoning = true request_body.include_reasoning = true
end end
-- NOTE: When using "o" series set the supported parameters only -- NOTE: When using "o" series set the supported parameters only
local stream = true local stream = true
if M.is_o_series_model(base.model) then if M.is_o_series_model(provider_conf.model) then
body_opts.max_completion_tokens = body_opts.max_tokens request_body.max_completion_tokens = request_body.max_tokens
body_opts.max_tokens = nil request_body.max_tokens = nil
body_opts.temperature = 1 request_body.temperature = 1
end end
local tools = nil local tools = nil
@ -318,20 +322,20 @@ M.parse_curl_args = function(provider, prompt_opts)
end end
end end
Utils.debug("endpoint", base.endpoint) Utils.debug("endpoint", provider_conf.endpoint)
Utils.debug("model", base.model) Utils.debug("model", provider_conf.model)
local request = { return {
url = Utils.url_join(base.endpoint, "/chat/completions"), url = Utils.url_join(provider_conf.endpoint, "/chat/completions"),
proxy = base.proxy, proxy = provider_conf.proxy,
insecure = base.allow_insecure, insecure = provider_conf.allow_insecure,
headers = headers, headers = headers,
body = vim.tbl_deep_extend("force", { body = vim.tbl_deep_extend("force", {
model = base.model, model = provider_conf.model,
messages = M.parse_messages(prompt_opts), messages = M.parse_messages(prompt_opts),
stream = stream, stream = stream,
tools = tools, tools = tools,
}, body_opts), }, request_body),
} }
-- 记录请求详细信息 -- 记录请求详细信息
Log.log_request(request.url, request.headers, request.body) Log.log_request(request.url, request.headers, request.body)

View File

@ -32,22 +32,22 @@ M.parse_api_key = function()
end end
M.parse_curl_args = function(provider, prompt_opts) M.parse_curl_args = function(provider, prompt_opts)
local base, body_opts = P.parse_config(provider) local provider_conf, request_body = P.parse_config(provider)
local location = vim.fn.getenv("LOCATION") or "default-location" local location = vim.fn.getenv("LOCATION") or "default-location"
local project_id = vim.fn.getenv("PROJECT_ID") or "default-project-id" local project_id = vim.fn.getenv("PROJECT_ID") or "default-project-id"
local model_id = base.model or "default-model-id" local model_id = provider_conf.model or "default-model-id"
local url = base.endpoint:gsub("LOCATION", location):gsub("PROJECT_ID", project_id) local url = provider_conf.endpoint:gsub("LOCATION", location):gsub("PROJECT_ID", project_id)
url = string.format("%s/%s:streamGenerateContent?alt=sse", url, model_id) url = string.format("%s/%s:streamGenerateContent?alt=sse", url, model_id)
body_opts = vim.tbl_deep_extend("force", body_opts, { request_body = vim.tbl_deep_extend("force", request_body, {
generationConfig = { generationConfig = {
temperature = body_opts.temperature, temperature = request_body.temperature,
maxOutputTokens = body_opts.max_tokens, maxOutputTokens = request_body.max_tokens,
}, },
}) })
body_opts.temperature = nil request_body.temperature = nil
body_opts.max_tokens = nil request_body.max_tokens = nil
local bearer_token = M.parse_api_key() local bearer_token = M.parse_api_key()
return { return {
@ -56,9 +56,9 @@ M.parse_curl_args = function(provider, prompt_opts)
["Authorization"] = "Bearer " .. bearer_token, ["Authorization"] = "Bearer " .. bearer_token,
["Content-Type"] = "application/json; charset=utf-8", ["Content-Type"] = "application/json; charset=utf-8",
}, },
proxy = base.proxy, proxy = provider_conf.proxy,
insecure = base.allow_insecure, insecure = provider_conf.allow_insecure,
body = vim.tbl_deep_extend("force", {}, M.parse_messages(prompt_opts), body_opts), body = vim.tbl_deep_extend("force", {}, M.parse_messages(prompt_opts), request_body),
} }
end end

View File

@ -29,21 +29,10 @@ end
function RepoMap.setup() vim.defer_fn(RepoMap._init_repo_map_lib, 1000) end function RepoMap.setup() vim.defer_fn(RepoMap._init_repo_map_lib, 1000) end
function RepoMap.get_ts_lang(filepath) function RepoMap.get_ts_lang(filepath)
local filetype = RepoMap.get_filetype(filepath) local filetype = Utils.get_filetype(filepath)
return filetype_map[filetype] or filetype return filetype_map[filetype] or filetype
end end
function RepoMap.get_filetype(filepath)
-- Some files are sometimes not detected correctly when buffer is not included
-- https://github.com/neovim/neovim/issues/27265
local buf = vim.api.nvim_create_buf(false, true)
local filetype = vim.filetype.match({ filename = filepath, buf = buf })
vim.api.nvim_buf_delete(buf, { force = true })
return filetype
end
function RepoMap._build_repo_map(project_root, file_ext) function RepoMap._build_repo_map(project_root, file_ext)
local output = {} local output = {}
local gitignore_path = project_root .. "/.gitignore" local gitignore_path = project_root .. "/.gitignore"
@ -70,7 +59,7 @@ function RepoMap._build_repo_map(project_root, file_ext)
if definitions == "" then return end if definitions == "" then return end
table.insert(output, { table.insert(output, {
path = Utils.relative_path(filepath), path = Utils.relative_path(filepath),
lang = RepoMap.get_filetype(filepath), lang = Utils.get_filetype(filepath),
defs = definitions, defs = definitions,
}) })
end) end)
@ -142,7 +131,7 @@ function RepoMap._get_repo_map(file_ext)
if not found then if not found then
table.insert(repo_map, { table.insert(repo_map, {
path = Utils.relative_path(abs_filepath), path = Utils.relative_path(abs_filepath),
lang = RepoMap.get_filetype(abs_filepath), lang = Utils.get_filetype(abs_filepath),
defs = definitions, defs = definitions,
}) })
end end

View File

@ -57,6 +57,7 @@ function Sidebar:new(id)
selected_files_container = nil, selected_files_container = nil,
input_container = nil, input_container = nil,
file_selector = FileSelector:new(id), file_selector = FileSelector:new(id),
is_generating = false,
}, { __index = self }) }, { __index = self })
end end
@ -66,9 +67,24 @@ function Sidebar:delete_autocmds()
end end
function Sidebar:reset() function Sidebar:reset()
-- clean up event handlers
if self.augroup then
api.nvim_del_augroup_by_id(self.augroup)
self.augroup = nil
end
-- clean up keymaps
self:unbind_apply_key() self:unbind_apply_key()
self:unbind_sidebar_keys() self:unbind_sidebar_keys()
self:delete_autocmds()
-- clean up file selector events
if self.file_selector then self.file_selector:off("update") end
if self.result_container then self.result_container:unmount() end
if self.selected_code_container then self.selected_code_container:unmount() end
if self.selected_files_container then self.selected_files_container:unmount() end
if self.input_container then self.input_container:unmount() end
self.code = { bufnr = 0, winid = 0, selection = nil } self.code = { bufnr = 0, winid = 0, selection = nil }
self.winids = self.winids =
{ result_container = 0, selected_files_container = 0, selected_code_container = 0, input_container = 0 } { result_container = 0, selected_files_container = 0, selected_code_container = 0, input_container = 0 }
@ -200,24 +216,34 @@ local function transform_result_content(selected_files, result_content, prev_fil
local current_filepath local current_filepath
local i = 1 local i = 1
while i <= #result_lines do while true do
if i > #result_lines then break end
local line_content = result_lines[i] local line_content = result_lines[i]
if line_content:match("<FILEPATH>.+</FILEPATH>") then if line_content:match("<[Ff][Ii][Ll][Ee][Pp][Aa][Tt][Hh]>.+</[Ff][Ii][Ll][Ee][Pp][Aa][Tt][Hh]>") then
local filepath = line_content:match("<FILEPATH>(.+)</FILEPATH>") local filepath = line_content:match("<[Ff][Ii][Ll][Ee][Pp][Aa][Tt][Hh]>(.+)</[Ff][Ii][Ll][Ee][Pp][Aa][Tt][Hh]>")
if filepath then if filepath then
current_filepath = filepath current_filepath = filepath
table.insert(transformed_lines, string.format("Filepath: %s", filepath)) table.insert(transformed_lines, string.format("Filepath: %s", filepath))
goto continue goto continue
end end
end end
if line_content == "<SEARCH>" then if line_content:match("^%s*<[Ss][Ee][Aa][Rr][Cc][Hh]>") then
is_searching = true is_searching = true
if not line_content:match("^%s*<[Ss][Ee][Aa][Rr][Cc][Hh]>%s*$") then
local search_start_line = line_content:match("<[Ss][Ee][Aa][Rr][Cc][Hh]>(.+)$")
line_content = "<SEARCH>"
result_lines[i] = line_content
if search_start_line and search_start_line ~= "" then table.insert(result_lines, i + 1, search_start_line) end
end
line_content = "<SEARCH>"
local prev_line = result_lines[i - 1] local prev_line = result_lines[i - 1]
if if
prev_line prev_line
and prev_filepath and prev_filepath
and not prev_line:match("Filepath:.+") and not prev_line:match("Filepath:.+")
and not prev_line:match("<FILEPATH>.+</FILEPATH>") and not prev_line:match("<[Ff][Ii][Ll][Ee][Pp][Aa][Tt][Hh]>.+</[Ff][Ii][Ll][Ee][Pp][Aa][Tt][Hh]>")
then then
table.insert(transformed_lines, string.format("Filepath: %s", prev_filepath)) table.insert(transformed_lines, string.format("Filepath: %s", prev_filepath))
end end
@ -225,7 +251,23 @@ local function transform_result_content(selected_files, result_content, prev_fil
if next_line and next_line:match("^%s*```%w+$") then i = i + 1 end if next_line and next_line:match("^%s*```%w+$") then i = i + 1 end
search_start = i + 1 search_start = i + 1
last_search_tag_start_line = i last_search_tag_start_line = i
elseif line_content == "</SEARCH>" then elseif line_content:match("</[Ss][Ee][Aa][Rr][Cc][Hh]>%s*$") then
if is_replacing then
result_lines[i] = line_content:gsub("</[Ss][Ee][Aa][Rr][Cc][Hh]>", "</REPLACE>")
goto continue_without_increment
end
-- Handle case where </SEARCH> is a suffix
if not line_content:match("^%s*</[Ss][Ee][Aa][Rr][Cc][Hh]>%s*$") then
local search_end_line = line_content:match("^(.+)</[Ss][Ee][Aa][Rr][Cc][Hh]>")
line_content = "</SEARCH>"
result_lines[i] = line_content
if search_end_line and search_end_line ~= "" then
table.insert(result_lines, i, search_end_line)
goto continue_without_increment
end
end
is_searching = false is_searching = false
local search_end = i local search_end = i
@ -248,24 +290,28 @@ local function transform_result_content(selected_files, result_content, prev_fil
if not the_matched_file then if not the_matched_file then
if not PPath:new(filepath):exists() then if not PPath:new(filepath):exists() then
Utils.warn("File not found: " .. filepath) the_matched_file = {
goto continue filepath = filepath,
content = "",
file_type = nil,
}
else
if not PPath:new(filepath):is_file() then
Utils.warn("Not a file: " .. filepath)
goto continue
end
local lines = Utils.read_file_from_buf_or_disk(filepath)
if lines == nil then
Utils.warn("Failed to read file: " .. filepath)
goto continue
end
local content = table.concat(lines, "\n")
the_matched_file = {
filepath = filepath,
content = content,
file_type = nil,
}
end end
if not PPath:new(filepath):is_file() then
Utils.warn("Not a file: " .. filepath)
goto continue
end
local lines = Utils.read_file_from_buf_or_disk(filepath)
if lines == nil then
Utils.warn("Failed to read file: " .. filepath)
goto continue
end
local content = table.concat(lines, "\n")
the_matched_file = {
filepath = filepath,
content = content,
file_type = nil,
}
end end
local file_content = vim.split(the_matched_file.content, "\n") local file_content = vim.split(the_matched_file.content, "\n")
@ -292,8 +338,7 @@ local function transform_result_content(selected_files, result_content, prev_fil
-- can happen if the llm tries to edit or create a file outside of it's context. -- can happen if the llm tries to edit or create a file outside of it's context.
if not match_filetype then if not match_filetype then
local snippet_file_path = current_filepath or prev_filepath local snippet_file_path = current_filepath or prev_filepath
local snippet_file_type = vim.filetype.match({ filename = snippet_file_path }) or "unknown" match_filetype = Utils.get_filetype(snippet_file_path)
match_filetype = snippet_file_type
end end
local search_start_tag_idx_in_transformed_lines = 0 local search_start_tag_idx_in_transformed_lines = 0
@ -311,13 +356,31 @@ local function transform_result_content(selected_files, result_content, prev_fil
string.format("```%s", match_filetype), string.format("```%s", match_filetype),
}) })
goto continue goto continue
elseif line_content == "<REPLACE>" then elseif line_content:match("^%s*<[Rr][Ee][Pp][Ll][Aa][Cc][Ee]>") then
is_replacing = true is_replacing = true
if not line_content:match("^%s*<[Rr][Ee][Pp][Ll][Aa][Cc][Ee]>%s*$") then
local replace_first_line = line_content:match("<[Rr][Ee][Pp][Ll][Aa][Cc][Ee]>(.+)$")
line_content = "<REPLACE>"
result_lines[i] = line_content
if replace_first_line and replace_first_line ~= "" then
table.insert(result_lines, i + 1, replace_first_line)
end
end
local next_line = result_lines[i + 1] local next_line = result_lines[i + 1]
if next_line and next_line:match("^%s*```%w+$") then i = i + 1 end if next_line and next_line:match("^%s*```%w+$") then i = i + 1 end
last_replace_tag_start_line = i last_replace_tag_start_line = i
goto continue goto continue
elseif line_content == "</REPLACE>" then elseif line_content:match("</[Rr][Ee][Pp][Ll][Aa][Cc][Ee]>%s*$") then
-- Handle case where </REPLACE> is a suffix
if not line_content:match("^%s*</[Rr][Ee][Pp][Ll][Aa][Cc][Ee]>%s*$") then
local replace_end_line = line_content:match("^(.+)</[Rr][Ee][Pp][Ll][Aa][Cc][Ee]>")
line_content = "</REPLACE>"
result_lines[i] = line_content
if replace_end_line and replace_end_line ~= "" then
table.insert(result_lines, i, replace_end_line)
goto continue_without_increment
end
end
is_replacing = false is_replacing = false
local prev_line = result_lines[i - 1] local prev_line = result_lines[i - 1]
if not (prev_line and prev_line:match("^%s*```$")) then table.insert(transformed_lines, "```") end if not (prev_line and prev_line:match("^%s*```$")) then table.insert(transformed_lines, "```") end
@ -332,6 +395,7 @@ local function transform_result_content(selected_files, result_content, prev_fil
table.insert(transformed_lines, line_content) table.insert(transformed_lines, line_content)
::continue:: ::continue::
i = i + 1 i = i + 1
::continue_without_increment::
end end
return { return {
@ -397,8 +461,8 @@ local function get_searching_hint()
end end
local thinking_spinner_chars = { local thinking_spinner_chars = {
"🤯", Utils.icon("🤯", "?"),
"🙄", Utils.icon("🙄", "¿"),
} }
local thinking_spinner_index = 1 local thinking_spinner_index = 1
@ -437,8 +501,10 @@ local function generate_display_content(replacement)
return string.format(" > %s", line) return string.format(" > %s", line)
end) end)
:totable() :totable()
local result_lines = local result_lines = vim.list_extend(
vim.list_extend(vim.list_slice(lines, 1, replacement.last_search_tag_start_line), { "🤔 Thought content:" }) vim.list_slice(lines, 1, replacement.last_search_tag_start_line),
{ Utils.icon("🤔 ") .. "Thought content:" }
)
result_lines = vim.list_extend(result_lines, formatted_thinking_content_lines) result_lines = vim.list_extend(result_lines, formatted_thinking_content_lines)
result_lines = vim.list_extend(result_lines, vim.list_slice(lines, last_think_tag_end_line + 1)) result_lines = vim.list_extend(result_lines, vim.list_slice(lines, last_think_tag_end_line + 1))
return table.concat(result_lines, "\n") return table.concat(result_lines, "\n")
@ -695,28 +761,22 @@ local function minimize_snippet(original_lines, snippet)
return new_snippets return new_snippets
end end
---@param snippets_map table<string, AvanteCodeSnippet[]> ---@param filepath string
---@param snippets AvanteCodeSnippet[]
---@return table<string, AvanteCodeSnippet[]> ---@return table<string, AvanteCodeSnippet[]>
function Sidebar:minimize_snippets(snippets_map) function Sidebar:minimize_snippets(filepath, snippets)
local original_lines = {} local original_lines = {}
if vim.tbl_count(snippets_map) > 0 then local original_lines_ = Utils.read_file_from_buf_or_disk(filepath)
local filepaths = vim.tbl_keys(snippets_map) if original_lines_ then original_lines = original_lines_ end
local original_lines_, _, err = Utils.read_file_from_buf_or_disk(filepaths[1])
if err ~= nil then return {} end
if original_lines_ then original_lines = original_lines_ end
end
local results = {} local results = {}
for filepath, snippets in pairs(snippets_map) do for _, snippet in ipairs(snippets) do
for _, snippet in ipairs(snippets) do local new_snippets = minimize_snippet(original_lines, snippet)
local new_snippets = minimize_snippet(original_lines, snippet) if new_snippets then
if new_snippets then for _, new_snippet in ipairs(new_snippets) do
results[filepath] = results[filepath] or {} table.insert(results, new_snippet)
for _, new_snippet in ipairs(new_snippets) do
table.insert(results[filepath], new_snippet)
end
end end
end end
end end
@ -749,12 +809,13 @@ function Sidebar:apply(current_cursor)
selected_snippets_map = all_snippets_map selected_snippets_map = all_snippets_map
end end
if Config.behaviour.minimize_diff then selected_snippets_map = self:minimize_snippets(selected_snippets_map) end
vim.defer_fn(function() vim.defer_fn(function()
api.nvim_set_current_win(self.code.winid) api.nvim_set_current_win(self.code.winid)
for filepath, snippets in pairs(selected_snippets_map) do for filepath, snippets in pairs(selected_snippets_map) do
if Config.behaviour.minimize_diff then snippets = self:minimize_snippets(filepath, snippets) end
local bufnr = Utils.get_or_create_buffer_with_filepath(filepath) local bufnr = Utils.get_or_create_buffer_with_filepath(filepath)
local path_ = PPath:new(filepath)
path_:parent():mkdir({ parents = true, exists_ok = true })
insert_conflict_contents(bufnr, snippets) insert_conflict_contents(bufnr, snippets)
local process = function(winid) local process = function(winid)
api.nvim_set_current_win(winid) api.nvim_set_current_win(winid)
@ -845,7 +906,7 @@ function Sidebar:render_result()
then then
return return
end end
local header_text = "󰭻 Avante" local header_text = Utils.icon("󰭻 ") .. "Avante"
self:render_header( self:render_header(
self.result_container.winid, self.result_container.winid,
self.result_container.bufnr, self.result_container.bufnr,
@ -867,13 +928,15 @@ function Sidebar:render_input(ask)
end end
local header_text = string.format( local header_text = string.format(
"󱜸 %s (" .. Config.mappings.sidebar.switch_windows .. ": switch focus)", "%s%s (" .. Config.mappings.sidebar.switch_windows .. ": switch focus)",
Utils.icon("󱜸 "),
ask and "Ask" or "Chat with" ask and "Ask" or "Chat with"
) )
if self.code.selection ~= nil then if self.code.selection ~= nil then
header_text = string.format( header_text = string.format(
"󱜸 %s (%d:%d) (<Tab>: switch focus)", "%s%s (%d:%d) (<Tab>: switch focus)",
Utils.icon("󱜸 "),
ask and "Ask" or "Chat with", ask and "Ask" or "Chat with",
self.code.selection.range.start.lnum, self.code.selection.range.start.lnum,
self.code.selection.range.finish.lnum self.code.selection.range.finish.lnum
@ -906,7 +969,8 @@ function Sidebar:render_selected_code()
selected_code_lines_count = #selected_code_lines selected_code_lines_count = #selected_code_lines
end end
local header_text = " Selected Code" local header_text = Utils.icon("")
.. "Selected Code"
.. ( .. (
selected_code_lines_count > selected_code_max_lines_count selected_code_lines_count > selected_code_max_lines_count
and " (Show only the first " .. tostring(selected_code_max_lines_count) .. " lines)" and " (Show only the first " .. tostring(selected_code_max_lines_count) .. " lines)"
@ -1312,6 +1376,20 @@ function Sidebar:initialize()
return self return self
end end
function Sidebar:is_focused()
if not self:is_open() then return false end
local current_winid = api.nvim_get_current_win()
if self.winids.result_container and self.winids.result_container == current_winid then return true end
if self.winids.selected_files_container and self.winids.selected_files_container == current_winid then
return true
end
if self.winids.selected_code_container and self.winids.selected_code_container == current_winid then return true end
if self.winids.input_container and self.winids.input_container == current_winid then return true end
return false
end
function Sidebar:is_focused_on_result() function Sidebar:is_focused_on_result()
return self:is_open() and self.result_container and self.result_container.winid == api.nvim_get_current_win() return self:is_open() and self.result_container and self.result_container.winid == api.nvim_get_current_win()
end end
@ -1892,6 +1970,8 @@ function Sidebar:create_input_container(opts)
---@type AvanteLLMChunkCallback ---@type AvanteLLMChunkCallback
local on_chunk = function(chunk) local on_chunk = function(chunk)
self.is_generating = true
original_response = original_response .. chunk original_response = original_response .. chunk
local selected_files = self.file_selector:get_selected_files_contents() local selected_files = self.file_selector:get_selected_files_contents()
@ -1926,6 +2006,8 @@ function Sidebar:create_input_container(opts)
---@type AvanteLLMStopCallback ---@type AvanteLLMStopCallback
local on_stop = function(stop_opts) local on_stop = function(stop_opts)
self.is_generating = false
pcall(function() pcall(function()
---remove keymaps ---remove keymaps
vim.keymap.del("n", "j", { buffer = self.result_container.bufnr }) vim.keymap.del("n", "j", { buffer = self.result_container.bufnr })
@ -2039,6 +2121,7 @@ function Sidebar:create_input_container(opts)
local request = table.concat(lines, "\n") local request = table.concat(lines, "\n")
if request == "" then return end if request == "" then return end
api.nvim_buf_set_lines(self.input_container.bufnr, 0, -1, false, {}) api.nvim_buf_set_lines(self.input_container.bufnr, 0, -1, false, {})
api.nvim_win_set_cursor(self.input_container.winid, { 1, 0 })
handle_submit(request) handle_submit(request)
end end
@ -2454,7 +2537,7 @@ function Sidebar:create_selected_files_container()
self:render_header( self:render_header(
self.selected_files_container.winid, self.selected_files_container.winid,
selected_files_buf, selected_files_buf,
"Selected Files", Utils.icon("") .. "Selected Files",
Highlights.SUBTITLE, Highlights.SUBTITLE,
Highlights.REVERSED_SUBTITLE Highlights.REVERSED_SUBTITLE
) )

View File

@ -84,8 +84,8 @@ function Suggestion:suggest()
L1: def fib L1: def fib
L2: L2:
L3: if __name__ == "__main__": L3: if __name__ == "__main__":
L4: # just pass L4: # just pass
L5: pass L5: pass
</code> </code>
]], ]],
}, },
@ -95,7 +95,7 @@ L5: pass
}, },
{ {
role = "user", role = "user",
content = '<question>{ "indentSize": 4, "position": { "row": 1, "col": 2 } }</question>', content = '<question>{"insertSpaces":true,"tabSize":4,"indentSize":4,"position":{"row":1,"col":7}}</question>',
}, },
{ {
role = "assistant", role = "assistant",

View File

@ -10,10 +10,17 @@
Act as an expert software developer. Act as an expert software developer.
Always use best practices when coding. Always use best practices when coding.
Respect and use existing conventions, libraries, etc that are already present in the code base. Respect and use existing conventions, libraries, etc that are already present in the code base.
You have access to tools, but only use them when necessary. If a tool is not required, respond as normal. Don't directly search for code context in historical messages. Instead, prioritize using tools to obtain context first, then use context from historical messages as a secondary source, since context from historical messages is often not up to date.
If you encounter a URL, prioritize using the fetch tool to obtain its content.
If you have information that you don't know, please proactively use the tools provided by users! Especially the web search tool. Tools Usage Guide:
When available tools cannot meet the requirements, please try to use the `run_command` tool to solve the problem whenever possible. - You have access to tools, but only use them when necessary. If a tool is not required, respond as normal.
- If you encounter a URL, prioritize using the fetch tool to obtain its content.
- If you have information that you don't know, please proactively use the tools provided by users! Especially the web search tool.
- When available tools cannot meet the requirements, please try to use the `run_command` tool to solve the problem whenever possible.
- When attempting to modify a file that is not in the context, please first use the `list_files` tool and `search_files` tool to check if the file you want to modify exists, then use the `read_file` tool to read the file content. Don't modify blindly!
- When generating files, first use `list_files` tool to read the directory structure, don't generate blindly!
- When creating files, first check if the directory exists. If it doesn't exist, create the directory before creating the file.
- After `web_search`, if you don't get detailed enough information, do not continue use `web_search`, just continue using the `fetch` tool to get more information you need from the links in the search results.
{% if system_info -%} {% if system_info -%}
Use the appropriate shell based on the user's system info: Use the appropriate shell based on the user's system info:

View File

@ -665,6 +665,19 @@ function M.scan_directory_respect_gitignore(options)
local directory = options.directory local directory = options.directory
local gitignore_path = directory .. "/.gitignore" local gitignore_path = directory .. "/.gitignore"
local gitignore_patterns, gitignore_negate_patterns = M.parse_gitignore(gitignore_path) local gitignore_patterns, gitignore_negate_patterns = M.parse_gitignore(gitignore_path)
-- Convert relative paths in gitignore to absolute paths based on project root
local project_root = M.get_project_root()
local function to_absolute_path(pattern)
-- Skip if already absolute path
if pattern:sub(1, 1) == "/" then return pattern end
-- Convert relative path to absolute
return Path:new(project_root, pattern):absolute()
end
gitignore_patterns = vim.tbl_map(to_absolute_path, gitignore_patterns)
gitignore_negate_patterns = vim.tbl_map(to_absolute_path, gitignore_negate_patterns)
return M.scan_directory({ return M.scan_directory({
directory = directory, directory = directory,
gitignore_patterns = gitignore_patterns, gitignore_patterns = gitignore_patterns,
@ -890,9 +903,19 @@ function M.is_same_file(filepath_a, filepath_b) return M.uniform_path(filepath_a
function M.trim_think_content(content) return content:gsub("^<think>.-</think>", "", 1) end function M.trim_think_content(content) return content:gsub("^<think>.-</think>", "", 1) end
function M.get_filetype(filepath)
-- Some files are sometimes not detected correctly when buffer is not included
-- https://github.com/neovim/neovim/issues/27265
local buf = vim.api.nvim_create_buf(false, true)
local filetype = vim.filetype.match({ filename = filepath, buf = buf }) or ""
vim.api.nvim_buf_delete(buf, { force = true })
return filetype
end
---@param file_path string ---@param file_path string
---@return string[]|nil lines ---@return string[]|nil lines
---@return string|nil file_type
---@return string|nil error ---@return string|nil error
function M.read_file_from_buf_or_disk(file_path) function M.read_file_from_buf_or_disk(file_path)
--- Lookup if the file is loaded in a buffer --- Lookup if the file is loaded in a buffer
@ -900,8 +923,7 @@ function M.read_file_from_buf_or_disk(file_path)
if bufnr ~= -1 and vim.api.nvim_buf_is_loaded(bufnr) then if bufnr ~= -1 and vim.api.nvim_buf_is_loaded(bufnr) then
-- If buffer exists and is loaded, get buffer content -- If buffer exists and is loaded, get buffer content
local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
local file_type = vim.api.nvim_get_option_value("filetype", { buf = bufnr }) return lines, nil
return lines, file_type, nil
end end
-- Fallback: read file from disk -- Fallback: read file from disk
@ -909,12 +931,28 @@ function M.read_file_from_buf_or_disk(file_path)
if file then if file then
local content = file:read("*all") local content = file:read("*all")
file:close() file:close()
-- Detect the file type using the specific file's content return vim.split(content, "\n"), nil
local file_type = vim.filetype.match({ filename = file_path, contents = { content } }) or "unknown"
return vim.split(content, "\n"), file_type, nil
else else
M.error("failed to open file: " .. file_path .. " with error: " .. open_err) -- M.error("failed to open file: " .. file_path .. " with error: " .. open_err)
return {}, nil, open_err return {}, open_err
end
end
---Check if an icon plugin is installed
---@return boolean
M.icons_enabled = function() return M.has("nvim-web-devicons") or M.has("mini.icons") or M.has("mini.nvim") end
---Display an string with icon, if an icon plugin is available.
---Dev icons are an optional install for avante, this function prevents ugly chars
---being displayed by displaying fallback options or nothing at all.
---@param string_with_icon string
---@param utf8_fallback string|nil
---@return string
M.icon = function(string_with_icon, utf8_fallback)
if M.icons_enabled() then
return string_with_icon
else
return utf8_fallback or ""
end end
end end