feat: add add_file_to_context tool (#1191)

This commit is contained in:
yetone 2025-02-06 16:00:14 +08:00 committed by GitHub
parent 4f41154e83
commit f2bd4adba4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 88 additions and 25 deletions

View File

@ -344,7 +344,7 @@ end
---@param idx integer
---@return boolean
function FileSelector:remove_selected_filepaths(idx)
function FileSelector:remove_selected_filepaths_with_index(idx)
if idx > 0 and idx <= #self.selected_filepaths then
table.remove(self.selected_filepaths, idx)
self:emit("update")
@ -353,6 +353,12 @@ function FileSelector:remove_selected_filepaths(idx)
return false
end
function FileSelector:remove_selected_file(rel_path)
local uniform_path = Utils.uniform_path(rel_path)
local idx = Utils.tbl_indexof(self.selected_filepaths, uniform_path)
if idx then self:remove_selected_filepaths_with_index(idx) end
end
---@return { path: string, content: string, file_type: string }[]
function FileSelector:get_selected_files_contents()
local contents = {}

View File

@ -146,7 +146,7 @@ M._stream = function(opts)
on_chunk = opts.on_chunk,
on_stop = function(stop_opts)
if stop_opts.reason == "tool_use" and stop_opts.tool_use then
local result, error = LLMTools.process_tool_use(stop_opts.tool_use, opts.on_tool_log)
local result, error = LLMTools.process_tool_use(opts.tools, stop_opts.tool_use, opts.on_tool_log)
local tool_result = {
tool_use_id = stop_opts.tool_use.id,
content = error ~= nil and error or result,

View File

@ -311,6 +311,7 @@ end
---@class AvanteLLMTool
---@field name string
---@field description string
---@field func? fun(input: any): (string | nil, string | nil)
---@field param AvanteLLMToolParam
---@field returns AvanteLLMToolReturn[]
@ -716,16 +717,17 @@ M.tools = {
},
}
---@param tools AvanteLLMTool[]
---@param tool_use AvanteLLMToolUse
---@param on_log? fun(tool_name: string, log: string): nil
---@return string | nil result
---@return string | nil error
function M.process_tool_use(tool_use, on_log)
function M.process_tool_use(tools, tool_use, on_log)
Utils.debug("use tool", tool_use.name, tool_use.input_json)
local tool = vim.iter(M.tools):find(function(tool) return tool.name == tool_use.name end)
local tool = vim.iter(tools):find(function(tool) return tool.name == tool_use.name end)
if tool == nil then return end
local input_json = vim.json.decode(tool_use.input_json)
local func = M[tool.name]
local func = tool.func or M[tool.name]
if on_log then on_log(tool_use.name, "running tool") end
local result, error = func(input_json, function(log)
if on_log then on_log(tool_use.name, log) end

View File

@ -1628,6 +1628,41 @@ function Sidebar:create_input_container(opts)
local chat_history = Path.history.load(self.code.bufnr)
local tools = vim.deepcopy(LLMTools.tools)
table.insert(tools, {
name = "add_file_to_context",
description = "Add a file to the context",
---@param input { rel_path: string }
---@return string | nil result
---@return string | nil error
func = function(input)
self.file_selector:add_selected_file(input.rel_path)
return "Added file to context", nil
end,
param = {
type = "table",
fields = { { name = "rel_path", description = "Relative path to the file", type = "string" } },
},
returns = {},
})
table.insert(tools, {
name = "remove_file_from_context",
description = "Remove a file from the context",
---@param input { rel_path: string }
---@return string | nil result
---@return string | nil error
func = function(input)
self.file_selector:remove_selected_file(input.rel_path)
return "Removed file from context", nil
end,
param = {
type = "table",
fields = { { name = "rel_path", description = "Relative path to the file", type = "string" } },
},
returns = {},
})
---@param request string
---@return GeneratePromptsOptions
local function get_generate_prompts_options(request)
@ -1697,7 +1732,7 @@ function Sidebar:create_input_container(opts)
selected_code = selected_code_content,
instructions = request,
mode = "planning",
tools = LLMTools.tools,
tools = tools,
}
end
@ -2133,6 +2168,41 @@ function Sidebar:get_selected_code_size()
return selected_code_size
end
function Sidebar:get_selected_files_size()
if not self.file_selector then return 0 end
local selected_files_max_lines_count = 10
local selected_files = self.file_selector:get_selected_filepaths()
local selected_files_size = #selected_files
selected_files_size = math.min(selected_files_size, selected_files_max_lines_count)
return selected_files_size
end
function Sidebar:get_result_container_height()
local selected_code_size = self:get_selected_code_size()
local selected_files_size = self:get_selected_files_size()
if self:get_layout() == "horizontal" then return math.floor(Config.windows.height / 100 * vim.o.lines) end
return math.max(1, api.nvim_win_get_height(self.code.winid) - selected_files_size - selected_code_size - 3 - 8)
end
function Sidebar:get_result_container_width()
if self:get_layout() == "vertical" then return math.floor(Config.windows.width / 100 * vim.o.columns) end
return math.max(1, api.nvim_win_get_width(self.code.winid))
end
function Sidebar:adjust_result_container_layout()
local width = self:get_result_container_width()
local height = self:get_result_container_height()
api.nvim_win_set_width(self.result_container.winid, width)
api.nvim_win_set_height(self.result_container.winid, height)
end
---@param opts AskOptions
function Sidebar:render(opts)
local chat_history = Path.history.load(self.code.bufnr)
@ -2141,20 +2211,6 @@ function Sidebar:render(opts)
return (opts and opts.win and opts.win.position) and opts.win.position or calculate_config_window_position()
end
local get_height = function()
local selected_code_size = self:get_selected_code_size()
if self:get_layout() == "horizontal" then return math.floor(Config.windows.height / 100 * vim.o.lines) end
return math.max(1, api.nvim_win_get_height(self.code.winid) - selected_code_size - 3 - 8)
end
local get_width = function()
if self:get_layout() == "vertical" then return math.floor(Config.windows.width / 100 * vim.o.columns) end
return math.max(1, api.nvim_win_get_width(self.code.winid))
end
self.result_container = Split({
enter = false,
relative = "editor",
@ -2170,8 +2226,8 @@ function Sidebar:render(opts)
wrap = Config.windows.wrap,
}),
size = {
width = get_width(),
height = get_height(),
width = self:get_result_container_width(),
height = self:get_result_container_height(),
},
})
@ -2270,13 +2326,12 @@ function Sidebar:create_selected_files_container()
Highlights.SUBTITLE,
Highlights.REVERSED_SUBTITLE
)
self:adjust_result_container_layout()
end
self.file_selector:on("update", render)
local remove_file = function(line_number)
if self.file_selector:remove_selected_filepaths(line_number) then render() end
end
local remove_file = function(line_number) self.file_selector:remove_selected_filepaths_with_index(line_number) end
-- Function to show hint
local function show_hint()