feat: tools (#1180)

* feat: tools

* feat: claude use tools

* feat: openai use tools
This commit is contained in:
yetone 2025-02-05 22:39:54 +08:00 committed by GitHub
parent 1726d32778
commit 1437f319d2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 1321 additions and 74 deletions

View File

@ -38,6 +38,9 @@ jobs:
mkdir -p _neovim
curl -sL "https://github.com/neovim/neovim/releases/download/${{ matrix.rev }}" | tar xzf - --strip-components=1 -C "${PWD}/_neovim"
}
sudo apt-get update
sudo apt-get install -y ripgrep
sudo apt-get install -y silversearcher-ag
- name: Run tests
run: |

View File

@ -31,6 +31,7 @@ struct TemplateContext {
selected_code: Option<String>,
project_context: Option<String>,
diagnostics: Option<String>,
system_info: Option<String>,
}
// Given the file name registered after add, the context table in Lua, resulted in a formatted
@ -54,6 +55,7 @@ fn render(state: &State, template: &str, context: TemplateContext) -> LuaResult<
selected_code => context.selected_code,
project_context => context.project_context,
diagnostics => context.diagnostics,
system_info => context.system_info,
})
.map_err(LuaError::external)
.unwrap())

View File

@ -20,6 +20,14 @@ M._defaults = {
-- For most providers that we support we will determine this automatically.
-- If you wish to use a given implementation, then you can override it here.
tokenizer = "tiktoken",
web_search_engine = {
provider = "tavily",
api_key_name = "TAVILY_API_KEY",
provider_opts = {
time_range = "d",
include_answer = "basic",
},
},
---@type AvanteSupportedProvider
openai = {
endpoint = "https://api.openai.com/v1",

View File

@ -8,6 +8,7 @@ local Utils = require("avante.utils")
local Config = require("avante.config")
local Path = require("avante.path")
local P = require("avante.providers")
local LLMTools = require("avante.llm_tools")
---@class avante.LLM
local M = {}
@ -45,6 +46,8 @@ M.generate_prompts = function(opts)
local project_root = Utils.root.get()
Path.prompts.initialize(Path.prompts.get(project_root))
local system_info = Utils.get_system_info()
local template_opts = {
use_xml_format = Provider.use_xml_format,
ask = opts.ask, -- TODO: add mode without ask instruction
@ -53,6 +56,7 @@ M.generate_prompts = function(opts)
selected_code = opts.selected_code,
project_context = opts.project_context,
diagnostics = opts.diagnostics,
system_info = system_info,
}
local system_prompt = Path.prompts.render_mode(mode, template_opts)
@ -111,6 +115,10 @@ M.generate_prompts = function(opts)
system_prompt = system_prompt,
messages = messages,
image_paths = image_paths,
tools = opts.tools,
tool_use = opts.tool_use,
tool_result = opts.tool_result,
response_content = opts.response_content,
}
end
@ -135,7 +143,28 @@ M._stream = function(opts)
local current_event_state = nil
---@type AvanteHandlerOptions
local handler_opts = { on_chunk = opts.on_chunk, on_complete = opts.on_complete }
local handler_opts = {
on_start = opts.on_start,
on_chunk = opts.on_chunk,
on_stop = function(stop_opts)
if stop_opts.reason == "tool_use" and stop_opts.tool_use then
local result, error = LLMTools.process_tool_use(stop_opts.tool_use)
local tool_result = {
tool_use_id = stop_opts.tool_use.id,
content = error ~= nil and error or result,
is_error = error ~= nil,
}
local new_opts = vim.tbl_deep_extend(
"force",
opts,
{ tool_result = tool_result, tool_use = stop_opts.tool_use, response_content = stop_opts.response_content }
)
return M._stream(new_opts)
end
return opts.on_stop(stop_opts)
end,
}
---@type AvanteCurlOutput
local spec = Provider.parse_curl_args(Provider, code_opts)
@ -180,7 +209,7 @@ M._stream = function(opts)
stream = function(err, data, _)
if err then
completed = true
opts.on_complete(err)
handler_opts.on_stop({ reason = "error", error = err })
return
end
if not data then return end
@ -224,7 +253,7 @@ M._stream = function(opts)
active_job = nil
completed = true
cleanup()
opts.on_complete(err)
handler_opts.on_stop({ reason = "error", error = err })
end,
callback = function(result)
active_job = nil
@ -238,9 +267,10 @@ M._stream = function(opts)
vim.schedule(function()
if not completed then
completed = true
opts.on_complete(
"API request failed with status " .. result.status .. ". Body: " .. vim.inspect(result.body)
)
handler_opts.on_stop({
reason = "error",
error = "API request failed with status " .. result.status .. ". Body: " .. vim.inspect(result.body),
})
end
end)
end
@ -335,9 +365,9 @@ M._dual_boost_stream = function(opts, Provider1, Provider2)
on_chunk = function(chunk)
if chunk then response = response .. chunk end
end,
on_complete = function(err)
if err then
Utils.error(string.format("Stream %d failed: %s", index, err))
on_stop = function(stop_opts)
if stop_opts.error then
Utils.error(string.format("Stream %d failed: %s", index, stop_opts.error))
return
end
Utils.debug(string.format("Response %d completed", index))
@ -381,10 +411,15 @@ end
---@field instructions string
---@field mode LlmMode
---@field provider AvanteProviderFunctor | AvanteBedrockProviderFunctor | nil
---@field tools? AvanteLLMTool[]
---@field tool_result? AvanteLLMToolResult
---@field tool_use? AvanteLLMToolUse
---@field response_content? string
---
---@class StreamOptions: GeneratePromptsOptions
---@field on_chunk AvanteChunkParser
---@field on_complete AvanteCompleteParser
---@field on_start AvanteLLMStartCallback
---@field on_chunk AvanteLLMChunkCallback
---@field on_stop AvanteLLMStopCallback
---@param opts StreamOptions
M.stream = function(opts)
@ -396,12 +431,12 @@ M.stream = function(opts)
return original_on_chunk(chunk)
end)
end
if opts.on_complete ~= nil then
local original_on_complete = opts.on_complete
opts.on_complete = vim.schedule_wrap(function(err)
if opts.on_stop ~= nil then
local original_on_stop = opts.on_stop
opts.on_stop = vim.schedule_wrap(function(stop_opts)
if is_completed then return end
is_completed = true
return original_on_complete(err)
if stop_opts.reason == "complete" or stop_opts.reason == "error" then is_completed = true end
return original_on_stop(stop_opts)
end)
end
if Config.dual_boost.enabled then

714
lua/avante/llm_tools.lua Normal file
View File

@ -0,0 +1,714 @@
local curl = require("plenary.curl")
local Utils = require("avante.utils")
local Path = require("plenary.path")
local Config = require("avante.config")
local M = {}
---@param rel_path string
---@return string
local function get_abs_path(rel_path)
local project_root = Utils.get_project_root()
return Path:new(project_root):joinpath(rel_path):absolute()
end
function M.comfirm(msg)
local ok = vim.fn.confirm(msg, "&Yes\n&No", 2)
return ok == 1
end
---@param abs_path string
---@return boolean
local function has_permission_to_access(abs_path)
if not Path:new(abs_path):is_absolute() then return false end
local project_root = Utils.get_project_root()
if abs_path:sub(1, #project_root) ~= project_root then return false end
local gitignore_path = project_root .. "/.gitignore"
local gitignore_patterns, gitignore_negate_patterns = Utils.parse_gitignore(gitignore_path)
return not Utils.is_ignored(abs_path, gitignore_patterns, gitignore_negate_patterns)
end
---@param opts { rel_path: string, depth?: integer }
---@return string files
---@return string|nil error
function M.list_files(opts)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
local files = Utils.scan_directory_respect_gitignore({
directory = abs_path,
add_dirs = true,
depth = opts.depth,
})
local result = ""
for _, file in ipairs(files) do
local uniform_path = Utils.uniform_path(file)
result = result .. uniform_path .. "\n"
end
result = result:gsub("\n$", "")
return result, nil
end
---@param opts { rel_path: string, keyword: string }
---@return string files
---@return string|nil error
function M.search_files(opts)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
local files = Utils.scan_directory_respect_gitignore({
directory = abs_path,
})
local result = ""
for _, file in ipairs(files) do
if file:find(opts.keyword) then result = result .. file .. "\n" end
end
result = result:gsub("\n$", "")
return result, nil
end
---@param opts { rel_path: string, keyword: string }
---@return string result
---@return string|nil error
function M.search(opts)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return "", "No such file or directory: " .. abs_path end
---check if any search cmd is available
local search_cmd = vim.fn.exepath("rg")
if search_cmd == "" then search_cmd = vim.fn.exepath("ag") end
if search_cmd == "" then search_cmd = vim.fn.exepath("ack") end
if search_cmd == "" then search_cmd = vim.fn.exepath("grep") end
if search_cmd == "" then return "", "No search command found" end
---execute the search command
local cmd = ""
if search_cmd:find("rg") then
cmd = string.format("%s --files-with-matches --no-ignore-vcs --ignore-case --hidden --glob '!.git'", search_cmd)
cmd = string.format("%s '%s' %s", cmd, opts.keyword, abs_path)
elseif search_cmd:find("ag") then
cmd = string.format("%s '%s' --nocolor --nogroup --hidden --ignore .git %s", search_cmd, opts.keyword, abs_path)
elseif search_cmd:find("ack") then
cmd = string.format("%s --nocolor --nogroup --hidden --ignore-dir .git", search_cmd)
cmd = string.format("%s '%s' %s", cmd, opts.keyword, abs_path)
elseif search_cmd:find("grep") then
cmd = string.format("%s -riH --exclude-dir=.git %s %s", search_cmd, opts.keyword, abs_path)
end
Utils.debug("cmd", cmd)
local result = vim.fn.system(cmd)
return result or "", nil
end
---@param opts { rel_path: string }
---@return string definitions
---@return string|nil error
function M.read_file_toplevel_symbols(opts)
local RepoMap = require("avante.repo_map")
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
local filetype = RepoMap.get_ts_lang(abs_path)
local repo_map_lib = RepoMap._init_repo_map_lib()
if not repo_map_lib then return "", "Failed to load avante_repo_map" end
local definitions = filetype
and repo_map_lib.stringify_definitions(filetype, Utils.file.read_content(abs_path) or "")
or ""
return definitions, nil
end
---@param opts { rel_path: string }
---@return string content
---@return string|nil error
function M.read_file(opts)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end
local file = io.open(abs_path, "r")
if not file then return "", "file not found: " .. abs_path end
local content = file:read("*a")
file:close()
return content, nil
end
---@param opts { rel_path: string }
---@return boolean success
---@return string|nil error
function M.create_file(opts)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
---create directory if it doesn't exist
local dir = Path:new(abs_path):parent()
if not dir:exists() then dir:mkdir({ parents = true }) end
---create file if it doesn't exist
if not dir:joinpath(opts.rel_path):exists() then
local file = io.open(abs_path, "w")
if not file then return false, "file not found: " .. abs_path end
file:close()
end
return true, nil
end
---@param opts { rel_path: string, new_rel_path: string }
---@return boolean success
---@return string|nil error
function M.rename_file(opts)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end
if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end
local new_abs_path = get_abs_path(opts.new_rel_path)
if not has_permission_to_access(new_abs_path) then return false, "No permission to access path: " .. new_abs_path end
if Path:new(new_abs_path):exists() then return false, "File already exists: " .. new_abs_path end
if not M.confirm("Are you sure you want to rename the file: " .. abs_path .. " to: " .. new_abs_path) then
return false, "User canceled"
end
os.rename(abs_path, new_abs_path)
return true, nil
end
---@param opts { rel_path: string, new_rel_path: string }
---@return boolean success
---@return string|nil error
function M.copy_file(opts)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end
if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end
local new_abs_path = get_abs_path(opts.new_rel_path)
if not has_permission_to_access(new_abs_path) then return false, "No permission to access path: " .. new_abs_path end
if Path:new(new_abs_path):exists() then return false, "File already exists: " .. new_abs_path end
Path:new(new_abs_path):write(Path:new(abs_path):read())
return true, nil
end
---@param opts { rel_path: string }
---@return boolean success
---@return string|nil error
function M.delete_file(opts)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end
if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end
if not M.confirm("Are you sure you want to delete the file: " .. abs_path) then return false, "User canceled" end
os.remove(abs_path)
return true, nil
end
---@param opts { rel_path: string }
---@return boolean success
---@return string|nil error
function M.create_dir(opts)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if Path:new(abs_path):exists() then return false, "Directory already exists: " .. abs_path end
Path:new(abs_path):mkdir({ parents = true })
return true, nil
end
---@param opts { rel_path: string, new_rel_path: string }
---@return boolean success
---@return string|nil error
function M.rename_dir(opts)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "Directory not found: " .. abs_path end
if not Path:new(abs_path):is_dir() then return false, "Path is not a directory: " .. abs_path end
local new_abs_path = get_abs_path(opts.new_rel_path)
if not has_permission_to_access(new_abs_path) then return false, "No permission to access path: " .. new_abs_path end
if Path:new(new_abs_path):exists() then return false, "Directory already exists: " .. new_abs_path end
if not M.confirm("Are you sure you want to rename directory " .. abs_path .. " to " .. new_abs_path .. "?") then
return false, "User canceled"
end
os.rename(abs_path, new_abs_path)
return true, nil
end
---@param opts { rel_path: string }
---@return boolean success
---@return string|nil error
function M.delete_dir(opts)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "Directory not found: " .. abs_path end
if not Path:new(abs_path):is_dir() then return false, "Path is not a directory: " .. abs_path end
if not M.confirm("Are you sure you want to delete the directory: " .. abs_path) then
return false, "User canceled"
end
os.remove(abs_path)
return true, nil
end
---@param opts { rel_path: string, command: string }
---@return string|boolean result
---@return string|nil error
function M.run_command(opts)
local abs_path = get_abs_path(opts.rel_path)
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
if not Path:new(abs_path):exists() then return false, "Path not found: " .. abs_path end
if
not M.confirm("Are you sure you want to run the command: `" .. opts.command .. "` in the directory: " .. abs_path)
then
return false, "User canceled"
end
---change cwd to abs_path
local old_cwd = vim.fn.getcwd()
vim.fn.chdir(abs_path)
local res = Utils.shell_run(opts.command)
vim.fn.chdir(old_cwd)
if res.code ~= 0 then
if res.stdout then return false, "Error: " .. res.stdout .. "; Error code: " .. tostring(res.code) end
return false, "Error code: " .. tostring(res.code)
end
return res.stdout, nil
end
---@param opts { query: string }
---@return string|nil result
---@return string|nil error
function M.web_search(opts)
local search_engine = Config.web_search_engine
if search_engine.provider == "tavily" then
if search_engine.api_key_name == "" then return nil, "No API key provided" end
local api_key = os.getenv(search_engine.api_key_name)
if api_key == nil or api_key == "" then
return nil, "Environment variable " .. search_engine.api_key_name .. " is not set"
end
local resp = curl.post("https://api.tavily.com/search", {
headers = {
["Content-Type"] = "application/json",
["Authorization"] = "Bearer " .. api_key,
},
body = vim.json.encode(vim.tbl_deep_extend("force", {
query = opts.query,
}, search_engine.provider_opts)),
})
if resp.status ~= 200 then return nil, "Error: " .. resp.body end
local jsn = vim.json.decode(resp.body)
return jsn.anwser, nil
end
end
---@class AvanteLLMTool
---@field name string
---@field description string
---@field param AvanteLLMToolParam
---@field returns AvanteLLMToolReturn[]
---@class AvanteLLMToolParam
---@field type string
---@field fields AvanteLLMToolParamField[]
---@class AvanteLLMToolParamField
---@field name string
---@field description string
---@field type string
---@field optional? boolean
---@class AvanteLLMToolReturn
---@field name string
---@field description string
---@field type string
---@field optional? boolean
---@type AvanteLLMTool[]
M.tools = {
{
name = "list_files",
description = "List files in a directory",
param = {
type = "table",
fields = {
{
name = "rel_path",
description = "Relative path to the directory",
type = "string",
},
{
name = "depth",
description = "Depth of the directory",
type = "integer",
optional = true,
},
},
},
returns = {
{
name = "files",
description = "List of files in the directory",
type = "string[]",
},
{
name = "error",
description = "Error message if the directory was not listed successfully",
type = "string",
optional = true,
},
},
},
{
name = "search_files",
description = "Search for files in a directory",
param = {
type = "table",
fields = {
{
name = "rel_path",
description = "Relative path to the directory",
type = "string",
},
{
name = "keyword",
description = "Keyword to search for",
type = "string",
},
},
},
returns = {
{
name = "files",
description = "List of files that match the keyword",
type = "string",
},
{
name = "error",
description = "Error message if the directory was not searched successfully",
type = "string",
optional = true,
},
},
},
{
name = "search",
description = "Search for a keyword in a directory",
param = {
type = "table",
fields = {
{
name = "rel_path",
description = "Relative path to the directory",
type = "string",
},
{
name = "keyword",
description = "Keyword to search for",
type = "string",
},
},
},
returns = {
{
name = "files",
description = "List of files that match the keyword",
type = "string",
},
{
name = "error",
description = "Error message if the directory was not searched successfully",
type = "string",
optional = true,
},
},
},
{
name = "read_file_toplevel_symbols",
description = "Read the top-level symbols of a file",
param = {
type = "table",
fields = {
{
name = "rel_path",
description = "Relative path to the file",
type = "string",
},
},
},
returns = {
{
name = "definitions",
description = "Top-level symbols of the file",
type = "string",
},
{
name = "error",
description = "Error message if the file was not read successfully",
type = "string",
optional = true,
},
},
},
{
name = "read_file",
description = "Read the contents of a file",
param = {
type = "table",
fields = {
{
name = "rel_path",
description = "Relative path to the file",
type = "string",
},
},
},
returns = {
{
name = "content",
description = "Contents of the file",
type = "string",
},
{
name = "error",
description = "Error message if the file was not read successfully",
type = "string",
optional = true,
},
},
},
{
name = "create_file",
description = "Create a new file",
param = {
type = "table",
fields = {
{
name = "rel_path",
description = "Relative path to the file",
type = "string",
},
},
},
returns = {
{
name = "success",
description = "True if the file was created successfully, false otherwise",
type = "boolean",
},
{
name = "error",
description = "Error message if the file was not created successfully",
type = "string",
optional = true,
},
},
},
{
name = "rename_file",
description = "Rename a file",
param = {
type = "table",
fields = {
{
name = "rel_path",
description = "Relative path to the file",
type = "string",
},
{
name = "new_rel_path",
description = "New relative path for the file",
type = "string",
},
},
},
returns = {
{
name = "success",
description = "True if the file was renamed successfully, false otherwise",
type = "boolean",
},
{
name = "error",
description = "Error message if the file was not renamed successfully",
type = "string",
optional = true,
},
},
},
{
name = "delete_file",
description = "Delete a file",
param = {
type = "table",
fields = {
{
name = "rel_path",
description = "Relative path to the file",
type = "string",
},
},
},
returns = {
{
name = "success",
description = "True if the file was deleted successfully, false otherwise",
type = "boolean",
},
{
name = "error",
description = "Error message if the file was not deleted successfully",
type = "string",
optional = true,
},
},
},
{
name = "create_dir",
description = "Create a new directory",
param = {
type = "table",
fields = {
{
name = "rel_path",
description = "Relative path to the directory",
type = "string",
},
},
},
returns = {
{
name = "success",
description = "True if the directory was created successfully, false otherwise",
type = "boolean",
},
{
name = "error",
description = "Error message if the directory was not created successfully",
type = "string",
optional = true,
},
},
},
{
name = "rename_dir",
description = "Rename a directory",
param = {
type = "table",
fields = {
{
name = "rel_path",
description = "Relative path to the directory",
type = "string",
},
{
name = "new_rel_path",
description = "New relative path for the directory",
type = "string",
},
},
},
returns = {
{
name = "success",
description = "True if the directory was renamed successfully, false otherwise",
type = "boolean",
},
{
name = "error",
description = "Error message if the directory was not renamed successfully",
type = "string",
optional = true,
},
},
},
{
name = "delete_dir",
description = "Delete a directory",
param = {
type = "table",
fields = {
{
name = "rel_path",
description = "Relative path to the directory",
type = "string",
},
},
},
returns = {
{
name = "success",
description = "True if the directory was deleted successfully, false otherwise",
type = "boolean",
},
{
name = "error",
description = "Error message if the directory was not deleted successfully",
type = "string",
optional = true,
},
},
},
{
name = "run_command",
description = "Run a command in a directory",
param = {
type = "table",
fields = {
{
name = "rel_path",
description = "Relative path to the directory",
type = "string",
},
{
name = "command",
description = "Command to run",
type = "string",
},
},
},
returns = {
{
name = "stdout",
description = "Output of the command",
type = "string",
},
{
name = "error",
description = "Error message if the command was not run successfully",
type = "string",
optional = true,
},
},
},
{
name = "web_search",
description = "Search the web",
param = {
type = "table",
fields = {
{
name = "query",
description = "Query to search",
type = "string",
},
},
},
returns = {
{
name = "result",
description = "Result of the search",
type = "string",
},
{
name = "error",
description = "Error message if the search was not successful",
type = "string",
optional = true,
},
},
},
}
---@param tool_use AvanteLLMToolUse
---@return string | nil result
---@return string | nil error
function M.process_tool_use(tool_use)
Utils.debug("use tool", tool_use.name, tool_use.input_json)
local tool = vim.iter(M.tools):find(function(tool) return tool.name == tool_use.name end)
if tool == nil then return end
local input_json = vim.json.decode(tool_use.input_json)
local func = M[tool.name]
local result, error = func(input_json)
-- Utils.debug("result", result)
-- Utils.debug("error", error)
if result ~= nil and type(result) ~= "string" then result = vim.json.encode(result) end
return result, error
end
return M

View File

@ -6,6 +6,8 @@
---@field role "user" | "assistant"
---@field content [AvanteBedrockClaudeTextMessage][]
local Claude = require("avante.providers.claude")
---@class AvanteBedrockModelHandler
local M = {}
@ -33,25 +35,7 @@ M.parse_messages = function(opts)
return messages
end
M.parse_response = function(ctx, data_stream, event_state, opts)
if event_state == nil then
if data_stream:match('"content_block_delta"') then
event_state = "content_block_delta"
elseif data_stream:match('"message_stop"') then
event_state = "message_stop"
end
end
if event_state == "content_block_delta" then
local ok, json = pcall(vim.json.decode, data_stream)
if not ok then return end
opts.on_chunk(json.delta.text)
elseif event_state == "message_stop" then
opts.on_complete(nil)
return
elseif event_state == "error" then
opts.on_complete(vim.json.decode(data_stream))
end
end
M.parse_response = Claude.parse_response
---@param prompt_opts AvantePromptOptions
---@param body_opts table
@ -60,7 +44,6 @@ M.build_bedrock_payload = function(prompt_opts, body_opts)
local system_prompt = prompt_opts.system_prompt or ""
local messages = M.parse_messages(prompt_opts)
local max_tokens = body_opts.max_tokens or 2000
local temperature = body_opts.temperature or 0.7
local payload = {
anthropic_version = "bedrock-2023-05-31",
max_tokens = max_tokens,

View File

@ -17,6 +17,44 @@ local P = require("avante.providers")
---@field role "user" | "assistant"
---@field content [AvanteClaudeTextMessage | AvanteClaudeImageMessage][]
---@class AvanteClaudeTool
---@field name string
---@field description string
---@field input_schema AvanteClaudeToolInputSchema
---@class AvanteClaudeToolInputSchema
---@field type "object"
---@field properties table<string, AvanteClaudeToolInputSchemaProperty>
---@field required string[]
---@class AvanteClaudeToolInputSchemaProperty
---@field type "string" | "number" | "boolean"
---@field description string
---@field enum? string[]
---@param tool AvanteLLMTool
---@return AvanteClaudeTool
local function transform_tool(tool)
local input_schema_properties = {}
local required = {}
for _, field in ipairs(tool.param.fields) do
input_schema_properties[field.name] = {
type = field.type,
description = field.description,
}
if not field.optional then table.insert(required, field.name) end
end
return {
name = tool.name,
description = tool.description,
input_schema = {
type = "object",
properties = input_schema_properties,
required = required,
},
}
end
---@class AvanteProviderFunctor
local M = {}
@ -74,26 +112,101 @@ M.parse_messages = function(opts)
messages[#messages].content = message_content
end
if opts.tool_use then
local msg = {
role = "assistant",
content = {},
}
if opts.response_content then
msg.content[#msg.content + 1] = {
type = "text",
text = opts.response_content,
}
end
msg.content[#msg.content + 1] = {
type = "tool_use",
id = opts.tool_use.id,
name = opts.tool_use.name,
input = vim.json.decode(opts.tool_use.input_json),
}
messages[#messages + 1] = msg
end
if opts.tool_result then
messages[#messages + 1] = {
role = "user",
content = {
{
type = "tool_result",
tool_use_id = opts.tool_result.tool_use_id,
content = opts.tool_result.content,
is_error = opts.tool_result.is_error,
},
},
}
end
return messages
end
M.parse_response = function(ctx, data_stream, event_state, opts)
if event_state == nil then
if data_stream:match('"content_block_delta"') then
event_state = "content_block_delta"
if data_stream:match('"message_start"') then
event_state = "message_start"
elseif data_stream:match('"message_delta"') then
event_state = "message_delta"
elseif data_stream:match('"message_stop"') then
event_state = "message_stop"
elseif data_stream:match('"content_block_start"') then
event_state = "content_block_start"
elseif data_stream:match('"content_block_delta"') then
event_state = "content_block_delta"
elseif data_stream:match('"content_block_stop"') then
event_state = "content_block_stop"
end
end
if event_state == "content_block_delta" then
local ok, json = pcall(vim.json.decode, data_stream)
if event_state == "message_start" then
local ok, jsn = pcall(vim.json.decode, data_stream)
if not ok then return end
opts.on_chunk(json.delta.text)
elseif event_state == "message_stop" then
opts.on_complete(nil)
opts.on_start(jsn.message.usage)
elseif event_state == "content_block_start" then
local ok, jsn = pcall(vim.json.decode, data_stream)
if not ok then return end
if jsn.content_block.type == "tool_use" then
ctx.tool_use = {
name = jsn.content_block.name,
id = jsn.content_block.id,
input_json = "",
}
elseif jsn.content_block.type == "text" then
ctx.response_content = ""
end
elseif event_state == "content_block_delta" then
local ok, jsn = pcall(vim.json.decode, data_stream)
if not ok then return end
if ctx.tool_use and jsn.delta.type == "input_json_delta" then
ctx.tool_use.input_json = ctx.tool_use.input_json .. jsn.delta.partial_json
return
elseif ctx.response_content and jsn.delta.type == "text_delta" then
ctx.response_content = ctx.response_content .. jsn.delta.text
end
opts.on_chunk(jsn.delta.text)
elseif event_state == "message_delta" then
local ok, jsn = pcall(vim.json.decode, data_stream)
if not ok then return end
if jsn.delta.stop_reason == "end_turn" then
opts.on_stop({ reason = "complete", usage = jsn.usage })
elseif jsn.delta.stop_reason == "tool_use" then
opts.on_stop({
reason = "tool_use",
usage = jsn.usage,
tool_use = ctx.tool_use,
response_content = ctx.response_content,
})
end
return
elseif event_state == "error" then
opts.on_complete(vim.json.decode(data_stream))
opts.on_stop({ reason = "error", error = vim.json.decode(data_stream) })
end
end
@ -113,6 +226,13 @@ M.parse_curl_args = function(provider, prompt_opts)
local messages = M.parse_messages(prompt_opts)
local tools = {}
if prompt_opts.tools then
for _, tool in ipairs(prompt_opts.tools) do
table.insert(tools, transform_tool(tool))
end
end
return {
url = Utils.url_join(base.endpoint, "/v1/messages"),
proxy = base.proxy,
@ -128,6 +248,7 @@ M.parse_curl_args = function(provider, prompt_opts)
},
},
messages = messages,
tools = tools,
stream = true,
}, body_opts),
}

View File

@ -62,7 +62,7 @@ M.parse_stream_data = function(data, opts)
local json = vim.json.decode(data)
if json.type ~= nil then
if json.type == "message-end" and json.delta.finish_reason == "COMPLETE" then
opts.on_complete(nil)
opts.on_stop({ reason = "complete" })
return
end
if json.type == "content-delta" then opts.on_chunk(json.delta.message.content.text) end

View File

@ -66,17 +66,17 @@ end
M.parse_response = function(ctx, data_stream, _, opts)
local ok, json = pcall(vim.json.decode, data_stream)
if not ok then opts.on_complete(json) end
if not ok then opts.on_stop({ reason = "error", error = json }) end
if json.candidates then
if #json.candidates > 0 then
if json.candidates[1].finishReason and json.candidates[1].finishReason == "STOP" then
opts.on_chunk(json.candidates[1].content.parts[1].text)
opts.on_complete(nil)
opts.on_stop({ reason = "complete" })
else
opts.on_chunk(json.candidates[1].content.parts[1].text)
end
else
opts.on_complete(nil)
opts.on_stop({ reason = "complete" })
end
end
end

View File

@ -11,17 +11,28 @@ local DressingConfig = {
local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil }
---@class AvanteHandlerOptions: table<[string], string>
---@field on_chunk AvanteChunkParser
---@field on_complete AvanteCompleteParser
---@field on_start AvanteLLMStartCallback
---@field on_chunk AvanteLLMChunkCallback
---@field on_stop AvanteLLMStopCallback
---
---@class AvanteLLMMessage
---@field role "user" | "assistant"
---@field content string
---
---@class AvanteLLMToolResult
---@field tool_name string
---@field tool_use_id string
---@field content string
---@field is_error? boolean
---
---@class AvantePromptOptions: table<[string], string>
---@field system_prompt string
---@field messages AvanteLLMMessage[]
---@field image_paths? string[]
---@field tools? AvanteLLMTool[]
---@field tool_result? AvanteLLMToolResult
---@field tool_use? AvanteLLMToolUse
---@field response_content? string
---
---@class AvanteGeminiMessage
---@field role "user"
@ -35,8 +46,9 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil }
---@alias AvanteCurlArgsParser fun(opts: AvanteProvider | AvanteProviderFunctor | AvanteBedrockProviderFunctor, code_opts: AvantePromptOptions): AvanteCurlOutput
---
---@class ResponseParser
---@field on_chunk fun(chunk: string): any
---@field on_complete fun(err: string|nil): any
---@field on_start AvanteLLMStartCallback
---@field on_chunk AvanteLLMChunkCallback
---@field on_stop AvanteLLMStopCallback
---@alias AvanteResponseParser fun(ctx: any, data_stream: string, event_state: string, opts: ResponseParser): nil
---
---@class AvanteDefaultBaseProvider: table<string, any>
@ -54,9 +66,31 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil }
---@field temperature? number
---@field max_tokens? number
---
---@class AvanteLLMUsage
---@field input_tokens number
---@field cache_creation_input_tokens number
---@field cache_read_input_tokens number
---@field output_tokens number
---
---@class AvanteLLMToolUse
---@field name string
---@field id string
---@field input_json string
---
---@class AvanteLLMStartCallbackOptions
---@field usage? AvanteLLMUsage
---
---@class AvanteLLMStopCallbackOptions
---@field reason "complete" | "tool_use" | "error"
---@field error? string | table
---@field usage? AvanteLLMUsage
---@field tool_use? AvanteLLMToolUse
---@field response_content? string
---
---@alias AvanteStreamParser fun(line: string, handler_opts: AvanteHandlerOptions): nil
---@alias AvanteChunkParser fun(chunk: string): any
---@alias AvanteCompleteParser fun(err: string|nil): nil
---@alias AvanteLLMStartCallback fun(opts: AvanteLLMStartCallbackOptions): nil
---@alias AvanteLLMChunkCallback fun(chunk: string): any
---@alias AvanteLLMStopCallback fun(opts: AvanteLLMStopCallbackOptions): nil
---@alias AvanteLLMConfigHandler fun(opts: AvanteSupportedProvider): AvanteDefaultBaseProvider, table<string, any>
---
---@class AvanteProvider: AvanteSupportedProvider

View File

@ -24,12 +24,72 @@ local P = require("avante.providers")
---@field index integer
---@field logprobs integer
---
---@class OpenAIMessageToolCallFunction
---@field name string
---@field arguments string
---
---@class OpenAIMessageToolCall
---@field id string
---@field type "function"
---@field function OpenAIMessageToolCallFunction
---
---@class OpenAIMessage
---@field role? "user" | "system" | "assistant"
---@field content? string
---@field reasoning_content? string
---@field reasoning? string
---@field tool_calls? OpenAIMessageToolCall[]
---
---@class AvanteOpenAITool
---@field type "function"
---@field function AvanteOpenAIToolFunction
---
---@class AvanteOpenAIToolFunction
---@field name string
---@field description string
---@field parameters AvanteOpenAIToolFunctionParameters
---@field strict boolean
---
---@class AvanteOpenAIToolFunctionParameters
---@field type string
---@field properties table<string, AvanteOpenAIToolFunctionParameterProperty>
---@field required string[]
---@field additionalProperties boolean
---
---@class AvanteOpenAIToolFunctionParameterProperty
---@field type string
---@field description string
---@param tool AvanteLLMTool
---@return AvanteOpenAITool
local function transform_tool(tool)
local input_schema_properties = {}
local required = {}
for _, field in ipairs(tool.param.fields) do
input_schema_properties[field.name] = {
type = field.type,
description = field.description,
}
if not field.optional then table.insert(required, field.name) end
end
local res = {
type = "function",
["function"] = {
name = tool.name,
description = tool.description,
},
}
if vim.tbl_count(input_schema_properties) > 0 then
res["function"].parameters = {
type = "object",
properties = input_schema_properties,
required = required,
additionalProperties = false,
}
end
return res
end
---@class AvanteProviderFunctor
local M = {}
@ -107,12 +167,34 @@ M.parse_messages = function(opts)
table.insert(final_messages, { role = M.role_map[role] or role, content = message.content })
end)
if opts.tool_result then
table.insert(final_messages, {
role = M.role_map["assistant"],
tool_calls = {
{
id = opts.tool_use.id,
type = "function",
["function"] = {
name = opts.tool_use.name,
arguments = opts.tool_use.input_json,
},
},
},
})
local result_content = opts.tool_result.content or ""
table.insert(final_messages, {
role = "tool",
tool_call_id = opts.tool_result.tool_use_id,
content = opts.tool_result.is_error and "Error: " .. result_content or result_content,
})
end
return final_messages
end
M.parse_response = function(ctx, data_stream, _, opts)
if data_stream:match('"%[DONE%]":') then
opts.on_complete(nil)
opts.on_stop({ reason = "complete" })
return
end
if data_stream:match('"delta":') then
@ -121,7 +203,14 @@ M.parse_response = function(ctx, data_stream, _, opts)
if jsn.choices and jsn.choices[1] then
local choice = jsn.choices[1]
if choice.finish_reason == "stop" or choice.finish_reason == "eos_token" then
opts.on_complete(nil)
opts.on_stop({ reason = "complete" })
elseif choice.finish_reason == "tool_calls" then
opts.on_stop({
reason = "tool_use",
usage = jsn.usage,
tool_use = ctx.tool_use,
response_content = ctx.response_content,
})
elseif choice.delta.reasoning_content and choice.delta.reasoning_content ~= vim.NIL then
if ctx.returned_think_start_tag == nil or not ctx.returned_think_start_tag then
ctx.returned_think_start_tag = true
@ -136,6 +225,17 @@ M.parse_response = function(ctx, data_stream, _, opts)
end
ctx.last_think_content = choice.delta.reasoning
opts.on_chunk(choice.delta.reasoning)
elseif choice.delta.tool_calls then
local tool_call = choice.delta.tool_calls[1]
if not ctx.tool_use then
ctx.tool_use = {
name = tool_call["function"].name,
id = tool_call.id,
input_json = "",
}
else
ctx.tool_use.input_json = ctx.tool_use.input_json .. tool_call["function"].arguments
end
elseif choice.delta.content then
if
ctx.returned_think_start_tag ~= nil and (ctx.returned_think_end_tag == nil or not ctx.returned_think_end_tag)
@ -164,7 +264,7 @@ M.parse_response_without_stream = function(data, _, opts)
local choice = json.choices[1]
if choice.message and choice.message.content then
opts.on_chunk(choice.message.content)
vim.schedule(function() opts.on_complete(nil) end)
vim.schedule(function() opts.on_stop({ reason = "complete" }) end)
end
end
end
@ -198,6 +298,13 @@ M.parse_curl_args = function(provider, code_opts)
body_opts.temperature = 1
end
local tools = {}
if code_opts.tools then
for _, tool in ipairs(code_opts.tools) do
table.insert(tools, transform_tool(tool))
end
end
Utils.debug("endpoint", base.endpoint)
Utils.debug("model", base.model)
@ -210,6 +317,7 @@ M.parse_curl_args = function(provider, code_opts)
model = base.model,
messages = M.parse_messages(code_opts),
stream = stream,
tools = tools,
}, body_opts),
}
end

View File

@ -157,7 +157,10 @@ function Selection:create_editing_input()
self.prompt_input:start_spinner()
---@type AvanteChunkParser
---@type AvanteLLMStartCallback
local on_start = function(start_opts) end
---@type AvanteLLMChunkCallback
local on_chunk = function(chunk)
full_response = full_response .. chunk
local response_lines_ = vim.split(full_response, "\n")
@ -182,13 +185,15 @@ function Selection:create_editing_input()
finish_line = start_line + #response_lines - 1
end
---@type AvanteCompleteParser
local on_complete = function(err)
if err then
---@type AvanteLLMStopCallback
local on_stop = function(stop_opts)
if stop_opts.error then
-- NOTE: in Ubuntu 22.04+ you will see this ignorable error from ~/.local/share/nvim/lazy/avante.nvim/lua/avante/llm.lua `on_error = function(err)`, check to avoid showing this error.
if type(err) == "table" and err.exit == nil and err.stderr == "{}" then return end
if type(stop_opts.error) == "table" and stop_opts.error.exit == nil and stop_opts.error.stderr == "{}" then
return
end
Utils.error(
"Error occurred while processing the response: " .. vim.inspect(err),
"Error occurred while processing the response: " .. vim.inspect(stop_opts.error),
{ once = true, title = "Avante" }
)
return
@ -216,8 +221,9 @@ function Selection:create_editing_input()
selected_code = self.selection.content,
instructions = input,
mode = "editing",
on_start = on_start,
on_chunk = on_chunk,
on_complete = on_complete,
on_stop = on_stop,
})
end

View File

@ -13,6 +13,7 @@ local Utils = require("avante.utils")
local Highlights = require("avante.highlights")
local RepoMap = require("avante.repo_map")
local FileSelector = require("avante.file_selector")
local LLMTools = require("avante.llm_tools")
local RESULT_BUF_NAME = "AVANTE_RESULT"
local VIEW_BUFFER_UPDATED_PATTERN = "AvanteViewBufferUpdated"
@ -1669,6 +1670,7 @@ function Sidebar:create_input_container(opts)
selected_code = selected_code_content,
instructions = request,
mode = "planning",
tools = LLMTools.tools,
}
end
@ -1755,7 +1757,10 @@ function Sidebar:create_input_container(opts)
vim.keymap.set("n", "k", on_k, { buffer = self.result_container.bufnr })
vim.keymap.set("n", "G", on_G, { buffer = self.result_container.bufnr })
---@type AvanteChunkParser
---@type AvanteLLMStartCallback
local on_start = function(start_opts) end
---@type AvanteLLMChunkCallback
local on_chunk = function(chunk)
original_response = original_response .. chunk
@ -1778,8 +1783,8 @@ function Sidebar:create_input_container(opts)
displayed_response = cur_displayed_response
end
---@type AvanteCompleteParser
local on_complete = function(err)
---@type AvanteLLMStopCallback
local on_stop = function(stop_opts)
pcall(function()
---remove keymaps
vim.keymap.del("n", "j", { buffer = self.result_container.bufnr })
@ -1787,9 +1792,9 @@ function Sidebar:create_input_container(opts)
vim.keymap.del("n", "G", { buffer = self.result_container.bufnr })
end)
if err ~= nil then
if stop_opts.error ~= nil then
self:update_content(
content_prefix .. displayed_response .. "\n\nError: " .. vim.inspect(err),
content_prefix .. displayed_response .. "\n\nError: " .. vim.inspect(stop_opts.error),
{ scroll = scroll }
)
return
@ -1835,8 +1840,9 @@ function Sidebar:create_input_container(opts)
---@type StreamOptions
---@diagnostic disable-next-line: assign-type-mismatch
local stream_options = vim.tbl_deep_extend("force", generate_prompts_options, {
on_start = on_start,
on_chunk = on_chunk,
on_complete = on_complete,
on_stop = on_stop,
})
Llm.stream(stream_options)

View File

@ -141,8 +141,10 @@ L5: pass
history_messages = history_messages,
instructions = vim.json.encode(doc),
mode = "suggesting",
on_start = function(_) end,
on_chunk = function(chunk) full_response = full_response .. chunk end,
on_complete = function(err)
on_stop = function(stop_opts)
local err = stop_opts.error
if err then
Utils.error("Error while suggesting: " .. vim.inspect(err), { once = true, title = "Avante" })
return

View File

@ -10,5 +10,13 @@
Act as an expert software developer.
Always use best practices when coding.
Respect and use existing conventions, libraries, etc that are already present in the code base.
You have access to tools, but only use them when necessary. If a tool is not required, respond as normal.
If you have information that you don't know, please proactively use the tools provided by users! Especially the web search tool.
{% if system_info -%}
Use the appropriate shell based on the user's system info:
{{system_info}}
{%- endif %}
{% block extra_prompt %}
{% endblock %}

View File

@ -47,6 +47,31 @@ M.get_os_name = function()
end
end
M.get_system_info = function()
local os_name = vim.loop.os_uname().sysname
local os_version = vim.loop.os_uname().release
local os_machine = vim.loop.os_uname().machine
local lang = os.getenv("LANG")
local res = string.format(
"- Platform: %s-%s-%s\n- Shell: %s\n- Language: %s\n- Current date: %s",
os_name,
os_version,
os_machine,
vim.o.shell,
lang,
os.date("%Y-%m-%d")
)
local project_root = M.root.get()
if project_root then res = res .. string.format("\n- Project root: %s", project_root) end
local is_git_repo = vim.fn.isdirectory(".git") == 1
if is_git_repo then res = res .. "\n- The user is operating inside a git repository" end
return res
end
--- This function will run given shell command synchronously.
---@param input_cmd string
---@return vim.SystemCompleted
@ -622,6 +647,7 @@ function M.parse_gitignore(gitignore_path)
end
file:close()
ignore_patterns = vim.list_extend(ignore_patterns, { "%.git", "%.worktree", "__pycache__", "node_modules" })
return ignore_patterns, negate_patterns
end
@ -635,26 +661,28 @@ function M.is_ignored(file, ignore_patterns, negate_patterns)
return false
end
---@param options { directory: string, add_dirs?: boolean }
---@param options { directory: string, add_dirs?: boolean, depth?: integer }
function M.scan_directory_respect_gitignore(options)
local directory = options.directory
local gitignore_path = directory .. "/.gitignore"
local gitignore_patterns, gitignore_negate_patterns = M.parse_gitignore(gitignore_path)
gitignore_patterns = vim.list_extend(gitignore_patterns, { "%.git", "%.worktree", "__pycache__", "node_modules" })
return M.scan_directory({
directory = directory,
gitignore_patterns = gitignore_patterns,
gitignore_negate_patterns = gitignore_negate_patterns,
add_dirs = options.add_dirs,
depth = options.depth,
})
end
---@param options { directory: string, gitignore_patterns: string[], gitignore_negate_patterns: string[], add_dirs?: boolean }
---@param options { directory: string, gitignore_patterns: string[], gitignore_negate_patterns: string[], add_dirs?: boolean, depth?: integer, current_depth?: integer }
function M.scan_directory(options)
local directory = options.directory
local ignore_patterns = options.gitignore_patterns
local negate_patterns = options.gitignore_negate_patterns
local add_dirs = options.add_dirs or false
local depth = options.depth or -1
local current_depth = options.current_depth or 0
local files = {}
local handle = vim.loop.fs_scandir(directory)
@ -662,6 +690,8 @@ function M.scan_directory(options)
if not handle then return files end
while true do
if depth > 0 and current_depth >= depth then break end
local name, type = vim.loop.fs_scandir_next(handle)
if not name then break end
@ -677,6 +707,7 @@ function M.scan_directory(options)
gitignore_patterns = ignore_patterns,
gitignore_negate_patterns = negate_patterns,
add_dirs = add_dirs,
current_depth = current_depth + 1,
})
)
elseif type == "file" then

186
tests/llm_tools_spec.lua Normal file
View File

@ -0,0 +1,186 @@
local mock = require("luassert.mock")
local stub = require("luassert.stub")
local LlmTools = require("avante.llm_tools")
local Utils = require("avante.utils")
LlmTools.confirm = function(msg) return true end
describe("llm_tools", function()
local test_dir = "/tmp/test_llm_tools"
local test_file = test_dir .. "/test.txt"
before_each(function()
-- 创建测试目录和文件
os.execute("mkdir -p " .. test_dir)
local file = io.open(test_file, "w")
file:write("test content")
file:close()
-- Mock get_project_root
stub(Utils, "get_project_root", function() return test_dir end)
end)
after_each(function()
-- 清理测试目录
os.execute("rm -rf " .. test_dir)
-- 恢复 mock
Utils.get_project_root:revert()
end)
describe("list_files", function()
it("should list files in directory", function()
local result, err = LlmTools.list_files({ rel_path = ".", depth = 1 })
assert.is_nil(err)
assert.truthy(result:find("test.txt"))
end)
end)
describe("read_file", function()
it("should read file content", function()
local content, err = LlmTools.read_file({ rel_path = "test.txt" })
assert.is_nil(err)
assert.equals("test content", content)
end)
it("should return error for non-existent file", function()
local content, err = LlmTools.read_file({ rel_path = "non_existent.txt" })
assert.truthy(err)
assert.equals("", content)
end)
end)
describe("create_file", function()
it("should create new file", function()
local success, err = LlmTools.create_file({ rel_path = "new_file.txt" })
assert.is_nil(err)
assert.is_true(success)
local file_exists = io.open(test_dir .. "/new_file.txt", "r") ~= nil
assert.is_true(file_exists)
end)
end)
describe("create_dir", function()
it("should create new directory", function()
local success, err = LlmTools.create_dir({ rel_path = "new_dir" })
assert.is_nil(err)
assert.is_true(success)
local dir_exists = io.open(test_dir .. "/new_dir", "r") ~= nil
assert.is_true(dir_exists)
end)
end)
describe("delete_file", function()
it("should delete existing file", function()
local success, err = LlmTools.delete_file({ rel_path = "test.txt" })
assert.is_nil(err)
assert.is_true(success)
local file_exists = io.open(test_file, "r") ~= nil
assert.is_false(file_exists)
end)
end)
describe("search_files", function()
it("should find files matching pattern", function()
local result, err = LlmTools.search_files({ rel_path = ".", keyword = "test" })
assert.is_nil(err)
assert.truthy(result:find("test.txt"))
end)
end)
describe("search", function()
local original_exepath = vim.fn.exepath
after_each(function() vim.fn.exepath = original_exepath end)
it("should search using ripgrep when available", function()
-- Mock exepath to return rg path
vim.fn.exepath = function(cmd)
if cmd == "rg" then return "/usr/bin/rg" end
return ""
end
-- Create a test file with searchable content
local file = io.open(test_dir .. "/searchable.txt", "w")
file:write("this is searchable content")
file:close()
file = io.open(test_dir .. "/nothing.txt", "w")
file:write("this is nothing")
file:close()
local result, err = LlmTools.search({ rel_path = ".", keyword = "searchable" })
assert.is_nil(err)
assert.truthy(result:find("searchable.txt"))
assert.falsy(result:find("nothing.txt"))
end)
it("should search using ag when rg is not available", function()
-- Mock exepath to return ag path
vim.fn.exepath = function(cmd)
if cmd == "ag" then return "/usr/bin/ag" end
return ""
end
-- Create a test file specifically for ag
local file = io.open(test_dir .. "/ag_test.txt", "w")
file:write("content for ag test")
file:close()
local result, err = LlmTools.search({ rel_path = ".", keyword = "ag test" })
assert.is_nil(err)
assert.is_string(result)
assert.truthy(result:find("ag_test.txt"))
end)
it("should search using grep when rg and ag are not available", function()
-- Mock exepath to return grep path
vim.fn.exepath = function(cmd)
if cmd == "grep" then return "/usr/bin/grep" end
return ""
end
local result, err = LlmTools.search({ rel_path = ".", keyword = "test" })
assert.is_nil(err)
assert.truthy(result:find("test.txt"))
end)
it("should return error when no search tool is available", function()
-- Mock exepath to return nothing
vim.fn.exepath = function() return "" end
local result, err = LlmTools.search({ rel_path = ".", keyword = "test" })
assert.equals("", result)
assert.equals("No search command found", err)
end)
it("should respect path permissions", function()
local result, err = LlmTools.search({ rel_path = "../outside_project", keyword = "test" })
assert.truthy(err:find("No permission to access path"))
end)
it("should handle non-existent paths", function()
local result, err = LlmTools.search({ rel_path = "non_existent_dir", keyword = "test" })
assert.equals("", result)
assert.truthy(err)
assert.truthy(err:find("No such file or directory"))
end)
end)
describe("run_command", function()
it("should execute command and return output", function()
local result, err = LlmTools.run_command({ rel_path = ".", command = "echo 'test'" })
assert.is_nil(err)
assert.equals("test\n", result)
end)
it("should return error when running outside current directory", function()
local result, err = LlmTools.run_command({ rel_path = "../outside_project", command = "echo 'test'" })
assert.is_false(result)
assert.truthy(err)
assert.truthy(err:find("No permission to access path"))
end)
end)
end)