diff --git a/lua/avante/providers/copilot.lua b/lua/avante/providers/copilot.lua index bfdbeb6..47305ff 100644 --- a/lua/avante/providers/copilot.lua +++ b/lua/avante/providers/copilot.lua @@ -31,7 +31,7 @@ local Config = require("avante.config") local Path = require("plenary.path") local Utils = require("avante.utils") local P = require("avante.providers") -local O = require("avante.providers").openai +local OpenAI = require("avante.providers").openai local H = {} @@ -215,10 +215,31 @@ M.parse_messages = function(opts) vim .iter(opts.messages) :each(function(msg) table.insert(messages, { role = M.role_map[msg.role], content = msg.content }) end) + if opts.tool_result then + table.insert(messages, { + role = M.role_map["assistant"], + tool_calls = { + { + id = opts.tool_use.id, + type = "function", + ["function"] = { + name = opts.tool_use.name, + arguments = opts.tool_use.input_json, + }, + }, + }, + }) + local result_content = opts.tool_result.content or "" + table.insert(messages, { + role = "tool", + tool_call_id = opts.tool_result.tool_use_id, + content = opts.tool_result.is_error and "Error: " .. result_content or result_content, + }) + end return messages end -M.parse_response = O.parse_response +M.parse_response = OpenAI.parse_response M.parse_curl_args = function(provider, code_opts) -- refresh token synchronously, only if it has expired @@ -227,6 +248,13 @@ M.parse_curl_args = function(provider, code_opts) local base, body_opts = P.parse_config(provider) + local tools = {} + if code_opts.tools then + for _, tool in ipairs(code_opts.tools) do + table.insert(tools, OpenAI.transform_tool(tool)) + end + end + return { url = H.chat_completion_url(base.endpoint), timeout = base.timeout, @@ -242,6 +270,7 @@ M.parse_curl_args = function(provider, code_opts) model = base.model, messages = M.parse_messages(code_opts), stream = true, + tools = tools, }, body_opts), } end diff --git a/lua/avante/providers/openai.lua b/lua/avante/providers/openai.lua index 54ab509..1ff6cf8 100644 --- a/lua/avante/providers/openai.lua +++ b/lua/avante/providers/openai.lua @@ -60,9 +60,19 @@ local P = require("avante.providers") ---@field type string ---@field description string +---@class AvanteProviderFunctor +local M = {} + +M.api_key_name = "OPENAI_API_KEY" + +M.role_map = { + user = "user", + assistant = "assistant", +} + ---@param tool AvanteLLMTool ---@return AvanteOpenAITool -local function transform_tool(tool) +function M.transform_tool(tool) local input_schema_properties = {} local required = {} for _, field in ipairs(tool.param.fields) do @@ -90,16 +100,6 @@ local function transform_tool(tool) return res end ----@class AvanteProviderFunctor -local M = {} - -M.api_key_name = "OPENAI_API_KEY" - -M.role_map = { - user = "user", - assistant = "assistant", -} - M.is_openrouter = function(url) return url:match("^https://openrouter%.ai/") end ---@param opts AvantePromptOptions @@ -301,7 +301,7 @@ M.parse_curl_args = function(provider, code_opts) local tools = {} if code_opts.tools then for _, tool in ipairs(code_opts.tools) do - table.insert(tools, transform_tool(tool)) + table.insert(tools, M.transform_tool(tool)) end end