diff --git a/.cargo/config.toml b/.cargo/config.toml index af95132..83d9435 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -3,3 +3,6 @@ rustflags = ["-C", "link-arg=-undefined", "-C", "link-arg=dynamic_lookup"] [target.aarch64-apple-darwin] rustflags = ["-C", "link-arg=-undefined", "-C", "link-arg=dynamic_lookup"] + +[target.x86_64-unknown-linux-musl] +rustflags = ["-C", "target-feature=-crt-static"] diff --git a/README.md b/README.md index 18ab5c1..c7f3363 100644 --- a/README.md +++ b/README.md @@ -60,6 +60,7 @@ For building binary if you wish to build from source, then `cargo` is required. timeout = 30000, -- timeout in milliseconds temperature = 0, -- adjust if needed max_tokens = 4096, + reasoning_effort = "high" -- only supported for "o" models }, }, -- if you want to build from source then do `make BUILD_FROM_SOURCE=true` @@ -380,12 +381,44 @@ This is achieved by emulating nvim-cmp using blink.compat ```lua file_selector = { - --- @alias FileSelectorProvider "native" | "fzf" | "mini.pick" | "snacks" | "telescope" | string + --- @alias FileSelectorProvider "native" | "fzf" | "mini.pick" | "snacks" | "telescope" | string | fun(params: avante.file_selector.IParams|nil): nil provider = "fzf", -- Options override for custom providers provider_opts = {}, } ``` + +To create a customized file_selector, you can specify a customized function to launch a picker to select items and pass the selected items to the `handler` callback. + +```lua + file_selector = { + ---@param params avante.file_selector.IParams + provider = function(params) + local filepaths = params.filepaths ---@type string[] + local title = params.title ---@type string + local handler = params.handler ---@type fun(selected_filepaths: string[]|nil): nil + + -- Launch your customized picker with the items built from `filepaths`, then in the `on_confirm` callback, + -- pass the selected items (convert back to file paths) to the `handler` function. + + local items = __your_items_formatter__(filepaths) + __your_picker__({ + items = items, + on_cancel = function() + handler(nil) + end, + on_confirm = function(selected_items) + local selected_filepaths = {} + for _, item in ipairs(selected_items) do + table.insert(selected_filepaths, item.filepath) + end + handler(selected_filepaths) + end + }) + end, + } +``` + Choose a selector other that native, the default as that currently has an issue For lazyvim users copy the full config for blink.cmp from the website or extend the options ```lua @@ -471,8 +504,10 @@ Given its early stage, `avante.nvim` currently supports the following basic func > For Amazon Bedrock: > > ```sh -> export BEDROCK_KEYS=aws_access_key_id,aws_secret_access_key,aws_region +> export BEDROCK_KEYS=aws_access_key_id,aws_secret_access_key,aws_region[,aws_session_token] +> > ``` +> Note: The aws_session_token is optional and only needed when using temporary AWS credentials 1. Open a code file in Neovim. 2. Use the `:AvanteAsk` command to query the AI about the code. @@ -548,15 +583,16 @@ For more information, see [Custom Providers](https://github.com/yetone/avante.nv ## Web Search Engines -Avante's tools include some web search engines, currently support [tavily](https://tavily.com/) and [serpapi](https://serpapi.com/). The default is tavily, and can be changed through configuring `Config.web_search_engine.provider`: +Avante's tools include some web search engines, currently support [tavily](https://tavily.com/), [serpapi](https://serpapi.com/), [searchapi](https://www.searchapi.io/) and google's [programmable search engine](https://developers.google.com/custom-search/v1/overview). The default is tavily, and can be changed through configuring `Config.web_search_engine.provider`: ```lua web_search_engine = { - provider = "tavily", -- tavily or serpapi + provider = "tavily", -- tavily, serpapi, searchapi or google } ``` -You need to set the environment variable `TAVILY_API_KEY` or `SERPAPI_API_KEY` to use tavily or serpapi. +You need to set the environment variable `TAVILY_API_KEY` , `SERPAPI_API_KEY`, `SEARCHAPI_API_KEY` to use tavily or serpapi or searchapi. +To use google, set the `GOOGLE_SEARCH_API_KEY` as the [API key](https://developers.google.com/custom-search/v1/overview), and `GOOGLE_SEARCH_ENGINE_ID` as the [search engine](https://programmablesearchengine.google.com) ID. ## Disable Tools diff --git a/lua/avante/config.lua b/lua/avante/config.lua index fd099f0..c8d32ff 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -5,6 +5,11 @@ local Utils = require("avante.utils") +---@class avante.file_selector.IParams +---@field public title string +---@field public filepaths string[] +---@field public handler fun(filepaths: string[]|nil): nil + ---@class avante.CoreConfig: avante.Config local M = {} ---@class avante.Config @@ -28,11 +33,10 @@ M._defaults = { tavily = { api_key_name = "TAVILY_API_KEY", extra_request_body = { - time_range = "d", include_answer = "basic", }, ---@type WebSearchEngineProviderResponseBodyFormatter - format_response_body = function(body) return body.anwser, nil end, + format_response_body = function(body) return body.answer, nil end, }, serpapi = { api_key_name = "SERPAPI_API_KEY", @@ -52,11 +56,65 @@ M._defaults = { title = result.title, link = result.link, snippet = result.snippet, + date = result.date, } end ) + :take(10) + :totable() + return vim.json.encode(jsn), nil + end + return "", nil + end, + }, + searchapi = { + api_key_name = "SEARCHAPI_API_KEY", + extra_request_body = { + engine = "google", + }, + ---@type WebSearchEngineProviderResponseBodyFormatter + format_response_body = function(body) + if body.answer_box ~= nil then return body.answer_box.result, nil end + if body.organic_results ~= nil then + local jsn = vim + .iter(body.organic_results) + :map( + function(result) + return { + title = result.title, + link = result.link, + snippet = result.snippet, + date = result.date, + } + end + ) + :take(10) + :totable() + return vim.json.encode(jsn), nil + end + return "", nil + end, + }, + google = { + api_key_name = "GOOGLE_SEARCH_API_KEY", + engine_id_name = "GOOGLE_SEARCH_ENGINE_ID", + extra_request_body = {}, + ---@type WebSearchEngineProviderResponseBodyFormatter + format_response_body = function(body) + if body.items ~= nil then + local jsn = vim + .iter(body.items) + :map( + function(result) + return { + title = result.title, + link = result.link, + snippet = result.snippet, + } + end + ) + :take(10) :totable() - if #jsn > 5 then jsn = vim.list_slice(jsn, 1, 5) end return vim.json.encode(jsn), nil end return "", nil @@ -307,7 +365,7 @@ M._defaults = { }, --- @class AvanteFileSelectorConfig file_selector = { - --- @alias FileSelectorProvider "native" | "fzf" | "mini.pick" | "snacks" | "telescope" | string + --- @alias FileSelectorProvider "native" | "fzf" | "mini.pick" | "snacks" | "telescope" | string | fun(params: avante.file_selector.IParams|nil): nil provider = "native", -- Options override for custom providers provider_opts = {}, diff --git a/lua/avante/file_selector.lua b/lua/avante/file_selector.lua index 4af6ac6..6eece15 100644 --- a/lua/avante/file_selector.lua +++ b/lua/avante/file_selector.lua @@ -330,6 +330,11 @@ function FileSelector:show_select_ui() self:snacks_picker_ui(handler) elseif Config.file_selector.provider == "telescope" then self:telescope_ui(handler) + elseif type(Config.file_selector.provider) == "function" then + local title = string.format("%s:", PROMPT_TITLE) ---@type string + local filepaths = self:get_filepaths() ---@type string[] + local params = { title = title, filepaths = filepaths, handler = handler } ---@type avante.file_selector.IParams + Config.file_selector.provider(params) else Utils.error("Unknown file selector provider: " .. Config.file_selector.provider) end @@ -363,9 +368,9 @@ end function FileSelector:get_selected_files_contents() local contents = {} for _, file_path in ipairs(self.selected_filepaths) do - local lines, filetype, error = Utils.read_file_from_buf_or_disk(file_path) + local lines, error = Utils.read_file_from_buf_or_disk(file_path) lines = lines or {} - filetype = filetype or "unknown" + local filetype = Utils.get_filetype(file_path) if error ~= nil then Utils.error("error reading file: " .. error) else diff --git a/lua/avante/health.lua b/lua/avante/health.lua index 357db84..78804ac 100644 --- a/lua/avante/health.lua +++ b/lua/avante/health.lua @@ -23,9 +23,7 @@ M.check = function() end -- Optional dependencies - local has_devicons = Utils.has("nvim-web-devicons") - local has_mini_icons = Utils.has("mini.icons") or Utils.has("mini.nvim") - if has_devicons or has_mini_icons then + if Utils.icons_enabled() then H.ok("Found icons plugin (nvim-web-devicons or mini.icons)") else H.warn("No icons plugin found (nvim-web-devicons or mini.icons). Icons will not be displayed") diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index 522d54b..064cb10 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -25,8 +25,8 @@ M.generate_prompts = function(opts) local Provider = opts.provider or P[Config.provider] local mode = opts.mode or "planning" ---@type AvanteProviderFunctor | AvanteBedrockProviderFunctor - local _, body_opts = P.parse_config(Provider) - local max_tokens = body_opts.max_tokens or 4096 + local _, request_body = P.parse_config(Provider) + local max_tokens = request_body.max_tokens or 4096 -- Check if the instructions contains an image path local image_paths = {} @@ -449,7 +449,7 @@ M.stream = function(opts) return original_on_stop(stop_opts) end) end - if Config.dual_boost.enabled then + if Config.dual_boost.enabled and opts.mode == "planning" then M._dual_boost_stream(opts, P[Config.dual_boost.first_provider], P[Config.dual_boost.second_provider]) else M._stream(opts) diff --git a/lua/avante/llm_tools.lua b/lua/avante/llm_tools.lua index 9d75346..5dbec4d 100644 --- a/lua/avante/llm_tools.lua +++ b/lua/avante/llm_tools.lua @@ -7,6 +7,7 @@ local M = {} ---@param rel_path string ---@return string local function get_abs_path(rel_path) + if Path:new(rel_path):is_absolute() then return rel_path end local project_root = Utils.get_project_root() return Path:new(project_root):joinpath(rel_path):absolute() end @@ -41,13 +42,12 @@ function M.list_files(opts, on_log) add_dirs = true, depth = opts.depth, }) - local result = "" + local filepaths = {} for _, file in ipairs(files) do local uniform_path = Utils.uniform_path(file) - result = result .. uniform_path .. "\n" + table.insert(filepaths, uniform_path) end - result = result:gsub("\n$", "") - return result, nil + return vim.json.encode(filepaths), nil end ---@param opts { rel_path: string, keyword: string } @@ -62,12 +62,11 @@ function M.search_files(opts, on_log) local files = Utils.scan_directory_respect_gitignore({ directory = abs_path, }) - local result = "" + local filepaths = {} for _, file in ipairs(files) do - if file:find(opts.keyword) then result = result .. file .. "\n" end + if file:find(opts.keyword) then table.insert(filepaths, file) end end - result = result:gsub("\n$", "") - return result, nil + return vim.json.encode(filepaths), nil end ---@param opts { rel_path: string, keyword: string } @@ -104,7 +103,9 @@ function M.search(opts, on_log) if on_log then on_log("Running command: " .. cmd) end local result = vim.fn.system(cmd) - return result or "", nil + local filepaths = vim.split(result, "\n") + + return vim.json.encode(filepaths), nil end ---@param opts { rel_path: string } @@ -183,9 +184,10 @@ function M.rename_file(opts, on_log) end ---@param opts { rel_path: string, new_rel_path: string } +---@param on_log? fun(log: string): nil ---@return boolean success ---@return string|nil error -function M.copy_file(opts) +function M.copy_file(opts, on_log) local abs_path = get_abs_path(opts.rel_path) if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end @@ -193,38 +195,47 @@ function M.copy_file(opts) local new_abs_path = get_abs_path(opts.new_rel_path) if not has_permission_to_access(new_abs_path) then return false, "No permission to access path: " .. new_abs_path end if Path:new(new_abs_path):exists() then return false, "File already exists: " .. new_abs_path end + if on_log then on_log("Copying file: " .. abs_path .. " to " .. new_abs_path) end Path:new(new_abs_path):write(Path:new(abs_path):read()) return true, nil end ---@param opts { rel_path: string } +---@param on_log? fun(log: string): nil ---@return boolean success ---@return string|nil error -function M.delete_file(opts) +function M.delete_file(opts, on_log) local abs_path = get_abs_path(opts.rel_path) if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end if not M.confirm("Are you sure you want to delete the file: " .. abs_path) then return false, "User canceled" end + if on_log then on_log("Deleting file: " .. abs_path) end os.remove(abs_path) return true, nil end ---@param opts { rel_path: string } +---@param on_log? fun(log: string): nil ---@return boolean success ---@return string|nil error -function M.create_dir(opts) +function M.create_dir(opts, on_log) local abs_path = get_abs_path(opts.rel_path) if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if Path:new(abs_path):exists() then return false, "Directory already exists: " .. abs_path end + if not M.confirm("Are you sure you want to create the directory: " .. abs_path) then + return false, "User canceled" + end + if on_log then on_log("Creating directory: " .. abs_path) end Path:new(abs_path):mkdir({ parents = true }) return true, nil end ---@param opts { rel_path: string, new_rel_path: string } +---@param on_log? fun(log: string): nil ---@return boolean success ---@return string|nil error -function M.rename_dir(opts) +function M.rename_dir(opts, on_log) local abs_path = get_abs_path(opts.rel_path) if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Path:new(abs_path):exists() then return false, "Directory not found: " .. abs_path end @@ -235,14 +246,16 @@ function M.rename_dir(opts) if not M.confirm("Are you sure you want to rename directory " .. abs_path .. " to " .. new_abs_path .. "?") then return false, "User canceled" end + if on_log then on_log("Renaming directory: " .. abs_path .. " to " .. new_abs_path) end os.rename(abs_path, new_abs_path) return true, nil end ---@param opts { rel_path: string } +---@param on_log? fun(log: string): nil ---@return boolean success ---@return string|nil error -function M.delete_dir(opts) +function M.delete_dir(opts, on_log) local abs_path = get_abs_path(opts.rel_path) if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end if not Path:new(abs_path):exists() then return false, "Directory not found: " .. abs_path end @@ -250,6 +263,7 @@ function M.delete_dir(opts) if not M.confirm("Are you sure you want to delete the directory: " .. abs_path) then return false, "User canceled" end + if on_log then on_log("Deleting directory: " .. abs_path) end os.remove(abs_path) return true, nil end @@ -326,6 +340,45 @@ function M.web_search(opts, on_log) if resp.status ~= 200 then return nil, "Error: " .. resp.body end local jsn = vim.json.decode(resp.body) return search_engine.format_response_body(jsn) + elseif provider_type == "searchapi" then + local query_params = vim.tbl_deep_extend("force", { + api_key = api_key, + q = opts.query, + }, search_engine.extra_request_body) + local query_string = "" + for key, value in pairs(query_params) do + query_string = query_string .. key .. "=" .. vim.uri_encode(value) .. "&" + end + local resp = curl.get("https://searchapi.io/api/v1/search?" .. query_string, { + headers = { + ["Content-Type"] = "application/json", + }, + }) + if resp.status ~= 200 then return nil, "Error: " .. resp.body end + local jsn = vim.json.decode(resp.body) + return search_engine.format_response_body(jsn) + elseif provider_type == "google" then + local engine_id = os.getenv(search_engine.engine_id_name) + if engine_id == nil or engine_id == "" then + return nil, "Environment variable " .. search_engine.engine_id_namee .. " is not set" + end + local query_params = vim.tbl_deep_extend("force", { + key = api_key, + cx = engine_id, + q = opts.query, + }, search_engine.extra_request_body) + local query_string = "" + for key, value in pairs(query_params) do + query_string = query_string .. key .. "=" .. vim.uri_encode(value) .. "&" + end + local resp = curl.get("https://www.googleapis.com/customsearch/v1?" .. query_string, { + headers = { + ["Content-Type"] = "application/json", + }, + }) + if resp.status ~= 200 then return nil, "Error: " .. resp.body end + local jsn = vim.json.decode(resp.body) + return search_engine.format_response_body(jsn) end end diff --git a/lua/avante/providers/azure.lua b/lua/avante/providers/azure.lua index 83ca617..270e1a2 100644 --- a/lua/avante/providers/azure.lua +++ b/lua/avante/providers/azure.lua @@ -18,31 +18,34 @@ M.parse_response = O.parse_response M.parse_response_without_stream = O.parse_response_without_stream M.parse_curl_args = function(provider, prompt_opts) - local base, body_opts = P.parse_config(provider) + local provider_conf, request_body = P.parse_config(provider) local headers = { ["Content-Type"] = "application/json", } - if P.env.require_api_key(base) then headers["api-key"] = provider.parse_api_key() end + if P.env.require_api_key(provider_conf) then headers["api-key"] = provider.parse_api_key() end -- NOTE: When using "o" series set the supported parameters only - if O.is_o_series_model(base.model) then - body_opts.max_tokens = nil - body_opts.temperature = 1 + if O.is_o_series_model(provider_conf.model) then + request_body.max_tokens = nil + request_body.temperature = 1 end return { url = Utils.url_join( - base.endpoint, - "/openai/deployments/" .. base.deployment .. "/chat/completions?api-version=" .. base.api_version + provider_conf.endpoint, + "/openai/deployments/" + .. provider_conf.deployment + .. "/chat/completions?api-version=" + .. provider_conf.api_version ), - proxy = base.proxy, - insecure = base.allow_insecure, + proxy = provider_conf.proxy, + insecure = provider_conf.allow_insecure, headers = headers, body = vim.tbl_deep_extend("force", { messages = M.parse_messages(prompt_opts), stream = true, - }, body_opts), + }, request_body), } end diff --git a/lua/avante/providers/bedrock.lua b/lua/avante/providers/bedrock.lua index 2b0cd7e..07ef4f0 100644 --- a/lua/avante/providers/bedrock.lua +++ b/lua/avante/providers/bedrock.lua @@ -1,5 +1,4 @@ local Utils = require("avante.utils") -local Clipboard = require("avante.clipboard") local P = require("avante.providers") ---@alias AvanteBedrockPayloadBuilder fun(prompt_opts: AvantePromptOptions, body_opts: table): table @@ -17,17 +16,14 @@ M.api_key_name = "BEDROCK_KEYS" M.use_xml_format = true M.load_model_handler = function() - local base, _ = P.parse_config(P["bedrock"]) - local bedrock_model = base.model - if base.model:match("anthropic") then bedrock_model = "claude" end + local provider_conf, _ = P.parse_config(P["bedrock"]) + local bedrock_model = provider_conf.model + if provider_conf.model:match("anthropic") then bedrock_model = "claude" end local ok, model_module = pcall(require, "avante.providers.bedrock." .. bedrock_model) - if ok then - return model_module - else - local error_msg = "Bedrock model handler not found: " .. bedrock_model - Utils.error(error_msg, { once = true, title = "Avante" }) - end + if ok then return model_module end + local error_msg = "Bedrock model handler not found: " .. bedrock_model + error(error_msg) end M.parse_response = function(ctx, data_stream, event_state, opts) @@ -46,8 +42,8 @@ M.parse_stream_data = function(data, opts) -- The `type` field in the decoded JSON determines how the response is handled. local bedrock_match = data:gmatch("event(%b{})") for bedrock_data_match in bedrock_match do - local data = vim.json.decode(bedrock_data_match) - local data_stream = vim.base64.decode(data.bytes) + local jsn = vim.json.decode(bedrock_data_match) + local data_stream = vim.base64.decode(jsn.bytes) local json = vim.json.decode(data_stream) M.parse_response({}, data_stream, json.type, opts) end @@ -60,10 +56,12 @@ M.parse_curl_args = function(provider, prompt_opts) local base, body_opts = P.parse_config(provider) local api_key = provider.parse_api_key() + if api_key == nil then error("Cannot get the bedrock api key!") end local parts = vim.split(api_key, ",") local aws_access_key_id = parts[1] local aws_secret_access_key = parts[2] local aws_region = parts[3] + local aws_session_token = parts[4] local endpoint = string.format( "https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke-with-response-stream", @@ -75,6 +73,8 @@ M.parse_curl_args = function(provider, prompt_opts) ["Content-Type"] = "application/json", } + if aws_session_token and aws_session_token ~= "" then headers["x-amz-security-token"] = aws_session_token end + local body_payload = M.build_bedrock_payload(prompt_opts, body_opts) local rawArgs = { @@ -105,7 +105,6 @@ M.on_error = function(result) end local error_msg = body.error.message - local error_type = body.error.type Utils.error(error_msg, { once = true, title = "Avante" }) end diff --git a/lua/avante/providers/claude.lua b/lua/avante/providers/claude.lua index ecc97dd..bf64fa3 100644 --- a/lua/avante/providers/claude.lua +++ b/lua/avante/providers/claude.lua @@ -226,7 +226,7 @@ end ---@param prompt_opts AvantePromptOptions ---@return table M.parse_curl_args = function(provider, prompt_opts) - local base, body_opts = P.parse_config(provider) + local provider_conf, request_body = P.parse_config(provider) local headers = { ["Content-Type"] = "application/json", @@ -234,7 +234,7 @@ M.parse_curl_args = function(provider, prompt_opts) ["anthropic-beta"] = "prompt-caching-2024-07-31", } - if P.env.require_api_key(base) then headers["x-api-key"] = provider.parse_api_key() end + if P.env.require_api_key(provider_conf) then headers["x-api-key"] = provider.parse_api_key() end local messages = M.parse_messages(prompt_opts) @@ -246,12 +246,12 @@ M.parse_curl_args = function(provider, prompt_opts) end return { - url = Utils.url_join(base.endpoint, "/v1/messages"), - proxy = base.proxy, - insecure = base.allow_insecure, + url = Utils.url_join(provider_conf.endpoint, "/v1/messages"), + proxy = provider_conf.proxy, + insecure = provider_conf.allow_insecure, headers = headers, body = vim.tbl_deep_extend("force", { - model = base.model, + model = provider_conf.model, system = { { type = "text", @@ -262,7 +262,7 @@ M.parse_curl_args = function(provider, prompt_opts) messages = messages, tools = tools, stream = true, - }, body_opts), + }, request_body), } end diff --git a/lua/avante/providers/cohere.lua b/lua/avante/providers/cohere.lua index 059e23e..f6c7953 100644 --- a/lua/avante/providers/cohere.lua +++ b/lua/avante/providers/cohere.lua @@ -70,7 +70,7 @@ M.parse_stream_data = function(data, opts) end M.parse_curl_args = function(provider, prompt_opts) - local base, body_opts = P.parse_config(provider) + local provider_conf, request_body = P.parse_config(provider) local headers = { ["Accept"] = "application/json", @@ -82,17 +82,17 @@ M.parse_curl_args = function(provider, prompt_opts) .. "." .. vim.version().patch, } - if P.env.require_api_key(base) then headers["Authorization"] = "Bearer " .. provider.parse_api_key() end + if P.env.require_api_key(provider_conf) then headers["Authorization"] = "Bearer " .. provider.parse_api_key() end return { - url = Utils.url_join(base.endpoint, "/chat"), - proxy = base.proxy, - insecure = base.allow_insecure, + url = Utils.url_join(provider_conf.endpoint, "/chat"), + proxy = provider_conf.proxy, + insecure = provider_conf.allow_insecure, headers = headers, body = vim.tbl_deep_extend("force", { - model = base.model, + model = provider_conf.model, stream = true, - }, M.parse_messages(prompt_opts), body_opts), + }, M.parse_messages(prompt_opts), request_body), } end diff --git a/lua/avante/providers/copilot.lua b/lua/avante/providers/copilot.lua index 47ec4b4..050b934 100644 --- a/lua/avante/providers/copilot.lua +++ b/lua/avante/providers/copilot.lua @@ -249,7 +249,7 @@ M.parse_curl_args = function(provider, prompt_opts) -- (this should rarely happen, as we refresh the token in the background) H.refresh_token(false, false) - local base, body_opts = P.parse_config(provider) + local provider_conf, request_body = P.parse_config(provider) local tools = {} if prompt_opts.tools then @@ -259,10 +259,10 @@ M.parse_curl_args = function(provider, prompt_opts) end return { - url = H.chat_completion_url(base.endpoint), - timeout = base.timeout, - proxy = base.proxy, - insecure = base.allow_insecure, + url = H.chat_completion_url(provider_conf.endpoint), + timeout = provider_conf.timeout, + proxy = provider_conf.proxy, + insecure = provider_conf.allow_insecure, headers = { ["Content-Type"] = "application/json", ["Authorization"] = "Bearer " .. M.state.github_token.token, @@ -270,11 +270,11 @@ M.parse_curl_args = function(provider, prompt_opts) ["Editor-Version"] = ("Neovim/%s.%s.%s"):format(vim.version().major, vim.version().minor, vim.version().patch), }, body = vim.tbl_deep_extend("force", { - model = base.model, + model = provider_conf.model, messages = M.parse_messages(prompt_opts), stream = true, tools = tools, - }, body_opts), + }, request_body), } end diff --git a/lua/avante/providers/gemini.lua b/lua/avante/providers/gemini.lua index 1528b87..9a13346 100644 --- a/lua/avante/providers/gemini.lua +++ b/lua/avante/providers/gemini.lua @@ -82,26 +82,29 @@ M.parse_response = function(ctx, data_stream, _, opts) end M.parse_curl_args = function(provider, prompt_opts) - local base, body_opts = P.parse_config(provider) + local provider_conf, request_body = P.parse_config(provider) - body_opts = vim.tbl_deep_extend("force", body_opts, { + request_body = vim.tbl_deep_extend("force", request_body, { generationConfig = { - temperature = body_opts.temperature, - maxOutputTokens = body_opts.max_tokens, + temperature = request_body.temperature, + maxOutputTokens = request_body.max_tokens, }, }) - body_opts.temperature = nil - body_opts.max_tokens = nil + request_body.temperature = nil + request_body.max_tokens = nil local api_key = provider.parse_api_key() if api_key == nil then error("Cannot get the gemini api key!") end return { - url = Utils.url_join(base.endpoint, base.model .. ":streamGenerateContent?alt=sse&key=" .. api_key), - proxy = base.proxy, - insecure = base.allow_insecure, + url = Utils.url_join( + provider_conf.endpoint, + provider_conf.model .. ":streamGenerateContent?alt=sse&key=" .. api_key + ), + proxy = provider_conf.proxy, + insecure = provider_conf.allow_insecure, headers = { ["Content-Type"] = "application/json" }, - body = vim.tbl_deep_extend("force", {}, M.parse_messages(prompt_opts), body_opts), + body = vim.tbl_deep_extend("force", {}, M.parse_messages(prompt_opts), request_body), } end diff --git a/lua/avante/providers/init.lua b/lua/avante/providers/init.lua index 3a30469..77009e7 100644 --- a/lua/avante/providers/init.lua +++ b/lua/avante/providers/init.lua @@ -64,6 +64,7 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil } ---@field __inherited_from? string ---@field temperature? number ---@field max_tokens? number +---@field reasoning_effort? string --- ---@class AvanteLLMUsage ---@field input_tokens number @@ -347,9 +348,9 @@ M = setmetatable(M, { if t[k].has == nil then t[k].has = function() return E.parse_envvar(t[k]) ~= nil end end if t[k].setup == nil then - local base = M.parse_config(t[k]) + local provider_conf = M.parse_config(t[k]) t[k].setup = function() - if E.require_api_key(base) then t[k].parse_api_key() end + if E.require_api_key(provider_conf) then t[k].parse_api_key() end require("avante.tokenizers").setup(t[k].tokenizer_id) end end diff --git a/lua/avante/providers/openai.lua b/lua/avante/providers/openai.lua index d3b5bfe..7914bbd 100644 --- a/lua/avante/providers/openai.lua +++ b/lua/avante/providers/openai.lua @@ -223,7 +223,7 @@ M.parse_response = function(ctx, data_stream, _, opts) end ctx.last_think_content = choice.delta.reasoning opts.on_chunk(choice.delta.reasoning) - elseif choice.delta.tool_calls then + elseif choice.delta.tool_calls and choice.delta.tool_calls ~= vim.NIL then local tool_call = choice.delta.tool_calls[1] if not ctx.tool_use_list then ctx.tool_use_list = {} end if not ctx.tool_use_list[tool_call.index + 1] then @@ -272,13 +272,14 @@ end local Log = require("avante.utils.log") M.parse_curl_args = function(provider, prompt_opts) - local base, body_opts = P.parse_config(provider) - local disable_tools = base.disable_tools or false + local provider_conf, request_body = P.parse_config(provider) + local disable_tools = provider_conf.disable_tools or false local headers = { ["Content-Type"] = "application/json", } +<<<<<<< HEAD -- Add appid header for baidu provider if Config.provider == "baidu" then local baidu_config = Config.get_provider("baidu") @@ -289,6 +290,9 @@ M.parse_curl_args = function(provider, prompt_opts) end if P.env.require_api_key(base) then +======= + if P.env.require_api_key(provider_conf) then +>>>>>>> b6ae4dfe7fe443362f5f31d71797173ec12c2598 local api_key = provider.parse_api_key() if api_key == nil then error(Config.provider .. " API key is not set, please set it in your environment variable or config file") @@ -296,18 +300,18 @@ M.parse_curl_args = function(provider, prompt_opts) headers["Authorization"] = "Bearer " .. api_key end - if M.is_openrouter(base.endpoint) then + if M.is_openrouter(provider_conf.endpoint) then headers["HTTP-Referer"] = "https://github.com/yetone/avante.nvim" headers["X-Title"] = "Avante.nvim" - body_opts.include_reasoning = true + request_body.include_reasoning = true end -- NOTE: When using "o" series set the supported parameters only local stream = true - if M.is_o_series_model(base.model) then - body_opts.max_completion_tokens = body_opts.max_tokens - body_opts.max_tokens = nil - body_opts.temperature = 1 + if M.is_o_series_model(provider_conf.model) then + request_body.max_completion_tokens = request_body.max_tokens + request_body.max_tokens = nil + request_body.temperature = 1 end local tools = nil @@ -318,20 +322,20 @@ M.parse_curl_args = function(provider, prompt_opts) end end - Utils.debug("endpoint", base.endpoint) - Utils.debug("model", base.model) + Utils.debug("endpoint", provider_conf.endpoint) + Utils.debug("model", provider_conf.model) - local request = { - url = Utils.url_join(base.endpoint, "/chat/completions"), - proxy = base.proxy, - insecure = base.allow_insecure, + return { + url = Utils.url_join(provider_conf.endpoint, "/chat/completions"), + proxy = provider_conf.proxy, + insecure = provider_conf.allow_insecure, headers = headers, body = vim.tbl_deep_extend("force", { - model = base.model, + model = provider_conf.model, messages = M.parse_messages(prompt_opts), stream = stream, tools = tools, - }, body_opts), + }, request_body), } -- 记录请求详细信息 Log.log_request(request.url, request.headers, request.body) diff --git a/lua/avante/providers/vertex.lua b/lua/avante/providers/vertex.lua index f1ca9ee..5273483 100644 --- a/lua/avante/providers/vertex.lua +++ b/lua/avante/providers/vertex.lua @@ -32,22 +32,22 @@ M.parse_api_key = function() end M.parse_curl_args = function(provider, prompt_opts) - local base, body_opts = P.parse_config(provider) + local provider_conf, request_body = P.parse_config(provider) local location = vim.fn.getenv("LOCATION") or "default-location" local project_id = vim.fn.getenv("PROJECT_ID") or "default-project-id" - local model_id = base.model or "default-model-id" - local url = base.endpoint:gsub("LOCATION", location):gsub("PROJECT_ID", project_id) + local model_id = provider_conf.model or "default-model-id" + local url = provider_conf.endpoint:gsub("LOCATION", location):gsub("PROJECT_ID", project_id) url = string.format("%s/%s:streamGenerateContent?alt=sse", url, model_id) - body_opts = vim.tbl_deep_extend("force", body_opts, { + request_body = vim.tbl_deep_extend("force", request_body, { generationConfig = { - temperature = body_opts.temperature, - maxOutputTokens = body_opts.max_tokens, + temperature = request_body.temperature, + maxOutputTokens = request_body.max_tokens, }, }) - body_opts.temperature = nil - body_opts.max_tokens = nil + request_body.temperature = nil + request_body.max_tokens = nil local bearer_token = M.parse_api_key() return { @@ -56,9 +56,9 @@ M.parse_curl_args = function(provider, prompt_opts) ["Authorization"] = "Bearer " .. bearer_token, ["Content-Type"] = "application/json; charset=utf-8", }, - proxy = base.proxy, - insecure = base.allow_insecure, - body = vim.tbl_deep_extend("force", {}, M.parse_messages(prompt_opts), body_opts), + proxy = provider_conf.proxy, + insecure = provider_conf.allow_insecure, + body = vim.tbl_deep_extend("force", {}, M.parse_messages(prompt_opts), request_body), } end diff --git a/lua/avante/repo_map.lua b/lua/avante/repo_map.lua index c639922..42334d8 100644 --- a/lua/avante/repo_map.lua +++ b/lua/avante/repo_map.lua @@ -29,21 +29,10 @@ end function RepoMap.setup() vim.defer_fn(RepoMap._init_repo_map_lib, 1000) end function RepoMap.get_ts_lang(filepath) - local filetype = RepoMap.get_filetype(filepath) + local filetype = Utils.get_filetype(filepath) return filetype_map[filetype] or filetype end -function RepoMap.get_filetype(filepath) - -- Some files are sometimes not detected correctly when buffer is not included - -- https://github.com/neovim/neovim/issues/27265 - - local buf = vim.api.nvim_create_buf(false, true) - local filetype = vim.filetype.match({ filename = filepath, buf = buf }) - vim.api.nvim_buf_delete(buf, { force = true }) - - return filetype -end - function RepoMap._build_repo_map(project_root, file_ext) local output = {} local gitignore_path = project_root .. "/.gitignore" @@ -70,7 +59,7 @@ function RepoMap._build_repo_map(project_root, file_ext) if definitions == "" then return end table.insert(output, { path = Utils.relative_path(filepath), - lang = RepoMap.get_filetype(filepath), + lang = Utils.get_filetype(filepath), defs = definitions, }) end) @@ -142,7 +131,7 @@ function RepoMap._get_repo_map(file_ext) if not found then table.insert(repo_map, { path = Utils.relative_path(abs_filepath), - lang = RepoMap.get_filetype(abs_filepath), + lang = Utils.get_filetype(abs_filepath), defs = definitions, }) end diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index 39a426d..c4c1e78 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -57,6 +57,7 @@ function Sidebar:new(id) selected_files_container = nil, input_container = nil, file_selector = FileSelector:new(id), + is_generating = false, }, { __index = self }) end @@ -66,9 +67,24 @@ function Sidebar:delete_autocmds() end function Sidebar:reset() + -- clean up event handlers + if self.augroup then + api.nvim_del_augroup_by_id(self.augroup) + self.augroup = nil + end + + -- clean up keymaps self:unbind_apply_key() self:unbind_sidebar_keys() - self:delete_autocmds() + + -- clean up file selector events + if self.file_selector then self.file_selector:off("update") end + + if self.result_container then self.result_container:unmount() end + if self.selected_code_container then self.selected_code_container:unmount() end + if self.selected_files_container then self.selected_files_container:unmount() end + if self.input_container then self.input_container:unmount() end + self.code = { bufnr = 0, winid = 0, selection = nil } self.winids = { result_container = 0, selected_files_container = 0, selected_code_container = 0, input_container = 0 } @@ -200,24 +216,34 @@ local function transform_result_content(selected_files, result_content, prev_fil local current_filepath local i = 1 - while i <= #result_lines do + while true do + if i > #result_lines then break end local line_content = result_lines[i] - if line_content:match(".+") then - local filepath = line_content:match("(.+)") + if line_content:match("<[Ff][Ii][Ll][Ee][Pp][Aa][Tt][Hh]>.+") then + local filepath = line_content:match("<[Ff][Ii][Ll][Ee][Pp][Aa][Tt][Hh]>(.+)") if filepath then current_filepath = filepath table.insert(transformed_lines, string.format("Filepath: %s", filepath)) goto continue end end - if line_content == "" then + if line_content:match("^%s*<[Ss][Ee][Aa][Rr][Cc][Hh]>") then is_searching = true + + if not line_content:match("^%s*<[Ss][Ee][Aa][Rr][Cc][Hh]>%s*$") then + local search_start_line = line_content:match("<[Ss][Ee][Aa][Rr][Cc][Hh]>(.+)$") + line_content = "" + result_lines[i] = line_content + if search_start_line and search_start_line ~= "" then table.insert(result_lines, i + 1, search_start_line) end + end + line_content = "" + local prev_line = result_lines[i - 1] if prev_line and prev_filepath and not prev_line:match("Filepath:.+") - and not prev_line:match(".+") + and not prev_line:match("<[Ff][Ii][Ll][Ee][Pp][Aa][Tt][Hh]>.+") then table.insert(transformed_lines, string.format("Filepath: %s", prev_filepath)) end @@ -225,7 +251,23 @@ local function transform_result_content(selected_files, result_content, prev_fil if next_line and next_line:match("^%s*```%w+$") then i = i + 1 end search_start = i + 1 last_search_tag_start_line = i - elseif line_content == "" then + elseif line_content:match("%s*$") then + if is_replacing then + result_lines[i] = line_content:gsub("", "") + goto continue_without_increment + end + + -- Handle case where is a suffix + if not line_content:match("^%s*%s*$") then + local search_end_line = line_content:match("^(.+)") + line_content = "" + result_lines[i] = line_content + if search_end_line and search_end_line ~= "" then + table.insert(result_lines, i, search_end_line) + goto continue_without_increment + end + end + is_searching = false local search_end = i @@ -248,24 +290,28 @@ local function transform_result_content(selected_files, result_content, prev_fil if not the_matched_file then if not PPath:new(filepath):exists() then - Utils.warn("File not found: " .. filepath) - goto continue + the_matched_file = { + filepath = filepath, + content = "", + file_type = nil, + } + else + if not PPath:new(filepath):is_file() then + Utils.warn("Not a file: " .. filepath) + goto continue + end + local lines = Utils.read_file_from_buf_or_disk(filepath) + if lines == nil then + Utils.warn("Failed to read file: " .. filepath) + goto continue + end + local content = table.concat(lines, "\n") + the_matched_file = { + filepath = filepath, + content = content, + file_type = nil, + } end - if not PPath:new(filepath):is_file() then - Utils.warn("Not a file: " .. filepath) - goto continue - end - local lines = Utils.read_file_from_buf_or_disk(filepath) - if lines == nil then - Utils.warn("Failed to read file: " .. filepath) - goto continue - end - local content = table.concat(lines, "\n") - the_matched_file = { - filepath = filepath, - content = content, - file_type = nil, - } end local file_content = vim.split(the_matched_file.content, "\n") @@ -292,8 +338,7 @@ local function transform_result_content(selected_files, result_content, prev_fil -- can happen if the llm tries to edit or create a file outside of it's context. if not match_filetype then local snippet_file_path = current_filepath or prev_filepath - local snippet_file_type = vim.filetype.match({ filename = snippet_file_path }) or "unknown" - match_filetype = snippet_file_type + match_filetype = Utils.get_filetype(snippet_file_path) end local search_start_tag_idx_in_transformed_lines = 0 @@ -311,13 +356,31 @@ local function transform_result_content(selected_files, result_content, prev_fil string.format("```%s", match_filetype), }) goto continue - elseif line_content == "" then + elseif line_content:match("^%s*<[Rr][Ee][Pp][Ll][Aa][Cc][Ee]>") then is_replacing = true + if not line_content:match("^%s*<[Rr][Ee][Pp][Ll][Aa][Cc][Ee]>%s*$") then + local replace_first_line = line_content:match("<[Rr][Ee][Pp][Ll][Aa][Cc][Ee]>(.+)$") + line_content = "" + result_lines[i] = line_content + if replace_first_line and replace_first_line ~= "" then + table.insert(result_lines, i + 1, replace_first_line) + end + end local next_line = result_lines[i + 1] if next_line and next_line:match("^%s*```%w+$") then i = i + 1 end last_replace_tag_start_line = i goto continue - elseif line_content == "" then + elseif line_content:match("%s*$") then + -- Handle case where is a suffix + if not line_content:match("^%s*%s*$") then + local replace_end_line = line_content:match("^(.+)") + line_content = "" + result_lines[i] = line_content + if replace_end_line and replace_end_line ~= "" then + table.insert(result_lines, i, replace_end_line) + goto continue_without_increment + end + end is_replacing = false local prev_line = result_lines[i - 1] if not (prev_line and prev_line:match("^%s*```$")) then table.insert(transformed_lines, "```") end @@ -332,6 +395,7 @@ local function transform_result_content(selected_files, result_content, prev_fil table.insert(transformed_lines, line_content) ::continue:: i = i + 1 + ::continue_without_increment:: end return { @@ -397,8 +461,8 @@ local function get_searching_hint() end local thinking_spinner_chars = { - "🤯", - "🙄", + Utils.icon("🤯", "?"), + Utils.icon("🙄", "¿"), } local thinking_spinner_index = 1 @@ -437,8 +501,10 @@ local function generate_display_content(replacement) return string.format(" > %s", line) end) :totable() - local result_lines = - vim.list_extend(vim.list_slice(lines, 1, replacement.last_search_tag_start_line), { "🤔 Thought content:" }) + local result_lines = vim.list_extend( + vim.list_slice(lines, 1, replacement.last_search_tag_start_line), + { Utils.icon("🤔 ") .. "Thought content:" } + ) result_lines = vim.list_extend(result_lines, formatted_thinking_content_lines) result_lines = vim.list_extend(result_lines, vim.list_slice(lines, last_think_tag_end_line + 1)) return table.concat(result_lines, "\n") @@ -695,28 +761,22 @@ local function minimize_snippet(original_lines, snippet) return new_snippets end ----@param snippets_map table +---@param filepath string +---@param snippets AvanteCodeSnippet[] ---@return table -function Sidebar:minimize_snippets(snippets_map) +function Sidebar:minimize_snippets(filepath, snippets) local original_lines = {} - if vim.tbl_count(snippets_map) > 0 then - local filepaths = vim.tbl_keys(snippets_map) - local original_lines_, _, err = Utils.read_file_from_buf_or_disk(filepaths[1]) - if err ~= nil then return {} end - if original_lines_ then original_lines = original_lines_ end - end + local original_lines_ = Utils.read_file_from_buf_or_disk(filepath) + if original_lines_ then original_lines = original_lines_ end local results = {} - for filepath, snippets in pairs(snippets_map) do - for _, snippet in ipairs(snippets) do - local new_snippets = minimize_snippet(original_lines, snippet) - if new_snippets then - results[filepath] = results[filepath] or {} - for _, new_snippet in ipairs(new_snippets) do - table.insert(results[filepath], new_snippet) - end + for _, snippet in ipairs(snippets) do + local new_snippets = minimize_snippet(original_lines, snippet) + if new_snippets then + for _, new_snippet in ipairs(new_snippets) do + table.insert(results, new_snippet) end end end @@ -749,12 +809,13 @@ function Sidebar:apply(current_cursor) selected_snippets_map = all_snippets_map end - if Config.behaviour.minimize_diff then selected_snippets_map = self:minimize_snippets(selected_snippets_map) end - vim.defer_fn(function() api.nvim_set_current_win(self.code.winid) for filepath, snippets in pairs(selected_snippets_map) do + if Config.behaviour.minimize_diff then snippets = self:minimize_snippets(filepath, snippets) end local bufnr = Utils.get_or_create_buffer_with_filepath(filepath) + local path_ = PPath:new(filepath) + path_:parent():mkdir({ parents = true, exists_ok = true }) insert_conflict_contents(bufnr, snippets) local process = function(winid) api.nvim_set_current_win(winid) @@ -845,7 +906,7 @@ function Sidebar:render_result() then return end - local header_text = "󰭻 Avante" + local header_text = Utils.icon("󰭻 ") .. "Avante" self:render_header( self.result_container.winid, self.result_container.bufnr, @@ -867,13 +928,15 @@ function Sidebar:render_input(ask) end local header_text = string.format( - "󱜸 %s (" .. Config.mappings.sidebar.switch_windows .. ": switch focus)", + "%s%s (" .. Config.mappings.sidebar.switch_windows .. ": switch focus)", + Utils.icon("󱜸 "), ask and "Ask" or "Chat with" ) if self.code.selection ~= nil then header_text = string.format( - "󱜸 %s (%d:%d) (: switch focus)", + "%s%s (%d:%d) (: switch focus)", + Utils.icon("󱜸 "), ask and "Ask" or "Chat with", self.code.selection.range.start.lnum, self.code.selection.range.finish.lnum @@ -906,7 +969,8 @@ function Sidebar:render_selected_code() selected_code_lines_count = #selected_code_lines end - local header_text = " Selected Code" + local header_text = Utils.icon(" ") + .. "Selected Code" .. ( selected_code_lines_count > selected_code_max_lines_count and " (Show only the first " .. tostring(selected_code_max_lines_count) .. " lines)" @@ -1312,6 +1376,20 @@ function Sidebar:initialize() return self end +function Sidebar:is_focused() + if not self:is_open() then return false end + + local current_winid = api.nvim_get_current_win() + if self.winids.result_container and self.winids.result_container == current_winid then return true end + if self.winids.selected_files_container and self.winids.selected_files_container == current_winid then + return true + end + if self.winids.selected_code_container and self.winids.selected_code_container == current_winid then return true end + if self.winids.input_container and self.winids.input_container == current_winid then return true end + + return false +end + function Sidebar:is_focused_on_result() return self:is_open() and self.result_container and self.result_container.winid == api.nvim_get_current_win() end @@ -1892,6 +1970,8 @@ function Sidebar:create_input_container(opts) ---@type AvanteLLMChunkCallback local on_chunk = function(chunk) + self.is_generating = true + original_response = original_response .. chunk local selected_files = self.file_selector:get_selected_files_contents() @@ -1926,6 +2006,8 @@ function Sidebar:create_input_container(opts) ---@type AvanteLLMStopCallback local on_stop = function(stop_opts) + self.is_generating = false + pcall(function() ---remove keymaps vim.keymap.del("n", "j", { buffer = self.result_container.bufnr }) @@ -2039,6 +2121,7 @@ function Sidebar:create_input_container(opts) local request = table.concat(lines, "\n") if request == "" then return end api.nvim_buf_set_lines(self.input_container.bufnr, 0, -1, false, {}) + api.nvim_win_set_cursor(self.input_container.winid, { 1, 0 }) handle_submit(request) end @@ -2454,7 +2537,7 @@ function Sidebar:create_selected_files_container() self:render_header( self.selected_files_container.winid, selected_files_buf, - " Selected Files", + Utils.icon(" ") .. "Selected Files", Highlights.SUBTITLE, Highlights.REVERSED_SUBTITLE ) diff --git a/lua/avante/suggestion.lua b/lua/avante/suggestion.lua index 37c521b..10451f3 100644 --- a/lua/avante/suggestion.lua +++ b/lua/avante/suggestion.lua @@ -84,8 +84,8 @@ function Suggestion:suggest() L1: def fib L2: L3: if __name__ == "__main__": -L4: # just pass -L5: pass +L4: # just pass +L5: pass ]], }, @@ -95,7 +95,7 @@ L5: pass }, { role = "user", - content = '{ "indentSize": 4, "position": { "row": 1, "col": 2 } }', + content = '{"insertSpaces":true,"tabSize":4,"indentSize":4,"position":{"row":1,"col":7}}', }, { role = "assistant", diff --git a/lua/avante/templates/base.avanterules b/lua/avante/templates/base.avanterules index 06a2efe..c96886e 100644 --- a/lua/avante/templates/base.avanterules +++ b/lua/avante/templates/base.avanterules @@ -10,10 +10,17 @@ Act as an expert software developer. Always use best practices when coding. Respect and use existing conventions, libraries, etc that are already present in the code base. -You have access to tools, but only use them when necessary. If a tool is not required, respond as normal. -If you encounter a URL, prioritize using the fetch tool to obtain its content. -If you have information that you don't know, please proactively use the tools provided by users! Especially the web search tool. -When available tools cannot meet the requirements, please try to use the `run_command` tool to solve the problem whenever possible. +Don't directly search for code context in historical messages. Instead, prioritize using tools to obtain context first, then use context from historical messages as a secondary source, since context from historical messages is often not up to date. + +Tools Usage Guide: + - You have access to tools, but only use them when necessary. If a tool is not required, respond as normal. + - If you encounter a URL, prioritize using the fetch tool to obtain its content. + - If you have information that you don't know, please proactively use the tools provided by users! Especially the web search tool. + - When available tools cannot meet the requirements, please try to use the `run_command` tool to solve the problem whenever possible. + - When attempting to modify a file that is not in the context, please first use the `list_files` tool and `search_files` tool to check if the file you want to modify exists, then use the `read_file` tool to read the file content. Don't modify blindly! + - When generating files, first use `list_files` tool to read the directory structure, don't generate blindly! + - When creating files, first check if the directory exists. If it doesn't exist, create the directory before creating the file. + - After `web_search`, if you don't get detailed enough information, do not continue use `web_search`, just continue using the `fetch` tool to get more information you need from the links in the search results. {% if system_info -%} Use the appropriate shell based on the user's system info: diff --git a/lua/avante/utils/init.lua b/lua/avante/utils/init.lua index 8921c05..fef7648 100644 --- a/lua/avante/utils/init.lua +++ b/lua/avante/utils/init.lua @@ -665,6 +665,19 @@ function M.scan_directory_respect_gitignore(options) local directory = options.directory local gitignore_path = directory .. "/.gitignore" local gitignore_patterns, gitignore_negate_patterns = M.parse_gitignore(gitignore_path) + + -- Convert relative paths in gitignore to absolute paths based on project root + local project_root = M.get_project_root() + local function to_absolute_path(pattern) + -- Skip if already absolute path + if pattern:sub(1, 1) == "/" then return pattern end + -- Convert relative path to absolute + return Path:new(project_root, pattern):absolute() + end + + gitignore_patterns = vim.tbl_map(to_absolute_path, gitignore_patterns) + gitignore_negate_patterns = vim.tbl_map(to_absolute_path, gitignore_negate_patterns) + return M.scan_directory({ directory = directory, gitignore_patterns = gitignore_patterns, @@ -890,9 +903,19 @@ function M.is_same_file(filepath_a, filepath_b) return M.uniform_path(filepath_a function M.trim_think_content(content) return content:gsub("^.-", "", 1) end +function M.get_filetype(filepath) + -- Some files are sometimes not detected correctly when buffer is not included + -- https://github.com/neovim/neovim/issues/27265 + + local buf = vim.api.nvim_create_buf(false, true) + local filetype = vim.filetype.match({ filename = filepath, buf = buf }) or "" + vim.api.nvim_buf_delete(buf, { force = true }) + + return filetype +end + ---@param file_path string ---@return string[]|nil lines ----@return string|nil file_type ---@return string|nil error function M.read_file_from_buf_or_disk(file_path) --- Lookup if the file is loaded in a buffer @@ -900,8 +923,7 @@ function M.read_file_from_buf_or_disk(file_path) if bufnr ~= -1 and vim.api.nvim_buf_is_loaded(bufnr) then -- If buffer exists and is loaded, get buffer content local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) - local file_type = vim.api.nvim_get_option_value("filetype", { buf = bufnr }) - return lines, file_type, nil + return lines, nil end -- Fallback: read file from disk @@ -909,12 +931,28 @@ function M.read_file_from_buf_or_disk(file_path) if file then local content = file:read("*all") file:close() - -- Detect the file type using the specific file's content - local file_type = vim.filetype.match({ filename = file_path, contents = { content } }) or "unknown" - return vim.split(content, "\n"), file_type, nil + return vim.split(content, "\n"), nil else - M.error("failed to open file: " .. file_path .. " with error: " .. open_err) - return {}, nil, open_err + -- M.error("failed to open file: " .. file_path .. " with error: " .. open_err) + return {}, open_err + end +end + +---Check if an icon plugin is installed +---@return boolean +M.icons_enabled = function() return M.has("nvim-web-devicons") or M.has("mini.icons") or M.has("mini.nvim") end + +---Display an string with icon, if an icon plugin is available. +---Dev icons are an optional install for avante, this function prevents ugly chars +---being displayed by displaying fallback options or nothing at all. +---@param string_with_icon string +---@param utf8_fallback string|nil +---@return string +M.icon = function(string_with_icon, utf8_fallback) + if M.icons_enabled() then + return string_with_icon + else + return utf8_fallback or "" end end