diff --git a/lua/avante/config.lua b/lua/avante/config.lua index f959286..9314d21 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -396,6 +396,7 @@ M.BASE_PROVIDER_KEYS = { "use_xml_format", "role_map", "__inherited_from", + "disable_tools", } return M diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index 9a843cf..7501bae 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -116,18 +116,16 @@ M.generate_prompts = function(opts) messages = messages, image_paths = image_paths, tools = opts.tools, - tool_use = opts.tool_use, - tool_result = opts.tool_result, - response_content = opts.response_content, + tool_histories = opts.tool_histories, } end ---@param opts GeneratePromptsOptions ---@return integer M.calculate_tokens = function(opts) - local code_opts = M.generate_prompts(opts) - local tokens = Utils.tokens.calculate_tokens(code_opts.system_prompt) - for _, message in ipairs(code_opts.messages) do + local prompt_opts = M.generate_prompts(opts) + local tokens = Utils.tokens.calculate_tokens(prompt_opts.system_prompt) + for _, message in ipairs(prompt_opts.messages) do tokens = tokens + Utils.tokens.calculate_tokens(message.content) end return tokens @@ -137,7 +135,7 @@ end M._stream = function(opts) local Provider = opts.provider or P[Config.provider] - local code_opts = M.generate_prompts(opts) + local prompt_opts = M.generate_prompts(opts) ---@type string local current_event_state = nil @@ -154,11 +152,14 @@ M._stream = function(opts) content = error ~= nil and error or result, is_error = error ~= nil, } - local new_opts = vim.tbl_deep_extend( - "force", - opts, + local old_tool_histories = vim.deepcopy(opts.tool_histories) or {} + table.insert( + old_tool_histories, { tool_result = tool_result, tool_use = stop_opts.tool_use, response_content = stop_opts.response_content } ) + local new_opts = vim.tbl_deep_extend("force", opts, { + tool_histories = old_tool_histories, + }) return M._stream(new_opts) end return opts.on_stop(stop_opts) @@ -166,7 +167,7 @@ M._stream = function(opts) } ---@type AvanteCurlOutput - local spec = Provider.parse_curl_args(Provider, code_opts) + local spec = Provider.parse_curl_args(Provider, prompt_opts) local resp_ctx = {} @@ -310,7 +311,7 @@ local function _merge_response(first_response, second_response, opts) prompt = prompt .. "\n" - -- append this reference prompt to the code_opts messages at last + -- append this reference prompt to the prompt_opts messages at last opts.instructions = opts.instructions .. prompt M._stream(opts) @@ -412,6 +413,9 @@ end ---@field mode LlmMode ---@field provider AvanteProviderFunctor | AvanteBedrockProviderFunctor | nil ---@field tools? AvanteLLMTool[] +---@field tool_histories? AvanteLLMToolHistory[] +--- +---@class AvanteLLMToolHistory ---@field tool_result? AvanteLLMToolResult ---@field tool_use? AvanteLLMToolUse ---@field response_content? string diff --git a/lua/avante/providers/azure.lua b/lua/avante/providers/azure.lua index 49ee887..83ca617 100644 --- a/lua/avante/providers/azure.lua +++ b/lua/avante/providers/azure.lua @@ -17,7 +17,7 @@ M.parse_messages = O.parse_messages M.parse_response = O.parse_response M.parse_response_without_stream = O.parse_response_without_stream -M.parse_curl_args = function(provider, code_opts) +M.parse_curl_args = function(provider, prompt_opts) local base, body_opts = P.parse_config(provider) local headers = { @@ -40,7 +40,7 @@ M.parse_curl_args = function(provider, code_opts) insecure = base.allow_insecure, headers = headers, body = vim.tbl_deep_extend("force", { - messages = M.parse_messages(code_opts), + messages = M.parse_messages(prompt_opts), stream = true, }, body_opts), } diff --git a/lua/avante/providers/claude.lua b/lua/avante/providers/claude.lua index 300e671..3e48d92 100644 --- a/lua/avante/providers/claude.lua +++ b/lua/avante/providers/claude.lua @@ -112,38 +112,42 @@ M.parse_messages = function(opts) messages[#messages].content = message_content end - if opts.tool_use then - local msg = { - role = "assistant", - content = {}, - } - if opts.response_content then - msg.content[#msg.content + 1] = { - type = "text", - text = opts.response_content, - } - end - msg.content[#msg.content + 1] = { - type = "tool_use", - id = opts.tool_use.id, - name = opts.tool_use.name, - input = vim.json.decode(opts.tool_use.input_json), - } - messages[#messages + 1] = msg - end + if opts.tool_histories then + for _, tool_history in ipairs(opts.tool_histories) do + if tool_history.tool_use then + local msg = { + role = "assistant", + content = {}, + } + if tool_history.response_content then + msg.content[#msg.content + 1] = { + type = "text", + text = tool_history.response_content, + } + end + msg.content[#msg.content + 1] = { + type = "tool_use", + id = tool_history.tool_use.id, + name = tool_history.tool_use.name, + input = vim.json.decode(tool_history.tool_use.input_json), + } + messages[#messages + 1] = msg + end - if opts.tool_result then - messages[#messages + 1] = { - role = "user", - content = { - { - type = "tool_result", - tool_use_id = opts.tool_result.tool_use_id, - content = opts.tool_result.content, - is_error = opts.tool_result.is_error, - }, - }, - } + if tool_history.tool_result then + messages[#messages + 1] = { + role = "user", + content = { + { + type = "tool_result", + tool_use_id = tool_history.tool_result.tool_use_id, + content = tool_history.tool_result.content, + is_error = tool_history.tool_result.is_error, + }, + }, + } + end + end end return messages diff --git a/lua/avante/providers/cohere.lua b/lua/avante/providers/cohere.lua index e6f787e..059e23e 100644 --- a/lua/avante/providers/cohere.lua +++ b/lua/avante/providers/cohere.lua @@ -69,7 +69,7 @@ M.parse_stream_data = function(data, opts) end end -M.parse_curl_args = function(provider, code_opts) +M.parse_curl_args = function(provider, prompt_opts) local base, body_opts = P.parse_config(provider) local headers = { @@ -92,7 +92,7 @@ M.parse_curl_args = function(provider, code_opts) body = vim.tbl_deep_extend("force", { model = base.model, stream = true, - }, M.parse_messages(code_opts), body_opts), + }, M.parse_messages(prompt_opts), body_opts), } end diff --git a/lua/avante/providers/copilot.lua b/lua/avante/providers/copilot.lua index 47305ff..dd47820 100644 --- a/lua/avante/providers/copilot.lua +++ b/lua/avante/providers/copilot.lua @@ -241,7 +241,7 @@ end M.parse_response = OpenAI.parse_response -M.parse_curl_args = function(provider, code_opts) +M.parse_curl_args = function(provider, prompt_opts) -- refresh token synchronously, only if it has expired -- (this should rarely happen, as we refresh the token in the background) H.refresh_token(false, false) @@ -249,8 +249,8 @@ M.parse_curl_args = function(provider, code_opts) local base, body_opts = P.parse_config(provider) local tools = {} - if code_opts.tools then - for _, tool in ipairs(code_opts.tools) do + if prompt_opts.tools then + for _, tool in ipairs(prompt_opts.tools) do table.insert(tools, OpenAI.transform_tool(tool)) end end @@ -268,7 +268,7 @@ M.parse_curl_args = function(provider, code_opts) }, body = vim.tbl_deep_extend("force", { model = base.model, - messages = M.parse_messages(code_opts), + messages = M.parse_messages(prompt_opts), stream = true, tools = tools, }, body_opts), diff --git a/lua/avante/providers/gemini.lua b/lua/avante/providers/gemini.lua index af63565..b357093 100644 --- a/lua/avante/providers/gemini.lua +++ b/lua/avante/providers/gemini.lua @@ -81,7 +81,7 @@ M.parse_response = function(ctx, data_stream, _, opts) end end -M.parse_curl_args = function(provider, code_opts) +M.parse_curl_args = function(provider, prompt_opts) local base, body_opts = P.parse_config(provider) body_opts = vim.tbl_deep_extend("force", body_opts, { @@ -101,7 +101,7 @@ M.parse_curl_args = function(provider, code_opts) proxy = base.proxy, insecure = base.allow_insecure, headers = { ["Content-Type"] = "application/json" }, - body = vim.tbl_deep_extend("force", {}, M.parse_messages(code_opts), body_opts), + body = vim.tbl_deep_extend("force", {}, M.parse_messages(prompt_opts), body_opts), } end diff --git a/lua/avante/providers/init.lua b/lua/avante/providers/init.lua index ba9848c..f9509ab 100644 --- a/lua/avante/providers/init.lua +++ b/lua/avante/providers/init.lua @@ -30,9 +30,7 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil } ---@field messages AvanteLLMMessage[] ---@field image_paths? string[] ---@field tools? AvanteLLMTool[] ----@field tool_result? AvanteLLMToolResult ----@field tool_use? AvanteLLMToolUse ----@field response_content? string +---@field tool_histories? AvanteLLMToolHistory[] --- ---@class AvanteGeminiMessage ---@field role "user" @@ -43,7 +41,7 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil } ---@alias AvanteMessagesParser fun(opts: AvantePromptOptions): AvanteChatMessage[] --- ---@class AvanteCurlOutput: {url: string, proxy: string, insecure: boolean, body: table | string, headers: table, rawArgs: string[] | nil} ----@alias AvanteCurlArgsParser fun(opts: AvanteProvider | AvanteProviderFunctor | AvanteBedrockProviderFunctor, code_opts: AvantePromptOptions): AvanteCurlOutput +---@alias AvanteCurlArgsParser fun(opts: AvanteProvider | AvanteProviderFunctor | AvanteBedrockProviderFunctor, prompt_opts: AvantePromptOptions): AvanteCurlOutput --- ---@class ResponseParser ---@field on_start AvanteLLMStartCallback @@ -60,6 +58,7 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil } ---@field allow_insecure? boolean ---@field api_key_name? string ---@field _shellenv? string +---@field disable_tools? boolean --- ---@class AvanteSupportedProvider: AvanteDefaultBaseProvider ---@field __inherited_from? string @@ -382,26 +381,31 @@ function M.refresh(provider) end ---@param opts AvanteProvider | AvanteSupportedProvider | AvanteProviderFunctor | AvanteBedrockProviderFunctor ----@return AvanteDefaultBaseProvider, table +---@return AvanteDefaultBaseProvider provider_opts +---@return table request_body M.parse_config = function(opts) ---@type AvanteDefaultBaseProvider - local s1 = {} + local provider_opts = {} ---@type table - local s2 = {} + local request_body = {} for key, value in pairs(opts) do if vim.tbl_contains(Config.BASE_PROVIDER_KEYS, key) then - s1[key] = value + provider_opts[key] = value else - s2[key] = value + request_body[key] = value end end - return s1, - vim.iter(s2):filter(function(_, v) return type(v) ~= "function" end):fold({}, function(acc, k, v) + request_body = vim + .iter(request_body) + :filter(function(_, v) return type(v) ~= "function" end) + :fold({}, function(acc, k, v) acc[k] = v return acc end) + + return provider_opts, request_body end ---@private diff --git a/lua/avante/providers/openai.lua b/lua/avante/providers/openai.lua index 1ff6cf8..80647ff 100644 --- a/lua/avante/providers/openai.lua +++ b/lua/avante/providers/openai.lua @@ -167,26 +167,28 @@ M.parse_messages = function(opts) table.insert(final_messages, { role = M.role_map[role] or role, content = message.content }) end) - if opts.tool_result then - table.insert(final_messages, { - role = M.role_map["assistant"], - tool_calls = { - { - id = opts.tool_use.id, - type = "function", - ["function"] = { - name = opts.tool_use.name, - arguments = opts.tool_use.input_json, + if opts.tool_histories then + for _, tool_history in ipairs(opts.tool_histories) do + table.insert(final_messages, { + role = M.role_map["assistant"], + tool_calls = { + { + id = tool_history.tool_use.id, + type = "function", + ["function"] = { + name = tool_history.tool_use.name, + arguments = tool_history.tool_use.input_json, + }, }, }, - }, - }) - local result_content = opts.tool_result.content or "" - table.insert(final_messages, { - role = "tool", - tool_call_id = opts.tool_result.tool_use_id, - content = opts.tool_result.is_error and "Error: " .. result_content or result_content, - }) + }) + local result_content = tool_history.tool_result.content or "" + table.insert(final_messages, { + role = "tool", + tool_call_id = tool_history.tool_result.tool_use_id, + content = tool_history.tool_result.is_error and "Error: " .. result_content or result_content, + }) + end end return final_messages @@ -269,8 +271,9 @@ M.parse_response_without_stream = function(data, _, opts) end end -M.parse_curl_args = function(provider, code_opts) +M.parse_curl_args = function(provider, prompt_opts) local base, body_opts = P.parse_config(provider) + local disable_tools = base.disable_tools or false local headers = { ["Content-Type"] = "application/json", @@ -298,9 +301,10 @@ M.parse_curl_args = function(provider, code_opts) body_opts.temperature = 1 end - local tools = {} - if code_opts.tools then - for _, tool in ipairs(code_opts.tools) do + local tools = nil + if not disable_tools and prompt_opts.tools then + tools = {} + for _, tool in ipairs(prompt_opts.tools) do table.insert(tools, M.transform_tool(tool)) end end @@ -315,7 +319,7 @@ M.parse_curl_args = function(provider, code_opts) headers = headers, body = vim.tbl_deep_extend("force", { model = base.model, - messages = M.parse_messages(code_opts), + messages = M.parse_messages(prompt_opts), stream = stream, tools = tools, }, body_opts), diff --git a/lua/avante/providers/vertex.lua b/lua/avante/providers/vertex.lua index f28e8a2..f1ca9ee 100644 --- a/lua/avante/providers/vertex.lua +++ b/lua/avante/providers/vertex.lua @@ -31,7 +31,7 @@ M.parse_api_key = function() return direct_output end -M.parse_curl_args = function(provider, code_opts) +M.parse_curl_args = function(provider, prompt_opts) local base, body_opts = P.parse_config(provider) local location = vim.fn.getenv("LOCATION") or "default-location" local project_id = vim.fn.getenv("PROJECT_ID") or "default-project-id" @@ -58,7 +58,7 @@ M.parse_curl_args = function(provider, code_opts) }, proxy = base.proxy, insecure = base.allow_insecure, - body = vim.tbl_deep_extend("force", {}, M.parse_messages(code_opts), body_opts), + body = vim.tbl_deep_extend("force", {}, M.parse_messages(prompt_opts), body_opts), } end diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index a55c394..a3d2f22 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -4,6 +4,7 @@ local fn = vim.fn local Split = require("nui.split") local event = require("nui.utils.autocmd").event +local PPath = require("plenary.path") local Provider = require("avante.providers") local Path = require("avante.path") local Config = require("avante.config") @@ -236,28 +237,54 @@ local function transform_result_content(selected_files, result_content, prev_fil local end_line = 0 local match_filetype = nil local filepath = current_filepath or prev_filepath or "" + ---@type {path: string, content: string, file_type: string | nil} | nil + local the_matched_file = nil for _, file in ipairs(selected_files) do - if not Utils.is_same_file(file.path, filepath) then goto continue1 end - local file_content = vim.split(file.content, "\n") - if start_line ~= 0 or end_line ~= 0 then break end - for j = 1, #file_content - (search_end - search_start) + 1 do - local match = true - for k = 0, search_end - search_start - 1 do - if - Utils.remove_indentation(file_content[j + k]) ~= Utils.remove_indentation(result_lines[search_start + k]) - then - match = false - break - end - end - if match then - start_line = j - end_line = j + (search_end - search_start) - 1 - match_filetype = file.file_type + if Utils.is_same_file(file.path, filepath) then + the_matched_file = file + break + end + end + + if not the_matched_file then + if not PPath:new(filepath):exists() then + Utils.warn("File not found: " .. filepath) + goto continue + end + if not PPath:new(filepath):is_file() then + Utils.warn("Not a file: " .. filepath) + goto continue + end + local content = Utils.file.read_content(filepath) + if content == nil then + Utils.warn("Failed to read file: " .. filepath) + goto continue + end + the_matched_file = { + filepath = filepath, + content = content, + file_type = nil, + } + end + + local file_content = vim.split(the_matched_file.content, "\n") + if start_line ~= 0 or end_line ~= 0 then break end + for j = 1, #file_content - (search_end - search_start) + 1 do + local match = true + for k = 0, search_end - search_start - 1 do + if + Utils.remove_indentation(file_content[j + k]) ~= Utils.remove_indentation(result_lines[search_start + k]) + then + match = false break end end - ::continue1:: + if match then + start_line = j + end_line = j + (search_end - search_start) - 1 + match_filetype = the_matched_file.file_type + break + end end -- when the filetype isn't detected, fallback to matching based on filepath. @@ -1758,7 +1785,7 @@ function Sidebar:create_input_container(opts) vim.keymap.set("n", "G", on_G, { buffer = self.result_container.bufnr }) ---@type AvanteLLMStartCallback - local on_start = function(start_opts) end + local on_start = function(_) end ---@type AvanteLLMChunkCallback local on_chunk = function(chunk)