feat: tools (#1180)
* feat: tools * feat: claude use tools * feat: openai use tools
This commit is contained in:
		
							parent
							
								
									1726d32778
								
							
						
					
					
						commit
						1437f319d2
					
				
							
								
								
									
										3
									
								
								.github/workflows/lua.yaml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.github/workflows/lua.yaml
									
									
									
									
										vendored
									
									
								
							| @ -38,6 +38,9 @@ jobs: | ||||
|             mkdir -p _neovim | ||||
|             curl -sL "https://github.com/neovim/neovim/releases/download/${{ matrix.rev }}" | tar xzf - --strip-components=1 -C "${PWD}/_neovim" | ||||
|           } | ||||
|           sudo apt-get update | ||||
|           sudo apt-get install -y ripgrep | ||||
|           sudo apt-get install -y silversearcher-ag | ||||
| 
 | ||||
|       - name: Run tests | ||||
|         run: | | ||||
|  | ||||
| @ -31,6 +31,7 @@ struct TemplateContext { | ||||
|     selected_code: Option<String>, | ||||
|     project_context: Option<String>, | ||||
|     diagnostics: Option<String>, | ||||
|     system_info: Option<String>, | ||||
| } | ||||
| 
 | ||||
| // Given the file name registered after add, the context table in Lua, resulted in a formatted
 | ||||
| @ -54,6 +55,7 @@ fn render(state: &State, template: &str, context: TemplateContext) -> LuaResult< | ||||
|                   selected_code => context.selected_code, | ||||
|                   project_context => context.project_context, | ||||
|                   diagnostics => context.diagnostics, | ||||
|                   system_info => context.system_info, | ||||
|                 }) | ||||
|                 .map_err(LuaError::external) | ||||
|                 .unwrap()) | ||||
|  | ||||
| @ -20,6 +20,14 @@ M._defaults = { | ||||
|   -- For most providers that we support we will determine this automatically. | ||||
|   -- If you wish to use a given implementation, then you can override it here. | ||||
|   tokenizer = "tiktoken", | ||||
|   web_search_engine = { | ||||
|     provider = "tavily", | ||||
|     api_key_name = "TAVILY_API_KEY", | ||||
|     provider_opts = { | ||||
|       time_range = "d", | ||||
|       include_answer = "basic", | ||||
|     }, | ||||
|   }, | ||||
|   ---@type AvanteSupportedProvider | ||||
|   openai = { | ||||
|     endpoint = "https://api.openai.com/v1", | ||||
|  | ||||
| @ -8,6 +8,7 @@ local Utils = require("avante.utils") | ||||
| local Config = require("avante.config") | ||||
| local Path = require("avante.path") | ||||
| local P = require("avante.providers") | ||||
| local LLMTools = require("avante.llm_tools") | ||||
| 
 | ||||
| ---@class avante.LLM | ||||
| local M = {} | ||||
| @ -45,6 +46,8 @@ M.generate_prompts = function(opts) | ||||
|   local project_root = Utils.root.get() | ||||
|   Path.prompts.initialize(Path.prompts.get(project_root)) | ||||
| 
 | ||||
|   local system_info = Utils.get_system_info() | ||||
| 
 | ||||
|   local template_opts = { | ||||
|     use_xml_format = Provider.use_xml_format, | ||||
|     ask = opts.ask, -- TODO: add mode without ask instruction | ||||
| @ -53,6 +56,7 @@ M.generate_prompts = function(opts) | ||||
|     selected_code = opts.selected_code, | ||||
|     project_context = opts.project_context, | ||||
|     diagnostics = opts.diagnostics, | ||||
|     system_info = system_info, | ||||
|   } | ||||
| 
 | ||||
|   local system_prompt = Path.prompts.render_mode(mode, template_opts) | ||||
| @ -111,6 +115,10 @@ M.generate_prompts = function(opts) | ||||
|     system_prompt = system_prompt, | ||||
|     messages = messages, | ||||
|     image_paths = image_paths, | ||||
|     tools = opts.tools, | ||||
|     tool_use = opts.tool_use, | ||||
|     tool_result = opts.tool_result, | ||||
|     response_content = opts.response_content, | ||||
|   } | ||||
| end | ||||
| 
 | ||||
| @ -135,7 +143,28 @@ M._stream = function(opts) | ||||
|   local current_event_state = nil | ||||
| 
 | ||||
|   ---@type AvanteHandlerOptions | ||||
|   local handler_opts = { on_chunk = opts.on_chunk, on_complete = opts.on_complete } | ||||
|   local handler_opts = { | ||||
|     on_start = opts.on_start, | ||||
|     on_chunk = opts.on_chunk, | ||||
|     on_stop = function(stop_opts) | ||||
|       if stop_opts.reason == "tool_use" and stop_opts.tool_use then | ||||
|         local result, error = LLMTools.process_tool_use(stop_opts.tool_use) | ||||
|         local tool_result = { | ||||
|           tool_use_id = stop_opts.tool_use.id, | ||||
|           content = error ~= nil and error or result, | ||||
|           is_error = error ~= nil, | ||||
|         } | ||||
|         local new_opts = vim.tbl_deep_extend( | ||||
|           "force", | ||||
|           opts, | ||||
|           { tool_result = tool_result, tool_use = stop_opts.tool_use, response_content = stop_opts.response_content } | ||||
|         ) | ||||
|         return M._stream(new_opts) | ||||
|       end | ||||
|       return opts.on_stop(stop_opts) | ||||
|     end, | ||||
|   } | ||||
| 
 | ||||
|   ---@type AvanteCurlOutput | ||||
|   local spec = Provider.parse_curl_args(Provider, code_opts) | ||||
| 
 | ||||
| @ -180,7 +209,7 @@ M._stream = function(opts) | ||||
|     stream = function(err, data, _) | ||||
|       if err then | ||||
|         completed = true | ||||
|         opts.on_complete(err) | ||||
|         handler_opts.on_stop({ reason = "error", error = err }) | ||||
|         return | ||||
|       end | ||||
|       if not data then return end | ||||
| @ -224,7 +253,7 @@ M._stream = function(opts) | ||||
|       active_job = nil | ||||
|       completed = true | ||||
|       cleanup() | ||||
|       opts.on_complete(err) | ||||
|       handler_opts.on_stop({ reason = "error", error = err }) | ||||
|     end, | ||||
|     callback = function(result) | ||||
|       active_job = nil | ||||
| @ -238,9 +267,10 @@ M._stream = function(opts) | ||||
|         vim.schedule(function() | ||||
|           if not completed then | ||||
|             completed = true | ||||
|             opts.on_complete( | ||||
|               "API request failed with status " .. result.status .. ". Body: " .. vim.inspect(result.body) | ||||
|             ) | ||||
|             handler_opts.on_stop({ | ||||
|               reason = "error", | ||||
|               error = "API request failed with status " .. result.status .. ". Body: " .. vim.inspect(result.body), | ||||
|             }) | ||||
|           end | ||||
|         end) | ||||
|       end | ||||
| @ -335,9 +365,9 @@ M._dual_boost_stream = function(opts, Provider1, Provider2) | ||||
|       on_chunk = function(chunk) | ||||
|         if chunk then response = response .. chunk end | ||||
|       end, | ||||
|       on_complete = function(err) | ||||
|         if err then | ||||
|           Utils.error(string.format("Stream %d failed: %s", index, err)) | ||||
|       on_stop = function(stop_opts) | ||||
|         if stop_opts.error then | ||||
|           Utils.error(string.format("Stream %d failed: %s", index, stop_opts.error)) | ||||
|           return | ||||
|         end | ||||
|         Utils.debug(string.format("Response %d completed", index)) | ||||
| @ -381,10 +411,15 @@ end | ||||
| ---@field instructions string | ||||
| ---@field mode LlmMode | ||||
| ---@field provider AvanteProviderFunctor | AvanteBedrockProviderFunctor | nil | ||||
| ---@field tools? AvanteLLMTool[] | ||||
| ---@field tool_result? AvanteLLMToolResult | ||||
| ---@field tool_use? AvanteLLMToolUse | ||||
| ---@field response_content? string | ||||
| --- | ||||
| ---@class StreamOptions: GeneratePromptsOptions | ||||
| ---@field on_chunk AvanteChunkParser | ||||
| ---@field on_complete AvanteCompleteParser | ||||
| ---@field on_start AvanteLLMStartCallback | ||||
| ---@field on_chunk AvanteLLMChunkCallback | ||||
| ---@field on_stop AvanteLLMStopCallback | ||||
| 
 | ||||
| ---@param opts StreamOptions | ||||
| M.stream = function(opts) | ||||
| @ -396,12 +431,12 @@ M.stream = function(opts) | ||||
|       return original_on_chunk(chunk) | ||||
|     end) | ||||
|   end | ||||
|   if opts.on_complete ~= nil then | ||||
|     local original_on_complete = opts.on_complete | ||||
|     opts.on_complete = vim.schedule_wrap(function(err) | ||||
|   if opts.on_stop ~= nil then | ||||
|     local original_on_stop = opts.on_stop | ||||
|     opts.on_stop = vim.schedule_wrap(function(stop_opts) | ||||
|       if is_completed then return end | ||||
|       is_completed = true | ||||
|       return original_on_complete(err) | ||||
|       if stop_opts.reason == "complete" or stop_opts.reason == "error" then is_completed = true end | ||||
|       return original_on_stop(stop_opts) | ||||
|     end) | ||||
|   end | ||||
|   if Config.dual_boost.enabled then | ||||
|  | ||||
							
								
								
									
										714
									
								
								lua/avante/llm_tools.lua
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										714
									
								
								lua/avante/llm_tools.lua
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,714 @@ | ||||
| local curl = require("plenary.curl") | ||||
| local Utils = require("avante.utils") | ||||
| local Path = require("plenary.path") | ||||
| local Config = require("avante.config") | ||||
| local M = {} | ||||
| 
 | ||||
| ---@param rel_path string | ||||
| ---@return string | ||||
| local function get_abs_path(rel_path) | ||||
|   local project_root = Utils.get_project_root() | ||||
|   return Path:new(project_root):joinpath(rel_path):absolute() | ||||
| end | ||||
| 
 | ||||
| function M.comfirm(msg) | ||||
|   local ok = vim.fn.confirm(msg, "&Yes\n&No", 2) | ||||
|   return ok == 1 | ||||
| end | ||||
| 
 | ||||
| ---@param abs_path string | ||||
| ---@return boolean | ||||
| local function has_permission_to_access(abs_path) | ||||
|   if not Path:new(abs_path):is_absolute() then return false end | ||||
|   local project_root = Utils.get_project_root() | ||||
|   if abs_path:sub(1, #project_root) ~= project_root then return false end | ||||
|   local gitignore_path = project_root .. "/.gitignore" | ||||
|   local gitignore_patterns, gitignore_negate_patterns = Utils.parse_gitignore(gitignore_path) | ||||
|   return not Utils.is_ignored(abs_path, gitignore_patterns, gitignore_negate_patterns) | ||||
| end | ||||
| 
 | ||||
| ---@param opts { rel_path: string, depth?: integer } | ||||
| ---@return string files | ||||
| ---@return string|nil error | ||||
| function M.list_files(opts) | ||||
|   local abs_path = get_abs_path(opts.rel_path) | ||||
|   if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end | ||||
|   local files = Utils.scan_directory_respect_gitignore({ | ||||
|     directory = abs_path, | ||||
|     add_dirs = true, | ||||
|     depth = opts.depth, | ||||
|   }) | ||||
|   local result = "" | ||||
|   for _, file in ipairs(files) do | ||||
|     local uniform_path = Utils.uniform_path(file) | ||||
|     result = result .. uniform_path .. "\n" | ||||
|   end | ||||
|   result = result:gsub("\n$", "") | ||||
|   return result, nil | ||||
| end | ||||
| 
 | ||||
| ---@param opts { rel_path: string, keyword: string } | ||||
| ---@return string files | ||||
| ---@return string|nil error | ||||
| function M.search_files(opts) | ||||
|   local abs_path = get_abs_path(opts.rel_path) | ||||
|   if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end | ||||
|   local files = Utils.scan_directory_respect_gitignore({ | ||||
|     directory = abs_path, | ||||
|   }) | ||||
|   local result = "" | ||||
|   for _, file in ipairs(files) do | ||||
|     if file:find(opts.keyword) then result = result .. file .. "\n" end | ||||
|   end | ||||
|   result = result:gsub("\n$", "") | ||||
|   return result, nil | ||||
| end | ||||
| 
 | ||||
| ---@param opts { rel_path: string, keyword: string } | ||||
| ---@return string result | ||||
| ---@return string|nil error | ||||
| function M.search(opts) | ||||
|   local abs_path = get_abs_path(opts.rel_path) | ||||
|   if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end | ||||
|   if not Path:new(abs_path):exists() then return "", "No such file or directory: " .. abs_path end | ||||
| 
 | ||||
|   ---check if any search cmd is available | ||||
|   local search_cmd = vim.fn.exepath("rg") | ||||
|   if search_cmd == "" then search_cmd = vim.fn.exepath("ag") end | ||||
|   if search_cmd == "" then search_cmd = vim.fn.exepath("ack") end | ||||
|   if search_cmd == "" then search_cmd = vim.fn.exepath("grep") end | ||||
|   if search_cmd == "" then return "", "No search command found" end | ||||
| 
 | ||||
|   ---execute the search command | ||||
|   local cmd = "" | ||||
|   if search_cmd:find("rg") then | ||||
|     cmd = string.format("%s --files-with-matches --no-ignore-vcs --ignore-case --hidden --glob '!.git'", search_cmd) | ||||
|     cmd = string.format("%s '%s' %s", cmd, opts.keyword, abs_path) | ||||
|   elseif search_cmd:find("ag") then | ||||
|     cmd = string.format("%s '%s' --nocolor --nogroup --hidden --ignore .git %s", search_cmd, opts.keyword, abs_path) | ||||
|   elseif search_cmd:find("ack") then | ||||
|     cmd = string.format("%s --nocolor --nogroup --hidden --ignore-dir .git", search_cmd) | ||||
|     cmd = string.format("%s '%s' %s", cmd, opts.keyword, abs_path) | ||||
|   elseif search_cmd:find("grep") then | ||||
|     cmd = string.format("%s -riH --exclude-dir=.git %s %s", search_cmd, opts.keyword, abs_path) | ||||
|   end | ||||
| 
 | ||||
|   Utils.debug("cmd", cmd) | ||||
|   local result = vim.fn.system(cmd) | ||||
| 
 | ||||
|   return result or "", nil | ||||
| end | ||||
| 
 | ||||
| ---@param opts { rel_path: string } | ||||
| ---@return string definitions | ||||
| ---@return string|nil error | ||||
| function M.read_file_toplevel_symbols(opts) | ||||
|   local RepoMap = require("avante.repo_map") | ||||
|   local abs_path = get_abs_path(opts.rel_path) | ||||
|   if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end | ||||
|   local filetype = RepoMap.get_ts_lang(abs_path) | ||||
|   local repo_map_lib = RepoMap._init_repo_map_lib() | ||||
|   if not repo_map_lib then return "", "Failed to load avante_repo_map" end | ||||
|   local definitions = filetype | ||||
|       and repo_map_lib.stringify_definitions(filetype, Utils.file.read_content(abs_path) or "") | ||||
|     or "" | ||||
|   return definitions, nil | ||||
| end | ||||
| 
 | ||||
| ---@param opts { rel_path: string } | ||||
| ---@return string content | ||||
| ---@return string|nil error | ||||
| function M.read_file(opts) | ||||
|   local abs_path = get_abs_path(opts.rel_path) | ||||
|   if not has_permission_to_access(abs_path) then return "", "No permission to access path: " .. abs_path end | ||||
|   local file = io.open(abs_path, "r") | ||||
|   if not file then return "", "file not found: " .. abs_path end | ||||
|   local content = file:read("*a") | ||||
|   file:close() | ||||
|   return content, nil | ||||
| end | ||||
| 
 | ||||
| ---@param opts { rel_path: string } | ||||
| ---@return boolean success | ||||
| ---@return string|nil error | ||||
| function M.create_file(opts) | ||||
|   local abs_path = get_abs_path(opts.rel_path) | ||||
|   if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end | ||||
|   ---create directory if it doesn't exist | ||||
|   local dir = Path:new(abs_path):parent() | ||||
|   if not dir:exists() then dir:mkdir({ parents = true }) end | ||||
|   ---create file if it doesn't exist | ||||
|   if not dir:joinpath(opts.rel_path):exists() then | ||||
|     local file = io.open(abs_path, "w") | ||||
|     if not file then return false, "file not found: " .. abs_path end | ||||
|     file:close() | ||||
|   end | ||||
| 
 | ||||
|   return true, nil | ||||
| end | ||||
| 
 | ||||
| ---@param opts { rel_path: string, new_rel_path: string } | ||||
| ---@return boolean success | ||||
| ---@return string|nil error | ||||
| function M.rename_file(opts) | ||||
|   local abs_path = get_abs_path(opts.rel_path) | ||||
|   if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end | ||||
|   if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end | ||||
|   if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end | ||||
|   local new_abs_path = get_abs_path(opts.new_rel_path) | ||||
|   if not has_permission_to_access(new_abs_path) then return false, "No permission to access path: " .. new_abs_path end | ||||
|   if Path:new(new_abs_path):exists() then return false, "File already exists: " .. new_abs_path end | ||||
|   if not M.confirm("Are you sure you want to rename the file: " .. abs_path .. " to: " .. new_abs_path) then | ||||
|     return false, "User canceled" | ||||
|   end | ||||
|   os.rename(abs_path, new_abs_path) | ||||
|   return true, nil | ||||
| end | ||||
| 
 | ||||
| ---@param opts { rel_path: string, new_rel_path: string } | ||||
| ---@return boolean success | ||||
| ---@return string|nil error | ||||
| function M.copy_file(opts) | ||||
|   local abs_path = get_abs_path(opts.rel_path) | ||||
|   if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end | ||||
|   if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end | ||||
|   if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end | ||||
|   local new_abs_path = get_abs_path(opts.new_rel_path) | ||||
|   if not has_permission_to_access(new_abs_path) then return false, "No permission to access path: " .. new_abs_path end | ||||
|   if Path:new(new_abs_path):exists() then return false, "File already exists: " .. new_abs_path end | ||||
|   Path:new(new_abs_path):write(Path:new(abs_path):read()) | ||||
|   return true, nil | ||||
| end | ||||
| 
 | ||||
| ---@param opts { rel_path: string } | ||||
| ---@return boolean success | ||||
| ---@return string|nil error | ||||
| function M.delete_file(opts) | ||||
|   local abs_path = get_abs_path(opts.rel_path) | ||||
|   if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end | ||||
|   if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end | ||||
|   if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end | ||||
|   if not M.confirm("Are you sure you want to delete the file: " .. abs_path) then return false, "User canceled" end | ||||
|   os.remove(abs_path) | ||||
|   return true, nil | ||||
| end | ||||
| 
 | ||||
| ---@param opts { rel_path: string } | ||||
| ---@return boolean success | ||||
| ---@return string|nil error | ||||
| function M.create_dir(opts) | ||||
|   local abs_path = get_abs_path(opts.rel_path) | ||||
|   if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end | ||||
|   if Path:new(abs_path):exists() then return false, "Directory already exists: " .. abs_path end | ||||
|   Path:new(abs_path):mkdir({ parents = true }) | ||||
|   return true, nil | ||||
| end | ||||
| 
 | ||||
| ---@param opts { rel_path: string, new_rel_path: string } | ||||
| ---@return boolean success | ||||
| ---@return string|nil error | ||||
| function M.rename_dir(opts) | ||||
|   local abs_path = get_abs_path(opts.rel_path) | ||||
|   if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end | ||||
|   if not Path:new(abs_path):exists() then return false, "Directory not found: " .. abs_path end | ||||
|   if not Path:new(abs_path):is_dir() then return false, "Path is not a directory: " .. abs_path end | ||||
|   local new_abs_path = get_abs_path(opts.new_rel_path) | ||||
|   if not has_permission_to_access(new_abs_path) then return false, "No permission to access path: " .. new_abs_path end | ||||
|   if Path:new(new_abs_path):exists() then return false, "Directory already exists: " .. new_abs_path end | ||||
|   if not M.confirm("Are you sure you want to rename directory " .. abs_path .. " to " .. new_abs_path .. "?") then | ||||
|     return false, "User canceled" | ||||
|   end | ||||
|   os.rename(abs_path, new_abs_path) | ||||
|   return true, nil | ||||
| end | ||||
| 
 | ||||
| ---@param opts { rel_path: string } | ||||
| ---@return boolean success | ||||
| ---@return string|nil error | ||||
| function M.delete_dir(opts) | ||||
|   local abs_path = get_abs_path(opts.rel_path) | ||||
|   if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end | ||||
|   if not Path:new(abs_path):exists() then return false, "Directory not found: " .. abs_path end | ||||
|   if not Path:new(abs_path):is_dir() then return false, "Path is not a directory: " .. abs_path end | ||||
|   if not M.confirm("Are you sure you want to delete the directory: " .. abs_path) then | ||||
|     return false, "User canceled" | ||||
|   end | ||||
|   os.remove(abs_path) | ||||
|   return true, nil | ||||
| end | ||||
| 
 | ||||
| ---@param opts { rel_path: string, command: string } | ||||
| ---@return string|boolean result | ||||
| ---@return string|nil error | ||||
| function M.run_command(opts) | ||||
|   local abs_path = get_abs_path(opts.rel_path) | ||||
|   if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end | ||||
|   if not Path:new(abs_path):exists() then return false, "Path not found: " .. abs_path end | ||||
|   if | ||||
|     not M.confirm("Are you sure you want to run the command: `" .. opts.command .. "` in the directory: " .. abs_path) | ||||
|   then | ||||
|     return false, "User canceled" | ||||
|   end | ||||
|   ---change cwd to abs_path | ||||
|   local old_cwd = vim.fn.getcwd() | ||||
|   vim.fn.chdir(abs_path) | ||||
|   local res = Utils.shell_run(opts.command) | ||||
|   vim.fn.chdir(old_cwd) | ||||
|   if res.code ~= 0 then | ||||
|     if res.stdout then return false, "Error: " .. res.stdout .. "; Error code: " .. tostring(res.code) end | ||||
|     return false, "Error code: " .. tostring(res.code) | ||||
|   end | ||||
|   return res.stdout, nil | ||||
| end | ||||
| 
 | ||||
| ---@param opts { query: string } | ||||
| ---@return string|nil result | ||||
| ---@return string|nil error | ||||
| function M.web_search(opts) | ||||
|   local search_engine = Config.web_search_engine | ||||
|   if search_engine.provider == "tavily" then | ||||
|     if search_engine.api_key_name == "" then return nil, "No API key provided" end | ||||
|     local api_key = os.getenv(search_engine.api_key_name) | ||||
|     if api_key == nil or api_key == "" then | ||||
|       return nil, "Environment variable " .. search_engine.api_key_name .. " is not set" | ||||
|     end | ||||
|     local resp = curl.post("https://api.tavily.com/search", { | ||||
|       headers = { | ||||
|         ["Content-Type"] = "application/json", | ||||
|         ["Authorization"] = "Bearer " .. api_key, | ||||
|       }, | ||||
|       body = vim.json.encode(vim.tbl_deep_extend("force", { | ||||
|         query = opts.query, | ||||
|       }, search_engine.provider_opts)), | ||||
|     }) | ||||
|     if resp.status ~= 200 then return nil, "Error: " .. resp.body end | ||||
|     local jsn = vim.json.decode(resp.body) | ||||
|     return jsn.anwser, nil | ||||
|   end | ||||
| end | ||||
| 
 | ||||
| ---@class AvanteLLMTool | ||||
| ---@field name string | ||||
| ---@field description string | ||||
| ---@field param AvanteLLMToolParam | ||||
| ---@field returns AvanteLLMToolReturn[] | ||||
| 
 | ||||
| ---@class AvanteLLMToolParam | ||||
| ---@field type string | ||||
| ---@field fields AvanteLLMToolParamField[] | ||||
| 
 | ||||
| ---@class AvanteLLMToolParamField | ||||
| ---@field name string | ||||
| ---@field description string | ||||
| ---@field type string | ||||
| ---@field optional? boolean | ||||
| 
 | ||||
| ---@class AvanteLLMToolReturn | ||||
| ---@field name string | ||||
| ---@field description string | ||||
| ---@field type string | ||||
| ---@field optional? boolean | ||||
| 
 | ||||
| ---@type AvanteLLMTool[] | ||||
| M.tools = { | ||||
|   { | ||||
|     name = "list_files", | ||||
|     description = "List files in a directory", | ||||
|     param = { | ||||
|       type = "table", | ||||
|       fields = { | ||||
|         { | ||||
|           name = "rel_path", | ||||
|           description = "Relative path to the directory", | ||||
|           type = "string", | ||||
|         }, | ||||
|         { | ||||
|           name = "depth", | ||||
|           description = "Depth of the directory", | ||||
|           type = "integer", | ||||
|           optional = true, | ||||
|         }, | ||||
|       }, | ||||
|     }, | ||||
|     returns = { | ||||
|       { | ||||
|         name = "files", | ||||
|         description = "List of files in the directory", | ||||
|         type = "string[]", | ||||
|       }, | ||||
|       { | ||||
|         name = "error", | ||||
|         description = "Error message if the directory was not listed successfully", | ||||
|         type = "string", | ||||
|         optional = true, | ||||
|       }, | ||||
|     }, | ||||
|   }, | ||||
|   { | ||||
|     name = "search_files", | ||||
|     description = "Search for files in a directory", | ||||
|     param = { | ||||
|       type = "table", | ||||
|       fields = { | ||||
|         { | ||||
|           name = "rel_path", | ||||
|           description = "Relative path to the directory", | ||||
|           type = "string", | ||||
|         }, | ||||
|         { | ||||
|           name = "keyword", | ||||
|           description = "Keyword to search for", | ||||
|           type = "string", | ||||
|         }, | ||||
|       }, | ||||
|     }, | ||||
|     returns = { | ||||
|       { | ||||
|         name = "files", | ||||
|         description = "List of files that match the keyword", | ||||
|         type = "string", | ||||
|       }, | ||||
|       { | ||||
|         name = "error", | ||||
|         description = "Error message if the directory was not searched successfully", | ||||
|         type = "string", | ||||
|         optional = true, | ||||
|       }, | ||||
|     }, | ||||
|   }, | ||||
|   { | ||||
|     name = "search", | ||||
|     description = "Search for a keyword in a directory", | ||||
|     param = { | ||||
|       type = "table", | ||||
|       fields = { | ||||
|         { | ||||
|           name = "rel_path", | ||||
|           description = "Relative path to the directory", | ||||
|           type = "string", | ||||
|         }, | ||||
|         { | ||||
|           name = "keyword", | ||||
|           description = "Keyword to search for", | ||||
|           type = "string", | ||||
|         }, | ||||
|       }, | ||||
|     }, | ||||
|     returns = { | ||||
|       { | ||||
|         name = "files", | ||||
|         description = "List of files that match the keyword", | ||||
|         type = "string", | ||||
|       }, | ||||
|       { | ||||
|         name = "error", | ||||
|         description = "Error message if the directory was not searched successfully", | ||||
|         type = "string", | ||||
|         optional = true, | ||||
|       }, | ||||
|     }, | ||||
|   }, | ||||
|   { | ||||
|     name = "read_file_toplevel_symbols", | ||||
|     description = "Read the top-level symbols of a file", | ||||
|     param = { | ||||
|       type = "table", | ||||
|       fields = { | ||||
|         { | ||||
|           name = "rel_path", | ||||
|           description = "Relative path to the file", | ||||
|           type = "string", | ||||
|         }, | ||||
|       }, | ||||
|     }, | ||||
|     returns = { | ||||
|       { | ||||
|         name = "definitions", | ||||
|         description = "Top-level symbols of the file", | ||||
|         type = "string", | ||||
|       }, | ||||
|       { | ||||
|         name = "error", | ||||
|         description = "Error message if the file was not read successfully", | ||||
|         type = "string", | ||||
|         optional = true, | ||||
|       }, | ||||
|     }, | ||||
|   }, | ||||
|   { | ||||
|     name = "read_file", | ||||
|     description = "Read the contents of a file", | ||||
|     param = { | ||||
|       type = "table", | ||||
|       fields = { | ||||
|         { | ||||
|           name = "rel_path", | ||||
|           description = "Relative path to the file", | ||||
|           type = "string", | ||||
|         }, | ||||
|       }, | ||||
|     }, | ||||
|     returns = { | ||||
|       { | ||||
|         name = "content", | ||||
|         description = "Contents of the file", | ||||
|         type = "string", | ||||
|       }, | ||||
|       { | ||||
|         name = "error", | ||||
|         description = "Error message if the file was not read successfully", | ||||
|         type = "string", | ||||
|         optional = true, | ||||
|       }, | ||||
|     }, | ||||
|   }, | ||||
|   { | ||||
|     name = "create_file", | ||||
|     description = "Create a new file", | ||||
|     param = { | ||||
|       type = "table", | ||||
|       fields = { | ||||
|         { | ||||
|           name = "rel_path", | ||||
|           description = "Relative path to the file", | ||||
|           type = "string", | ||||
|         }, | ||||
|       }, | ||||
|     }, | ||||
|     returns = { | ||||
|       { | ||||
|         name = "success", | ||||
|         description = "True if the file was created successfully, false otherwise", | ||||
|         type = "boolean", | ||||
|       }, | ||||
|       { | ||||
|         name = "error", | ||||
|         description = "Error message if the file was not created successfully", | ||||
|         type = "string", | ||||
|         optional = true, | ||||
|       }, | ||||
|     }, | ||||
|   }, | ||||
|   { | ||||
|     name = "rename_file", | ||||
|     description = "Rename a file", | ||||
|     param = { | ||||
|       type = "table", | ||||
|       fields = { | ||||
|         { | ||||
|           name = "rel_path", | ||||
|           description = "Relative path to the file", | ||||
|           type = "string", | ||||
|         }, | ||||
|         { | ||||
|           name = "new_rel_path", | ||||
|           description = "New relative path for the file", | ||||
|           type = "string", | ||||
|         }, | ||||
|       }, | ||||
|     }, | ||||
|     returns = { | ||||
|       { | ||||
|         name = "success", | ||||
|         description = "True if the file was renamed successfully, false otherwise", | ||||
|         type = "boolean", | ||||
|       }, | ||||
|       { | ||||
|         name = "error", | ||||
|         description = "Error message if the file was not renamed successfully", | ||||
|         type = "string", | ||||
|         optional = true, | ||||
|       }, | ||||
|     }, | ||||
|   }, | ||||
|   { | ||||
|     name = "delete_file", | ||||
|     description = "Delete a file", | ||||
|     param = { | ||||
|       type = "table", | ||||
|       fields = { | ||||
|         { | ||||
|           name = "rel_path", | ||||
|           description = "Relative path to the file", | ||||
|           type = "string", | ||||
|         }, | ||||
|       }, | ||||
|     }, | ||||
|     returns = { | ||||
|       { | ||||
|         name = "success", | ||||
|         description = "True if the file was deleted successfully, false otherwise", | ||||
|         type = "boolean", | ||||
|       }, | ||||
|       { | ||||
|         name = "error", | ||||
|         description = "Error message if the file was not deleted successfully", | ||||
|         type = "string", | ||||
|         optional = true, | ||||
|       }, | ||||
|     }, | ||||
|   }, | ||||
|   { | ||||
|     name = "create_dir", | ||||
|     description = "Create a new directory", | ||||
|     param = { | ||||
|       type = "table", | ||||
|       fields = { | ||||
|         { | ||||
|           name = "rel_path", | ||||
|           description = "Relative path to the directory", | ||||
|           type = "string", | ||||
|         }, | ||||
|       }, | ||||
|     }, | ||||
|     returns = { | ||||
|       { | ||||
|         name = "success", | ||||
|         description = "True if the directory was created successfully, false otherwise", | ||||
|         type = "boolean", | ||||
|       }, | ||||
|       { | ||||
|         name = "error", | ||||
|         description = "Error message if the directory was not created successfully", | ||||
|         type = "string", | ||||
|         optional = true, | ||||
|       }, | ||||
|     }, | ||||
|   }, | ||||
|   { | ||||
|     name = "rename_dir", | ||||
|     description = "Rename a directory", | ||||
|     param = { | ||||
|       type = "table", | ||||
|       fields = { | ||||
|         { | ||||
|           name = "rel_path", | ||||
|           description = "Relative path to the directory", | ||||
|           type = "string", | ||||
|         }, | ||||
|         { | ||||
|           name = "new_rel_path", | ||||
|           description = "New relative path for the directory", | ||||
|           type = "string", | ||||
|         }, | ||||
|       }, | ||||
|     }, | ||||
|     returns = { | ||||
|       { | ||||
|         name = "success", | ||||
|         description = "True if the directory was renamed successfully, false otherwise", | ||||
|         type = "boolean", | ||||
|       }, | ||||
|       { | ||||
|         name = "error", | ||||
|         description = "Error message if the directory was not renamed successfully", | ||||
|         type = "string", | ||||
|         optional = true, | ||||
|       }, | ||||
|     }, | ||||
|   }, | ||||
|   { | ||||
|     name = "delete_dir", | ||||
|     description = "Delete a directory", | ||||
|     param = { | ||||
|       type = "table", | ||||
|       fields = { | ||||
|         { | ||||
|           name = "rel_path", | ||||
|           description = "Relative path to the directory", | ||||
|           type = "string", | ||||
|         }, | ||||
|       }, | ||||
|     }, | ||||
|     returns = { | ||||
|       { | ||||
|         name = "success", | ||||
|         description = "True if the directory was deleted successfully, false otherwise", | ||||
|         type = "boolean", | ||||
|       }, | ||||
|       { | ||||
|         name = "error", | ||||
|         description = "Error message if the directory was not deleted successfully", | ||||
|         type = "string", | ||||
|         optional = true, | ||||
|       }, | ||||
|     }, | ||||
|   }, | ||||
|   { | ||||
|     name = "run_command", | ||||
|     description = "Run a command in a directory", | ||||
|     param = { | ||||
|       type = "table", | ||||
|       fields = { | ||||
|         { | ||||
|           name = "rel_path", | ||||
|           description = "Relative path to the directory", | ||||
|           type = "string", | ||||
|         }, | ||||
|         { | ||||
|           name = "command", | ||||
|           description = "Command to run", | ||||
|           type = "string", | ||||
|         }, | ||||
|       }, | ||||
|     }, | ||||
|     returns = { | ||||
|       { | ||||
|         name = "stdout", | ||||
|         description = "Output of the command", | ||||
|         type = "string", | ||||
|       }, | ||||
|       { | ||||
|         name = "error", | ||||
|         description = "Error message if the command was not run successfully", | ||||
|         type = "string", | ||||
|         optional = true, | ||||
|       }, | ||||
|     }, | ||||
|   }, | ||||
|   { | ||||
|     name = "web_search", | ||||
|     description = "Search the web", | ||||
|     param = { | ||||
|       type = "table", | ||||
|       fields = { | ||||
|         { | ||||
|           name = "query", | ||||
|           description = "Query to search", | ||||
|           type = "string", | ||||
|         }, | ||||
|       }, | ||||
|     }, | ||||
|     returns = { | ||||
|       { | ||||
|         name = "result", | ||||
|         description = "Result of the search", | ||||
|         type = "string", | ||||
|       }, | ||||
|       { | ||||
|         name = "error", | ||||
|         description = "Error message if the search was not successful", | ||||
|         type = "string", | ||||
|         optional = true, | ||||
|       }, | ||||
|     }, | ||||
|   }, | ||||
| } | ||||
| 
 | ||||
| ---@param tool_use AvanteLLMToolUse | ||||
| ---@return string | nil result | ||||
| ---@return string | nil error | ||||
| function M.process_tool_use(tool_use) | ||||
|   Utils.debug("use tool", tool_use.name, tool_use.input_json) | ||||
|   local tool = vim.iter(M.tools):find(function(tool) return tool.name == tool_use.name end) | ||||
|   if tool == nil then return end | ||||
|   local input_json = vim.json.decode(tool_use.input_json) | ||||
|   local func = M[tool.name] | ||||
|   local result, error = func(input_json) | ||||
|   -- Utils.debug("result", result) | ||||
|   -- Utils.debug("error", error) | ||||
|   if result ~= nil and type(result) ~= "string" then result = vim.json.encode(result) end | ||||
|   return result, error | ||||
| end | ||||
| 
 | ||||
| return M | ||||
| @ -6,6 +6,8 @@ | ||||
| ---@field role "user" | "assistant" | ||||
| ---@field content [AvanteBedrockClaudeTextMessage][] | ||||
| 
 | ||||
| local Claude = require("avante.providers.claude") | ||||
| 
 | ||||
| ---@class AvanteBedrockModelHandler | ||||
| local M = {} | ||||
| 
 | ||||
| @ -33,25 +35,7 @@ M.parse_messages = function(opts) | ||||
|   return messages | ||||
| end | ||||
| 
 | ||||
| M.parse_response = function(ctx, data_stream, event_state, opts) | ||||
|   if event_state == nil then | ||||
|     if data_stream:match('"content_block_delta"') then | ||||
|       event_state = "content_block_delta" | ||||
|     elseif data_stream:match('"message_stop"') then | ||||
|       event_state = "message_stop" | ||||
|     end | ||||
|   end | ||||
|   if event_state == "content_block_delta" then | ||||
|     local ok, json = pcall(vim.json.decode, data_stream) | ||||
|     if not ok then return end | ||||
|     opts.on_chunk(json.delta.text) | ||||
|   elseif event_state == "message_stop" then | ||||
|     opts.on_complete(nil) | ||||
|     return | ||||
|   elseif event_state == "error" then | ||||
|     opts.on_complete(vim.json.decode(data_stream)) | ||||
|   end | ||||
| end | ||||
| M.parse_response = Claude.parse_response | ||||
| 
 | ||||
| ---@param prompt_opts AvantePromptOptions | ||||
| ---@param body_opts table | ||||
| @ -60,7 +44,6 @@ M.build_bedrock_payload = function(prompt_opts, body_opts) | ||||
|   local system_prompt = prompt_opts.system_prompt or "" | ||||
|   local messages = M.parse_messages(prompt_opts) | ||||
|   local max_tokens = body_opts.max_tokens or 2000 | ||||
|   local temperature = body_opts.temperature or 0.7 | ||||
|   local payload = { | ||||
|     anthropic_version = "bedrock-2023-05-31", | ||||
|     max_tokens = max_tokens, | ||||
|  | ||||
| @ -17,6 +17,44 @@ local P = require("avante.providers") | ||||
| ---@field role "user" | "assistant" | ||||
| ---@field content [AvanteClaudeTextMessage | AvanteClaudeImageMessage][] | ||||
| 
 | ||||
| ---@class AvanteClaudeTool | ||||
| ---@field name string | ||||
| ---@field description string | ||||
| ---@field input_schema AvanteClaudeToolInputSchema | ||||
| 
 | ||||
| ---@class AvanteClaudeToolInputSchema | ||||
| ---@field type "object" | ||||
| ---@field properties table<string, AvanteClaudeToolInputSchemaProperty> | ||||
| ---@field required string[] | ||||
| 
 | ||||
| ---@class AvanteClaudeToolInputSchemaProperty | ||||
| ---@field type "string" | "number" | "boolean" | ||||
| ---@field description string | ||||
| ---@field enum? string[] | ||||
| 
 | ||||
| ---@param tool AvanteLLMTool | ||||
| ---@return AvanteClaudeTool | ||||
| local function transform_tool(tool) | ||||
|   local input_schema_properties = {} | ||||
|   local required = {} | ||||
|   for _, field in ipairs(tool.param.fields) do | ||||
|     input_schema_properties[field.name] = { | ||||
|       type = field.type, | ||||
|       description = field.description, | ||||
|     } | ||||
|     if not field.optional then table.insert(required, field.name) end | ||||
|   end | ||||
|   return { | ||||
|     name = tool.name, | ||||
|     description = tool.description, | ||||
|     input_schema = { | ||||
|       type = "object", | ||||
|       properties = input_schema_properties, | ||||
|       required = required, | ||||
|     }, | ||||
|   } | ||||
| end | ||||
| 
 | ||||
| ---@class AvanteProviderFunctor | ||||
| local M = {} | ||||
| 
 | ||||
| @ -74,26 +112,101 @@ M.parse_messages = function(opts) | ||||
|     messages[#messages].content = message_content | ||||
|   end | ||||
| 
 | ||||
|   if opts.tool_use then | ||||
|     local msg = { | ||||
|       role = "assistant", | ||||
|       content = {}, | ||||
|     } | ||||
|     if opts.response_content then | ||||
|       msg.content[#msg.content + 1] = { | ||||
|         type = "text", | ||||
|         text = opts.response_content, | ||||
|       } | ||||
|     end | ||||
|     msg.content[#msg.content + 1] = { | ||||
|       type = "tool_use", | ||||
|       id = opts.tool_use.id, | ||||
|       name = opts.tool_use.name, | ||||
|       input = vim.json.decode(opts.tool_use.input_json), | ||||
|     } | ||||
|     messages[#messages + 1] = msg | ||||
|   end | ||||
| 
 | ||||
|   if opts.tool_result then | ||||
|     messages[#messages + 1] = { | ||||
|       role = "user", | ||||
|       content = { | ||||
|         { | ||||
|           type = "tool_result", | ||||
|           tool_use_id = opts.tool_result.tool_use_id, | ||||
|           content = opts.tool_result.content, | ||||
|           is_error = opts.tool_result.is_error, | ||||
|         }, | ||||
|       }, | ||||
|     } | ||||
|   end | ||||
| 
 | ||||
|   return messages | ||||
| end | ||||
| 
 | ||||
| M.parse_response = function(ctx, data_stream, event_state, opts) | ||||
|   if event_state == nil then | ||||
|     if data_stream:match('"content_block_delta"') then | ||||
|       event_state = "content_block_delta" | ||||
|     if data_stream:match('"message_start"') then | ||||
|       event_state = "message_start" | ||||
|     elseif data_stream:match('"message_delta"') then | ||||
|       event_state = "message_delta" | ||||
|     elseif data_stream:match('"message_stop"') then | ||||
|       event_state = "message_stop" | ||||
|     elseif data_stream:match('"content_block_start"') then | ||||
|       event_state = "content_block_start" | ||||
|     elseif data_stream:match('"content_block_delta"') then | ||||
|       event_state = "content_block_delta" | ||||
|     elseif data_stream:match('"content_block_stop"') then | ||||
|       event_state = "content_block_stop" | ||||
|     end | ||||
|   end | ||||
|   if event_state == "content_block_delta" then | ||||
|     local ok, json = pcall(vim.json.decode, data_stream) | ||||
|   if event_state == "message_start" then | ||||
|     local ok, jsn = pcall(vim.json.decode, data_stream) | ||||
|     if not ok then return end | ||||
|     opts.on_chunk(json.delta.text) | ||||
|   elseif event_state == "message_stop" then | ||||
|     opts.on_complete(nil) | ||||
|     opts.on_start(jsn.message.usage) | ||||
|   elseif event_state == "content_block_start" then | ||||
|     local ok, jsn = pcall(vim.json.decode, data_stream) | ||||
|     if not ok then return end | ||||
|     if jsn.content_block.type == "tool_use" then | ||||
|       ctx.tool_use = { | ||||
|         name = jsn.content_block.name, | ||||
|         id = jsn.content_block.id, | ||||
|         input_json = "", | ||||
|       } | ||||
|     elseif jsn.content_block.type == "text" then | ||||
|       ctx.response_content = "" | ||||
|     end | ||||
|   elseif event_state == "content_block_delta" then | ||||
|     local ok, jsn = pcall(vim.json.decode, data_stream) | ||||
|     if not ok then return end | ||||
|     if ctx.tool_use and jsn.delta.type == "input_json_delta" then | ||||
|       ctx.tool_use.input_json = ctx.tool_use.input_json .. jsn.delta.partial_json | ||||
|       return | ||||
|     elseif ctx.response_content and jsn.delta.type == "text_delta" then | ||||
|       ctx.response_content = ctx.response_content .. jsn.delta.text | ||||
|     end | ||||
|     opts.on_chunk(jsn.delta.text) | ||||
|   elseif event_state == "message_delta" then | ||||
|     local ok, jsn = pcall(vim.json.decode, data_stream) | ||||
|     if not ok then return end | ||||
|     if jsn.delta.stop_reason == "end_turn" then | ||||
|       opts.on_stop({ reason = "complete", usage = jsn.usage }) | ||||
|     elseif jsn.delta.stop_reason == "tool_use" then | ||||
|       opts.on_stop({ | ||||
|         reason = "tool_use", | ||||
|         usage = jsn.usage, | ||||
|         tool_use = ctx.tool_use, | ||||
|         response_content = ctx.response_content, | ||||
|       }) | ||||
|     end | ||||
|     return | ||||
|   elseif event_state == "error" then | ||||
|     opts.on_complete(vim.json.decode(data_stream)) | ||||
|     opts.on_stop({ reason = "error", error = vim.json.decode(data_stream) }) | ||||
|   end | ||||
| end | ||||
| 
 | ||||
| @ -113,6 +226,13 @@ M.parse_curl_args = function(provider, prompt_opts) | ||||
| 
 | ||||
|   local messages = M.parse_messages(prompt_opts) | ||||
| 
 | ||||
|   local tools = {} | ||||
|   if prompt_opts.tools then | ||||
|     for _, tool in ipairs(prompt_opts.tools) do | ||||
|       table.insert(tools, transform_tool(tool)) | ||||
|     end | ||||
|   end | ||||
| 
 | ||||
|   return { | ||||
|     url = Utils.url_join(base.endpoint, "/v1/messages"), | ||||
|     proxy = base.proxy, | ||||
| @ -128,6 +248,7 @@ M.parse_curl_args = function(provider, prompt_opts) | ||||
|         }, | ||||
|       }, | ||||
|       messages = messages, | ||||
|       tools = tools, | ||||
|       stream = true, | ||||
|     }, body_opts), | ||||
|   } | ||||
|  | ||||
| @ -62,7 +62,7 @@ M.parse_stream_data = function(data, opts) | ||||
|   local json = vim.json.decode(data) | ||||
|   if json.type ~= nil then | ||||
|     if json.type == "message-end" and json.delta.finish_reason == "COMPLETE" then | ||||
|       opts.on_complete(nil) | ||||
|       opts.on_stop({ reason = "complete" }) | ||||
|       return | ||||
|     end | ||||
|     if json.type == "content-delta" then opts.on_chunk(json.delta.message.content.text) end | ||||
|  | ||||
| @ -66,17 +66,17 @@ end | ||||
| 
 | ||||
| M.parse_response = function(ctx, data_stream, _, opts) | ||||
|   local ok, json = pcall(vim.json.decode, data_stream) | ||||
|   if not ok then opts.on_complete(json) end | ||||
|   if not ok then opts.on_stop({ reason = "error", error = json }) end | ||||
|   if json.candidates then | ||||
|     if #json.candidates > 0 then | ||||
|       if json.candidates[1].finishReason and json.candidates[1].finishReason == "STOP" then | ||||
|         opts.on_chunk(json.candidates[1].content.parts[1].text) | ||||
|         opts.on_complete(nil) | ||||
|         opts.on_stop({ reason = "complete" }) | ||||
|       else | ||||
|         opts.on_chunk(json.candidates[1].content.parts[1].text) | ||||
|       end | ||||
|     else | ||||
|       opts.on_complete(nil) | ||||
|       opts.on_stop({ reason = "complete" }) | ||||
|     end | ||||
|   end | ||||
| end | ||||
|  | ||||
| @ -11,17 +11,28 @@ local DressingConfig = { | ||||
| local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil } | ||||
| 
 | ||||
| ---@class AvanteHandlerOptions: table<[string], string> | ||||
| ---@field on_chunk AvanteChunkParser | ||||
| ---@field on_complete AvanteCompleteParser | ||||
| ---@field on_start AvanteLLMStartCallback | ||||
| ---@field on_chunk AvanteLLMChunkCallback | ||||
| ---@field on_stop AvanteLLMStopCallback | ||||
| --- | ||||
| ---@class AvanteLLMMessage | ||||
| ---@field role "user" | "assistant" | ||||
| ---@field content string | ||||
| --- | ||||
| ---@class AvanteLLMToolResult | ||||
| ---@field tool_name string | ||||
| ---@field tool_use_id string | ||||
| ---@field content string | ||||
| ---@field is_error? boolean | ||||
| --- | ||||
| ---@class AvantePromptOptions: table<[string], string> | ||||
| ---@field system_prompt string | ||||
| ---@field messages AvanteLLMMessage[] | ||||
| ---@field image_paths? string[] | ||||
| ---@field tools? AvanteLLMTool[] | ||||
| ---@field tool_result? AvanteLLMToolResult | ||||
| ---@field tool_use? AvanteLLMToolUse | ||||
| ---@field response_content? string | ||||
| --- | ||||
| ---@class AvanteGeminiMessage | ||||
| ---@field role "user" | ||||
| @ -35,8 +46,9 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil } | ||||
| ---@alias AvanteCurlArgsParser fun(opts: AvanteProvider | AvanteProviderFunctor | AvanteBedrockProviderFunctor, code_opts: AvantePromptOptions): AvanteCurlOutput | ||||
| --- | ||||
| ---@class ResponseParser | ||||
| ---@field on_chunk fun(chunk: string): any | ||||
| ---@field on_complete fun(err: string|nil): any | ||||
| ---@field on_start AvanteLLMStartCallback | ||||
| ---@field on_chunk AvanteLLMChunkCallback | ||||
| ---@field on_stop AvanteLLMStopCallback | ||||
| ---@alias AvanteResponseParser fun(ctx: any, data_stream: string, event_state: string, opts: ResponseParser): nil | ||||
| --- | ||||
| ---@class AvanteDefaultBaseProvider: table<string, any> | ||||
| @ -54,9 +66,31 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil } | ||||
| ---@field temperature? number | ||||
| ---@field max_tokens? number | ||||
| --- | ||||
| ---@class AvanteLLMUsage | ||||
| ---@field input_tokens number | ||||
| ---@field cache_creation_input_tokens number | ||||
| ---@field cache_read_input_tokens number | ||||
| ---@field output_tokens number | ||||
| --- | ||||
| ---@class AvanteLLMToolUse | ||||
| ---@field name string | ||||
| ---@field id string | ||||
| ---@field input_json string | ||||
| --- | ||||
| ---@class AvanteLLMStartCallbackOptions | ||||
| ---@field usage? AvanteLLMUsage | ||||
| --- | ||||
| ---@class AvanteLLMStopCallbackOptions | ||||
| ---@field reason "complete" | "tool_use" | "error" | ||||
| ---@field error? string | table | ||||
| ---@field usage? AvanteLLMUsage | ||||
| ---@field tool_use? AvanteLLMToolUse | ||||
| ---@field response_content? string | ||||
| --- | ||||
| ---@alias AvanteStreamParser fun(line: string, handler_opts: AvanteHandlerOptions): nil | ||||
| ---@alias AvanteChunkParser fun(chunk: string): any | ||||
| ---@alias AvanteCompleteParser fun(err: string|nil): nil | ||||
| ---@alias AvanteLLMStartCallback fun(opts: AvanteLLMStartCallbackOptions): nil | ||||
| ---@alias AvanteLLMChunkCallback fun(chunk: string): any | ||||
| ---@alias AvanteLLMStopCallback fun(opts: AvanteLLMStopCallbackOptions): nil | ||||
| ---@alias AvanteLLMConfigHandler fun(opts: AvanteSupportedProvider): AvanteDefaultBaseProvider, table<string, any> | ||||
| --- | ||||
| ---@class AvanteProvider: AvanteSupportedProvider | ||||
|  | ||||
| @ -24,12 +24,72 @@ local P = require("avante.providers") | ||||
| ---@field index integer | ||||
| ---@field logprobs integer | ||||
| --- | ||||
| ---@class OpenAIMessageToolCallFunction | ||||
| ---@field name string | ||||
| ---@field arguments string | ||||
| --- | ||||
| ---@class OpenAIMessageToolCall | ||||
| ---@field id string | ||||
| ---@field type "function" | ||||
| ---@field function OpenAIMessageToolCallFunction | ||||
| --- | ||||
| ---@class OpenAIMessage | ||||
| ---@field role? "user" | "system" | "assistant" | ||||
| ---@field content? string | ||||
| ---@field reasoning_content? string | ||||
| ---@field reasoning? string | ||||
| ---@field tool_calls? OpenAIMessageToolCall[] | ||||
| --- | ||||
| ---@class AvanteOpenAITool | ||||
| ---@field type "function" | ||||
| ---@field function AvanteOpenAIToolFunction | ||||
| --- | ||||
| ---@class AvanteOpenAIToolFunction | ||||
| ---@field name string | ||||
| ---@field description string | ||||
| ---@field parameters AvanteOpenAIToolFunctionParameters | ||||
| ---@field strict boolean | ||||
| --- | ||||
| ---@class AvanteOpenAIToolFunctionParameters | ||||
| ---@field type string | ||||
| ---@field properties table<string, AvanteOpenAIToolFunctionParameterProperty> | ||||
| ---@field required string[] | ||||
| ---@field additionalProperties boolean | ||||
| --- | ||||
| ---@class AvanteOpenAIToolFunctionParameterProperty | ||||
| ---@field type string | ||||
| ---@field description string | ||||
| 
 | ||||
| ---@param tool AvanteLLMTool | ||||
| ---@return AvanteOpenAITool | ||||
| local function transform_tool(tool) | ||||
|   local input_schema_properties = {} | ||||
|   local required = {} | ||||
|   for _, field in ipairs(tool.param.fields) do | ||||
|     input_schema_properties[field.name] = { | ||||
|       type = field.type, | ||||
|       description = field.description, | ||||
|     } | ||||
|     if not field.optional then table.insert(required, field.name) end | ||||
|   end | ||||
|   local res = { | ||||
|     type = "function", | ||||
|     ["function"] = { | ||||
|       name = tool.name, | ||||
|       description = tool.description, | ||||
|     }, | ||||
|   } | ||||
|   if vim.tbl_count(input_schema_properties) > 0 then | ||||
|     res["function"].parameters = { | ||||
|       type = "object", | ||||
|       properties = input_schema_properties, | ||||
|       required = required, | ||||
|       additionalProperties = false, | ||||
|     } | ||||
|   end | ||||
|   return res | ||||
| end | ||||
| 
 | ||||
| ---@class AvanteProviderFunctor | ||||
| local M = {} | ||||
| 
 | ||||
| @ -107,12 +167,34 @@ M.parse_messages = function(opts) | ||||
|     table.insert(final_messages, { role = M.role_map[role] or role, content = message.content }) | ||||
|   end) | ||||
| 
 | ||||
|   if opts.tool_result then | ||||
|     table.insert(final_messages, { | ||||
|       role = M.role_map["assistant"], | ||||
|       tool_calls = { | ||||
|         { | ||||
|           id = opts.tool_use.id, | ||||
|           type = "function", | ||||
|           ["function"] = { | ||||
|             name = opts.tool_use.name, | ||||
|             arguments = opts.tool_use.input_json, | ||||
|           }, | ||||
|         }, | ||||
|       }, | ||||
|     }) | ||||
|     local result_content = opts.tool_result.content or "" | ||||
|     table.insert(final_messages, { | ||||
|       role = "tool", | ||||
|       tool_call_id = opts.tool_result.tool_use_id, | ||||
|       content = opts.tool_result.is_error and "Error: " .. result_content or result_content, | ||||
|     }) | ||||
|   end | ||||
| 
 | ||||
|   return final_messages | ||||
| end | ||||
| 
 | ||||
| M.parse_response = function(ctx, data_stream, _, opts) | ||||
|   if data_stream:match('"%[DONE%]":') then | ||||
|     opts.on_complete(nil) | ||||
|     opts.on_stop({ reason = "complete" }) | ||||
|     return | ||||
|   end | ||||
|   if data_stream:match('"delta":') then | ||||
| @ -121,7 +203,14 @@ M.parse_response = function(ctx, data_stream, _, opts) | ||||
|     if jsn.choices and jsn.choices[1] then | ||||
|       local choice = jsn.choices[1] | ||||
|       if choice.finish_reason == "stop" or choice.finish_reason == "eos_token" then | ||||
|         opts.on_complete(nil) | ||||
|         opts.on_stop({ reason = "complete" }) | ||||
|       elseif choice.finish_reason == "tool_calls" then | ||||
|         opts.on_stop({ | ||||
|           reason = "tool_use", | ||||
|           usage = jsn.usage, | ||||
|           tool_use = ctx.tool_use, | ||||
|           response_content = ctx.response_content, | ||||
|         }) | ||||
|       elseif choice.delta.reasoning_content and choice.delta.reasoning_content ~= vim.NIL then | ||||
|         if ctx.returned_think_start_tag == nil or not ctx.returned_think_start_tag then | ||||
|           ctx.returned_think_start_tag = true | ||||
| @ -136,6 +225,17 @@ M.parse_response = function(ctx, data_stream, _, opts) | ||||
|         end | ||||
|         ctx.last_think_content = choice.delta.reasoning | ||||
|         opts.on_chunk(choice.delta.reasoning) | ||||
|       elseif choice.delta.tool_calls then | ||||
|         local tool_call = choice.delta.tool_calls[1] | ||||
|         if not ctx.tool_use then | ||||
|           ctx.tool_use = { | ||||
|             name = tool_call["function"].name, | ||||
|             id = tool_call.id, | ||||
|             input_json = "", | ||||
|           } | ||||
|         else | ||||
|           ctx.tool_use.input_json = ctx.tool_use.input_json .. tool_call["function"].arguments | ||||
|         end | ||||
|       elseif choice.delta.content then | ||||
|         if | ||||
|           ctx.returned_think_start_tag ~= nil and (ctx.returned_think_end_tag == nil or not ctx.returned_think_end_tag) | ||||
| @ -164,7 +264,7 @@ M.parse_response_without_stream = function(data, _, opts) | ||||
|     local choice = json.choices[1] | ||||
|     if choice.message and choice.message.content then | ||||
|       opts.on_chunk(choice.message.content) | ||||
|       vim.schedule(function() opts.on_complete(nil) end) | ||||
|       vim.schedule(function() opts.on_stop({ reason = "complete" }) end) | ||||
|     end | ||||
|   end | ||||
| end | ||||
| @ -198,6 +298,13 @@ M.parse_curl_args = function(provider, code_opts) | ||||
|     body_opts.temperature = 1 | ||||
|   end | ||||
| 
 | ||||
|   local tools = {} | ||||
|   if code_opts.tools then | ||||
|     for _, tool in ipairs(code_opts.tools) do | ||||
|       table.insert(tools, transform_tool(tool)) | ||||
|     end | ||||
|   end | ||||
| 
 | ||||
|   Utils.debug("endpoint", base.endpoint) | ||||
|   Utils.debug("model", base.model) | ||||
| 
 | ||||
| @ -210,6 +317,7 @@ M.parse_curl_args = function(provider, code_opts) | ||||
|       model = base.model, | ||||
|       messages = M.parse_messages(code_opts), | ||||
|       stream = stream, | ||||
|       tools = tools, | ||||
|     }, body_opts), | ||||
|   } | ||||
| end | ||||
|  | ||||
| @ -157,7 +157,10 @@ function Selection:create_editing_input() | ||||
| 
 | ||||
|     self.prompt_input:start_spinner() | ||||
| 
 | ||||
|     ---@type AvanteChunkParser | ||||
|     ---@type AvanteLLMStartCallback | ||||
|     local on_start = function(start_opts) end | ||||
| 
 | ||||
|     ---@type AvanteLLMChunkCallback | ||||
|     local on_chunk = function(chunk) | ||||
|       full_response = full_response .. chunk | ||||
|       local response_lines_ = vim.split(full_response, "\n") | ||||
| @ -182,13 +185,15 @@ function Selection:create_editing_input() | ||||
|       finish_line = start_line + #response_lines - 1 | ||||
|     end | ||||
| 
 | ||||
|     ---@type AvanteCompleteParser | ||||
|     local on_complete = function(err) | ||||
|       if err then | ||||
|     ---@type AvanteLLMStopCallback | ||||
|     local on_stop = function(stop_opts) | ||||
|       if stop_opts.error then | ||||
|         -- NOTE: in Ubuntu 22.04+ you will see this ignorable error from ~/.local/share/nvim/lazy/avante.nvim/lua/avante/llm.lua `on_error = function(err)`, check to avoid showing this error. | ||||
|         if type(err) == "table" and err.exit == nil and err.stderr == "{}" then return end | ||||
|         if type(stop_opts.error) == "table" and stop_opts.error.exit == nil and stop_opts.error.stderr == "{}" then | ||||
|           return | ||||
|         end | ||||
|         Utils.error( | ||||
|           "Error occurred while processing the response: " .. vim.inspect(err), | ||||
|           "Error occurred while processing the response: " .. vim.inspect(stop_opts.error), | ||||
|           { once = true, title = "Avante" } | ||||
|         ) | ||||
|         return | ||||
| @ -216,8 +221,9 @@ function Selection:create_editing_input() | ||||
|       selected_code = self.selection.content, | ||||
|       instructions = input, | ||||
|       mode = "editing", | ||||
|       on_start = on_start, | ||||
|       on_chunk = on_chunk, | ||||
|       on_complete = on_complete, | ||||
|       on_stop = on_stop, | ||||
|     }) | ||||
|   end | ||||
| 
 | ||||
|  | ||||
| @ -13,6 +13,7 @@ local Utils = require("avante.utils") | ||||
| local Highlights = require("avante.highlights") | ||||
| local RepoMap = require("avante.repo_map") | ||||
| local FileSelector = require("avante.file_selector") | ||||
| local LLMTools = require("avante.llm_tools") | ||||
| 
 | ||||
| local RESULT_BUF_NAME = "AVANTE_RESULT" | ||||
| local VIEW_BUFFER_UPDATED_PATTERN = "AvanteViewBufferUpdated" | ||||
| @ -1669,6 +1670,7 @@ function Sidebar:create_input_container(opts) | ||||
|       selected_code = selected_code_content, | ||||
|       instructions = request, | ||||
|       mode = "planning", | ||||
|       tools = LLMTools.tools, | ||||
|     } | ||||
|   end | ||||
| 
 | ||||
| @ -1755,7 +1757,10 @@ function Sidebar:create_input_container(opts) | ||||
|     vim.keymap.set("n", "k", on_k, { buffer = self.result_container.bufnr }) | ||||
|     vim.keymap.set("n", "G", on_G, { buffer = self.result_container.bufnr }) | ||||
| 
 | ||||
|     ---@type AvanteChunkParser | ||||
|     ---@type AvanteLLMStartCallback | ||||
|     local on_start = function(start_opts) end | ||||
| 
 | ||||
|     ---@type AvanteLLMChunkCallback | ||||
|     local on_chunk = function(chunk) | ||||
|       original_response = original_response .. chunk | ||||
| 
 | ||||
| @ -1778,8 +1783,8 @@ function Sidebar:create_input_container(opts) | ||||
|       displayed_response = cur_displayed_response | ||||
|     end | ||||
| 
 | ||||
|     ---@type AvanteCompleteParser | ||||
|     local on_complete = function(err) | ||||
|     ---@type AvanteLLMStopCallback | ||||
|     local on_stop = function(stop_opts) | ||||
|       pcall(function() | ||||
|         ---remove keymaps | ||||
|         vim.keymap.del("n", "j", { buffer = self.result_container.bufnr }) | ||||
| @ -1787,9 +1792,9 @@ function Sidebar:create_input_container(opts) | ||||
|         vim.keymap.del("n", "G", { buffer = self.result_container.bufnr }) | ||||
|       end) | ||||
| 
 | ||||
|       if err ~= nil then | ||||
|       if stop_opts.error ~= nil then | ||||
|         self:update_content( | ||||
|           content_prefix .. displayed_response .. "\n\nError: " .. vim.inspect(err), | ||||
|           content_prefix .. displayed_response .. "\n\nError: " .. vim.inspect(stop_opts.error), | ||||
|           { scroll = scroll } | ||||
|         ) | ||||
|         return | ||||
| @ -1835,8 +1840,9 @@ function Sidebar:create_input_container(opts) | ||||
|     ---@type StreamOptions | ||||
|     ---@diagnostic disable-next-line: assign-type-mismatch | ||||
|     local stream_options = vim.tbl_deep_extend("force", generate_prompts_options, { | ||||
|       on_start = on_start, | ||||
|       on_chunk = on_chunk, | ||||
|       on_complete = on_complete, | ||||
|       on_stop = on_stop, | ||||
|     }) | ||||
| 
 | ||||
|     Llm.stream(stream_options) | ||||
|  | ||||
| @ -141,8 +141,10 @@ L5:    pass | ||||
|     history_messages = history_messages, | ||||
|     instructions = vim.json.encode(doc), | ||||
|     mode = "suggesting", | ||||
|     on_start = function(_) end, | ||||
|     on_chunk = function(chunk) full_response = full_response .. chunk end, | ||||
|     on_complete = function(err) | ||||
|     on_stop = function(stop_opts) | ||||
|       local err = stop_opts.error | ||||
|       if err then | ||||
|         Utils.error("Error while suggesting: " .. vim.inspect(err), { once = true, title = "Avante" }) | ||||
|         return | ||||
|  | ||||
| @ -10,5 +10,13 @@ | ||||
| Act as an expert software developer. | ||||
| Always use best practices when coding. | ||||
| Respect and use existing conventions, libraries, etc that are already present in the code base. | ||||
| You have access to tools, but only use them when necessary. If a tool is not required, respond as normal. | ||||
| If you have information that you don't know, please proactively use the tools provided by users! Especially the web search tool. | ||||
| 
 | ||||
| {% if system_info -%} | ||||
| Use the appropriate shell based on the user's system info: | ||||
| {{system_info}} | ||||
| {%- endif %} | ||||
| 
 | ||||
| {% block extra_prompt %} | ||||
| {% endblock %} | ||||
|  | ||||
| @ -47,6 +47,31 @@ M.get_os_name = function() | ||||
|   end | ||||
| end | ||||
| 
 | ||||
| M.get_system_info = function() | ||||
|   local os_name = vim.loop.os_uname().sysname | ||||
|   local os_version = vim.loop.os_uname().release | ||||
|   local os_machine = vim.loop.os_uname().machine | ||||
|   local lang = os.getenv("LANG") | ||||
| 
 | ||||
|   local res = string.format( | ||||
|     "- Platform: %s-%s-%s\n- Shell: %s\n- Language: %s\n- Current date: %s", | ||||
|     os_name, | ||||
|     os_version, | ||||
|     os_machine, | ||||
|     vim.o.shell, | ||||
|     lang, | ||||
|     os.date("%Y-%m-%d") | ||||
|   ) | ||||
| 
 | ||||
|   local project_root = M.root.get() | ||||
|   if project_root then res = res .. string.format("\n- Project root: %s", project_root) end | ||||
| 
 | ||||
|   local is_git_repo = vim.fn.isdirectory(".git") == 1 | ||||
|   if is_git_repo then res = res .. "\n- The user is operating inside a git repository" end | ||||
| 
 | ||||
|   return res | ||||
| end | ||||
| 
 | ||||
| --- This function will run given shell command synchronously. | ||||
| ---@param input_cmd string | ||||
| ---@return vim.SystemCompleted | ||||
| @ -622,6 +647,7 @@ function M.parse_gitignore(gitignore_path) | ||||
|   end | ||||
| 
 | ||||
|   file:close() | ||||
|   ignore_patterns = vim.list_extend(ignore_patterns, { "%.git", "%.worktree", "__pycache__", "node_modules" }) | ||||
|   return ignore_patterns, negate_patterns | ||||
| end | ||||
| 
 | ||||
| @ -635,26 +661,28 @@ function M.is_ignored(file, ignore_patterns, negate_patterns) | ||||
|   return false | ||||
| end | ||||
| 
 | ||||
| ---@param options { directory: string, add_dirs?: boolean } | ||||
| ---@param options { directory: string, add_dirs?: boolean, depth?: integer } | ||||
| function M.scan_directory_respect_gitignore(options) | ||||
|   local directory = options.directory | ||||
|   local gitignore_path = directory .. "/.gitignore" | ||||
|   local gitignore_patterns, gitignore_negate_patterns = M.parse_gitignore(gitignore_path) | ||||
|   gitignore_patterns = vim.list_extend(gitignore_patterns, { "%.git", "%.worktree", "__pycache__", "node_modules" }) | ||||
|   return M.scan_directory({ | ||||
|     directory = directory, | ||||
|     gitignore_patterns = gitignore_patterns, | ||||
|     gitignore_negate_patterns = gitignore_negate_patterns, | ||||
|     add_dirs = options.add_dirs, | ||||
|     depth = options.depth, | ||||
|   }) | ||||
| end | ||||
| 
 | ||||
| ---@param options { directory: string, gitignore_patterns: string[], gitignore_negate_patterns: string[], add_dirs?: boolean } | ||||
| ---@param options { directory: string, gitignore_patterns: string[], gitignore_negate_patterns: string[], add_dirs?: boolean, depth?: integer, current_depth?: integer } | ||||
| function M.scan_directory(options) | ||||
|   local directory = options.directory | ||||
|   local ignore_patterns = options.gitignore_patterns | ||||
|   local negate_patterns = options.gitignore_negate_patterns | ||||
|   local add_dirs = options.add_dirs or false | ||||
|   local depth = options.depth or -1 | ||||
|   local current_depth = options.current_depth or 0 | ||||
| 
 | ||||
|   local files = {} | ||||
|   local handle = vim.loop.fs_scandir(directory) | ||||
| @ -662,6 +690,8 @@ function M.scan_directory(options) | ||||
|   if not handle then return files end | ||||
| 
 | ||||
|   while true do | ||||
|     if depth > 0 and current_depth >= depth then break end | ||||
| 
 | ||||
|     local name, type = vim.loop.fs_scandir_next(handle) | ||||
|     if not name then break end | ||||
| 
 | ||||
| @ -677,6 +707,7 @@ function M.scan_directory(options) | ||||
|           gitignore_patterns = ignore_patterns, | ||||
|           gitignore_negate_patterns = negate_patterns, | ||||
|           add_dirs = add_dirs, | ||||
|           current_depth = current_depth + 1, | ||||
|         }) | ||||
|       ) | ||||
|     elseif type == "file" then | ||||
|  | ||||
							
								
								
									
										186
									
								
								tests/llm_tools_spec.lua
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										186
									
								
								tests/llm_tools_spec.lua
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,186 @@ | ||||
| local mock = require("luassert.mock") | ||||
| local stub = require("luassert.stub") | ||||
| local LlmTools = require("avante.llm_tools") | ||||
| local Utils = require("avante.utils") | ||||
| 
 | ||||
| LlmTools.confirm = function(msg) return true end | ||||
| 
 | ||||
| describe("llm_tools", function() | ||||
|   local test_dir = "/tmp/test_llm_tools" | ||||
|   local test_file = test_dir .. "/test.txt" | ||||
| 
 | ||||
|   before_each(function() | ||||
|     -- 创建测试目录和文件 | ||||
|     os.execute("mkdir -p " .. test_dir) | ||||
|     local file = io.open(test_file, "w") | ||||
|     file:write("test content") | ||||
|     file:close() | ||||
| 
 | ||||
|     -- Mock get_project_root | ||||
|     stub(Utils, "get_project_root", function() return test_dir end) | ||||
|   end) | ||||
| 
 | ||||
|   after_each(function() | ||||
|     -- 清理测试目录 | ||||
|     os.execute("rm -rf " .. test_dir) | ||||
|     -- 恢复 mock | ||||
|     Utils.get_project_root:revert() | ||||
|   end) | ||||
| 
 | ||||
|   describe("list_files", function() | ||||
|     it("should list files in directory", function() | ||||
|       local result, err = LlmTools.list_files({ rel_path = ".", depth = 1 }) | ||||
|       assert.is_nil(err) | ||||
|       assert.truthy(result:find("test.txt")) | ||||
|     end) | ||||
|   end) | ||||
| 
 | ||||
|   describe("read_file", function() | ||||
|     it("should read file content", function() | ||||
|       local content, err = LlmTools.read_file({ rel_path = "test.txt" }) | ||||
|       assert.is_nil(err) | ||||
|       assert.equals("test content", content) | ||||
|     end) | ||||
| 
 | ||||
|     it("should return error for non-existent file", function() | ||||
|       local content, err = LlmTools.read_file({ rel_path = "non_existent.txt" }) | ||||
|       assert.truthy(err) | ||||
|       assert.equals("", content) | ||||
|     end) | ||||
|   end) | ||||
| 
 | ||||
|   describe("create_file", function() | ||||
|     it("should create new file", function() | ||||
|       local success, err = LlmTools.create_file({ rel_path = "new_file.txt" }) | ||||
|       assert.is_nil(err) | ||||
|       assert.is_true(success) | ||||
| 
 | ||||
|       local file_exists = io.open(test_dir .. "/new_file.txt", "r") ~= nil | ||||
|       assert.is_true(file_exists) | ||||
|     end) | ||||
|   end) | ||||
| 
 | ||||
|   describe("create_dir", function() | ||||
|     it("should create new directory", function() | ||||
|       local success, err = LlmTools.create_dir({ rel_path = "new_dir" }) | ||||
|       assert.is_nil(err) | ||||
|       assert.is_true(success) | ||||
| 
 | ||||
|       local dir_exists = io.open(test_dir .. "/new_dir", "r") ~= nil | ||||
|       assert.is_true(dir_exists) | ||||
|     end) | ||||
|   end) | ||||
| 
 | ||||
|   describe("delete_file", function() | ||||
|     it("should delete existing file", function() | ||||
|       local success, err = LlmTools.delete_file({ rel_path = "test.txt" }) | ||||
|       assert.is_nil(err) | ||||
|       assert.is_true(success) | ||||
| 
 | ||||
|       local file_exists = io.open(test_file, "r") ~= nil | ||||
|       assert.is_false(file_exists) | ||||
|     end) | ||||
|   end) | ||||
| 
 | ||||
|   describe("search_files", function() | ||||
|     it("should find files matching pattern", function() | ||||
|       local result, err = LlmTools.search_files({ rel_path = ".", keyword = "test" }) | ||||
|       assert.is_nil(err) | ||||
|       assert.truthy(result:find("test.txt")) | ||||
|     end) | ||||
|   end) | ||||
| 
 | ||||
|   describe("search", function() | ||||
|     local original_exepath = vim.fn.exepath | ||||
| 
 | ||||
|     after_each(function() vim.fn.exepath = original_exepath end) | ||||
| 
 | ||||
|     it("should search using ripgrep when available", function() | ||||
|       -- Mock exepath to return rg path | ||||
|       vim.fn.exepath = function(cmd) | ||||
|         if cmd == "rg" then return "/usr/bin/rg" end | ||||
|         return "" | ||||
|       end | ||||
| 
 | ||||
|       -- Create a test file with searchable content | ||||
|       local file = io.open(test_dir .. "/searchable.txt", "w") | ||||
|       file:write("this is searchable content") | ||||
|       file:close() | ||||
| 
 | ||||
|       file = io.open(test_dir .. "/nothing.txt", "w") | ||||
|       file:write("this is nothing") | ||||
|       file:close() | ||||
| 
 | ||||
|       local result, err = LlmTools.search({ rel_path = ".", keyword = "searchable" }) | ||||
|       assert.is_nil(err) | ||||
|       assert.truthy(result:find("searchable.txt")) | ||||
|       assert.falsy(result:find("nothing.txt")) | ||||
|     end) | ||||
| 
 | ||||
|     it("should search using ag when rg is not available", function() | ||||
|       -- Mock exepath to return ag path | ||||
|       vim.fn.exepath = function(cmd) | ||||
|         if cmd == "ag" then return "/usr/bin/ag" end | ||||
|         return "" | ||||
|       end | ||||
| 
 | ||||
|       -- Create a test file specifically for ag | ||||
|       local file = io.open(test_dir .. "/ag_test.txt", "w") | ||||
|       file:write("content for ag test") | ||||
|       file:close() | ||||
| 
 | ||||
|       local result, err = LlmTools.search({ rel_path = ".", keyword = "ag test" }) | ||||
|       assert.is_nil(err) | ||||
|       assert.is_string(result) | ||||
|       assert.truthy(result:find("ag_test.txt")) | ||||
|     end) | ||||
| 
 | ||||
|     it("should search using grep when rg and ag are not available", function() | ||||
|       -- Mock exepath to return grep path | ||||
|       vim.fn.exepath = function(cmd) | ||||
|         if cmd == "grep" then return "/usr/bin/grep" end | ||||
|         return "" | ||||
|       end | ||||
| 
 | ||||
|       local result, err = LlmTools.search({ rel_path = ".", keyword = "test" }) | ||||
|       assert.is_nil(err) | ||||
|       assert.truthy(result:find("test.txt")) | ||||
|     end) | ||||
| 
 | ||||
|     it("should return error when no search tool is available", function() | ||||
|       -- Mock exepath to return nothing | ||||
|       vim.fn.exepath = function() return "" end | ||||
| 
 | ||||
|       local result, err = LlmTools.search({ rel_path = ".", keyword = "test" }) | ||||
|       assert.equals("", result) | ||||
|       assert.equals("No search command found", err) | ||||
|     end) | ||||
| 
 | ||||
|     it("should respect path permissions", function() | ||||
|       local result, err = LlmTools.search({ rel_path = "../outside_project", keyword = "test" }) | ||||
|       assert.truthy(err:find("No permission to access path")) | ||||
|     end) | ||||
| 
 | ||||
|     it("should handle non-existent paths", function() | ||||
|       local result, err = LlmTools.search({ rel_path = "non_existent_dir", keyword = "test" }) | ||||
|       assert.equals("", result) | ||||
|       assert.truthy(err) | ||||
|       assert.truthy(err:find("No such file or directory")) | ||||
|     end) | ||||
|   end) | ||||
| 
 | ||||
|   describe("run_command", function() | ||||
|     it("should execute command and return output", function() | ||||
|       local result, err = LlmTools.run_command({ rel_path = ".", command = "echo 'test'" }) | ||||
|       assert.is_nil(err) | ||||
|       assert.equals("test\n", result) | ||||
|     end) | ||||
| 
 | ||||
|     it("should return error when running outside current directory", function() | ||||
|       local result, err = LlmTools.run_command({ rel_path = "../outside_project", command = "echo 'test'" }) | ||||
|       assert.is_false(result) | ||||
|       assert.truthy(err) | ||||
|       assert.truthy(err:find("No permission to access path")) | ||||
|     end) | ||||
|   end) | ||||
| end) | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 yetone
						yetone