diff --git a/lua/avante/config.lua b/lua/avante/config.lua index e055a8d..d156e60 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -9,7 +9,7 @@ M.defaults = { debug = false, ---Currently, default supported providers include "claude", "openai", "azure", "deepseek", "groq" ---For custom provider, see README.md - ---@alias Provider "openai" | "claude" | "azure" | "deepseek" | "groq" | string + ---@alias Provider "openai" | "claude" | "azure" | "deepseek" | "groq" | "copilot" | string provider = "claude", ---@type AvanteSupportedProvider openai = { @@ -19,6 +19,16 @@ M.defaults = { max_tokens = 4096, ["local"] = false, }, + ---@type AvanteCopilotProvider + copilot = { + endpoint = "https://api.githubcopilot.com", + model = "gpt-4o-2024-05-13", + proxy = nil, -- [protocol://]host[:port] Use this proxy + allow_insecure = false, -- Allow insecure server connections + timeout = 30000, -- Timeout in milliseconds + temperature = 0, + max_tokens = 8192, + }, ---@type AvanteAzureProvider azure = { endpoint = "", -- example: "https://.openai.azure.com" diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index f5eb074..62d0805 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -12,25 +12,69 @@ local M = {} M.CANCEL_PATTERN = "AvanteLLMEscape" +---@class CopilotToken +---@field annotations_enabled boolean +---@field chat_enabled boolean +---@field chat_jetbrains_enabled boolean +---@field code_quote_enabled boolean +---@field codesearch boolean +---@field copilotignore_enabled boolean +---@field endpoints {api: string, ["origin-tracker"]: string, proxy: string, telemetry: string} +---@field expires_at integer +---@field individual boolean +---@field nes_enabled boolean +---@field prompt_8k boolean +---@field public_suggestions string +---@field refresh_in integer +---@field sku string +---@field snippy_load_test_enabled boolean +---@field telemetry string +---@field token string +---@field tracking_id string +---@field vsc_electron_fetcher boolean +---@field xcode boolean +---@field xcode_chat boolean +--- +---@privaate +---@class AvanteCopilot: table +---@field proxy string +---@field allow_insecure boolean +---@field token? CopilotToken +---@field github_token? string +---@field sessionid? string +---@field machineid? string +M.copilot = nil + ---@class EnvironmentHandler: table<[Provider], string> local E = { - ---@type table + ---@type table env = { openai = "OPENAI_API_KEY", claude = "ANTHROPIC_API_KEY", azure = "AZURE_OPENAI_API_KEY", deepseek = "DEEPSEEK_API_KEY", groq = "GROQ_API_KEY", + copilot = function() + if Utils.has("copilot.lua") or Utils.has("copilot.vim") then + return true + end + Utils.warn("copilot is not setup correctly. Please use copilot.lua or copilot.vim for authentication.") + return false + end, }, } setmetatable(E, { ---@param k Provider __index = function(_, k) + if E.is_local(k) then + return true + end + local builtins = E.env[k] if builtins then - if Config.options[k]["local"] then - return true + if type(builtins) == "function" then + return builtins() end return os.getenv(builtins) and true or false end @@ -38,9 +82,6 @@ setmetatable(E, { ---@type AvanteProvider | nil local external = Config.vendors[k] if external then - if external["local"] then - return true - end return os.getenv(external.api_key_name) and true or false end end, @@ -54,6 +95,8 @@ E.is_default = function(provider) return E.env[provider] and true or false end +local AVANTE_INTERNAL_KEY = "__avante_internal" + --- return the environment variable name for the given provider ---@param provider? Provider ---@return string the envvar key @@ -61,16 +104,16 @@ E.key = function(provider) provider = provider or Config.provider if E.is_default(provider) then - return E.env[provider] + local result = E.env[provider] + return type(result) == "function" and AVANTE_INTERNAL_KEY or result end ---@type AvanteProvider | nil local external = Config.vendors[provider] if external then return external.api_key_name - else - error("Failed to find provider: " .. provider, 2) end + error("Failed to find provider: " .. provider, 2) end ---@param provider Provider @@ -87,17 +130,21 @@ end ---@param provider? Provider E.value = function(provider) if E.is_local(provider or Config.provider) then - return "dummy" + return "__avante_dummy" end return os.getenv(E.key(provider or Config.provider)) end --- intialize the environment variable for current neovim session. --- This will only run once and spawn a UI for users to input the envvar. ----@param var Provider supported providers +---@param var string supported providers ---@param refresh? boolean ---@private E.setup = function(var, refresh) + if var == AVANTE_INTERNAL_KEY then + return + end + refresh = refresh or false ---@param value string @@ -243,6 +290,11 @@ Remember: Accurate line numbers are CRITICAL. The range start_line to end_line m ---@field temperature number ---@field max_tokens number --- +---@class AvanteCopilotProvider: AvanteSupportedProvider +---@field proxy string | nil +---@field allow_insecure boolean +---@field timeout number +--- ---@class AvanteProvider: AvanteDefaultBaseProvider ---@field model? string ---@field api_key_name string @@ -424,6 +476,75 @@ M.make_openai_curl_args = function(code_opts) } end +------------------------------Copilot------------------------------ +---@type AvanteAiMessageBuilder +M.make_copilot_message = M.make_openai_message + +---@type AvanteResponseParser +M.parse_copilot_response = M.parse_openai_response + +---@type AvanteCurlArgsBuilder +M.make_copilot_curl_args = function(code_opts) + local github_token = Utils.copilot.cached_token() + + if not github_token then + error( + "No GitHub token found, please use `:Copilot auth` to setup with `copilot.lua` or `:Copilot setup` with `copilot.vim`" + ) + end + + local on_done = function() + return { + url = Utils.trim(Config.copilot.endpoint, { suffix = "/" }) .. "/chat/completions", + proxy = Config.copilot.proxy, + insecure = Config.copilot.allow_insecure, + headers = Utils.copilot.generate_headers(M.copilot.token.token, M.copilot.sessionid, M.copilot.machineid), + body = { + mode = Config.copilot.model, + n = 1, + top_p = 1, + stream = true, + temperature = Config.copilot.temperature, + max_tokens = Config.copilot.max_tokens, + messages = M.make_copilot_message(code_opts), + }, + } + end + + local result = nil + + if not M.copilot.token or (M.copilot.token.expires_at and M.copilot.token.expires_at <= math.floor(os.time())) then + local sessionid = Utils.copilot.uuid() .. tostring(math.floor(os.time() * 1000)) + + local url = "https://api.github.com/copilot_internal/v2/token" + local headers = { + ["Authorization"] = "token " .. github_token, + ["Accept"] = "application/json", + } + for key, value in pairs(Utils.copilot.version_headers) do + headers[key] = value + end + + local response = curl.get(url, { + timeout = Config.copilot.timeout, + headers = headers, + proxy = M.copilot.proxy, + insecure = M.copilot.allow_insecure, + on_error = function(err) + error("Failed to get response: " .. vim.inspect(err)) + end, + }) + + M.copilot.sessionid = sessionid + M.copilot.token = vim.json.decode(response.body) + result = on_done() + else + result = on_done() + end + + return result +end + ------------------------------Azure------------------------------ ---@type AvanteAiMessageBuilder @@ -613,6 +734,17 @@ end ---@private function M.setup() + if Config.provider == "copilot" and not M.copilot then + M.copilot = { + proxy = Config.copilot.proxy, + allow_insecure = Config.copilot.allow_insecure, + github_token = Utils.copilot.cached_token(), + sessionid = nil, + token = nil, + machineid = Utils.copilot.machine_id(), + } + end + local has = E[Config.provider] if not has then E.setup(E.key()) diff --git a/lua/avante/utils/copilot.lua b/lua/avante/utils/copilot.lua new file mode 100644 index 0000000..2c33e70 --- /dev/null +++ b/lua/avante/utils/copilot.lua @@ -0,0 +1,109 @@ +---This file COPY and MODIFIED based on: https://github.com/CopilotC-Nvim/CopilotChat.nvim/blob/canary/lua/CopilotChat/copilot.lua#L560 + +---@class avante.utils.copilot +local M = {} + +local version_headers = { + ["editor-version"] = "Neovim/" .. vim.version().major .. "." .. vim.version().minor .. "." .. vim.version().patch, + ["editor-plugin-version"] = "avante.nvim/0.0.0", + ["user-agent"] = "avante.nvim/0.0.0", +} + +---@return string +M.uuid = function() + local template = "xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx" + return ( + string.gsub(template, "[xy]", function(c) + local v = (c == "x") and math.random(0, 0xf) or math.random(8, 0xb) + return string.format("%x", v) + end) + ) +end + +---@return string +M.machine_id = function() + local length = 65 + local hex_chars = "0123456789abcdef" + local hex = "" + for _ = 1, length do + hex = hex .. hex_chars:sub(math.random(1, #hex_chars), math.random(1, #hex_chars)) + end + return hex +end + +---@return string | nil +local function find_config_path() + local config = vim.fn.expand("$XDG_CONFIG_HOME") + if config and vim.fn.isdirectory(config) > 0 then + return config + elseif vim.fn.has("win32") > 0 then + config = vim.fn.expand("~/AppData/Local") + if vim.fn.isdirectory(config) > 0 then + return config + end + else + config = vim.fn.expand("~/.config") + if vim.fn.isdirectory(config) > 0 then + return config + end + end +end + +M.cached_token = function() + -- loading token from the environment only in GitHub Codespaces + local token = os.getenv("GITHUB_TOKEN") + local codespaces = os.getenv("CODESPACES") + if token and codespaces then + return token + end + + -- loading token from the file + local config_path = find_config_path() + if not config_path then + return nil + end + + -- token can be sometimes in apps.json sometimes in hosts.json + local file_paths = { + config_path .. "/github-copilot/hosts.json", + config_path .. "/github-copilot/apps.json", + } + + for _, file_path in ipairs(file_paths) do + if vim.fn.filereadable(file_path) == 1 then + local userdata = vim.fn.json_decode(vim.fn.readfile(file_path)) + for key, value in pairs(userdata) do + if string.find(key, "github.com") then + return value.oauth_token + end + end + end + end + + return nil +end + +---@param token string +---@param sessionid string +---@param machineid string +---@return table +M.generate_headers = function(token, sessionid, machineid) + local headers = { + ["authorization"] = "Bearer " .. token, + ["x-request-id"] = M.uuid(), + ["vscode-sessionid"] = sessionid, + ["vscode-machineid"] = machineid, + ["copilot-integration-id"] = "vscode-chat", + ["openai-organization"] = "github-copilot", + ["openai-intent"] = "conversation-panel", + ["content-type"] = "application/json", + } + for key, value in pairs(version_headers) do + headers[key] = value + end + return headers +end + +M.version_headers = version_headers + +return M diff --git a/lua/avante/utils/init.lua b/lua/avante/utils/init.lua index ed85d10..7a9897e 100644 --- a/lua/avante/utils/init.lua +++ b/lua/avante/utils/init.lua @@ -1,8 +1,8 @@ local api = vim.api -local fn = vim.fn ---@class avante.Utils: LazyUtilCore ---@field colors avante.util.colors +---@field copilot avante.utils.copilot local M = {} setmetatable(M, { @@ -18,6 +18,13 @@ setmetatable(M, { end, }) +---Check if a plugin is installed +---@param plugin string +---@return boolean +M.has = function(plugin) + return require("lazy.core.config").plugins[plugin] ~= nil +end + ---@param str string ---@param opts? {suffix?: string, prefix?: string} function M.trim(str, opts)