diff --git a/lua/avante/providers/azure.lua b/lua/avante/providers/azure.lua index bfd9c71..851c30f 100644 --- a/lua/avante/providers/azure.lua +++ b/lua/avante/providers/azure.lua @@ -13,7 +13,14 @@ local M = {} M.api_key_name = "AZURE_OPENAI_API_KEY" -M.parse_message = O.parse_message +M.parse_message = function(opts) + local user_content = O.get_user_message(opts) + return { + { role = "system", content = opts.system_prompt }, + { role = "user", content = user_content }, + } +end + M.parse_response = O.parse_response M.parse_curl_args = function(provider, code_opts) diff --git a/lua/avante/providers/copilot.lua b/lua/avante/providers/copilot.lua index 5875acf..bb8dc3f 100644 --- a/lua/avante/providers/copilot.lua +++ b/lua/avante/providers/copilot.lua @@ -42,12 +42,6 @@ M.copilot = nil local H = {} -local version_headers = { - ["editor-version"] = "Neovim/" .. vim.version().major .. "." .. vim.version().minor .. "." .. vim.version().patch, - ["editor-plugin-version"] = "avante.nvim/0.0.0", - ["user-agent"] = "AvanteNvim/0.0.0", -} - ---@return string | nil H.find_config_path = function() local config = vim.fn.expand("$XDG_CONFIG_HOME") @@ -120,25 +114,31 @@ M.has = function() return false end -M.parse_message = O.parse_message +M.parse_message = function(opts) + local user_content = O.get_user_message(opts) + return { + { role = "system", content = opts.system_prompt }, + { role = "user", content = user_content }, + } +end + M.parse_response = O.parse_response M.parse_curl_args = function(provider, code_opts) - local base, body_opts = P.parse_config(provider) - M.refresh_token() + local base, body_opts = P.parse_config(provider) + return { url = Utils.trim(base.endpoint, { suffix = "/" }) .. "/chat/completions", + timeout = base.timeout, proxy = base.proxy, insecure = base.allow_insecure, - headers = vim.tbl_deep_extend("error", { - ["authorization"] = "Bearer " .. M.copilot.token.token, - ["copilot-integration-id"] = "vscode-chat", - ["openai-organization"] = "github-copilot", - ["openai-intent"] = "conversation-panel", - ["content-type"] = "application/json", - }, version_headers), + headers = { + ["Authorization"] = "Bearer " .. M.copilot.token.token, + ["Content-Type"] = "application/json", + ["editor-version"] = "Neovim/" .. vim.version().major .. "." .. vim.version().minor .. "." .. vim.version().patch, + }, body = vim.tbl_deep_extend("force", { model = base.model, n = 1, @@ -151,16 +151,23 @@ end M.on_error = function(result) Utils.error("Received error from Copilot API: " .. result.body, { once = true, title = "Avante" }) + Utils.debug(result) end M.refresh_token = function() if not M.copilot.token or (M.copilot.token.expires_at and M.copilot.token.expires_at <= math.floor(os.time())) then curl.get("https://api.github.com/copilot_internal/v2/token", { timeout = Config.copilot.timeout, - headers = vim.tbl_deep_extend("error", { + headers = { ["Authorization"] = "token " .. M.copilot.github_token, ["Accept"] = "application/json", - }, version_headers), + ["editor-version"] = "Neovim/" + .. vim.version().major + .. "." + .. vim.version().minor + .. "." + .. vim.version().patch, + }, proxy = Config.copilot.proxy, insecure = Config.copilot.allow_insecure, on_error = function(err) @@ -185,9 +192,8 @@ M.setup = function() if not M.copilot then M.copilot = { token = nil, github_token = github_token } + M.refresh_token() end - - M.refresh_token() end return M diff --git a/lua/avante/providers/init.lua b/lua/avante/providers/init.lua index be6c538..453c580 100644 --- a/lua/avante/providers/init.lua +++ b/lua/avante/providers/init.lua @@ -301,11 +301,8 @@ function M.refresh(provider) ---@type AvanteProviderFunctor local p = M[Config.provider] - if not p.has() then - E.setup({ provider = p, refresh = true }) - else - Utils.info("Switch to provider: " .. provider, { once = true, title = "Avante" }) - end + E.setup({ provider = p, refresh = true }) + Utils.info("Switch to provider: " .. provider, { once = true, title = "Avante" }) end local default_providers = { "openai", "claude", "azure", "gemini", "copilot" } diff --git a/lua/avante/providers/openai.lua b/lua/avante/providers/openai.lua index 4076ecf..00ac2be 100644 --- a/lua/avante/providers/openai.lua +++ b/lua/avante/providers/openai.lua @@ -27,7 +27,8 @@ local M = {} M.api_key_name = "OPENAI_API_KEY" -M.parse_message = function(opts) +---@param opts AvantePromptOptions +M.get_user_message = function(opts) local user_prompt = opts.base_prompt .. "\n\nCODE:\n" .. "```" @@ -56,6 +57,10 @@ M.parse_message = function(opts) .. opts.question end + return user_prompt +end + +M.parse_message = function(opts) local user_content = {} if Config.behaviour.support_paste_from_clipboard and Clipboard.has_content() then table.insert(user_content, { @@ -66,7 +71,7 @@ M.parse_message = function(opts) }) end - table.insert(user_content, { type = "text", text = user_prompt }) + table.insert(user_content, { type = "text", text = M.get_user_message(opts) }) return { { role = "system", content = opts.system_prompt },