perf(copilot): make signin and check keys API async (#275)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
Aaron Pham 2024-08-27 04:43:44 -04:00 committed by GitHub
parent d543d0ed53
commit 561f2f3380
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 83 additions and 125 deletions

View File

@ -1,9 +1,10 @@
---@class AvanteCopilotProvider: AvanteSupportedProvider ---@see https://github.com/B00TK1D/copilot-api/blob/main/api.py
---@field timeout number
local curl = require("plenary.curl") local curl = require("plenary.curl")
local Path = require("plenary.path")
local Utils = require("avante.utils") local Utils = require("avante.utils")
local Config = require("avante.config")
local P = require("avante.providers") local P = require("avante.providers")
local O = require("avante.providers").openai local O = require("avante.providers").openai
@ -37,8 +38,6 @@ local M = {}
---@class AvanteCopilot: table<string, any> ---@class AvanteCopilot: table<string, any>
---@field token? CopilotToken ---@field token? CopilotToken
---@field github_token? string ---@field github_token? string
---@field sessionid? string
---@field machineid? string
M.copilot = nil M.copilot = nil
local H = {} local H = {}
@ -46,31 +45,9 @@ local H = {}
local version_headers = { local version_headers = {
["editor-version"] = "Neovim/" .. vim.version().major .. "." .. vim.version().minor .. "." .. vim.version().patch, ["editor-version"] = "Neovim/" .. vim.version().major .. "." .. vim.version().minor .. "." .. vim.version().patch,
["editor-plugin-version"] = "avante.nvim/0.0.0", ["editor-plugin-version"] = "avante.nvim/0.0.0",
["user-agent"] = "avante.nvim/0.0.0", ["user-agent"] = "AvanteNvim/0.0.0",
} }
---@return string
H.uuid = function()
local template = "xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx"
return (
string.gsub(template, "[xy]", function(c)
local v = (c == "x") and math.random(0, 0xf) or math.random(8, 0xb)
return string.format("%x", v)
end)
)
end
---@return string
H.machine_id = function()
local length = 65
local hex_chars = "0123456789abcdef"
local hex = ""
for _ = 1, length do
hex = hex .. hex_chars:sub(math.random(1, #hex_chars), math.random(1, #hex_chars))
end
return hex
end
---@return string | nil ---@return string | nil
H.find_config_path = function() H.find_config_path = function()
local config = vim.fn.expand("$XDG_CONFIG_HOME") local config = vim.fn.expand("$XDG_CONFIG_HOME")
@ -89,6 +66,7 @@ H.find_config_path = function()
end end
end end
---@return string | nil
H.cached_token = function() H.cached_token = function()
-- loading token from the environment only in GitHub Codespaces -- loading token from the environment only in GitHub Codespaces
local token = os.getenv("GITHUB_TOKEN") local token = os.getenv("GITHUB_TOKEN")
@ -109,45 +87,33 @@ H.cached_token = function()
config_path .. "/github-copilot/apps.json", config_path .. "/github-copilot/apps.json",
} }
for _, file_path in ipairs(file_paths) do local fp = Path:new(vim
if vim.fn.filereadable(file_path) == 1 then .iter(file_paths)
local userdata = vim.fn.json_decode(vim.fn.readfile(file_path)) :filter(function(f)
for key, value in pairs(userdata) do return vim.fn.filereadable(f) == 1
if string.find(key, "github.com") then end)
return value.oauth_token :next())
end
end
end
end
return nil ---@type table<string, any>
end local creds = vim.json.decode(fp:read() or {})
---@type table<"token", string>
local value = vim
.iter(creds)
:filter(function(k, _)
return k:find("github.com")
end)
:fold({}, function(acc, _, v)
acc.token = v.oauth_token
return acc
end)
---@param token string return value.token or nil
---@param sessionid string
---@param machineid string
---@return table<string, string>
H.generate_headers = function(token, sessionid, machineid)
local headers = {
["authorization"] = "Bearer " .. token,
["x-request-id"] = H.uuid(),
["vscode-sessionid"] = sessionid,
["vscode-machineid"] = machineid,
["copilot-integration-id"] = "vscode-chat",
["openai-organization"] = "github-copilot",
["openai-intent"] = "conversation-panel",
["content-type"] = "application/json",
}
for key, value in pairs(version_headers) do
headers[key] = value
end
return headers
end end
M.api_key_name = P.AVANTE_INTERNAL_KEY M.api_key_name = P.AVANTE_INTERNAL_KEY
M.has = function() M.has = function()
if Utils.has("copilot.lua") or Utils.has("copilot.vim") or H.find_config_path() then if Utils.has("copilot.lua") or Utils.has("copilot.vim") or H.cached_token() ~= nil then
return true return true
end end
Utils.warn("copilot is not setup correctly. Please use copilot.lua or copilot.vim for authentication.") Utils.warn("copilot is not setup correctly. Please use copilot.lua or copilot.vim for authentication.")
@ -158,6 +124,57 @@ M.parse_message = O.parse_message
M.parse_response = O.parse_response M.parse_response = O.parse_response
M.parse_curl_args = function(provider, code_opts) M.parse_curl_args = function(provider, code_opts)
local base, body_opts = P.parse_config(provider)
M.refresh_token()
return {
url = Utils.trim(base.endpoint, { suffix = "/" }) .. "/chat/completions",
proxy = base.proxy,
insecure = base.allow_insecure,
headers = vim.tbl_deep_extend("error", {
["authorization"] = "Bearer " .. M.copilot.token.token,
["copilot-integration-id"] = "vscode-chat",
["openai-organization"] = "github-copilot",
["openai-intent"] = "conversation-panel",
["content-type"] = "application/json",
}, version_headers),
body = vim.tbl_deep_extend("force", {
model = base.model,
n = 1,
top_p = 1,
stream = true,
messages = M.parse_message(code_opts),
}, body_opts),
}
end
M.on_error = function(result)
Utils.error("Received error from Copilot API: " .. result.body, { once = true, title = "Avante" })
end
M.refresh_token = function()
if not M.copilot.token or (M.copilot.token.expires_at and M.copilot.token.expires_at <= math.floor(os.time())) then
curl.get("https://api.github.com/copilot_internal/v2/token", {
timeout = Config.copilot.timeout,
headers = vim.tbl_deep_extend("error", {
["Authorization"] = "token " .. M.copilot.github_token,
["Accept"] = "application/json",
}, version_headers),
proxy = Config.copilot.proxy,
insecure = Config.copilot.allow_insecure,
on_error = function(err)
error("Failed to get response: " .. vim.inspect(err))
end,
callback = function(output)
M.copilot.token = vim.json.decode(output.body)
vim.g.avante_login = true
end,
})
end
end
M.setup = function()
local github_token = H.cached_token() local github_token = H.cached_token()
if not github_token then if not github_token then
@ -165,67 +182,12 @@ M.parse_curl_args = function(provider, code_opts)
"No GitHub token found, please use `:Copilot auth` to setup with `copilot.lua` or `:Copilot setup` with `copilot.vim`" "No GitHub token found, please use `:Copilot auth` to setup with `copilot.lua` or `:Copilot setup` with `copilot.vim`"
) )
end end
local base, body_opts = P.parse_config(provider)
local on_done = function()
return {
url = Utils.trim(base.endpoint, { suffix = "/" }) .. "/chat/completions",
proxy = base.proxy,
insecure = base.allow_insecure,
headers = H.generate_headers(M.copilot.token.token, M.copilot.sessionid, M.copilot.machineid),
body = vim.tbl_deep_extend("force", {
model = base.model,
n = 1,
top_p = 1,
stream = true,
messages = M.parse_message(code_opts),
}, body_opts),
}
end
local result = nil
if not M.copilot.token or (M.copilot.token.expires_at and M.copilot.token.expires_at <= math.floor(os.time())) then
local sessionid = H.uuid() .. tostring(math.floor(os.time() * 1000))
local url = "https://api.github.com/copilot_internal/v2/token"
local headers = {
["Authorization"] = "token " .. github_token,
["Accept"] = "application/json",
}
for key, value in pairs(version_headers) do
headers[key] = value
end
local response = curl.get(url, {
timeout = base.timeout,
headers = headers,
proxy = base.proxy,
insecure = base.allow_insecure,
on_error = function(err)
error("Failed to get response: " .. vim.inspect(err))
end,
})
M.copilot.sessionid = sessionid
M.copilot.token = vim.json.decode(response.body)
result = on_done()
else
result = on_done()
end
return result
end
M.setup = function()
if not M.copilot then if not M.copilot then
M.copilot = { M.copilot = { token = nil, github_token = github_token }
sessionid = nil,
token = nil,
github_token = H.cached_token(),
machineid = H.machine_id(),
}
end end
M.refresh_token()
end end
return M return M

View File

@ -45,6 +45,7 @@ local Dressing = require("avante.ui.dressing")
---@field model? string ---@field model? string
---@field local? boolean ---@field local? boolean
---@field proxy? string ---@field proxy? string
---@field timeout? integer
---@field allow_insecure? boolean ---@field allow_insecure? boolean
---@field api_key_name? string ---@field api_key_name? string
---@field _shellenv? string ---@field _shellenv? string
@ -111,7 +112,6 @@ E.parse_envvar = function(Opts)
local key = nil local key = nil
vim.g.avante_login = false
if cmd ~= nil then if cmd ~= nil then
-- NOTE: in case api_key_name is cmd, and users still set envvar -- NOTE: in case api_key_name is cmd, and users still set envvar
-- We will try to get envvar first -- We will try to get envvar first
@ -171,13 +171,7 @@ end
E.setup = function(opts) E.setup = function(opts)
local var = opts.provider.api_key_name local var = opts.provider.api_key_name
-- check if var is a all caps string opts.provider.setup()
if var == M.AVANTE_INTERNAL_KEY then
return
elseif var:match("^cmd:(.*)") then
opts.provider.setup()
return
end
local refresh = opts.refresh or false local refresh = opts.refresh or false
@ -285,6 +279,8 @@ M = setmetatable(M, {
}) })
M.setup = function() M.setup = function()
vim.g.avante_login = false
---@type AvanteProviderFunctor ---@type AvanteProviderFunctor
local provider = M[Config.provider] local provider = M[Config.provider]
E.setup({ provider = provider }) E.setup({ provider = provider })