fix: multiple tool use histories and disable tools (#1185)

This commit is contained in:
yetone 2025-02-06 02:46:52 +08:00 committed by GitHub
parent e1125fca54
commit d1cc23fa54
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 152 additions and 108 deletions

View File

@ -396,6 +396,7 @@ M.BASE_PROVIDER_KEYS = {
"use_xml_format", "use_xml_format",
"role_map", "role_map",
"__inherited_from", "__inherited_from",
"disable_tools",
} }
return M return M

View File

@ -116,18 +116,16 @@ M.generate_prompts = function(opts)
messages = messages, messages = messages,
image_paths = image_paths, image_paths = image_paths,
tools = opts.tools, tools = opts.tools,
tool_use = opts.tool_use, tool_histories = opts.tool_histories,
tool_result = opts.tool_result,
response_content = opts.response_content,
} }
end end
---@param opts GeneratePromptsOptions ---@param opts GeneratePromptsOptions
---@return integer ---@return integer
M.calculate_tokens = function(opts) M.calculate_tokens = function(opts)
local code_opts = M.generate_prompts(opts) local prompt_opts = M.generate_prompts(opts)
local tokens = Utils.tokens.calculate_tokens(code_opts.system_prompt) local tokens = Utils.tokens.calculate_tokens(prompt_opts.system_prompt)
for _, message in ipairs(code_opts.messages) do for _, message in ipairs(prompt_opts.messages) do
tokens = tokens + Utils.tokens.calculate_tokens(message.content) tokens = tokens + Utils.tokens.calculate_tokens(message.content)
end end
return tokens return tokens
@ -137,7 +135,7 @@ end
M._stream = function(opts) M._stream = function(opts)
local Provider = opts.provider or P[Config.provider] local Provider = opts.provider or P[Config.provider]
local code_opts = M.generate_prompts(opts) local prompt_opts = M.generate_prompts(opts)
---@type string ---@type string
local current_event_state = nil local current_event_state = nil
@ -154,11 +152,14 @@ M._stream = function(opts)
content = error ~= nil and error or result, content = error ~= nil and error or result,
is_error = error ~= nil, is_error = error ~= nil,
} }
local new_opts = vim.tbl_deep_extend( local old_tool_histories = vim.deepcopy(opts.tool_histories) or {}
"force", table.insert(
opts, old_tool_histories,
{ tool_result = tool_result, tool_use = stop_opts.tool_use, response_content = stop_opts.response_content } { tool_result = tool_result, tool_use = stop_opts.tool_use, response_content = stop_opts.response_content }
) )
local new_opts = vim.tbl_deep_extend("force", opts, {
tool_histories = old_tool_histories,
})
return M._stream(new_opts) return M._stream(new_opts)
end end
return opts.on_stop(stop_opts) return opts.on_stop(stop_opts)
@ -166,7 +167,7 @@ M._stream = function(opts)
} }
---@type AvanteCurlOutput ---@type AvanteCurlOutput
local spec = Provider.parse_curl_args(Provider, code_opts) local spec = Provider.parse_curl_args(Provider, prompt_opts)
local resp_ctx = {} local resp_ctx = {}
@ -310,7 +311,7 @@ local function _merge_response(first_response, second_response, opts)
prompt = prompt .. "\n" prompt = prompt .. "\n"
-- append this reference prompt to the code_opts messages at last -- append this reference prompt to the prompt_opts messages at last
opts.instructions = opts.instructions .. prompt opts.instructions = opts.instructions .. prompt
M._stream(opts) M._stream(opts)
@ -412,6 +413,9 @@ end
---@field mode LlmMode ---@field mode LlmMode
---@field provider AvanteProviderFunctor | AvanteBedrockProviderFunctor | nil ---@field provider AvanteProviderFunctor | AvanteBedrockProviderFunctor | nil
---@field tools? AvanteLLMTool[] ---@field tools? AvanteLLMTool[]
---@field tool_histories? AvanteLLMToolHistory[]
---
---@class AvanteLLMToolHistory
---@field tool_result? AvanteLLMToolResult ---@field tool_result? AvanteLLMToolResult
---@field tool_use? AvanteLLMToolUse ---@field tool_use? AvanteLLMToolUse
---@field response_content? string ---@field response_content? string

View File

@ -17,7 +17,7 @@ M.parse_messages = O.parse_messages
M.parse_response = O.parse_response M.parse_response = O.parse_response
M.parse_response_without_stream = O.parse_response_without_stream M.parse_response_without_stream = O.parse_response_without_stream
M.parse_curl_args = function(provider, code_opts) M.parse_curl_args = function(provider, prompt_opts)
local base, body_opts = P.parse_config(provider) local base, body_opts = P.parse_config(provider)
local headers = { local headers = {
@ -40,7 +40,7 @@ M.parse_curl_args = function(provider, code_opts)
insecure = base.allow_insecure, insecure = base.allow_insecure,
headers = headers, headers = headers,
body = vim.tbl_deep_extend("force", { body = vim.tbl_deep_extend("force", {
messages = M.parse_messages(code_opts), messages = M.parse_messages(prompt_opts),
stream = true, stream = true,
}, body_opts), }, body_opts),
} }

View File

@ -112,38 +112,42 @@ M.parse_messages = function(opts)
messages[#messages].content = message_content messages[#messages].content = message_content
end end
if opts.tool_use then if opts.tool_histories then
local msg = { for _, tool_history in ipairs(opts.tool_histories) do
role = "assistant", if tool_history.tool_use then
content = {}, local msg = {
} role = "assistant",
if opts.response_content then content = {},
msg.content[#msg.content + 1] = { }
type = "text", if tool_history.response_content then
text = opts.response_content, msg.content[#msg.content + 1] = {
} type = "text",
end text = tool_history.response_content,
msg.content[#msg.content + 1] = { }
type = "tool_use", end
id = opts.tool_use.id, msg.content[#msg.content + 1] = {
name = opts.tool_use.name, type = "tool_use",
input = vim.json.decode(opts.tool_use.input_json), id = tool_history.tool_use.id,
} name = tool_history.tool_use.name,
messages[#messages + 1] = msg input = vim.json.decode(tool_history.tool_use.input_json),
end }
messages[#messages + 1] = msg
end
if opts.tool_result then if tool_history.tool_result then
messages[#messages + 1] = { messages[#messages + 1] = {
role = "user", role = "user",
content = { content = {
{ {
type = "tool_result", type = "tool_result",
tool_use_id = opts.tool_result.tool_use_id, tool_use_id = tool_history.tool_result.tool_use_id,
content = opts.tool_result.content, content = tool_history.tool_result.content,
is_error = opts.tool_result.is_error, is_error = tool_history.tool_result.is_error,
}, },
}, },
} }
end
end
end end
return messages return messages

View File

@ -69,7 +69,7 @@ M.parse_stream_data = function(data, opts)
end end
end end
M.parse_curl_args = function(provider, code_opts) M.parse_curl_args = function(provider, prompt_opts)
local base, body_opts = P.parse_config(provider) local base, body_opts = P.parse_config(provider)
local headers = { local headers = {
@ -92,7 +92,7 @@ M.parse_curl_args = function(provider, code_opts)
body = vim.tbl_deep_extend("force", { body = vim.tbl_deep_extend("force", {
model = base.model, model = base.model,
stream = true, stream = true,
}, M.parse_messages(code_opts), body_opts), }, M.parse_messages(prompt_opts), body_opts),
} }
end end

View File

@ -241,7 +241,7 @@ end
M.parse_response = OpenAI.parse_response M.parse_response = OpenAI.parse_response
M.parse_curl_args = function(provider, code_opts) M.parse_curl_args = function(provider, prompt_opts)
-- refresh token synchronously, only if it has expired -- refresh token synchronously, only if it has expired
-- (this should rarely happen, as we refresh the token in the background) -- (this should rarely happen, as we refresh the token in the background)
H.refresh_token(false, false) H.refresh_token(false, false)
@ -249,8 +249,8 @@ M.parse_curl_args = function(provider, code_opts)
local base, body_opts = P.parse_config(provider) local base, body_opts = P.parse_config(provider)
local tools = {} local tools = {}
if code_opts.tools then if prompt_opts.tools then
for _, tool in ipairs(code_opts.tools) do for _, tool in ipairs(prompt_opts.tools) do
table.insert(tools, OpenAI.transform_tool(tool)) table.insert(tools, OpenAI.transform_tool(tool))
end end
end end
@ -268,7 +268,7 @@ M.parse_curl_args = function(provider, code_opts)
}, },
body = vim.tbl_deep_extend("force", { body = vim.tbl_deep_extend("force", {
model = base.model, model = base.model,
messages = M.parse_messages(code_opts), messages = M.parse_messages(prompt_opts),
stream = true, stream = true,
tools = tools, tools = tools,
}, body_opts), }, body_opts),

View File

@ -81,7 +81,7 @@ M.parse_response = function(ctx, data_stream, _, opts)
end end
end end
M.parse_curl_args = function(provider, code_opts) M.parse_curl_args = function(provider, prompt_opts)
local base, body_opts = P.parse_config(provider) local base, body_opts = P.parse_config(provider)
body_opts = vim.tbl_deep_extend("force", body_opts, { body_opts = vim.tbl_deep_extend("force", body_opts, {
@ -101,7 +101,7 @@ M.parse_curl_args = function(provider, code_opts)
proxy = base.proxy, proxy = base.proxy,
insecure = base.allow_insecure, insecure = base.allow_insecure,
headers = { ["Content-Type"] = "application/json" }, headers = { ["Content-Type"] = "application/json" },
body = vim.tbl_deep_extend("force", {}, M.parse_messages(code_opts), body_opts), body = vim.tbl_deep_extend("force", {}, M.parse_messages(prompt_opts), body_opts),
} }
end end

View File

@ -30,9 +30,7 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil }
---@field messages AvanteLLMMessage[] ---@field messages AvanteLLMMessage[]
---@field image_paths? string[] ---@field image_paths? string[]
---@field tools? AvanteLLMTool[] ---@field tools? AvanteLLMTool[]
---@field tool_result? AvanteLLMToolResult ---@field tool_histories? AvanteLLMToolHistory[]
---@field tool_use? AvanteLLMToolUse
---@field response_content? string
--- ---
---@class AvanteGeminiMessage ---@class AvanteGeminiMessage
---@field role "user" ---@field role "user"
@ -43,7 +41,7 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil }
---@alias AvanteMessagesParser fun(opts: AvantePromptOptions): AvanteChatMessage[] ---@alias AvanteMessagesParser fun(opts: AvantePromptOptions): AvanteChatMessage[]
--- ---
---@class AvanteCurlOutput: {url: string, proxy: string, insecure: boolean, body: table<string, any> | string, headers: table<string, string>, rawArgs: string[] | nil} ---@class AvanteCurlOutput: {url: string, proxy: string, insecure: boolean, body: table<string, any> | string, headers: table<string, string>, rawArgs: string[] | nil}
---@alias AvanteCurlArgsParser fun(opts: AvanteProvider | AvanteProviderFunctor | AvanteBedrockProviderFunctor, code_opts: AvantePromptOptions): AvanteCurlOutput ---@alias AvanteCurlArgsParser fun(opts: AvanteProvider | AvanteProviderFunctor | AvanteBedrockProviderFunctor, prompt_opts: AvantePromptOptions): AvanteCurlOutput
--- ---
---@class ResponseParser ---@class ResponseParser
---@field on_start AvanteLLMStartCallback ---@field on_start AvanteLLMStartCallback
@ -60,6 +58,7 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil }
---@field allow_insecure? boolean ---@field allow_insecure? boolean
---@field api_key_name? string ---@field api_key_name? string
---@field _shellenv? string ---@field _shellenv? string
---@field disable_tools? boolean
--- ---
---@class AvanteSupportedProvider: AvanteDefaultBaseProvider ---@class AvanteSupportedProvider: AvanteDefaultBaseProvider
---@field __inherited_from? string ---@field __inherited_from? string
@ -382,26 +381,31 @@ function M.refresh(provider)
end end
---@param opts AvanteProvider | AvanteSupportedProvider | AvanteProviderFunctor | AvanteBedrockProviderFunctor ---@param opts AvanteProvider | AvanteSupportedProvider | AvanteProviderFunctor | AvanteBedrockProviderFunctor
---@return AvanteDefaultBaseProvider, table<string, any> ---@return AvanteDefaultBaseProvider provider_opts
---@return table<string, any> request_body
M.parse_config = function(opts) M.parse_config = function(opts)
---@type AvanteDefaultBaseProvider ---@type AvanteDefaultBaseProvider
local s1 = {} local provider_opts = {}
---@type table<string, any> ---@type table<string, any>
local s2 = {} local request_body = {}
for key, value in pairs(opts) do for key, value in pairs(opts) do
if vim.tbl_contains(Config.BASE_PROVIDER_KEYS, key) then if vim.tbl_contains(Config.BASE_PROVIDER_KEYS, key) then
s1[key] = value provider_opts[key] = value
else else
s2[key] = value request_body[key] = value
end end
end end
return s1, request_body = vim
vim.iter(s2):filter(function(_, v) return type(v) ~= "function" end):fold({}, function(acc, k, v) .iter(request_body)
:filter(function(_, v) return type(v) ~= "function" end)
:fold({}, function(acc, k, v)
acc[k] = v acc[k] = v
return acc return acc
end) end)
return provider_opts, request_body
end end
---@private ---@private

View File

@ -167,26 +167,28 @@ M.parse_messages = function(opts)
table.insert(final_messages, { role = M.role_map[role] or role, content = message.content }) table.insert(final_messages, { role = M.role_map[role] or role, content = message.content })
end) end)
if opts.tool_result then if opts.tool_histories then
table.insert(final_messages, { for _, tool_history in ipairs(opts.tool_histories) do
role = M.role_map["assistant"], table.insert(final_messages, {
tool_calls = { role = M.role_map["assistant"],
{ tool_calls = {
id = opts.tool_use.id, {
type = "function", id = tool_history.tool_use.id,
["function"] = { type = "function",
name = opts.tool_use.name, ["function"] = {
arguments = opts.tool_use.input_json, name = tool_history.tool_use.name,
arguments = tool_history.tool_use.input_json,
},
}, },
}, },
}, })
}) local result_content = tool_history.tool_result.content or ""
local result_content = opts.tool_result.content or "" table.insert(final_messages, {
table.insert(final_messages, { role = "tool",
role = "tool", tool_call_id = tool_history.tool_result.tool_use_id,
tool_call_id = opts.tool_result.tool_use_id, content = tool_history.tool_result.is_error and "Error: " .. result_content or result_content,
content = opts.tool_result.is_error and "Error: " .. result_content or result_content, })
}) end
end end
return final_messages return final_messages
@ -269,8 +271,9 @@ M.parse_response_without_stream = function(data, _, opts)
end end
end end
M.parse_curl_args = function(provider, code_opts) M.parse_curl_args = function(provider, prompt_opts)
local base, body_opts = P.parse_config(provider) local base, body_opts = P.parse_config(provider)
local disable_tools = base.disable_tools or false
local headers = { local headers = {
["Content-Type"] = "application/json", ["Content-Type"] = "application/json",
@ -298,9 +301,10 @@ M.parse_curl_args = function(provider, code_opts)
body_opts.temperature = 1 body_opts.temperature = 1
end end
local tools = {} local tools = nil
if code_opts.tools then if not disable_tools and prompt_opts.tools then
for _, tool in ipairs(code_opts.tools) do tools = {}
for _, tool in ipairs(prompt_opts.tools) do
table.insert(tools, M.transform_tool(tool)) table.insert(tools, M.transform_tool(tool))
end end
end end
@ -315,7 +319,7 @@ M.parse_curl_args = function(provider, code_opts)
headers = headers, headers = headers,
body = vim.tbl_deep_extend("force", { body = vim.tbl_deep_extend("force", {
model = base.model, model = base.model,
messages = M.parse_messages(code_opts), messages = M.parse_messages(prompt_opts),
stream = stream, stream = stream,
tools = tools, tools = tools,
}, body_opts), }, body_opts),

View File

@ -31,7 +31,7 @@ M.parse_api_key = function()
return direct_output return direct_output
end end
M.parse_curl_args = function(provider, code_opts) M.parse_curl_args = function(provider, prompt_opts)
local base, body_opts = P.parse_config(provider) local base, body_opts = P.parse_config(provider)
local location = vim.fn.getenv("LOCATION") or "default-location" local location = vim.fn.getenv("LOCATION") or "default-location"
local project_id = vim.fn.getenv("PROJECT_ID") or "default-project-id" local project_id = vim.fn.getenv("PROJECT_ID") or "default-project-id"
@ -58,7 +58,7 @@ M.parse_curl_args = function(provider, code_opts)
}, },
proxy = base.proxy, proxy = base.proxy,
insecure = base.allow_insecure, insecure = base.allow_insecure,
body = vim.tbl_deep_extend("force", {}, M.parse_messages(code_opts), body_opts), body = vim.tbl_deep_extend("force", {}, M.parse_messages(prompt_opts), body_opts),
} }
end end

View File

@ -4,6 +4,7 @@ local fn = vim.fn
local Split = require("nui.split") local Split = require("nui.split")
local event = require("nui.utils.autocmd").event local event = require("nui.utils.autocmd").event
local PPath = require("plenary.path")
local Provider = require("avante.providers") local Provider = require("avante.providers")
local Path = require("avante.path") local Path = require("avante.path")
local Config = require("avante.config") local Config = require("avante.config")
@ -236,28 +237,54 @@ local function transform_result_content(selected_files, result_content, prev_fil
local end_line = 0 local end_line = 0
local match_filetype = nil local match_filetype = nil
local filepath = current_filepath or prev_filepath or "" local filepath = current_filepath or prev_filepath or ""
---@type {path: string, content: string, file_type: string | nil} | nil
local the_matched_file = nil
for _, file in ipairs(selected_files) do for _, file in ipairs(selected_files) do
if not Utils.is_same_file(file.path, filepath) then goto continue1 end if Utils.is_same_file(file.path, filepath) then
local file_content = vim.split(file.content, "\n") the_matched_file = file
if start_line ~= 0 or end_line ~= 0 then break end break
for j = 1, #file_content - (search_end - search_start) + 1 do end
local match = true end
for k = 0, search_end - search_start - 1 do
if if not the_matched_file then
Utils.remove_indentation(file_content[j + k]) ~= Utils.remove_indentation(result_lines[search_start + k]) if not PPath:new(filepath):exists() then
then Utils.warn("File not found: " .. filepath)
match = false goto continue
break end
end if not PPath:new(filepath):is_file() then
end Utils.warn("Not a file: " .. filepath)
if match then goto continue
start_line = j end
end_line = j + (search_end - search_start) - 1 local content = Utils.file.read_content(filepath)
match_filetype = file.file_type if content == nil then
Utils.warn("Failed to read file: " .. filepath)
goto continue
end
the_matched_file = {
filepath = filepath,
content = content,
file_type = nil,
}
end
local file_content = vim.split(the_matched_file.content, "\n")
if start_line ~= 0 or end_line ~= 0 then break end
for j = 1, #file_content - (search_end - search_start) + 1 do
local match = true
for k = 0, search_end - search_start - 1 do
if
Utils.remove_indentation(file_content[j + k]) ~= Utils.remove_indentation(result_lines[search_start + k])
then
match = false
break break
end end
end end
::continue1:: if match then
start_line = j
end_line = j + (search_end - search_start) - 1
match_filetype = the_matched_file.file_type
break
end
end end
-- when the filetype isn't detected, fallback to matching based on filepath. -- when the filetype isn't detected, fallback to matching based on filepath.
@ -1758,7 +1785,7 @@ function Sidebar:create_input_container(opts)
vim.keymap.set("n", "G", on_G, { buffer = self.result_container.bufnr }) vim.keymap.set("n", "G", on_G, { buffer = self.result_container.bufnr })
---@type AvanteLLMStartCallback ---@type AvanteLLMStartCallback
local on_start = function(start_opts) end local on_start = function(_) end
---@type AvanteLLMChunkCallback ---@type AvanteLLMChunkCallback
local on_chunk = function(chunk) local on_chunk = function(chunk)