feat: tools ()

* feat: tools

* feat: claude use tools

* feat: openai use tools
This commit is contained in:
yetone 2025-02-05 22:39:54 +08:00 committed by GitHub
parent 1726d32778
commit 1437f319d2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 1321 additions and 74 deletions

@ -38,6 +38,9 @@ jobs:
mkdir -p _neovim mkdir -p _neovim
curl -sL "https://github.com/neovim/neovim/releases/download/${{ matrix.rev }}" | tar xzf - --strip-components=1 -C "${PWD}/_neovim" curl -sL "https://github.com/neovim/neovim/releases/download/${{ matrix.rev }}" | tar xzf - --strip-components=1 -C "${PWD}/_neovim"
} }
sudo apt-get update
sudo apt-get install -y ripgrep
sudo apt-get install -y silversearcher-ag
- name: Run tests - name: Run tests
run: | run: |

@ -31,6 +31,7 @@ struct TemplateContext {
selected_code: Option<String>, selected_code: Option<String>,
project_context: Option<String>, project_context: Option<String>,
diagnostics: Option<String>, diagnostics: Option<String>,
system_info: Option<String>,
} }
// Given the file name registered after add, the context table in Lua, resulted in a formatted // Given the file name registered after add, the context table in Lua, resulted in a formatted
@ -54,6 +55,7 @@ fn render(state: &State, template: &str, context: TemplateContext) -> LuaResult<
selected_code => context.selected_code, selected_code => context.selected_code,
project_context => context.project_context, project_context => context.project_context,
diagnostics => context.diagnostics, diagnostics => context.diagnostics,
system_info => context.system_info,
}) })
.map_err(LuaError::external) .map_err(LuaError::external)
.unwrap()) .unwrap())

@ -20,6 +20,14 @@ M._defaults = {
-- For most providers that we support we will determine this automatically. -- For most providers that we support we will determine this automatically.
-- If you wish to use a given implementation, then you can override it here. -- If you wish to use a given implementation, then you can override it here.
tokenizer = "tiktoken", tokenizer = "tiktoken",
web_search_engine = {
provider = "tavily",
api_key_name = "TAVILY_API_KEY",
provider_opts = {
time_range = "d",
include_answer = "basic",
},
},
---@type AvanteSupportedProvider ---@type AvanteSupportedProvider
openai = { openai = {
endpoint = "https://api.openai.com/v1", endpoint = "https://api.openai.com/v1",

@ -8,6 +8,7 @@ local Utils = require("avante.utils")
local Config = require("avante.config") local Config = require("avante.config")
local Path = require("avante.path") local Path = require("avante.path")
local P = require("avante.providers") local P = require("avante.providers")
local LLMTools = require("avante.llm_tools")
---@class avante.LLM ---@class avante.LLM
local M = {} local M = {}
@ -45,6 +46,8 @@ M.generate_prompts = function(opts)
local project_root = Utils.root.get() local project_root = Utils.root.get()
Path.prompts.initialize(Path.prompts.get(project_root)) Path.prompts.initialize(Path.prompts.get(project_root))
local system_info = Utils.get_system_info()
local template_opts = { local template_opts = {
use_xml_format = Provider.use_xml_format, use_xml_format = Provider.use_xml_format,
ask = opts.ask, -- TODO: add mode without ask instruction ask = opts.ask, -- TODO: add mode without ask instruction
@ -53,6 +56,7 @@ M.generate_prompts = function(opts)
selected_code = opts.selected_code, selected_code = opts.selected_code,
project_context = opts.project_context, project_context = opts.project_context,
diagnostics = opts.diagnostics, diagnostics = opts.diagnostics,
system_info = system_info,
} }
local system_prompt = Path.prompts.render_mode(mode, template_opts) local system_prompt = Path.prompts.render_mode(mode, template_opts)
@ -111,6 +115,10 @@ M.generate_prompts = function(opts)
system_prompt = system_prompt, system_prompt = system_prompt,
messages = messages, messages = messages,
image_paths = image_paths, image_paths = image_paths,
tools = opts.tools,
tool_use = opts.tool_use,
tool_result = opts.tool_result,
response_content = opts.response_content,
} }
end end
@ -135,7 +143,28 @@ M._stream = function(opts)
local current_event_state = nil local current_event_state = nil
---@type AvanteHandlerOptions ---@type AvanteHandlerOptions
local handler_opts = { on_chunk = opts.on_chunk, on_complete = opts.on_complete } local handler_opts = {
on_start = opts.on_start,
on_chunk = opts.on_chunk,
on_stop = function(stop_opts)
if stop_opts.reason == "tool_use" and stop_opts.tool_use then
local result, error = LLMTools.process_tool_use(stop_opts.tool_use)
local tool_result = {
tool_use_id = stop_opts.tool_use.id,
content = error ~= nil and error or result,
is_error = error ~= nil,
}
local new_opts = vim.tbl_deep_extend(
"force",
opts,
{ tool_result = tool_result, tool_use = stop_opts.tool_use, response_content = stop_opts.response_content }
)
return M._stream(new_opts)
end
return opts.on_stop(stop_opts)
end,
}
---@type AvanteCurlOutput ---@type AvanteCurlOutput
local spec = Provider.parse_curl_args(Provider, code_opts) local spec = Provider.parse_curl_args(Provider, code_opts)
@ -180,7 +209,7 @@ M._stream = function(opts)
stream = function(err, data, _) stream = function(err, data, _)
if err then if err then
completed = true completed = true
opts.on_complete(err) handler_opts.on_stop({ reason = "error", error = err })
return return
end end
if not data then return end if not data then return end
@ -224,7 +253,7 @@ M._stream = function(opts)
active_job = nil active_job = nil
completed = true completed = true
cleanup() cleanup()
opts.on_complete(err) handler_opts.on_stop({ reason = "error", error = err })
end, end,
callback = function(result) callback = function(result)
active_job = nil active_job = nil
@ -238,9 +267,10 @@ M._stream = function(opts)
vim.schedule(function() vim.schedule(function()
if not completed then if not completed then
completed = true completed = true
opts.on_complete( handler_opts.on_stop({
"API request failed with status " .. result.status .. ". Body: " .. vim.inspect(result.body) reason = "error",
) error = "API request failed with status " .. result.status .. ". Body: " .. vim.inspect(result.body),
})
end end
end) end)
end end
@ -335,9 +365,9 @@ M._dual_boost_stream = function(opts, Provider1, Provider2)
on_chunk = function(chunk) on_chunk = function(chunk)
if chunk then response = response .. chunk end if chunk then response = response .. chunk end
end, end,
on_complete = function(err) on_stop = function(stop_opts)
if err then if stop_opts.error then
Utils.error(string.format("Stream %d failed: %s", index, err)) Utils.error(string.format("Stream %d failed: %s", index, stop_opts.error))
return return
end end
Utils.debug(string.format("Response %d completed", index)) Utils.debug(string.format("Response %d completed", index))
@ -381,10 +411,15 @@ end
---@field instructions string ---@field instructions string
---@field mode LlmMode ---@field mode LlmMode
---@field provider AvanteProviderFunctor | AvanteBedrockProviderFunctor | nil ---@field provider AvanteProviderFunctor | AvanteBedrockProviderFunctor | nil
---@field tools? AvanteLLMTool[]
---@field tool_result? AvanteLLMToolResult
---@field tool_use? AvanteLLMToolUse
---@field response_content? string
--- ---
---@class StreamOptions: GeneratePromptsOptions ---@class StreamOptions: GeneratePromptsOptions
---@field on_chunk AvanteChunkParser ---@field on_start AvanteLLMStartCallback
---@field on_complete AvanteCompleteParser ---@field on_chunk AvanteLLMChunkCallback
---@field on_stop AvanteLLMStopCallback
---@param opts StreamOptions ---@param opts StreamOptions
M.stream = function(opts) M.stream = function(opts)
@ -396,12 +431,12 @@ M.stream = function(opts)
return original_on_chunk(chunk) return original_on_chunk(chunk)
end) end)
end end
if opts.on_complete ~= nil then if opts.on_stop ~= nil then
local original_on_complete = opts.on_complete local original_on_stop = opts.on_stop
opts.on_complete = vim.schedule_wrap(function(err) opts.on_stop = vim.schedule_wrap(function(stop_opts)
if is_completed then return end if is_completed then return end
is_completed = true if stop_opts.reason == "complete" or stop_opts.reason == "error" then is_completed = true end
return original_on_complete(err) return original_on_stop(stop_opts)
end) end)
end end
if Config.dual_boost.enabled then if Config.dual_boost.enabled then

714
lua/avante/llm_tools.lua Normal file

@ -0,0 +1,714 @@
local curl = require("plenary.curl")
local Utils = require("avante.utils")
local Path = require("plenary.path")
local Config = require("avante.config")
local M = {}
---@param rel_path string
---@return string
local function get_abs_path(rel_path)
local project_root = Utils.get_project_root()
return Path:new(project_root):joinpath(rel_path):absolute()
end
function M.comfirm(msg)
local ok = vim.fn.confirm(msg, "&Yes\n&No", 2)
return ok == 1
end
---@param abs_path string
---@return boolean
local function has_permission_to_access(abs_path)
if not Path:new(abs_path):is_absolute() then return false end
local project_root = Utils.get_project_root()
if abs_path:sub(1, #project_root) ~= project_root then return false end
local gitignore_path = project_root .. "/.gitignore"
local gitignore_patterns, gitignore_negate_patterns = Utils.parse_gitignore(gitignore_path)
return not Utils.is_ignored(abs_path, gitignore_patterns, gitignore_negate_patterns)
end
---@param opts { rel_path: string, depth?: integer }
---@return string files
---@return string|nil error
function M.list_files(opts)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
local files = Utils.scan_directory_respect_gitignore({
directory = abs_path,
add_dirs = true,
depth = opts.depth,
})
local result = ""
for _, file in ipairs(files) do
local uniform_path = Utils.uniform_path(file)
result = result .. uniform_path .. "\n"
end
result = result:gsub("\n$", "")
return result, nil
end
---@param opts { rel_path: string, keyword: string }
---@return string files
---@return string|nil error
function M.search_files(opts)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
local files = Utils.scan_directory_respect_gitignore({
directory = abs_path,
})
local result = ""
for _, file in ipairs(files) do
if file:find(opts.keyword) then result = result .. file .. "\n" end
end
result = result:gsub("\n$", "")
return result, nil
end
---@param opts { rel_path: string, keyword: string }
---@return string result
---@return string|nil error
function M.search(opts)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return "", "No such file or directory: " .. abs_path end
---check if any search cmd is available
local search_cmd = vim.fn.exepath("rg")
if search_cmd == "" then search_cmd = vim.fn.exepath("ag") end
if search_cmd == "" then search_cmd = vim.fn.exepath("ack") end
if search_cmd == "" then search_cmd = vim.fn.exepath("grep") end
if search_cmd == "" then return "", "No search command found" end
---execute the search command
local cmd = ""
if search_cmd:find("rg") then
cmd = string.format("%s --files-with-matches --no-ignore-vcs --ignore-case --hidden --glob '!.git'", search_cmd)
cmd = string.format("%s '%s' %s", cmd, opts.keyword, abs_path)
elseif search_cmd:find("ag") then
cmd = string.format("%s '%s' --nocolor --nogroup --hidden --ignore .git %s", search_cmd, opts.keyword, abs_path)
elseif search_cmd:find("ack") then
cmd = string.format("%s --nocolor --nogroup --hidden --ignore-dir .git", search_cmd)
cmd = string.format("%s '%s' %s", cmd, opts.keyword, abs_path)
elseif search_cmd:find("grep") then
cmd = string.format("%s -riH --exclude-dir=.git %s %s", search_cmd, opts.keyword, abs_path)
end
Utils.debug("cmd", cmd)
local result = vim.fn.system(cmd)
return result or "", nil
end
---@param opts { rel_path: string }
---@return string definitions
---@return string|nil error
function M.read_file_toplevel_symbols(opts)
local RepoMap = require("avante.repo_map")
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
local filetype = RepoMap.get_ts_lang(abs_path)
local repo_map_lib = RepoMap._init_repo_map_lib()
if not repo_map_lib then return "", "Failed to load avante_repo_map" end
local definitions = filetype
and repo_map_lib.stringify_definitions(filetype, Utils.file.read_content(abs_path) or "")
or ""
return definitions, nil
end
---@param opts { rel_path: string }
---@return string content
---@return string|nil error
function M.read_file(opts)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
local file = io.open(abs_path, "r")
if not file then return "", "file not found: " .. abs_path end
local content = file:read("*a")
file:close()
return content, nil
end
---@param opts { rel_path: string }
---@return boolean success
---@return string|nil error
function M.create_file(opts)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
---create directory if it doesn't exist
local dir = Path:new(abs_path):parent()
if not dir:exists() then dir:mkdir({ parents = true }) end
---create file if it doesn't exist
if not dir:joinpath(opts.rel_path):exists() then
local file = io.open(abs_path, "w")
if not file then return false, "file not found: " .. abs_path end
file:close()
end
return true, nil
end
---@param opts { rel_path: string, new_rel_path: string }
---@return boolean success
---@return string|nil error
function M.rename_file(opts)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end
if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end
local new_abs_path = get_abs_path(opts.new_rel_path)
if not has_permission_to_access(new_abs_path) then return false, "No permission to access path: " .. new_abs_path end
if Path:new(new_abs_path):exists() then return false, "File already exists: " .. new_abs_path end
if not M.confirm("Are you sure you want to rename the file: " .. abs_path .. " to: " .. new_abs_path) then
return false, "User canceled"
end
os.rename(abs_path, new_abs_path)
return true, nil
end
---@param opts { rel_path: string, new_rel_path: string }
---@return boolean success
---@return string|nil error
function M.copy_file(opts)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end
if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end
local new_abs_path = get_abs_path(opts.new_rel_path)
if not has_permission_to_access(new_abs_path) then return false, "No permission to access path: " .. new_abs_path end
if Path:new(new_abs_path):exists() then return false, "File already exists: " .. new_abs_path end
Path:new(new_abs_path):write(Path:new(abs_path):read())
return true, nil
end
---@param opts { rel_path: string }
---@return boolean success
---@return string|nil error
function M.delete_file(opts)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end
if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end
if not M.confirm("Are you sure you want to delete the file: " .. abs_path) then return false, "User canceled" end
os.remove(abs_path)
return true, nil
end
---@param opts { rel_path: string }
---@return boolean success
---@return string|nil error
function M.create_dir(opts)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if Path:new(abs_path):exists() then return false, "Directory already exists: " .. abs_path end
Path:new(abs_path):mkdir({ parents = true })
return true, nil
end
---@param opts { rel_path: string, new_rel_path: string }
---@return boolean success
---@return string|nil error
function M.rename_dir(opts)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "Directory not found: " .. abs_path end
if not Path:new(abs_path):is_dir() then return false, "Path is not a directory: " .. abs_path end
local new_abs_path = get_abs_path(opts.new_rel_path)
if not has_permission_to_access(new_abs_path) then return false, "No permission to access path: " .. new_abs_path end
if Path:new(new_abs_path):exists() then return false, "Directory already exists: " .. new_abs_path end
if not M.confirm("Are you sure you want to rename directory " .. abs_path .. " to " .. new_abs_path .. "?") then
return false, "User canceled"
end
os.rename(abs_path, new_abs_path)
return true, nil
end
---@param opts { rel_path: string }
---@return boolean success
---@return string|nil error
function M.delete_dir(opts)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "Directory not found: " .. abs_path end
if not Path:new(abs_path):is_dir() then return false, "Path is not a directory: " .. abs_path end
if not M.confirm("Are you sure you want to delete the directory: " .. abs_path) then
return false, "User canceled"
end
os.remove(abs_path)
return true, nil
end
---@param opts { rel_path: string, command: string }
---@return string|boolean result
---@return string|nil error
function M.run_command(opts)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "Path not found: " .. abs_path end
if
not M.confirm("Are you sure you want to run the command: `" .. opts.command .. "` in the directory: " .. abs_path)
then
return false, "User canceled"
end
---change cwd to abs_path
local old_cwd = vim.fn.getcwd()
vim.fn.chdir(abs_path)
local res = Utils.shell_run(opts.command)
vim.fn.chdir(old_cwd)
if res.code ~= 0 then
if res.stdout then return false, "Error: " .. res.stdout .. "; Error code: " .. tostring(res.code) end
return false, "Error code: " .. tostring(res.code)
end
return res.stdout, nil
end
---@param opts { query: string }
---@return string|nil result
---@return string|nil error
function M.web_search(opts)
local search_engine = Config.web_search_engine
if search_engine.provider == "tavily" then
if search_engine.api_key_name == "" then return nil, "No API key provided" end
local api_key = os.getenv(search_engine.api_key_name)
if api_key == nil or api_key == "" then
return nil, "Environment variable " .. search_engine.api_key_name .. " is not set"
end
local resp = curl.post("https://api.tavily.com/search", {
headers = {
["Content-Type"] = "application/json",
["Authorization"] = "Bearer " .. api_key,
},
body = vim.json.encode(vim.tbl_deep_extend("force", {
query = opts.query,
}, search_engine.provider_opts)),
})
if resp.status ~= 200 then return nil, "Error: " .. resp.body end
local jsn = vim.json.decode(resp.body)
return jsn.anwser, nil
end
end
---@class AvanteLLMTool
---@field name string
---@field description string
---@field param AvanteLLMToolParam
---@field returns AvanteLLMToolReturn[]
---@class AvanteLLMToolParam
---@field type string
---@field fields AvanteLLMToolParamField[]
---@class AvanteLLMToolParamField
---@field name string
---@field description string
---@field type string
---@field optional? boolean
---@class AvanteLLMToolReturn
---@field name string
---@field description string
---@field type string
---@field optional? boolean
---@type AvanteLLMTool[]
M.tools = {
{
name = "list_files",
description = "List files in a directory",
param = {
type = "table",
fields = {
{
name = "rel_path",
description = "Relative path to the directory",
type = "string",
},
{
name = "depth",
description = "Depth of the directory",
type = "integer",
optional = true,
},
},
},
returns = {
{
name = "files",
description = "List of files in the directory",
type = "string[]",
},
{
name = "error",
description = "Error message if the directory was not listed successfully",
type = "string",
optional = true,
},
},
},
{
name = "search_files",
description = "Search for files in a directory",
param = {
type = "table",
fields = {
{
name = "rel_path",
description = "Relative path to the directory",
type = "string",
},
{
name = "keyword",
description = "Keyword to search for",
type = "string",
},
},
},
returns = {
{
name = "files",
description = "List of files that match the keyword",
type = "string",
},
{
name = "error",
description = "Error message if the directory was not searched successfully",
type = "string",
optional = true,
},
},
},
{
name = "search",
description = "Search for a keyword in a directory",
param = {
type = "table",
fields = {
{
name = "rel_path",
description = "Relative path to the directory",
type = "string",
},
{
name = "keyword",
description = "Keyword to search for",
type = "string",
},
},
},
returns = {
{
name = "files",
description = "List of files that match the keyword",
type = "string",
},
{
name = "error",
description = "Error message if the directory was not searched successfully",
type = "string",
optional = true,
},
},
},
{
name = "read_file_toplevel_symbols",
description = "Read the top-level symbols of a file",
param = {
type = "table",
fields = {
{
name = "rel_path",
description = "Relative path to the file",
type = "string",
},
},
},
returns = {
{
name = "definitions",
description = "Top-level symbols of the file",
type = "string",
},
{
name = "error",
description = "Error message if the file was not read successfully",
type = "string",
optional = true,
},
},
},
{
name = "read_file",
description = "Read the contents of a file",
param = {
type = "table",
fields = {
{
name = "rel_path",
description = "Relative path to the file",
type = "string",
},
},
},
returns = {
{
name = "content",
description = "Contents of the file",
type = "string",
},
{
name = "error",
description = "Error message if the file was not read successfully",
type = "string",
optional = true,
},
},
},
{
name = "create_file",
description = "Create a new file",
param = {
type = "table",
fields = {
{
name = "rel_path",
description = "Relative path to the file",
type = "string",
},
},
},
returns = {
{
name = "success",
description = "True if the file was created successfully, false otherwise",
type = "boolean",
},
{
name = "error",
description = "Error message if the file was not created successfully",
type = "string",
optional = true,
},
},
},
{
name = "rename_file",
description = "Rename a file",
param = {
type = "table",
fields = {
{
name = "rel_path",
description = "Relative path to the file",
type = "string",
},
{
name = "new_rel_path",
description = "New relative path for the file",
type = "string",
},
},
},
returns = {
{
name = "success",
description = "True if the file was renamed successfully, false otherwise",
type = "boolean",
},
{
name = "error",
description = "Error message if the file was not renamed successfully",
type = "string",
optional = true,
},
},
},
{
name = "delete_file",
description = "Delete a file",
param = {
type = "table",
fields = {
{
name = "rel_path",
description = "Relative path to the file",
type = "string",
},
},
},
returns = {
{
name = "success",
description = "True if the file was deleted successfully, false otherwise",
type = "boolean",
},
{
name = "error",
description = "Error message if the file was not deleted successfully",
type = "string",
optional = true,
},
},
},
{
name = "create_dir",
description = "Create a new directory",
param = {
type = "table",
fields = {
{
name = "rel_path",
description = "Relative path to the directory",
type = "string",
},
},
},
returns = {
{
name = "success",
description = "True if the directory was created successfully, false otherwise",
type = "boolean",
},
{
name = "error",
description = "Error message if the directory was not created successfully",
type = "string",
optional = true,
},
},
},
{
name = "rename_dir",
description = "Rename a directory",
param = {
type = "table",
fields = {
{
name = "rel_path",
description = "Relative path to the directory",
type = "string",
},
{
name = "new_rel_path",
description = "New relative path for the directory",
type = "string",
},
},
},
returns = {
{
name = "success",
description = "True if the directory was renamed successfully, false otherwise",
type = "boolean",
},
{
name = "error",
description = "Error message if the directory was not renamed successfully",
type = "string",
optional = true,
},
},
},
{
name = "delete_dir",
description = "Delete a directory",
param = {
type = "table",
fields = {
{
name = "rel_path",
description = "Relative path to the directory",
type = "string",
},
},
},
returns = {
{
name = "success",
description = "True if the directory was deleted successfully, false otherwise",
type = "boolean",
},
{
name = "error",
description = "Error message if the directory was not deleted successfully",
type = "string",
optional = true,
},
},
},
{
name = "run_command",
description = "Run a command in a directory",
param = {
type = "table",
fields = {
{
name = "rel_path",
description = "Relative path to the directory",
type = "string",
},
{
name = "command",
description = "Command to run",
type = "string",
},
},
},
returns = {
{
name = "stdout",
description = "Output of the command",
type = "string",
},
{
name = "error",
description = "Error message if the command was not run successfully",
type = "string",
optional = true,
},
},
},
{
name = "web_search",
description = "Search the web",
param = {
type = "table",
fields = {
{
name = "query",
description = "Query to search",
type = "string",
},
},
},
returns = {
{
name = "result",
description = "Result of the search",
type = "string",
},
{
name = "error",
description = "Error message if the search was not successful",
type = "string",
optional = true,
},
},
},
}
---@param tool_use AvanteLLMToolUse
---@return string | nil result
---@return string | nil error
function M.process_tool_use(tool_use)
Utils.debug("use tool", tool_use.name, tool_use.input_json)
local tool = vim.iter(M.tools):find(function(tool) return tool.name == tool_use.name end)
if tool == nil then return end
local input_json = vim.json.decode(tool_use.input_json)
local func = M[tool.name]
local result, error = func(input_json)
-- Utils.debug("result", result)
-- Utils.debug("error", error)
if result ~= nil and type(result) ~= "string" then result = vim.json.encode(result) end
return result, error
end
return M

@ -6,6 +6,8 @@
---@field role "user" | "assistant" ---@field role "user" | "assistant"
---@field content [AvanteBedrockClaudeTextMessage][] ---@field content [AvanteBedrockClaudeTextMessage][]
local Claude = require("avante.providers.claude")
---@class AvanteBedrockModelHandler ---@class AvanteBedrockModelHandler
local M = {} local M = {}
@ -33,25 +35,7 @@ M.parse_messages = function(opts)
return messages return messages
end end
M.parse_response = function(ctx, data_stream, event_state, opts) M.parse_response = Claude.parse_response
if event_state == nil then
if data_stream:match('"content_block_delta"') then
event_state = "content_block_delta"
elseif data_stream:match('"message_stop"') then
event_state = "message_stop"
end
end
if event_state == "content_block_delta" then
local ok, json = pcall(vim.json.decode, data_stream)
if not ok then return end
opts.on_chunk(json.delta.text)
elseif event_state == "message_stop" then
opts.on_complete(nil)
return
elseif event_state == "error" then
opts.on_complete(vim.json.decode(data_stream))
end
end
---@param prompt_opts AvantePromptOptions ---@param prompt_opts AvantePromptOptions
---@param body_opts table ---@param body_opts table
@ -60,7 +44,6 @@ M.build_bedrock_payload = function(prompt_opts, body_opts)
local system_prompt = prompt_opts.system_prompt or "" local system_prompt = prompt_opts.system_prompt or ""
local messages = M.parse_messages(prompt_opts) local messages = M.parse_messages(prompt_opts)
local max_tokens = body_opts.max_tokens or 2000 local max_tokens = body_opts.max_tokens or 2000
local temperature = body_opts.temperature or 0.7
local payload = { local payload = {
anthropic_version = "bedrock-2023-05-31", anthropic_version = "bedrock-2023-05-31",
max_tokens = max_tokens, max_tokens = max_tokens,

@ -17,6 +17,44 @@ local P = require("avante.providers")
---@field role "user" | "assistant" ---@field role "user" | "assistant"
---@field content [AvanteClaudeTextMessage | AvanteClaudeImageMessage][] ---@field content [AvanteClaudeTextMessage | AvanteClaudeImageMessage][]
---@class AvanteClaudeTool
---@field name string
---@field description string
---@field input_schema AvanteClaudeToolInputSchema
---@class AvanteClaudeToolInputSchema
---@field type "object"
---@field properties table<string, AvanteClaudeToolInputSchemaProperty>
---@field required string[]
---@class AvanteClaudeToolInputSchemaProperty
---@field type "string" | "number" | "boolean"
---@field description string
---@field enum? string[]
---@param tool AvanteLLMTool
---@return AvanteClaudeTool
local function transform_tool(tool)
local input_schema_properties = {}
local required = {}
for _, field in ipairs(tool.param.fields) do
input_schema_properties[field.name] = {
type = field.type,
description = field.description,
}
if not field.optional then table.insert(required, field.name) end
end
return {
name = tool.name,
description = tool.description,
input_schema = {
type = "object",
properties = input_schema_properties,
required = required,
},
}
end
---@class AvanteProviderFunctor ---@class AvanteProviderFunctor
local M = {} local M = {}
@ -74,26 +112,101 @@ M.parse_messages = function(opts)
messages[#messages].content = message_content messages[#messages].content = message_content
end end
if opts.tool_use then
local msg = {
role = "assistant",
content = {},
}
if opts.response_content then
msg.content[#msg.content + 1] = {
type = "text",
text = opts.response_content,
}
end
msg.content[#msg.content + 1] = {
type = "tool_use",
id = opts.tool_use.id,
name = opts.tool_use.name,
input = vim.json.decode(opts.tool_use.input_json),
}
messages[#messages + 1] = msg
end
if opts.tool_result then
messages[#messages + 1] = {
role = "user",
content = {
{
type = "tool_result",
tool_use_id = opts.tool_result.tool_use_id,
content = opts.tool_result.content,
is_error = opts.tool_result.is_error,
},
},
}
end
return messages return messages
end end
M.parse_response = function(ctx, data_stream, event_state, opts) M.parse_response = function(ctx, data_stream, event_state, opts)
if event_state == nil then if event_state == nil then
if data_stream:match('"content_block_delta"') then if data_stream:match('"message_start"') then
event_state = "content_block_delta" event_state = "message_start"
elseif data_stream:match('"message_delta"') then
event_state = "message_delta"
elseif data_stream:match('"message_stop"') then elseif data_stream:match('"message_stop"') then
event_state = "message_stop" event_state = "message_stop"
elseif data_stream:match('"content_block_start"') then
event_state = "content_block_start"
elseif data_stream:match('"content_block_delta"') then
event_state = "content_block_delta"
elseif data_stream:match('"content_block_stop"') then
event_state = "content_block_stop"
end end
end end
if event_state == "content_block_delta" then if event_state == "message_start" then
local ok, json = pcall(vim.json.decode, data_stream) local ok, jsn = pcall(vim.json.decode, data_stream)
if not ok then return end if not ok then return end
opts.on_chunk(json.delta.text) opts.on_start(jsn.message.usage)
elseif event_state == "message_stop" then elseif event_state == "content_block_start" then
opts.on_complete(nil) local ok, jsn = pcall(vim.json.decode, data_stream)
if not ok then return end
if jsn.content_block.type == "tool_use" then
ctx.tool_use = {
name = jsn.content_block.name,
id = jsn.content_block.id,
input_json = "",
}
elseif jsn.content_block.type == "text" then
ctx.response_content = ""
end
elseif event_state == "content_block_delta" then
local ok, jsn = pcall(vim.json.decode, data_stream)
if not ok then return end
if ctx.tool_use and jsn.delta.type == "input_json_delta" then
ctx.tool_use.input_json = ctx.tool_use.input_json .. jsn.delta.partial_json
return
elseif ctx.response_content and jsn.delta.type == "text_delta" then
ctx.response_content = ctx.response_content .. jsn.delta.text
end
opts.on_chunk(jsn.delta.text)
elseif event_state == "message_delta" then
local ok, jsn = pcall(vim.json.decode, data_stream)
if not ok then return end
if jsn.delta.stop_reason == "end_turn" then
opts.on_stop({ reason = "complete", usage = jsn.usage })
elseif jsn.delta.stop_reason == "tool_use" then
opts.on_stop({
reason = "tool_use",
usage = jsn.usage,
tool_use = ctx.tool_use,
response_content = ctx.response_content,
})
end
return return
elseif event_state == "error" then elseif event_state == "error" then
opts.on_complete(vim.json.decode(data_stream)) opts.on_stop({ reason = "error", error = vim.json.decode(data_stream) })
end end
end end
@ -113,6 +226,13 @@ M.parse_curl_args = function(provider, prompt_opts)
local messages = M.parse_messages(prompt_opts) local messages = M.parse_messages(prompt_opts)
local tools = {}
if prompt_opts.tools then
for _, tool in ipairs(prompt_opts.tools) do
table.insert(tools, transform_tool(tool))
end
end
return { return {
url = Utils.url_join(base.endpoint, "/v1/messages"), url = Utils.url_join(base.endpoint, "/v1/messages"),
proxy = base.proxy, proxy = base.proxy,
@ -128,6 +248,7 @@ M.parse_curl_args = function(provider, prompt_opts)
}, },
}, },
messages = messages, messages = messages,
tools = tools,
stream = true, stream = true,
}, body_opts), }, body_opts),
} }

@ -62,7 +62,7 @@ M.parse_stream_data = function(data, opts)
local json = vim.json.decode(data) local json = vim.json.decode(data)
if json.type ~= nil then if json.type ~= nil then
if json.type == "message-end" and json.delta.finish_reason == "COMPLETE" then if json.type == "message-end" and json.delta.finish_reason == "COMPLETE" then
opts.on_complete(nil) opts.on_stop({ reason = "complete" })
return return
end end
if json.type == "content-delta" then opts.on_chunk(json.delta.message.content.text) end if json.type == "content-delta" then opts.on_chunk(json.delta.message.content.text) end

@ -66,17 +66,17 @@ end
M.parse_response = function(ctx, data_stream, _, opts) M.parse_response = function(ctx, data_stream, _, opts)
local ok, json = pcall(vim.json.decode, data_stream) local ok, json = pcall(vim.json.decode, data_stream)
if not ok then opts.on_complete(json) end if not ok then opts.on_stop({ reason = "error", error = json }) end
if json.candidates then if json.candidates then
if #json.candidates > 0 then if #json.candidates > 0 then
if json.candidates[1].finishReason and json.candidates[1].finishReason == "STOP" then if json.candidates[1].finishReason and json.candidates[1].finishReason == "STOP" then
opts.on_chunk(json.candidates[1].content.parts[1].text) opts.on_chunk(json.candidates[1].content.parts[1].text)
opts.on_complete(nil) opts.on_stop({ reason = "complete" })
else else
opts.on_chunk(json.candidates[1].content.parts[1].text) opts.on_chunk(json.candidates[1].content.parts[1].text)
end end
else else
opts.on_complete(nil) opts.on_stop({ reason = "complete" })
end end
end end
end end

@ -11,17 +11,28 @@ local DressingConfig = {
local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil } local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil }
---@class AvanteHandlerOptions: table<[string], string> ---@class AvanteHandlerOptions: table<[string], string>
---@field on_chunk AvanteChunkParser ---@field on_start AvanteLLMStartCallback
---@field on_complete AvanteCompleteParser ---@field on_chunk AvanteLLMChunkCallback
---@field on_stop AvanteLLMStopCallback
--- ---
---@class AvanteLLMMessage ---@class AvanteLLMMessage
---@field role "user" | "assistant" ---@field role "user" | "assistant"
---@field content string ---@field content string
--- ---
---@class AvanteLLMToolResult
---@field tool_name string
---@field tool_use_id string
---@field content string
---@field is_error? boolean
---
---@class AvantePromptOptions: table<[string], string> ---@class AvantePromptOptions: table<[string], string>
---@field system_prompt string ---@field system_prompt string
---@field messages AvanteLLMMessage[] ---@field messages AvanteLLMMessage[]
---@field image_paths? string[] ---@field image_paths? string[]
---@field tools? AvanteLLMTool[]
---@field tool_result? AvanteLLMToolResult
---@field tool_use? AvanteLLMToolUse
---@field response_content? string
--- ---
---@class AvanteGeminiMessage ---@class AvanteGeminiMessage
---@field role "user" ---@field role "user"
@ -35,8 +46,9 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil }
---@alias AvanteCurlArgsParser fun(opts: AvanteProvider | AvanteProviderFunctor | AvanteBedrockProviderFunctor, code_opts: AvantePromptOptions): AvanteCurlOutput ---@alias AvanteCurlArgsParser fun(opts: AvanteProvider | AvanteProviderFunctor | AvanteBedrockProviderFunctor, code_opts: AvantePromptOptions): AvanteCurlOutput
--- ---
---@class ResponseParser ---@class ResponseParser
---@field on_chunk fun(chunk: string): any ---@field on_start AvanteLLMStartCallback
---@field on_complete fun(err: string|nil): any ---@field on_chunk AvanteLLMChunkCallback
---@field on_stop AvanteLLMStopCallback
---@alias AvanteResponseParser fun(ctx: any, data_stream: string, event_state: string, opts: ResponseParser): nil ---@alias AvanteResponseParser fun(ctx: any, data_stream: string, event_state: string, opts: ResponseParser): nil
--- ---
---@class AvanteDefaultBaseProvider: table<string, any> ---@class AvanteDefaultBaseProvider: table<string, any>
@ -54,9 +66,31 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil }
---@field temperature? number ---@field temperature? number
---@field max_tokens? number ---@field max_tokens? number
--- ---
---@class AvanteLLMUsage
---@field input_tokens number
---@field cache_creation_input_tokens number
---@field cache_read_input_tokens number
---@field output_tokens number
---
---@class AvanteLLMToolUse
---@field name string
---@field id string
---@field input_json string
---
---@class AvanteLLMStartCallbackOptions
---@field usage? AvanteLLMUsage
---
---@class AvanteLLMStopCallbackOptions
---@field reason "complete" | "tool_use" | "error"
---@field error? string | table
---@field usage? AvanteLLMUsage
---@field tool_use? AvanteLLMToolUse
---@field response_content? string
---
---@alias AvanteStreamParser fun(line: string, handler_opts: AvanteHandlerOptions): nil ---@alias AvanteStreamParser fun(line: string, handler_opts: AvanteHandlerOptions): nil
---@alias AvanteChunkParser fun(chunk: string): any ---@alias AvanteLLMStartCallback fun(opts: AvanteLLMStartCallbackOptions): nil
---@alias AvanteCompleteParser fun(err: string|nil): nil ---@alias AvanteLLMChunkCallback fun(chunk: string): any
---@alias AvanteLLMStopCallback fun(opts: AvanteLLMStopCallbackOptions): nil
---@alias AvanteLLMConfigHandler fun(opts: AvanteSupportedProvider): AvanteDefaultBaseProvider, table<string, any> ---@alias AvanteLLMConfigHandler fun(opts: AvanteSupportedProvider): AvanteDefaultBaseProvider, table<string, any>
--- ---
---@class AvanteProvider: AvanteSupportedProvider ---@class AvanteProvider: AvanteSupportedProvider

@ -24,12 +24,72 @@ local P = require("avante.providers")
---@field index integer ---@field index integer
---@field logprobs integer ---@field logprobs integer
--- ---
---@class OpenAIMessageToolCallFunction
---@field name string
---@field arguments string
---
---@class OpenAIMessageToolCall
---@field id string
---@field type "function"
---@field function OpenAIMessageToolCallFunction
---
---@class OpenAIMessage ---@class OpenAIMessage
---@field role? "user" | "system" | "assistant" ---@field role? "user" | "system" | "assistant"
---@field content? string ---@field content? string
---@field reasoning_content? string ---@field reasoning_content? string
---@field reasoning? string ---@field reasoning? string
---@field tool_calls? OpenAIMessageToolCall[]
--- ---
---@class AvanteOpenAITool
---@field type "function"
---@field function AvanteOpenAIToolFunction
---
---@class AvanteOpenAIToolFunction
---@field name string
---@field description string
---@field parameters AvanteOpenAIToolFunctionParameters
---@field strict boolean
---
---@class AvanteOpenAIToolFunctionParameters
---@field type string
---@field properties table<string, AvanteOpenAIToolFunctionParameterProperty>
---@field required string[]
---@field additionalProperties boolean
---
---@class AvanteOpenAIToolFunctionParameterProperty
---@field type string
---@field description string
---@param tool AvanteLLMTool
---@return AvanteOpenAITool
local function transform_tool(tool)
local input_schema_properties = {}
local required = {}
for _, field in ipairs(tool.param.fields) do
input_schema_properties[field.name] = {
type = field.type,
description = field.description,
}
if not field.optional then table.insert(required, field.name) end
end
local res = {
type = "function",
["function"] = {
name = tool.name,
description = tool.description,
},
}
if vim.tbl_count(input_schema_properties) > 0 then
res["function"].parameters = {
type = "object",
properties = input_schema_properties,
required = required,
additionalProperties = false,
}
end
return res
end
---@class AvanteProviderFunctor ---@class AvanteProviderFunctor
local M = {} local M = {}
@ -107,12 +167,34 @@ M.parse_messages = function(opts)
table.insert(final_messages, { role = M.role_map[role] or role, content = message.content }) table.insert(final_messages, { role = M.role_map[role] or role, content = message.content })
end) end)
if opts.tool_result then
table.insert(final_messages, {
role = M.role_map["assistant"],
tool_calls = {
{
id = opts.tool_use.id,
type = "function",
["function"] = {
name = opts.tool_use.name,
arguments = opts.tool_use.input_json,
},
},
},
})
local result_content = opts.tool_result.content or ""
table.insert(final_messages, {
role = "tool",
tool_call_id = opts.tool_result.tool_use_id,
content = opts.tool_result.is_error and "Error: " .. result_content or result_content,
})
end
return final_messages return final_messages
end end
M.parse_response = function(ctx, data_stream, _, opts) M.parse_response = function(ctx, data_stream, _, opts)
if data_stream:match('"%[DONE%]":') then if data_stream:match('"%[DONE%]":') then
opts.on_complete(nil) opts.on_stop({ reason = "complete" })
return return
end end
if data_stream:match('"delta":') then if data_stream:match('"delta":') then
@ -121,7 +203,14 @@ M.parse_response = function(ctx, data_stream, _, opts)
if jsn.choices and jsn.choices[1] then if jsn.choices and jsn.choices[1] then
local choice = jsn.choices[1] local choice = jsn.choices[1]
if choice.finish_reason == "stop" or choice.finish_reason == "eos_token" then if choice.finish_reason == "stop" or choice.finish_reason == "eos_token" then
opts.on_complete(nil) opts.on_stop({ reason = "complete" })
elseif choice.finish_reason == "tool_calls" then
opts.on_stop({
reason = "tool_use",
usage = jsn.usage,
tool_use = ctx.tool_use,
response_content = ctx.response_content,
})
elseif choice.delta.reasoning_content and choice.delta.reasoning_content ~= vim.NIL then elseif choice.delta.reasoning_content and choice.delta.reasoning_content ~= vim.NIL then
if ctx.returned_think_start_tag == nil or not ctx.returned_think_start_tag then if ctx.returned_think_start_tag == nil or not ctx.returned_think_start_tag then
ctx.returned_think_start_tag = true ctx.returned_think_start_tag = true
@ -136,6 +225,17 @@ M.parse_response = function(ctx, data_stream, _, opts)
end end
ctx.last_think_content = choice.delta.reasoning ctx.last_think_content = choice.delta.reasoning
opts.on_chunk(choice.delta.reasoning) opts.on_chunk(choice.delta.reasoning)
elseif choice.delta.tool_calls then
local tool_call = choice.delta.tool_calls[1]
if not ctx.tool_use then
ctx.tool_use = {
name = tool_call["function"].name,
id = tool_call.id,
input_json = "",
}
else
ctx.tool_use.input_json = ctx.tool_use.input_json .. tool_call["function"].arguments
end
elseif choice.delta.content then elseif choice.delta.content then
if if
ctx.returned_think_start_tag ~= nil and (ctx.returned_think_end_tag == nil or not ctx.returned_think_end_tag) ctx.returned_think_start_tag ~= nil and (ctx.returned_think_end_tag == nil or not ctx.returned_think_end_tag)
@ -164,7 +264,7 @@ M.parse_response_without_stream = function(data, _, opts)
local choice = json.choices[1] local choice = json.choices[1]
if choice.message and choice.message.content then if choice.message and choice.message.content then
opts.on_chunk(choice.message.content) opts.on_chunk(choice.message.content)
vim.schedule(function() opts.on_complete(nil) end) vim.schedule(function() opts.on_stop({ reason = "complete" }) end)
end end
end end
end end
@ -198,6 +298,13 @@ M.parse_curl_args = function(provider, code_opts)
body_opts.temperature = 1 body_opts.temperature = 1
end end
local tools = {}
if code_opts.tools then
for _, tool in ipairs(code_opts.tools) do
table.insert(tools, transform_tool(tool))
end
end
Utils.debug("endpoint", base.endpoint) Utils.debug("endpoint", base.endpoint)
Utils.debug("model", base.model) Utils.debug("model", base.model)
@ -210,6 +317,7 @@ M.parse_curl_args = function(provider, code_opts)
model = base.model, model = base.model,
messages = M.parse_messages(code_opts), messages = M.parse_messages(code_opts),
stream = stream, stream = stream,
tools = tools,
}, body_opts), }, body_opts),
} }
end end

@ -157,7 +157,10 @@ function Selection:create_editing_input()
self.prompt_input:start_spinner() self.prompt_input:start_spinner()
---@type AvanteChunkParser ---@type AvanteLLMStartCallback
local on_start = function(start_opts) end
---@type AvanteLLMChunkCallback
local on_chunk = function(chunk) local on_chunk = function(chunk)
full_response = full_response .. chunk full_response = full_response .. chunk
local response_lines_ = vim.split(full_response, "\n") local response_lines_ = vim.split(full_response, "\n")
@ -182,13 +185,15 @@ function Selection:create_editing_input()
finish_line = start_line + #response_lines - 1 finish_line = start_line + #response_lines - 1
end end
---@type AvanteCompleteParser ---@type AvanteLLMStopCallback
local on_complete = function(err) local on_stop = function(stop_opts)
if err then if stop_opts.error then
-- NOTE: in Ubuntu 22.04+ you will see this ignorable error from ~/.local/share/nvim/lazy/avante.nvim/lua/avante/llm.lua `on_error = function(err)`, check to avoid showing this error. -- NOTE: in Ubuntu 22.04+ you will see this ignorable error from ~/.local/share/nvim/lazy/avante.nvim/lua/avante/llm.lua `on_error = function(err)`, check to avoid showing this error.
if type(err) == "table" and err.exit == nil and err.stderr == "{}" then return end if type(stop_opts.error) == "table" and stop_opts.error.exit == nil and stop_opts.error.stderr == "{}" then
return
end
Utils.error( Utils.error(
"Error occurred while processing the response: " .. vim.inspect(err), "Error occurred while processing the response: " .. vim.inspect(stop_opts.error),
{ once = true, title = "Avante" } { once = true, title = "Avante" }
) )
return return
@ -216,8 +221,9 @@ function Selection:create_editing_input()
selected_code = self.selection.content, selected_code = self.selection.content,
instructions = input, instructions = input,
mode = "editing", mode = "editing",
on_start = on_start,
on_chunk = on_chunk, on_chunk = on_chunk,
on_complete = on_complete, on_stop = on_stop,
}) })
end end

@ -13,6 +13,7 @@ local Utils = require("avante.utils")
local Highlights = require("avante.highlights") local Highlights = require("avante.highlights")
local RepoMap = require("avante.repo_map") local RepoMap = require("avante.repo_map")
local FileSelector = require("avante.file_selector") local FileSelector = require("avante.file_selector")
local LLMTools = require("avante.llm_tools")
local RESULT_BUF_NAME = "AVANTE_RESULT" local RESULT_BUF_NAME = "AVANTE_RESULT"
local VIEW_BUFFER_UPDATED_PATTERN = "AvanteViewBufferUpdated" local VIEW_BUFFER_UPDATED_PATTERN = "AvanteViewBufferUpdated"
@ -1669,6 +1670,7 @@ function Sidebar:create_input_container(opts)
selected_code = selected_code_content, selected_code = selected_code_content,
instructions = request, instructions = request,
mode = "planning", mode = "planning",
tools = LLMTools.tools,
} }
end end
@ -1755,7 +1757,10 @@ function Sidebar:create_input_container(opts)
vim.keymap.set("n", "k", on_k, { buffer = self.result_container.bufnr }) vim.keymap.set("n", "k", on_k, { buffer = self.result_container.bufnr })
vim.keymap.set("n", "G", on_G, { buffer = self.result_container.bufnr }) vim.keymap.set("n", "G", on_G, { buffer = self.result_container.bufnr })
---@type AvanteChunkParser ---@type AvanteLLMStartCallback
local on_start = function(start_opts) end
---@type AvanteLLMChunkCallback
local on_chunk = function(chunk) local on_chunk = function(chunk)
original_response = original_response .. chunk original_response = original_response .. chunk
@ -1778,8 +1783,8 @@ function Sidebar:create_input_container(opts)
displayed_response = cur_displayed_response displayed_response = cur_displayed_response
end end
---@type AvanteCompleteParser ---@type AvanteLLMStopCallback
local on_complete = function(err) local on_stop = function(stop_opts)
pcall(function() pcall(function()
---remove keymaps ---remove keymaps
vim.keymap.del("n", "j", { buffer = self.result_container.bufnr }) vim.keymap.del("n", "j", { buffer = self.result_container.bufnr })
@ -1787,9 +1792,9 @@ function Sidebar:create_input_container(opts)
vim.keymap.del("n", "G", { buffer = self.result_container.bufnr }) vim.keymap.del("n", "G", { buffer = self.result_container.bufnr })
end) end)
if err ~= nil then if stop_opts.error ~= nil then
self:update_content( self:update_content(
content_prefix .. displayed_response .. "\n\nError: " .. vim.inspect(err), content_prefix .. displayed_response .. "\n\nError: " .. vim.inspect(stop_opts.error),
{ scroll = scroll } { scroll = scroll }
) )
return return
@ -1835,8 +1840,9 @@ function Sidebar:create_input_container(opts)
---@type StreamOptions ---@type StreamOptions
---@diagnostic disable-next-line: assign-type-mismatch ---@diagnostic disable-next-line: assign-type-mismatch
local stream_options = vim.tbl_deep_extend("force", generate_prompts_options, { local stream_options = vim.tbl_deep_extend("force", generate_prompts_options, {
on_start = on_start,
on_chunk = on_chunk, on_chunk = on_chunk,
on_complete = on_complete, on_stop = on_stop,
}) })
Llm.stream(stream_options) Llm.stream(stream_options)

@ -141,8 +141,10 @@ L5: pass
history_messages = history_messages, history_messages = history_messages,
instructions = vim.json.encode(doc), instructions = vim.json.encode(doc),
mode = "suggesting", mode = "suggesting",
on_start = function(_) end,
on_chunk = function(chunk) full_response = full_response .. chunk end, on_chunk = function(chunk) full_response = full_response .. chunk end,
on_complete = function(err) on_stop = function(stop_opts)
local err = stop_opts.error
if err then if err then
Utils.error("Error while suggesting: " .. vim.inspect(err), { once = true, title = "Avante" }) Utils.error("Error while suggesting: " .. vim.inspect(err), { once = true, title = "Avante" })
return return

@ -10,5 +10,13 @@
Act as an expert software developer. Act as an expert software developer.
Always use best practices when coding. Always use best practices when coding.
Respect and use existing conventions, libraries, etc that are already present in the code base. Respect and use existing conventions, libraries, etc that are already present in the code base.
You have access to tools, but only use them when necessary. If a tool is not required, respond as normal.
If you have information that you don't know, please proactively use the tools provided by users! Especially the web search tool.
{% if system_info -%}
Use the appropriate shell based on the user's system info:
{{system_info}}
{%- endif %}
{% block extra_prompt %} {% block extra_prompt %}
{% endblock %} {% endblock %}

@ -47,6 +47,31 @@ M.get_os_name = function()
end end
end end
M.get_system_info = function()
local os_name = vim.loop.os_uname().sysname
local os_version = vim.loop.os_uname().release
local os_machine = vim.loop.os_uname().machine
local lang = os.getenv("LANG")
local res = string.format(
"- Platform: %s-%s-%s\n- Shell: %s\n- Language: %s\n- Current date: %s",
os_name,
os_version,
os_machine,
vim.o.shell,
lang,
os.date("%Y-%m-%d")
)
local project_root = M.root.get()
if project_root then res = res .. string.format("\n- Project root: %s", project_root) end
local is_git_repo = vim.fn.isdirectory(".git") == 1
if is_git_repo then res = res .. "\n- The user is operating inside a git repository" end
return res
end
--- This function will run given shell command synchronously. --- This function will run given shell command synchronously.
---@param input_cmd string ---@param input_cmd string
---@return vim.SystemCompleted ---@return vim.SystemCompleted
@ -622,6 +647,7 @@ function M.parse_gitignore(gitignore_path)
end end
file:close() file:close()
ignore_patterns = vim.list_extend(ignore_patterns, { "%.git", "%.worktree", "__pycache__", "node_modules" })
return ignore_patterns, negate_patterns return ignore_patterns, negate_patterns
end end
@ -635,26 +661,28 @@ function M.is_ignored(file, ignore_patterns, negate_patterns)
return false return false
end end
---@param options { directory: string, add_dirs?: boolean } ---@param options { directory: string, add_dirs?: boolean, depth?: integer }
function M.scan_directory_respect_gitignore(options) function M.scan_directory_respect_gitignore(options)
local directory = options.directory local directory = options.directory
local gitignore_path = directory .. "/.gitignore" local gitignore_path = directory .. "/.gitignore"
local gitignore_patterns, gitignore_negate_patterns = M.parse_gitignore(gitignore_path) local gitignore_patterns, gitignore_negate_patterns = M.parse_gitignore(gitignore_path)
gitignore_patterns = vim.list_extend(gitignore_patterns, { "%.git", "%.worktree", "__pycache__", "node_modules" })
return M.scan_directory({ return M.scan_directory({
directory = directory, directory = directory,
gitignore_patterns = gitignore_patterns, gitignore_patterns = gitignore_patterns,
gitignore_negate_patterns = gitignore_negate_patterns, gitignore_negate_patterns = gitignore_negate_patterns,
add_dirs = options.add_dirs, add_dirs = options.add_dirs,
depth = options.depth,
}) })
end end
---@param options { directory: string, gitignore_patterns: string[], gitignore_negate_patterns: string[], add_dirs?: boolean } ---@param options { directory: string, gitignore_patterns: string[], gitignore_negate_patterns: string[], add_dirs?: boolean, depth?: integer, current_depth?: integer }
function M.scan_directory(options) function M.scan_directory(options)
local directory = options.directory local directory = options.directory
local ignore_patterns = options.gitignore_patterns local ignore_patterns = options.gitignore_patterns
local negate_patterns = options.gitignore_negate_patterns local negate_patterns = options.gitignore_negate_patterns
local add_dirs = options.add_dirs or false local add_dirs = options.add_dirs or false
local depth = options.depth or -1
local current_depth = options.current_depth or 0
local files = {} local files = {}
local handle = vim.loop.fs_scandir(directory) local handle = vim.loop.fs_scandir(directory)
@ -662,6 +690,8 @@ function M.scan_directory(options)
if not handle then return files end if not handle then return files end
while true do while true do
if depth > 0 and current_depth >= depth then break end
local name, type = vim.loop.fs_scandir_next(handle) local name, type = vim.loop.fs_scandir_next(handle)
if not name then break end if not name then break end
@ -677,6 +707,7 @@ function M.scan_directory(options)
gitignore_patterns = ignore_patterns, gitignore_patterns = ignore_patterns,
gitignore_negate_patterns = negate_patterns, gitignore_negate_patterns = negate_patterns,
add_dirs = add_dirs, add_dirs = add_dirs,
current_depth = current_depth + 1,
}) })
) )
elseif type == "file" then elseif type == "file" then

186
tests/llm_tools_spec.lua Normal file

@ -0,0 +1,186 @@
local mock = require("luassert.mock")
local stub = require("luassert.stub")
local LlmTools = require("avante.llm_tools")
local Utils = require("avante.utils")
LlmTools.confirm = function(msg) return true end
describe("llm_tools", function()
local test_dir = "/tmp/test_llm_tools"
local test_file = test_dir .. "/test.txt"
before_each(function()
-- 创建测试目录和文件
os.execute("mkdir -p " .. test_dir)
local file = io.open(test_file, "w")
file:write("test content")
file:close()
-- Mock get_project_root
stub(Utils, "get_project_root", function() return test_dir end)
end)
after_each(function()
-- 清理测试目录
os.execute("rm -rf " .. test_dir)
-- 恢复 mock
Utils.get_project_root:revert()
end)
describe("list_files", function()
it("should list files in directory", function()
local result, err = LlmTools.list_files({ rel_path = ".", depth = 1 })
assert.is_nil(err)
assert.truthy(result:find("test.txt"))
end)
end)
describe("read_file", function()
it("should read file content", function()
local content, err = LlmTools.read_file({ rel_path = "test.txt" })
assert.is_nil(err)
assert.equals("test content", content)
end)
it("should return error for non-existent file", function()
local content, err = LlmTools.read_file({ rel_path = "non_existent.txt" })
assert.truthy(err)
assert.equals("", content)
end)
end)
describe("create_file", function()
it("should create new file", function()
local success, err = LlmTools.create_file({ rel_path = "new_file.txt" })
assert.is_nil(err)
assert.is_true(success)
local file_exists = io.open(test_dir .. "/new_file.txt", "r") ~= nil
assert.is_true(file_exists)
end)
end)
describe("create_dir", function()
it("should create new directory", function()
local success, err = LlmTools.create_dir({ rel_path = "new_dir" })
assert.is_nil(err)
assert.is_true(success)
local dir_exists = io.open(test_dir .. "/new_dir", "r") ~= nil
assert.is_true(dir_exists)
end)
end)
describe("delete_file", function()
it("should delete existing file", function()
local success, err = LlmTools.delete_file({ rel_path = "test.txt" })
assert.is_nil(err)
assert.is_true(success)
local file_exists = io.open(test_file, "r") ~= nil
assert.is_false(file_exists)
end)
end)
describe("search_files", function()
it("should find files matching pattern", function()
local result, err = LlmTools.search_files({ rel_path = ".", keyword = "test" })
assert.is_nil(err)
assert.truthy(result:find("test.txt"))
end)
end)
describe("search", function()
local original_exepath = vim.fn.exepath
after_each(function() vim.fn.exepath = original_exepath end)
it("should search using ripgrep when available", function()
-- Mock exepath to return rg path
vim.fn.exepath = function(cmd)
if cmd == "rg" then return "/usr/bin/rg" end
return ""
end
-- Create a test file with searchable content
local file = io.open(test_dir .. "/searchable.txt", "w")
file:write("this is searchable content")
file:close()
file = io.open(test_dir .. "/nothing.txt", "w")
file:write("this is nothing")
file:close()
local result, err = LlmTools.search({ rel_path = ".", keyword = "searchable" })
assert.is_nil(err)
assert.truthy(result:find("searchable.txt"))
assert.falsy(result:find("nothing.txt"))
end)
it("should search using ag when rg is not available", function()
-- Mock exepath to return ag path
vim.fn.exepath = function(cmd)
if cmd == "ag" then return "/usr/bin/ag" end
return ""
end
-- Create a test file specifically for ag
local file = io.open(test_dir .. "/ag_test.txt", "w")
file:write("content for ag test")
file:close()
local result, err = LlmTools.search({ rel_path = ".", keyword = "ag test" })
assert.is_nil(err)
assert.is_string(result)
assert.truthy(result:find("ag_test.txt"))
end)
it("should search using grep when rg and ag are not available", function()
-- Mock exepath to return grep path
vim.fn.exepath = function(cmd)
if cmd == "grep" then return "/usr/bin/grep" end
return ""
end
local result, err = LlmTools.search({ rel_path = ".", keyword = "test" })
assert.is_nil(err)
assert.truthy(result:find("test.txt"))
end)
it("should return error when no search tool is available", function()
-- Mock exepath to return nothing
vim.fn.exepath = function() return "" end
local result, err = LlmTools.search({ rel_path = ".", keyword = "test" })
assert.equals("", result)
assert.equals("No search command found", err)
end)
it("should respect path permissions", function()
local result, err = LlmTools.search({ rel_path = "../outside_project", keyword = "test" })
assert.truthy(err:find("No permission to access path"))
end)
it("should handle non-existent paths", function()
local result, err = LlmTools.search({ rel_path = "non_existent_dir", keyword = "test" })
assert.equals("", result)
assert.truthy(err)
assert.truthy(err:find("No such file or directory"))
end)
end)
describe("run_command", function()
it("should execute command and return output", function()
local result, err = LlmTools.run_command({ rel_path = ".", command = "echo 'test'" })
assert.is_nil(err)
assert.equals("test\n", result)
end)
it("should return error when running outside current directory", function()
local result, err = LlmTools.run_command({ rel_path = "../outside_project", command = "echo 'test'" })
assert.is_false(result)
assert.truthy(err)
assert.truthy(err:find("No permission to access path"))
end)
end)
end)