From 561f2f33800c948d261ce2398441704099596581 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Tue, 27 Aug 2024 04:43:44 -0400 Subject: [PATCH] perf(copilot): make signin and check keys API async (#275) Signed-off-by: Aaron Pham --- lua/avante/providers/copilot.lua | 196 +++++++++++++------------------ lua/avante/providers/init.lua | 12 +- 2 files changed, 83 insertions(+), 125 deletions(-) diff --git a/lua/avante/providers/copilot.lua b/lua/avante/providers/copilot.lua index 00fff16..5875acf 100644 --- a/lua/avante/providers/copilot.lua +++ b/lua/avante/providers/copilot.lua @@ -1,9 +1,10 @@ ----@class AvanteCopilotProvider: AvanteSupportedProvider ----@field timeout number +---@see https://github.com/B00TK1D/copilot-api/blob/main/api.py local curl = require("plenary.curl") +local Path = require("plenary.path") local Utils = require("avante.utils") +local Config = require("avante.config") local P = require("avante.providers") local O = require("avante.providers").openai @@ -37,8 +38,6 @@ local M = {} ---@class AvanteCopilot: table ---@field token? CopilotToken ---@field github_token? string ----@field sessionid? string ----@field machineid? string M.copilot = nil local H = {} @@ -46,31 +45,9 @@ local H = {} local version_headers = { ["editor-version"] = "Neovim/" .. vim.version().major .. "." .. vim.version().minor .. "." .. vim.version().patch, ["editor-plugin-version"] = "avante.nvim/0.0.0", - ["user-agent"] = "avante.nvim/0.0.0", + ["user-agent"] = "AvanteNvim/0.0.0", } ----@return string -H.uuid = function() - local template = "xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx" - return ( - string.gsub(template, "[xy]", function(c) - local v = (c == "x") and math.random(0, 0xf) or math.random(8, 0xb) - return string.format("%x", v) - end) - ) -end - ----@return string -H.machine_id = function() - local length = 65 - local hex_chars = "0123456789abcdef" - local hex = "" - for _ = 1, length do - hex = hex .. hex_chars:sub(math.random(1, #hex_chars), math.random(1, #hex_chars)) - end - return hex -end - ---@return string | nil H.find_config_path = function() local config = vim.fn.expand("$XDG_CONFIG_HOME") @@ -89,6 +66,7 @@ H.find_config_path = function() end end +---@return string | nil H.cached_token = function() -- loading token from the environment only in GitHub Codespaces local token = os.getenv("GITHUB_TOKEN") @@ -109,45 +87,33 @@ H.cached_token = function() config_path .. "/github-copilot/apps.json", } - for _, file_path in ipairs(file_paths) do - if vim.fn.filereadable(file_path) == 1 then - local userdata = vim.fn.json_decode(vim.fn.readfile(file_path)) - for key, value in pairs(userdata) do - if string.find(key, "github.com") then - return value.oauth_token - end - end - end - end + local fp = Path:new(vim + .iter(file_paths) + :filter(function(f) + return vim.fn.filereadable(f) == 1 + end) + :next()) - return nil -end + ---@type table + local creds = vim.json.decode(fp:read() or {}) + ---@type table<"token", string> + local value = vim + .iter(creds) + :filter(function(k, _) + return k:find("github.com") + end) + :fold({}, function(acc, _, v) + acc.token = v.oauth_token + return acc + end) ----@param token string ----@param sessionid string ----@param machineid string ----@return table -H.generate_headers = function(token, sessionid, machineid) - local headers = { - ["authorization"] = "Bearer " .. token, - ["x-request-id"] = H.uuid(), - ["vscode-sessionid"] = sessionid, - ["vscode-machineid"] = machineid, - ["copilot-integration-id"] = "vscode-chat", - ["openai-organization"] = "github-copilot", - ["openai-intent"] = "conversation-panel", - ["content-type"] = "application/json", - } - for key, value in pairs(version_headers) do - headers[key] = value - end - return headers + return value.token or nil end M.api_key_name = P.AVANTE_INTERNAL_KEY M.has = function() - if Utils.has("copilot.lua") or Utils.has("copilot.vim") or H.find_config_path() then + if Utils.has("copilot.lua") or Utils.has("copilot.vim") or H.cached_token() ~= nil then return true end Utils.warn("copilot is not setup correctly. Please use copilot.lua or copilot.vim for authentication.") @@ -158,6 +124,57 @@ M.parse_message = O.parse_message M.parse_response = O.parse_response M.parse_curl_args = function(provider, code_opts) + local base, body_opts = P.parse_config(provider) + + M.refresh_token() + + return { + url = Utils.trim(base.endpoint, { suffix = "/" }) .. "/chat/completions", + proxy = base.proxy, + insecure = base.allow_insecure, + headers = vim.tbl_deep_extend("error", { + ["authorization"] = "Bearer " .. M.copilot.token.token, + ["copilot-integration-id"] = "vscode-chat", + ["openai-organization"] = "github-copilot", + ["openai-intent"] = "conversation-panel", + ["content-type"] = "application/json", + }, version_headers), + body = vim.tbl_deep_extend("force", { + model = base.model, + n = 1, + top_p = 1, + stream = true, + messages = M.parse_message(code_opts), + }, body_opts), + } +end + +M.on_error = function(result) + Utils.error("Received error from Copilot API: " .. result.body, { once = true, title = "Avante" }) +end + +M.refresh_token = function() + if not M.copilot.token or (M.copilot.token.expires_at and M.copilot.token.expires_at <= math.floor(os.time())) then + curl.get("https://api.github.com/copilot_internal/v2/token", { + timeout = Config.copilot.timeout, + headers = vim.tbl_deep_extend("error", { + ["Authorization"] = "token " .. M.copilot.github_token, + ["Accept"] = "application/json", + }, version_headers), + proxy = Config.copilot.proxy, + insecure = Config.copilot.allow_insecure, + on_error = function(err) + error("Failed to get response: " .. vim.inspect(err)) + end, + callback = function(output) + M.copilot.token = vim.json.decode(output.body) + vim.g.avante_login = true + end, + }) + end +end + +M.setup = function() local github_token = H.cached_token() if not github_token then @@ -165,67 +182,12 @@ M.parse_curl_args = function(provider, code_opts) "No GitHub token found, please use `:Copilot auth` to setup with `copilot.lua` or `:Copilot setup` with `copilot.vim`" ) end - local base, body_opts = P.parse_config(provider) - local on_done = function() - return { - url = Utils.trim(base.endpoint, { suffix = "/" }) .. "/chat/completions", - proxy = base.proxy, - insecure = base.allow_insecure, - headers = H.generate_headers(M.copilot.token.token, M.copilot.sessionid, M.copilot.machineid), - body = vim.tbl_deep_extend("force", { - model = base.model, - n = 1, - top_p = 1, - stream = true, - messages = M.parse_message(code_opts), - }, body_opts), - } - end - - local result = nil - - if not M.copilot.token or (M.copilot.token.expires_at and M.copilot.token.expires_at <= math.floor(os.time())) then - local sessionid = H.uuid() .. tostring(math.floor(os.time() * 1000)) - - local url = "https://api.github.com/copilot_internal/v2/token" - local headers = { - ["Authorization"] = "token " .. github_token, - ["Accept"] = "application/json", - } - for key, value in pairs(version_headers) do - headers[key] = value - end - - local response = curl.get(url, { - timeout = base.timeout, - headers = headers, - proxy = base.proxy, - insecure = base.allow_insecure, - on_error = function(err) - error("Failed to get response: " .. vim.inspect(err)) - end, - }) - - M.copilot.sessionid = sessionid - M.copilot.token = vim.json.decode(response.body) - result = on_done() - else - result = on_done() - end - - return result -end - -M.setup = function() if not M.copilot then - M.copilot = { - sessionid = nil, - token = nil, - github_token = H.cached_token(), - machineid = H.machine_id(), - } + M.copilot = { token = nil, github_token = github_token } end + + M.refresh_token() end return M diff --git a/lua/avante/providers/init.lua b/lua/avante/providers/init.lua index f2a9770..65d4127 100644 --- a/lua/avante/providers/init.lua +++ b/lua/avante/providers/init.lua @@ -45,6 +45,7 @@ local Dressing = require("avante.ui.dressing") ---@field model? string ---@field local? boolean ---@field proxy? string +---@field timeout? integer ---@field allow_insecure? boolean ---@field api_key_name? string ---@field _shellenv? string @@ -111,7 +112,6 @@ E.parse_envvar = function(Opts) local key = nil - vim.g.avante_login = false if cmd ~= nil then -- NOTE: in case api_key_name is cmd, and users still set envvar -- We will try to get envvar first @@ -171,13 +171,7 @@ end E.setup = function(opts) local var = opts.provider.api_key_name - -- check if var is a all caps string - if var == M.AVANTE_INTERNAL_KEY then - return - elseif var:match("^cmd:(.*)") then - opts.provider.setup() - return - end + opts.provider.setup() local refresh = opts.refresh or false @@ -285,6 +279,8 @@ M = setmetatable(M, { }) M.setup = function() + vim.g.avante_login = false + ---@type AvanteProviderFunctor local provider = M[Config.provider] E.setup({ provider = provider })