diff --git a/.github/workflows/lua.yaml b/.github/workflows/lua.yaml index a1f97b2..50d2815 100644 --- a/.github/workflows/lua.yaml +++ b/.github/workflows/lua.yaml @@ -38,6 +38,9 @@ jobs: mkdir -p _neovim curl -sL "https://github.com/neovim/neovim/releases/download/${{ matrix.rev }}" | tar xzf - --strip-components=1 -C "${PWD}/_neovim" } + sudo apt-get update + sudo apt-get install -y ripgrep + sudo apt-get install -y silversearcher-ag - name: Run tests run: | diff --git a/crates/avante-templates/src/lib.rs b/crates/avante-templates/src/lib.rs index 144666c..ece86c4 100644 --- a/crates/avante-templates/src/lib.rs +++ b/crates/avante-templates/src/lib.rs @@ -31,6 +31,7 @@ struct TemplateContext { selected_code: Option, project_context: Option, diagnostics: Option, + system_info: Option, } // Given the file name registered after add, the context table in Lua, resulted in a formatted @@ -54,6 +55,7 @@ fn render(state: &State, template: &str, context: TemplateContext) -> LuaResult< selected_code => context.selected_code, project_context => context.project_context, diagnostics => context.diagnostics, + system_info => context.system_info, }) .map_err(LuaError::external) .unwrap()) diff --git a/lua/avante/config.lua b/lua/avante/config.lua index 90f5bd9..f959286 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -20,6 +20,14 @@ M._defaults = { -- For most providers that we support we will determine this automatically. -- If you wish to use a given implementation, then you can override it here. tokenizer = "tiktoken", + web_search_engine = { + provider = "tavily", + api_key_name = "TAVILY_API_KEY", + provider_opts = { + time_range = "d", + include_answer = "basic", + }, + }, ---@type AvanteSupportedProvider openai = { endpoint = "https://api.openai.com/v1", diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index cd5d1bc..9a843cf 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -8,6 +8,7 @@ local Utils = require("avante.utils") local Config = require("avante.config") local Path = require("avante.path") local P = require("avante.providers") +local LLMTools = require("avante.llm_tools") ---@class avante.LLM local M = {} @@ -45,6 +46,8 @@ M.generate_prompts = function(opts) local project_root = Utils.root.get() Path.prompts.initialize(Path.prompts.get(project_root)) + local system_info = Utils.get_system_info() + local template_opts = { use_xml_format = Provider.use_xml_format, ask = opts.ask, -- TODO: add mode without ask instruction @@ -53,6 +56,7 @@ M.generate_prompts = function(opts) selected_code = opts.selected_code, project_context = opts.project_context, diagnostics = opts.diagnostics, + system_info = system_info, } local system_prompt = Path.prompts.render_mode(mode, template_opts) @@ -111,6 +115,10 @@ M.generate_prompts = function(opts) system_prompt = system_prompt, messages = messages, image_paths = image_paths, + tools = opts.tools, + tool_use = opts.tool_use, + tool_result = opts.tool_result, + response_content = opts.response_content, } end @@ -135,7 +143,28 @@ M._stream = function(opts) local current_event_state = nil ---@type AvanteHandlerOptions - local handler_opts = { on_chunk = opts.on_chunk, on_complete = opts.on_complete } + local handler_opts = { + on_start = opts.on_start, + on_chunk = opts.on_chunk, + on_stop = function(stop_opts) + if stop_opts.reason == "tool_use" and stop_opts.tool_use then + local result, error = LLMTools.process_tool_use(stop_opts.tool_use) + local tool_result = { + tool_use_id = stop_opts.tool_use.id, + content = error ~= nil and error or result, + is_error = error ~= nil, + } + local new_opts = vim.tbl_deep_extend( + "force", + opts, + { tool_result = tool_result, tool_use = stop_opts.tool_use, response_content = stop_opts.response_content } + ) + return M._stream(new_opts) + end + return opts.on_stop(stop_opts) + end, + } + ---@type AvanteCurlOutput local spec = Provider.parse_curl_args(Provider, code_opts) @@ -180,7 +209,7 @@ M._stream = function(opts) stream = function(err, data, _) if err then completed = true - opts.on_complete(err) + handler_opts.on_stop({ reason = "error", error = err }) return end if not data then return end @@ -224,7 +253,7 @@ M._stream = function(opts) active_job = nil completed = true cleanup() - opts.on_complete(err) + handler_opts.on_stop({ reason = "error", error = err }) end, callback = function(result) active_job = nil @@ -238,9 +267,10 @@ M._stream = function(opts) vim.schedule(function() if not completed then completed = true - opts.on_complete( - "API request failed with status " .. result.status .. ". Body: " .. vim.inspect(result.body) - ) + handler_opts.on_stop({ + reason = "error", + error = "API request failed with status " .. result.status .. ". Body: " .. vim.inspect(result.body), + }) end end) end @@ -335,9 +365,9 @@ M._dual_boost_stream = function(opts, Provider1, Provider2) on_chunk = function(chunk) if chunk then response = response .. chunk end end, - on_complete = function(err) - if err then - Utils.error(string.format("Stream %d failed: %s", index, err)) + on_stop = function(stop_opts) + if stop_opts.error then + Utils.error(string.format("Stream %d failed: %s", index, stop_opts.error)) return end Utils.debug(string.format("Response %d completed", index)) @@ -381,10 +411,15 @@ end ---@field instructions string ---@field mode LlmMode ---@field provider AvanteProviderFunctor | AvanteBedrockProviderFunctor | nil +---@field tools? AvanteLLMTool[] +---@field tool_result? AvanteLLMToolResult +---@field tool_use? AvanteLLMToolUse +---@field response_content? string --- ---@class StreamOptions: GeneratePromptsOptions ----@field on_chunk AvanteChunkParser ----@field on_complete AvanteCompleteParser +---@field on_start AvanteLLMStartCallback +---@field on_chunk AvanteLLMChunkCallback +---@field on_stop AvanteLLMStopCallback ---@param opts StreamOptions M.stream = function(opts) @@ -396,12 +431,12 @@ M.stream = function(opts) return original_on_chunk(chunk) end) end - if opts.on_complete ~= nil then - local original_on_complete = opts.on_complete - opts.on_complete = vim.schedule_wrap(function(err) + if opts.on_stop ~= nil then + local original_on_stop = opts.on_stop + opts.on_stop = vim.schedule_wrap(function(stop_opts) if is_completed then return end - is_completed = true - return original_on_complete(err) + if stop_opts.reason == "complete" or stop_opts.reason == "error" then is_completed = true end + return original_on_stop(stop_opts) end) end if Config.dual_boost.enabled then diff --git a/lua/avante/llm_tools.lua b/lua/avante/llm_tools.lua new file mode 100644 index 0000000..a5f795a --- /dev/null +++ b/lua/avante/llm_tools.lua @@ -0,0 +1,714 @@ +local curl = require("plenary.curl") +local Utils = require("avante.utils") +local Path = require("plenary.path") +local Config = require("avante.config") +local M = {} + +---@param rel_path string +---@return string +local function get_abs_path(rel_path) + local project_root = Utils.get_project_root() + return Path:new(project_root):joinpath(rel_path):absolute() +end + +function M.comfirm(msg) + local ok = vim.fn.confirm(msg, "&Yes\n&No", 2) + return ok == 1 +end + +---@param abs_path string +---@return boolean +local function has_permission_to_access(abs_path) + if not Path:new(abs_path):is_absolute() then return false end + local project_root = Utils.get_project_root() + if abs_path:sub(1, #project_root) ~= project_root then return false end + local gitignore_path = project_root .. "/.gitignore" + local gitignore_patterns, gitignore_negate_patterns = Utils.parse_gitignore(gitignore_path) + return not Utils.is_ignored(abs_path, gitignore_patterns, gitignore_negate_patterns) +end + +---@param opts { rel_path: string, depth?: integer } +---@return string files +---@return string|nil error +function M.list_files(opts) + local abs_path = get_abs_path(opts.rel_path) + if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end + local files = Utils.scan_directory_respect_gitignore({ + directory = abs_path, + add_dirs = true, + depth = opts.depth, + }) + local result = "" + for _, file in ipairs(files) do + local uniform_path = Utils.uniform_path(file) + result = result .. uniform_path .. "\n" + end + result = result:gsub("\n$", "") + return result, nil +end + +---@param opts { rel_path: string, keyword: string } +---@return string files +---@return string|nil error +function M.search_files(opts) + local abs_path = get_abs_path(opts.rel_path) + if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end + local files = Utils.scan_directory_respect_gitignore({ + directory = abs_path, + }) + local result = "" + for _, file in ipairs(files) do + if file:find(opts.keyword) then result = result .. file .. "\n" end + end + result = result:gsub("\n$", "") + return result, nil +end + +---@param opts { rel_path: string, keyword: string } +---@return string result +---@return string|nil error +function M.search(opts) + local abs_path = get_abs_path(opts.rel_path) + if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end + if not Path:new(abs_path):exists() then return "", "No such file or directory: " .. abs_path end + + ---check if any search cmd is available + local search_cmd = vim.fn.exepath("rg") + if search_cmd == "" then search_cmd = vim.fn.exepath("ag") end + if search_cmd == "" then search_cmd = vim.fn.exepath("ack") end + if search_cmd == "" then search_cmd = vim.fn.exepath("grep") end + if search_cmd == "" then return "", "No search command found" end + + ---execute the search command + local cmd = "" + if search_cmd:find("rg") then + cmd = string.format("%s --files-with-matches --no-ignore-vcs --ignore-case --hidden --glob '!.git'", search_cmd) + cmd = string.format("%s '%s' %s", cmd, opts.keyword, abs_path) + elseif search_cmd:find("ag") then + cmd = string.format("%s '%s' --nocolor --nogroup --hidden --ignore .git %s", search_cmd, opts.keyword, abs_path) + elseif search_cmd:find("ack") then + cmd = string.format("%s --nocolor --nogroup --hidden --ignore-dir .git", search_cmd) + cmd = string.format("%s '%s' %s", cmd, opts.keyword, abs_path) + elseif search_cmd:find("grep") then + cmd = string.format("%s -riH --exclude-dir=.git %s %s", search_cmd, opts.keyword, abs_path) + end + + Utils.debug("cmd", cmd) + local result = vim.fn.system(cmd) + + return result or "", nil +end + +---@param opts { rel_path: string } +---@return string definitions +---@return string|nil error +function M.read_file_toplevel_symbols(opts) + local RepoMap = require("avante.repo_map") + local abs_path = get_abs_path(opts.rel_path) + if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end + local filetype = RepoMap.get_ts_lang(abs_path) + local repo_map_lib = RepoMap._init_repo_map_lib() + if not repo_map_lib then return "", "Failed to load avante_repo_map" end + local definitions = filetype + and repo_map_lib.stringify_definitions(filetype, Utils.file.read_content(abs_path) or "") + or "" + return definitions, nil +end + +---@param opts { rel_path: string } +---@return string content +---@return string|nil error +function M.read_file(opts) + local abs_path = get_abs_path(opts.rel_path) + if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end + local file = io.open(abs_path, "r") + if not file then return "", "file not found: " .. abs_path end + local content = file:read("*a") + file:close() + return content, nil +end + +---@param opts { rel_path: string } +---@return boolean success +---@return string|nil error +function M.create_file(opts) + local abs_path = get_abs_path(opts.rel_path) + if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end + ---create directory if it doesn't exist + local dir = Path:new(abs_path):parent() + if not dir:exists() then dir:mkdir({ parents = true }) end + ---create file if it doesn't exist + if not dir:joinpath(opts.rel_path):exists() then + local file = io.open(abs_path, "w") + if not file then return false, "file not found: " .. abs_path end + file:close() + end + + return true, nil +end + +---@param opts { rel_path: string, new_rel_path: string } +---@return boolean success +---@return string|nil error +function M.rename_file(opts) + local abs_path = get_abs_path(opts.rel_path) + if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end + if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end + if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end + local new_abs_path = get_abs_path(opts.new_rel_path) + if not has_permission_to_access(new_abs_path) then return false, "No permission to access path: " .. new_abs_path end + if Path:new(new_abs_path):exists() then return false, "File already exists: " .. new_abs_path end + if not M.confirm("Are you sure you want to rename the file: " .. abs_path .. " to: " .. new_abs_path) then + return false, "User canceled" + end + os.rename(abs_path, new_abs_path) + return true, nil +end + +---@param opts { rel_path: string, new_rel_path: string } +---@return boolean success +---@return string|nil error +function M.copy_file(opts) + local abs_path = get_abs_path(opts.rel_path) + if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end + if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end + if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end + local new_abs_path = get_abs_path(opts.new_rel_path) + if not has_permission_to_access(new_abs_path) then return false, "No permission to access path: " .. new_abs_path end + if Path:new(new_abs_path):exists() then return false, "File already exists: " .. new_abs_path end + Path:new(new_abs_path):write(Path:new(abs_path):read()) + return true, nil +end + +---@param opts { rel_path: string } +---@return boolean success +---@return string|nil error +function M.delete_file(opts) + local abs_path = get_abs_path(opts.rel_path) + if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end + if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end + if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end + if not M.confirm("Are you sure you want to delete the file: " .. abs_path) then return false, "User canceled" end + os.remove(abs_path) + return true, nil +end + +---@param opts { rel_path: string } +---@return boolean success +---@return string|nil error +function M.create_dir(opts) + local abs_path = get_abs_path(opts.rel_path) + if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end + if Path:new(abs_path):exists() then return false, "Directory already exists: " .. abs_path end + Path:new(abs_path):mkdir({ parents = true }) + return true, nil +end + +---@param opts { rel_path: string, new_rel_path: string } +---@return boolean success +---@return string|nil error +function M.rename_dir(opts) + local abs_path = get_abs_path(opts.rel_path) + if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end + if not Path:new(abs_path):exists() then return false, "Directory not found: " .. abs_path end + if not Path:new(abs_path):is_dir() then return false, "Path is not a directory: " .. abs_path end + local new_abs_path = get_abs_path(opts.new_rel_path) + if not has_permission_to_access(new_abs_path) then return false, "No permission to access path: " .. new_abs_path end + if Path:new(new_abs_path):exists() then return false, "Directory already exists: " .. new_abs_path end + if not M.confirm("Are you sure you want to rename directory " .. abs_path .. " to " .. new_abs_path .. "?") then + return false, "User canceled" + end + os.rename(abs_path, new_abs_path) + return true, nil +end + +---@param opts { rel_path: string } +---@return boolean success +---@return string|nil error +function M.delete_dir(opts) + local abs_path = get_abs_path(opts.rel_path) + if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end + if not Path:new(abs_path):exists() then return false, "Directory not found: " .. abs_path end + if not Path:new(abs_path):is_dir() then return false, "Path is not a directory: " .. abs_path end + if not M.confirm("Are you sure you want to delete the directory: " .. abs_path) then + return false, "User canceled" + end + os.remove(abs_path) + return true, nil +end + +---@param opts { rel_path: string, command: string } +---@return string|boolean result +---@return string|nil error +function M.run_command(opts) + local abs_path = get_abs_path(opts.rel_path) + if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end + if not Path:new(abs_path):exists() then return false, "Path not found: " .. abs_path end + if + not M.confirm("Are you sure you want to run the command: `" .. opts.command .. "` in the directory: " .. abs_path) + then + return false, "User canceled" + end + ---change cwd to abs_path + local old_cwd = vim.fn.getcwd() + vim.fn.chdir(abs_path) + local res = Utils.shell_run(opts.command) + vim.fn.chdir(old_cwd) + if res.code ~= 0 then + if res.stdout then return false, "Error: " .. res.stdout .. "; Error code: " .. tostring(res.code) end + return false, "Error code: " .. tostring(res.code) + end + return res.stdout, nil +end + +---@param opts { query: string } +---@return string|nil result +---@return string|nil error +function M.web_search(opts) + local search_engine = Config.web_search_engine + if search_engine.provider == "tavily" then + if search_engine.api_key_name == "" then return nil, "No API key provided" end + local api_key = os.getenv(search_engine.api_key_name) + if api_key == nil or api_key == "" then + return nil, "Environment variable " .. search_engine.api_key_name .. " is not set" + end + local resp = curl.post("https://api.tavily.com/search", { + headers = { + ["Content-Type"] = "application/json", + ["Authorization"] = "Bearer " .. api_key, + }, + body = vim.json.encode(vim.tbl_deep_extend("force", { + query = opts.query, + }, search_engine.provider_opts)), + }) + if resp.status ~= 200 then return nil, "Error: " .. resp.body end + local jsn = vim.json.decode(resp.body) + return jsn.anwser, nil + end +end + +---@class AvanteLLMTool +---@field name string +---@field description string +---@field param AvanteLLMToolParam +---@field returns AvanteLLMToolReturn[] + +---@class AvanteLLMToolParam +---@field type string +---@field fields AvanteLLMToolParamField[] + +---@class AvanteLLMToolParamField +---@field name string +---@field description string +---@field type string +---@field optional? boolean + +---@class AvanteLLMToolReturn +---@field name string +---@field description string +---@field type string +---@field optional? boolean + +---@type AvanteLLMTool[] +M.tools = { + { + name = "list_files", + description = "List files in a directory", + param = { + type = "table", + fields = { + { + name = "rel_path", + description = "Relative path to the directory", + type = "string", + }, + { + name = "depth", + description = "Depth of the directory", + type = "integer", + optional = true, + }, + }, + }, + returns = { + { + name = "files", + description = "List of files in the directory", + type = "string[]", + }, + { + name = "error", + description = "Error message if the directory was not listed successfully", + type = "string", + optional = true, + }, + }, + }, + { + name = "search_files", + description = "Search for files in a directory", + param = { + type = "table", + fields = { + { + name = "rel_path", + description = "Relative path to the directory", + type = "string", + }, + { + name = "keyword", + description = "Keyword to search for", + type = "string", + }, + }, + }, + returns = { + { + name = "files", + description = "List of files that match the keyword", + type = "string", + }, + { + name = "error", + description = "Error message if the directory was not searched successfully", + type = "string", + optional = true, + }, + }, + }, + { + name = "search", + description = "Search for a keyword in a directory", + param = { + type = "table", + fields = { + { + name = "rel_path", + description = "Relative path to the directory", + type = "string", + }, + { + name = "keyword", + description = "Keyword to search for", + type = "string", + }, + }, + }, + returns = { + { + name = "files", + description = "List of files that match the keyword", + type = "string", + }, + { + name = "error", + description = "Error message if the directory was not searched successfully", + type = "string", + optional = true, + }, + }, + }, + { + name = "read_file_toplevel_symbols", + description = "Read the top-level symbols of a file", + param = { + type = "table", + fields = { + { + name = "rel_path", + description = "Relative path to the file", + type = "string", + }, + }, + }, + returns = { + { + name = "definitions", + description = "Top-level symbols of the file", + type = "string", + }, + { + name = "error", + description = "Error message if the file was not read successfully", + type = "string", + optional = true, + }, + }, + }, + { + name = "read_file", + description = "Read the contents of a file", + param = { + type = "table", + fields = { + { + name = "rel_path", + description = "Relative path to the file", + type = "string", + }, + }, + }, + returns = { + { + name = "content", + description = "Contents of the file", + type = "string", + }, + { + name = "error", + description = "Error message if the file was not read successfully", + type = "string", + optional = true, + }, + }, + }, + { + name = "create_file", + description = "Create a new file", + param = { + type = "table", + fields = { + { + name = "rel_path", + description = "Relative path to the file", + type = "string", + }, + }, + }, + returns = { + { + name = "success", + description = "True if the file was created successfully, false otherwise", + type = "boolean", + }, + { + name = "error", + description = "Error message if the file was not created successfully", + type = "string", + optional = true, + }, + }, + }, + { + name = "rename_file", + description = "Rename a file", + param = { + type = "table", + fields = { + { + name = "rel_path", + description = "Relative path to the file", + type = "string", + }, + { + name = "new_rel_path", + description = "New relative path for the file", + type = "string", + }, + }, + }, + returns = { + { + name = "success", + description = "True if the file was renamed successfully, false otherwise", + type = "boolean", + }, + { + name = "error", + description = "Error message if the file was not renamed successfully", + type = "string", + optional = true, + }, + }, + }, + { + name = "delete_file", + description = "Delete a file", + param = { + type = "table", + fields = { + { + name = "rel_path", + description = "Relative path to the file", + type = "string", + }, + }, + }, + returns = { + { + name = "success", + description = "True if the file was deleted successfully, false otherwise", + type = "boolean", + }, + { + name = "error", + description = "Error message if the file was not deleted successfully", + type = "string", + optional = true, + }, + }, + }, + { + name = "create_dir", + description = "Create a new directory", + param = { + type = "table", + fields = { + { + name = "rel_path", + description = "Relative path to the directory", + type = "string", + }, + }, + }, + returns = { + { + name = "success", + description = "True if the directory was created successfully, false otherwise", + type = "boolean", + }, + { + name = "error", + description = "Error message if the directory was not created successfully", + type = "string", + optional = true, + }, + }, + }, + { + name = "rename_dir", + description = "Rename a directory", + param = { + type = "table", + fields = { + { + name = "rel_path", + description = "Relative path to the directory", + type = "string", + }, + { + name = "new_rel_path", + description = "New relative path for the directory", + type = "string", + }, + }, + }, + returns = { + { + name = "success", + description = "True if the directory was renamed successfully, false otherwise", + type = "boolean", + }, + { + name = "error", + description = "Error message if the directory was not renamed successfully", + type = "string", + optional = true, + }, + }, + }, + { + name = "delete_dir", + description = "Delete a directory", + param = { + type = "table", + fields = { + { + name = "rel_path", + description = "Relative path to the directory", + type = "string", + }, + }, + }, + returns = { + { + name = "success", + description = "True if the directory was deleted successfully, false otherwise", + type = "boolean", + }, + { + name = "error", + description = "Error message if the directory was not deleted successfully", + type = "string", + optional = true, + }, + }, + }, + { + name = "run_command", + description = "Run a command in a directory", + param = { + type = "table", + fields = { + { + name = "rel_path", + description = "Relative path to the directory", + type = "string", + }, + { + name = "command", + description = "Command to run", + type = "string", + }, + }, + }, + returns = { + { + name = "stdout", + description = "Output of the command", + type = "string", + }, + { + name = "error", + description = "Error message if the command was not run successfully", + type = "string", + optional = true, + }, + }, + }, + { + name = "web_search", + description = "Search the web", + param = { + type = "table", + fields = { + { + name = "query", + description = "Query to search", + type = "string", + }, + }, + }, + returns = { + { + name = "result", + description = "Result of the search", + type = "string", + }, + { + name = "error", + description = "Error message if the search was not successful", + type = "string", + optional = true, + }, + }, + }, +} + +---@param tool_use AvanteLLMToolUse +---@return string | nil result +---@return string | nil error +function M.process_tool_use(tool_use) + Utils.debug("use tool", tool_use.name, tool_use.input_json) + local tool = vim.iter(M.tools):find(function(tool) return tool.name == tool_use.name end) + if tool == nil then return end + local input_json = vim.json.decode(tool_use.input_json) + local func = M[tool.name] + local result, error = func(input_json) + -- Utils.debug("result", result) + -- Utils.debug("error", error) + if result ~= nil and type(result) ~= "string" then result = vim.json.encode(result) end + return result, error +end + +return M diff --git a/lua/avante/providers/bedrock/claude.lua b/lua/avante/providers/bedrock/claude.lua index 47c6b40..f857cb9 100644 --- a/lua/avante/providers/bedrock/claude.lua +++ b/lua/avante/providers/bedrock/claude.lua @@ -6,6 +6,8 @@ ---@field role "user" | "assistant" ---@field content [AvanteBedrockClaudeTextMessage][] +local Claude = require("avante.providers.claude") + ---@class AvanteBedrockModelHandler local M = {} @@ -33,25 +35,7 @@ M.parse_messages = function(opts) return messages end -M.parse_response = function(ctx, data_stream, event_state, opts) - if event_state == nil then - if data_stream:match('"content_block_delta"') then - event_state = "content_block_delta" - elseif data_stream:match('"message_stop"') then - event_state = "message_stop" - end - end - if event_state == "content_block_delta" then - local ok, json = pcall(vim.json.decode, data_stream) - if not ok then return end - opts.on_chunk(json.delta.text) - elseif event_state == "message_stop" then - opts.on_complete(nil) - return - elseif event_state == "error" then - opts.on_complete(vim.json.decode(data_stream)) - end -end +M.parse_response = Claude.parse_response ---@param prompt_opts AvantePromptOptions ---@param body_opts table @@ -60,7 +44,6 @@ M.build_bedrock_payload = function(prompt_opts, body_opts) local system_prompt = prompt_opts.system_prompt or "" local messages = M.parse_messages(prompt_opts) local max_tokens = body_opts.max_tokens or 2000 - local temperature = body_opts.temperature or 0.7 local payload = { anthropic_version = "bedrock-2023-05-31", max_tokens = max_tokens, diff --git a/lua/avante/providers/claude.lua b/lua/avante/providers/claude.lua index a59bb6d..300e671 100644 --- a/lua/avante/providers/claude.lua +++ b/lua/avante/providers/claude.lua @@ -17,6 +17,44 @@ local P = require("avante.providers") ---@field role "user" | "assistant" ---@field content [AvanteClaudeTextMessage | AvanteClaudeImageMessage][] +---@class AvanteClaudeTool +---@field name string +---@field description string +---@field input_schema AvanteClaudeToolInputSchema + +---@class AvanteClaudeToolInputSchema +---@field type "object" +---@field properties table +---@field required string[] + +---@class AvanteClaudeToolInputSchemaProperty +---@field type "string" | "number" | "boolean" +---@field description string +---@field enum? string[] + +---@param tool AvanteLLMTool +---@return AvanteClaudeTool +local function transform_tool(tool) + local input_schema_properties = {} + local required = {} + for _, field in ipairs(tool.param.fields) do + input_schema_properties[field.name] = { + type = field.type, + description = field.description, + } + if not field.optional then table.insert(required, field.name) end + end + return { + name = tool.name, + description = tool.description, + input_schema = { + type = "object", + properties = input_schema_properties, + required = required, + }, + } +end + ---@class AvanteProviderFunctor local M = {} @@ -74,26 +112,101 @@ M.parse_messages = function(opts) messages[#messages].content = message_content end + if opts.tool_use then + local msg = { + role = "assistant", + content = {}, + } + if opts.response_content then + msg.content[#msg.content + 1] = { + type = "text", + text = opts.response_content, + } + end + msg.content[#msg.content + 1] = { + type = "tool_use", + id = opts.tool_use.id, + name = opts.tool_use.name, + input = vim.json.decode(opts.tool_use.input_json), + } + messages[#messages + 1] = msg + end + + if opts.tool_result then + messages[#messages + 1] = { + role = "user", + content = { + { + type = "tool_result", + tool_use_id = opts.tool_result.tool_use_id, + content = opts.tool_result.content, + is_error = opts.tool_result.is_error, + }, + }, + } + end + return messages end M.parse_response = function(ctx, data_stream, event_state, opts) if event_state == nil then - if data_stream:match('"content_block_delta"') then - event_state = "content_block_delta" + if data_stream:match('"message_start"') then + event_state = "message_start" + elseif data_stream:match('"message_delta"') then + event_state = "message_delta" elseif data_stream:match('"message_stop"') then event_state = "message_stop" + elseif data_stream:match('"content_block_start"') then + event_state = "content_block_start" + elseif data_stream:match('"content_block_delta"') then + event_state = "content_block_delta" + elseif data_stream:match('"content_block_stop"') then + event_state = "content_block_stop" end end - if event_state == "content_block_delta" then - local ok, json = pcall(vim.json.decode, data_stream) + if event_state == "message_start" then + local ok, jsn = pcall(vim.json.decode, data_stream) if not ok then return end - opts.on_chunk(json.delta.text) - elseif event_state == "message_stop" then - opts.on_complete(nil) + opts.on_start(jsn.message.usage) + elseif event_state == "content_block_start" then + local ok, jsn = pcall(vim.json.decode, data_stream) + if not ok then return end + if jsn.content_block.type == "tool_use" then + ctx.tool_use = { + name = jsn.content_block.name, + id = jsn.content_block.id, + input_json = "", + } + elseif jsn.content_block.type == "text" then + ctx.response_content = "" + end + elseif event_state == "content_block_delta" then + local ok, jsn = pcall(vim.json.decode, data_stream) + if not ok then return end + if ctx.tool_use and jsn.delta.type == "input_json_delta" then + ctx.tool_use.input_json = ctx.tool_use.input_json .. jsn.delta.partial_json + return + elseif ctx.response_content and jsn.delta.type == "text_delta" then + ctx.response_content = ctx.response_content .. jsn.delta.text + end + opts.on_chunk(jsn.delta.text) + elseif event_state == "message_delta" then + local ok, jsn = pcall(vim.json.decode, data_stream) + if not ok then return end + if jsn.delta.stop_reason == "end_turn" then + opts.on_stop({ reason = "complete", usage = jsn.usage }) + elseif jsn.delta.stop_reason == "tool_use" then + opts.on_stop({ + reason = "tool_use", + usage = jsn.usage, + tool_use = ctx.tool_use, + response_content = ctx.response_content, + }) + end return elseif event_state == "error" then - opts.on_complete(vim.json.decode(data_stream)) + opts.on_stop({ reason = "error", error = vim.json.decode(data_stream) }) end end @@ -113,6 +226,13 @@ M.parse_curl_args = function(provider, prompt_opts) local messages = M.parse_messages(prompt_opts) + local tools = {} + if prompt_opts.tools then + for _, tool in ipairs(prompt_opts.tools) do + table.insert(tools, transform_tool(tool)) + end + end + return { url = Utils.url_join(base.endpoint, "/v1/messages"), proxy = base.proxy, @@ -128,6 +248,7 @@ M.parse_curl_args = function(provider, prompt_opts) }, }, messages = messages, + tools = tools, stream = true, }, body_opts), } diff --git a/lua/avante/providers/cohere.lua b/lua/avante/providers/cohere.lua index af12faa..e6f787e 100644 --- a/lua/avante/providers/cohere.lua +++ b/lua/avante/providers/cohere.lua @@ -62,7 +62,7 @@ M.parse_stream_data = function(data, opts) local json = vim.json.decode(data) if json.type ~= nil then if json.type == "message-end" and json.delta.finish_reason == "COMPLETE" then - opts.on_complete(nil) + opts.on_stop({ reason = "complete" }) return end if json.type == "content-delta" then opts.on_chunk(json.delta.message.content.text) end diff --git a/lua/avante/providers/gemini.lua b/lua/avante/providers/gemini.lua index 9e29e4e..af63565 100644 --- a/lua/avante/providers/gemini.lua +++ b/lua/avante/providers/gemini.lua @@ -66,17 +66,17 @@ end M.parse_response = function(ctx, data_stream, _, opts) local ok, json = pcall(vim.json.decode, data_stream) - if not ok then opts.on_complete(json) end + if not ok then opts.on_stop({ reason = "error", error = json }) end if json.candidates then if #json.candidates > 0 then if json.candidates[1].finishReason and json.candidates[1].finishReason == "STOP" then opts.on_chunk(json.candidates[1].content.parts[1].text) - opts.on_complete(nil) + opts.on_stop({ reason = "complete" }) else opts.on_chunk(json.candidates[1].content.parts[1].text) end else - opts.on_complete(nil) + opts.on_stop({ reason = "complete" }) end end end diff --git a/lua/avante/providers/init.lua b/lua/avante/providers/init.lua index 1802f3a..ba9848c 100644 --- a/lua/avante/providers/init.lua +++ b/lua/avante/providers/init.lua @@ -11,17 +11,28 @@ local DressingConfig = { local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil } ---@class AvanteHandlerOptions: table<[string], string> ----@field on_chunk AvanteChunkParser ----@field on_complete AvanteCompleteParser +---@field on_start AvanteLLMStartCallback +---@field on_chunk AvanteLLMChunkCallback +---@field on_stop AvanteLLMStopCallback --- ---@class AvanteLLMMessage ---@field role "user" | "assistant" ---@field content string --- +---@class AvanteLLMToolResult +---@field tool_name string +---@field tool_use_id string +---@field content string +---@field is_error? boolean +--- ---@class AvantePromptOptions: table<[string], string> ---@field system_prompt string ---@field messages AvanteLLMMessage[] ---@field image_paths? string[] +---@field tools? AvanteLLMTool[] +---@field tool_result? AvanteLLMToolResult +---@field tool_use? AvanteLLMToolUse +---@field response_content? string --- ---@class AvanteGeminiMessage ---@field role "user" @@ -35,8 +46,9 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil } ---@alias AvanteCurlArgsParser fun(opts: AvanteProvider | AvanteProviderFunctor | AvanteBedrockProviderFunctor, code_opts: AvantePromptOptions): AvanteCurlOutput --- ---@class ResponseParser ----@field on_chunk fun(chunk: string): any ----@field on_complete fun(err: string|nil): any +---@field on_start AvanteLLMStartCallback +---@field on_chunk AvanteLLMChunkCallback +---@field on_stop AvanteLLMStopCallback ---@alias AvanteResponseParser fun(ctx: any, data_stream: string, event_state: string, opts: ResponseParser): nil --- ---@class AvanteDefaultBaseProvider: table @@ -54,9 +66,31 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil } ---@field temperature? number ---@field max_tokens? number --- +---@class AvanteLLMUsage +---@field input_tokens number +---@field cache_creation_input_tokens number +---@field cache_read_input_tokens number +---@field output_tokens number +--- +---@class AvanteLLMToolUse +---@field name string +---@field id string +---@field input_json string +--- +---@class AvanteLLMStartCallbackOptions +---@field usage? AvanteLLMUsage +--- +---@class AvanteLLMStopCallbackOptions +---@field reason "complete" | "tool_use" | "error" +---@field error? string | table +---@field usage? AvanteLLMUsage +---@field tool_use? AvanteLLMToolUse +---@field response_content? string +--- ---@alias AvanteStreamParser fun(line: string, handler_opts: AvanteHandlerOptions): nil ----@alias AvanteChunkParser fun(chunk: string): any ----@alias AvanteCompleteParser fun(err: string|nil): nil +---@alias AvanteLLMStartCallback fun(opts: AvanteLLMStartCallbackOptions): nil +---@alias AvanteLLMChunkCallback fun(chunk: string): any +---@alias AvanteLLMStopCallback fun(opts: AvanteLLMStopCallbackOptions): nil ---@alias AvanteLLMConfigHandler fun(opts: AvanteSupportedProvider): AvanteDefaultBaseProvider, table --- ---@class AvanteProvider: AvanteSupportedProvider diff --git a/lua/avante/providers/openai.lua b/lua/avante/providers/openai.lua index 40b7951..54ab509 100644 --- a/lua/avante/providers/openai.lua +++ b/lua/avante/providers/openai.lua @@ -24,12 +24,72 @@ local P = require("avante.providers") ---@field index integer ---@field logprobs integer --- +---@class OpenAIMessageToolCallFunction +---@field name string +---@field arguments string +--- +---@class OpenAIMessageToolCall +---@field id string +---@field type "function" +---@field function OpenAIMessageToolCallFunction +--- ---@class OpenAIMessage ---@field role? "user" | "system" | "assistant" ---@field content? string ---@field reasoning_content? string ---@field reasoning? string +---@field tool_calls? OpenAIMessageToolCall[] --- +---@class AvanteOpenAITool +---@field type "function" +---@field function AvanteOpenAIToolFunction +--- +---@class AvanteOpenAIToolFunction +---@field name string +---@field description string +---@field parameters AvanteOpenAIToolFunctionParameters +---@field strict boolean +--- +---@class AvanteOpenAIToolFunctionParameters +---@field type string +---@field properties table +---@field required string[] +---@field additionalProperties boolean +--- +---@class AvanteOpenAIToolFunctionParameterProperty +---@field type string +---@field description string + +---@param tool AvanteLLMTool +---@return AvanteOpenAITool +local function transform_tool(tool) + local input_schema_properties = {} + local required = {} + for _, field in ipairs(tool.param.fields) do + input_schema_properties[field.name] = { + type = field.type, + description = field.description, + } + if not field.optional then table.insert(required, field.name) end + end + local res = { + type = "function", + ["function"] = { + name = tool.name, + description = tool.description, + }, + } + if vim.tbl_count(input_schema_properties) > 0 then + res["function"].parameters = { + type = "object", + properties = input_schema_properties, + required = required, + additionalProperties = false, + } + end + return res +end + ---@class AvanteProviderFunctor local M = {} @@ -107,12 +167,34 @@ M.parse_messages = function(opts) table.insert(final_messages, { role = M.role_map[role] or role, content = message.content }) end) + if opts.tool_result then + table.insert(final_messages, { + role = M.role_map["assistant"], + tool_calls = { + { + id = opts.tool_use.id, + type = "function", + ["function"] = { + name = opts.tool_use.name, + arguments = opts.tool_use.input_json, + }, + }, + }, + }) + local result_content = opts.tool_result.content or "" + table.insert(final_messages, { + role = "tool", + tool_call_id = opts.tool_result.tool_use_id, + content = opts.tool_result.is_error and "Error: " .. result_content or result_content, + }) + end + return final_messages end M.parse_response = function(ctx, data_stream, _, opts) if data_stream:match('"%[DONE%]":') then - opts.on_complete(nil) + opts.on_stop({ reason = "complete" }) return end if data_stream:match('"delta":') then @@ -121,7 +203,14 @@ M.parse_response = function(ctx, data_stream, _, opts) if jsn.choices and jsn.choices[1] then local choice = jsn.choices[1] if choice.finish_reason == "stop" or choice.finish_reason == "eos_token" then - opts.on_complete(nil) + opts.on_stop({ reason = "complete" }) + elseif choice.finish_reason == "tool_calls" then + opts.on_stop({ + reason = "tool_use", + usage = jsn.usage, + tool_use = ctx.tool_use, + response_content = ctx.response_content, + }) elseif choice.delta.reasoning_content and choice.delta.reasoning_content ~= vim.NIL then if ctx.returned_think_start_tag == nil or not ctx.returned_think_start_tag then ctx.returned_think_start_tag = true @@ -136,6 +225,17 @@ M.parse_response = function(ctx, data_stream, _, opts) end ctx.last_think_content = choice.delta.reasoning opts.on_chunk(choice.delta.reasoning) + elseif choice.delta.tool_calls then + local tool_call = choice.delta.tool_calls[1] + if not ctx.tool_use then + ctx.tool_use = { + name = tool_call["function"].name, + id = tool_call.id, + input_json = "", + } + else + ctx.tool_use.input_json = ctx.tool_use.input_json .. tool_call["function"].arguments + end elseif choice.delta.content then if ctx.returned_think_start_tag ~= nil and (ctx.returned_think_end_tag == nil or not ctx.returned_think_end_tag) @@ -164,7 +264,7 @@ M.parse_response_without_stream = function(data, _, opts) local choice = json.choices[1] if choice.message and choice.message.content then opts.on_chunk(choice.message.content) - vim.schedule(function() opts.on_complete(nil) end) + vim.schedule(function() opts.on_stop({ reason = "complete" }) end) end end end @@ -198,6 +298,13 @@ M.parse_curl_args = function(provider, code_opts) body_opts.temperature = 1 end + local tools = {} + if code_opts.tools then + for _, tool in ipairs(code_opts.tools) do + table.insert(tools, transform_tool(tool)) + end + end + Utils.debug("endpoint", base.endpoint) Utils.debug("model", base.model) @@ -210,6 +317,7 @@ M.parse_curl_args = function(provider, code_opts) model = base.model, messages = M.parse_messages(code_opts), stream = stream, + tools = tools, }, body_opts), } end diff --git a/lua/avante/selection.lua b/lua/avante/selection.lua index 37d5de7..db68127 100644 --- a/lua/avante/selection.lua +++ b/lua/avante/selection.lua @@ -157,7 +157,10 @@ function Selection:create_editing_input() self.prompt_input:start_spinner() - ---@type AvanteChunkParser + ---@type AvanteLLMStartCallback + local on_start = function(start_opts) end + + ---@type AvanteLLMChunkCallback local on_chunk = function(chunk) full_response = full_response .. chunk local response_lines_ = vim.split(full_response, "\n") @@ -182,13 +185,15 @@ function Selection:create_editing_input() finish_line = start_line + #response_lines - 1 end - ---@type AvanteCompleteParser - local on_complete = function(err) - if err then + ---@type AvanteLLMStopCallback + local on_stop = function(stop_opts) + if stop_opts.error then -- NOTE: in Ubuntu 22.04+ you will see this ignorable error from ~/.local/share/nvim/lazy/avante.nvim/lua/avante/llm.lua `on_error = function(err)`, check to avoid showing this error. - if type(err) == "table" and err.exit == nil and err.stderr == "{}" then return end + if type(stop_opts.error) == "table" and stop_opts.error.exit == nil and stop_opts.error.stderr == "{}" then + return + end Utils.error( - "Error occurred while processing the response: " .. vim.inspect(err), + "Error occurred while processing the response: " .. vim.inspect(stop_opts.error), { once = true, title = "Avante" } ) return @@ -216,8 +221,9 @@ function Selection:create_editing_input() selected_code = self.selection.content, instructions = input, mode = "editing", + on_start = on_start, on_chunk = on_chunk, - on_complete = on_complete, + on_stop = on_stop, }) end diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index 111749c..a55c394 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -13,6 +13,7 @@ local Utils = require("avante.utils") local Highlights = require("avante.highlights") local RepoMap = require("avante.repo_map") local FileSelector = require("avante.file_selector") +local LLMTools = require("avante.llm_tools") local RESULT_BUF_NAME = "AVANTE_RESULT" local VIEW_BUFFER_UPDATED_PATTERN = "AvanteViewBufferUpdated" @@ -1669,6 +1670,7 @@ function Sidebar:create_input_container(opts) selected_code = selected_code_content, instructions = request, mode = "planning", + tools = LLMTools.tools, } end @@ -1755,7 +1757,10 @@ function Sidebar:create_input_container(opts) vim.keymap.set("n", "k", on_k, { buffer = self.result_container.bufnr }) vim.keymap.set("n", "G", on_G, { buffer = self.result_container.bufnr }) - ---@type AvanteChunkParser + ---@type AvanteLLMStartCallback + local on_start = function(start_opts) end + + ---@type AvanteLLMChunkCallback local on_chunk = function(chunk) original_response = original_response .. chunk @@ -1778,8 +1783,8 @@ function Sidebar:create_input_container(opts) displayed_response = cur_displayed_response end - ---@type AvanteCompleteParser - local on_complete = function(err) + ---@type AvanteLLMStopCallback + local on_stop = function(stop_opts) pcall(function() ---remove keymaps vim.keymap.del("n", "j", { buffer = self.result_container.bufnr }) @@ -1787,9 +1792,9 @@ function Sidebar:create_input_container(opts) vim.keymap.del("n", "G", { buffer = self.result_container.bufnr }) end) - if err ~= nil then + if stop_opts.error ~= nil then self:update_content( - content_prefix .. displayed_response .. "\n\nError: " .. vim.inspect(err), + content_prefix .. displayed_response .. "\n\nError: " .. vim.inspect(stop_opts.error), { scroll = scroll } ) return @@ -1835,8 +1840,9 @@ function Sidebar:create_input_container(opts) ---@type StreamOptions ---@diagnostic disable-next-line: assign-type-mismatch local stream_options = vim.tbl_deep_extend("force", generate_prompts_options, { + on_start = on_start, on_chunk = on_chunk, - on_complete = on_complete, + on_stop = on_stop, }) Llm.stream(stream_options) diff --git a/lua/avante/suggestion.lua b/lua/avante/suggestion.lua index da06e2a..2218b6c 100644 --- a/lua/avante/suggestion.lua +++ b/lua/avante/suggestion.lua @@ -141,8 +141,10 @@ L5: pass history_messages = history_messages, instructions = vim.json.encode(doc), mode = "suggesting", + on_start = function(_) end, on_chunk = function(chunk) full_response = full_response .. chunk end, - on_complete = function(err) + on_stop = function(stop_opts) + local err = stop_opts.error if err then Utils.error("Error while suggesting: " .. vim.inspect(err), { once = true, title = "Avante" }) return diff --git a/lua/avante/templates/base.avanterules b/lua/avante/templates/base.avanterules index ce2a6c3..a196d4a 100644 --- a/lua/avante/templates/base.avanterules +++ b/lua/avante/templates/base.avanterules @@ -10,5 +10,13 @@ Act as an expert software developer. Always use best practices when coding. Respect and use existing conventions, libraries, etc that are already present in the code base. +You have access to tools, but only use them when necessary. If a tool is not required, respond as normal. +If you have information that you don't know, please proactively use the tools provided by users! Especially the web search tool. + +{% if system_info -%} +Use the appropriate shell based on the user's system info: +{{system_info}} +{%- endif %} + {% block extra_prompt %} {% endblock %} diff --git a/lua/avante/utils/init.lua b/lua/avante/utils/init.lua index 23c1079..bf2e3d3 100644 --- a/lua/avante/utils/init.lua +++ b/lua/avante/utils/init.lua @@ -47,6 +47,31 @@ M.get_os_name = function() end end +M.get_system_info = function() + local os_name = vim.loop.os_uname().sysname + local os_version = vim.loop.os_uname().release + local os_machine = vim.loop.os_uname().machine + local lang = os.getenv("LANG") + + local res = string.format( + "- Platform: %s-%s-%s\n- Shell: %s\n- Language: %s\n- Current date: %s", + os_name, + os_version, + os_machine, + vim.o.shell, + lang, + os.date("%Y-%m-%d") + ) + + local project_root = M.root.get() + if project_root then res = res .. string.format("\n- Project root: %s", project_root) end + + local is_git_repo = vim.fn.isdirectory(".git") == 1 + if is_git_repo then res = res .. "\n- The user is operating inside a git repository" end + + return res +end + --- This function will run given shell command synchronously. ---@param input_cmd string ---@return vim.SystemCompleted @@ -622,6 +647,7 @@ function M.parse_gitignore(gitignore_path) end file:close() + ignore_patterns = vim.list_extend(ignore_patterns, { "%.git", "%.worktree", "__pycache__", "node_modules" }) return ignore_patterns, negate_patterns end @@ -635,26 +661,28 @@ function M.is_ignored(file, ignore_patterns, negate_patterns) return false end ----@param options { directory: string, add_dirs?: boolean } +---@param options { directory: string, add_dirs?: boolean, depth?: integer } function M.scan_directory_respect_gitignore(options) local directory = options.directory local gitignore_path = directory .. "/.gitignore" local gitignore_patterns, gitignore_negate_patterns = M.parse_gitignore(gitignore_path) - gitignore_patterns = vim.list_extend(gitignore_patterns, { "%.git", "%.worktree", "__pycache__", "node_modules" }) return M.scan_directory({ directory = directory, gitignore_patterns = gitignore_patterns, gitignore_negate_patterns = gitignore_negate_patterns, add_dirs = options.add_dirs, + depth = options.depth, }) end ----@param options { directory: string, gitignore_patterns: string[], gitignore_negate_patterns: string[], add_dirs?: boolean } +---@param options { directory: string, gitignore_patterns: string[], gitignore_negate_patterns: string[], add_dirs?: boolean, depth?: integer, current_depth?: integer } function M.scan_directory(options) local directory = options.directory local ignore_patterns = options.gitignore_patterns local negate_patterns = options.gitignore_negate_patterns local add_dirs = options.add_dirs or false + local depth = options.depth or -1 + local current_depth = options.current_depth or 0 local files = {} local handle = vim.loop.fs_scandir(directory) @@ -662,6 +690,8 @@ function M.scan_directory(options) if not handle then return files end while true do + if depth > 0 and current_depth >= depth then break end + local name, type = vim.loop.fs_scandir_next(handle) if not name then break end @@ -677,6 +707,7 @@ function M.scan_directory(options) gitignore_patterns = ignore_patterns, gitignore_negate_patterns = negate_patterns, add_dirs = add_dirs, + current_depth = current_depth + 1, }) ) elseif type == "file" then diff --git a/tests/llm_tools_spec.lua b/tests/llm_tools_spec.lua new file mode 100644 index 0000000..372499c --- /dev/null +++ b/tests/llm_tools_spec.lua @@ -0,0 +1,186 @@ +local mock = require("luassert.mock") +local stub = require("luassert.stub") +local LlmTools = require("avante.llm_tools") +local Utils = require("avante.utils") + +LlmTools.confirm = function(msg) return true end + +describe("llm_tools", function() + local test_dir = "/tmp/test_llm_tools" + local test_file = test_dir .. "/test.txt" + + before_each(function() + -- 创建测试目录和文件 + os.execute("mkdir -p " .. test_dir) + local file = io.open(test_file, "w") + file:write("test content") + file:close() + + -- Mock get_project_root + stub(Utils, "get_project_root", function() return test_dir end) + end) + + after_each(function() + -- 清理测试目录 + os.execute("rm -rf " .. test_dir) + -- 恢复 mock + Utils.get_project_root:revert() + end) + + describe("list_files", function() + it("should list files in directory", function() + local result, err = LlmTools.list_files({ rel_path = ".", depth = 1 }) + assert.is_nil(err) + assert.truthy(result:find("test.txt")) + end) + end) + + describe("read_file", function() + it("should read file content", function() + local content, err = LlmTools.read_file({ rel_path = "test.txt" }) + assert.is_nil(err) + assert.equals("test content", content) + end) + + it("should return error for non-existent file", function() + local content, err = LlmTools.read_file({ rel_path = "non_existent.txt" }) + assert.truthy(err) + assert.equals("", content) + end) + end) + + describe("create_file", function() + it("should create new file", function() + local success, err = LlmTools.create_file({ rel_path = "new_file.txt" }) + assert.is_nil(err) + assert.is_true(success) + + local file_exists = io.open(test_dir .. "/new_file.txt", "r") ~= nil + assert.is_true(file_exists) + end) + end) + + describe("create_dir", function() + it("should create new directory", function() + local success, err = LlmTools.create_dir({ rel_path = "new_dir" }) + assert.is_nil(err) + assert.is_true(success) + + local dir_exists = io.open(test_dir .. "/new_dir", "r") ~= nil + assert.is_true(dir_exists) + end) + end) + + describe("delete_file", function() + it("should delete existing file", function() + local success, err = LlmTools.delete_file({ rel_path = "test.txt" }) + assert.is_nil(err) + assert.is_true(success) + + local file_exists = io.open(test_file, "r") ~= nil + assert.is_false(file_exists) + end) + end) + + describe("search_files", function() + it("should find files matching pattern", function() + local result, err = LlmTools.search_files({ rel_path = ".", keyword = "test" }) + assert.is_nil(err) + assert.truthy(result:find("test.txt")) + end) + end) + + describe("search", function() + local original_exepath = vim.fn.exepath + + after_each(function() vim.fn.exepath = original_exepath end) + + it("should search using ripgrep when available", function() + -- Mock exepath to return rg path + vim.fn.exepath = function(cmd) + if cmd == "rg" then return "/usr/bin/rg" end + return "" + end + + -- Create a test file with searchable content + local file = io.open(test_dir .. "/searchable.txt", "w") + file:write("this is searchable content") + file:close() + + file = io.open(test_dir .. "/nothing.txt", "w") + file:write("this is nothing") + file:close() + + local result, err = LlmTools.search({ rel_path = ".", keyword = "searchable" }) + assert.is_nil(err) + assert.truthy(result:find("searchable.txt")) + assert.falsy(result:find("nothing.txt")) + end) + + it("should search using ag when rg is not available", function() + -- Mock exepath to return ag path + vim.fn.exepath = function(cmd) + if cmd == "ag" then return "/usr/bin/ag" end + return "" + end + + -- Create a test file specifically for ag + local file = io.open(test_dir .. "/ag_test.txt", "w") + file:write("content for ag test") + file:close() + + local result, err = LlmTools.search({ rel_path = ".", keyword = "ag test" }) + assert.is_nil(err) + assert.is_string(result) + assert.truthy(result:find("ag_test.txt")) + end) + + it("should search using grep when rg and ag are not available", function() + -- Mock exepath to return grep path + vim.fn.exepath = function(cmd) + if cmd == "grep" then return "/usr/bin/grep" end + return "" + end + + local result, err = LlmTools.search({ rel_path = ".", keyword = "test" }) + assert.is_nil(err) + assert.truthy(result:find("test.txt")) + end) + + it("should return error when no search tool is available", function() + -- Mock exepath to return nothing + vim.fn.exepath = function() return "" end + + local result, err = LlmTools.search({ rel_path = ".", keyword = "test" }) + assert.equals("", result) + assert.equals("No search command found", err) + end) + + it("should respect path permissions", function() + local result, err = LlmTools.search({ rel_path = "../outside_project", keyword = "test" }) + assert.truthy(err:find("No permission to access path")) + end) + + it("should handle non-existent paths", function() + local result, err = LlmTools.search({ rel_path = "non_existent_dir", keyword = "test" }) + assert.equals("", result) + assert.truthy(err) + assert.truthy(err:find("No such file or directory")) + end) + end) + + describe("run_command", function() + it("should execute command and return output", function() + local result, err = LlmTools.run_command({ rel_path = ".", command = "echo 'test'" }) + assert.is_nil(err) + assert.equals("test\n", result) + end) + + it("should return error when running outside current directory", function() + local result, err = LlmTools.run_command({ rel_path = "../outside_project", command = "echo 'test'" }) + assert.is_false(result) + assert.truthy(err) + assert.truthy(err:find("No permission to access path")) + end) + end) +end)