From a7d3defa3d084e19a5e9e5aa4713a8a0233c6d72 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Sat, 24 Aug 2024 17:52:38 -0400 Subject: [PATCH] feat(llm): add support for parsing secret vault (#200) Signed-off-by: Aaron Pham --- lua/avante/config.lua | 24 +++- lua/avante/llm.lua | 2 +- lua/avante/providers/azure.lua | 13 +- lua/avante/providers/claude.lua | 7 +- lua/avante/providers/cohere.lua | 8 +- lua/avante/providers/copilot.lua | 6 +- lua/avante/providers/gemini.lua | 7 +- lua/avante/providers/init.lua | 209 ++++++++++++++++++------------- lua/avante/providers/openai.lua | 6 +- 9 files changed, 161 insertions(+), 121 deletions(-) diff --git a/lua/avante/config.lua b/lua/avante/config.lua index e511521..51e3cdc 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -15,6 +15,7 @@ M.defaults = { openai = { endpoint = "https://api.openai.com/v1", model = "gpt-4o", + timeout = 30000, -- Timeout in milliseconds temperature = 0, max_tokens = 4096, ["local"] = false, @@ -34,6 +35,7 @@ M.defaults = { endpoint = "", -- example: "https://.openai.azure.com" deployment = "", -- Azure deployment name (e.g., "gpt-4o", "my-gpt-4o-deployment") api_version = "2024-06-01", + timeout = 30000, -- Timeout in milliseconds temperature = 0, max_tokens = 4096, ["local"] = false, @@ -42,22 +44,25 @@ M.defaults = { claude = { endpoint = "https://api.anthropic.com", model = "claude-3-5-sonnet-20240620", + timeout = 30000, -- Timeout in milliseconds temperature = 0, max_tokens = 4096, ["local"] = false, }, - ---@type AvanteGeminiProvider + ---@type AvanteSupportedProvider gemini = { endpoint = "https://generativelanguage.googleapis.com/v1beta/models", model = "gemini-1.5-pro", + timeout = 30000, -- Timeout in milliseconds temperature = 0, max_tokens = 4096, ["local"] = false, }, ---@type AvanteGeminiProvider cohere = { - endpoint = "https://api.cohere.com", + endpoint = "https://api.cohere.com/v1", model = "command-r-plus", + timeout = 30000, -- Timeout in milliseconds temperature = 0, max_tokens = 3072, ["local"] = false, @@ -189,7 +194,7 @@ end ---get supported providers ---@param provider Provider ----@return AvanteProvider | fun(): AvanteProvider +---@return AvanteProviderFunctor M.get_provider = function(provider) if M.options[provider] ~= nil then return vim.deepcopy(M.options[provider], true) @@ -200,8 +205,17 @@ M.get_provider = function(provider) end end -M.BASE_PROVIDER_KEYS = - { "endpoint", "model", "local", "deployment", "api_version", "proxy", "allow_insecure", "api_key_name" } +M.BASE_PROVIDER_KEYS = { + "endpoint", + "model", + "local", + "deployment", + "api_version", + "proxy", + "allow_insecure", + "api_key_name", + "timeout", +} ---@return {width: integer, height: integer} function M.get_sidebar_layout_options() diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index 6be209d..ad6d96f 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -94,7 +94,7 @@ M.stream = function(question, code_lang, code_content, selected_content_content, ---@type AvanteHandlerOptions local handler_opts = { on_chunk = on_chunk, on_complete = on_complete } ---@type AvanteCurlOutput - local spec = Provider.parse_curl_args(Config.get_provider(provider), code_opts) + local spec = Provider.parse_curl_args(Provider, code_opts) ---@param line string local function parse_stream_data(line) diff --git a/lua/avante/providers/azure.lua b/lua/avante/providers/azure.lua index 7763fab..bfd9c71 100644 --- a/lua/avante/providers/azure.lua +++ b/lua/avante/providers/azure.lua @@ -1,5 +1,10 @@ +---@class AvanteAzureProvider: AvanteDefaultBaseProvider +---@field deployment string +---@field api_version string +---@field temperature number +---@field max_tokens number + local Utils = require("avante.utils") -local Config = require("avante.config") local P = require("avante.providers") local O = require("avante.providers").openai @@ -8,10 +13,6 @@ local M = {} M.api_key_name = "AZURE_OPENAI_API_KEY" -M.has = function() - return os.getenv(M.api_key_name) and true or false -end - M.parse_message = O.parse_message M.parse_response = O.parse_response @@ -22,7 +23,7 @@ M.parse_curl_args = function(provider, code_opts) ["Content-Type"] = "application/json", } if not P.env.is_local("azure") then - headers["api-key"] = os.getenv(base.api_key_name or M.api_key_name) + headers["api-key"] = provider.parse_api_key() end return { diff --git a/lua/avante/providers/claude.lua b/lua/avante/providers/claude.lua index dff439b..82bb466 100644 --- a/lua/avante/providers/claude.lua +++ b/lua/avante/providers/claude.lua @@ -1,5 +1,4 @@ local Utils = require("avante.utils") -local Config = require("avante.config") local Tiktoken = require("avante.tiktoken") local P = require("avante.providers") @@ -8,10 +7,6 @@ local M = {} M.api_key_name = "ANTHROPIC_API_KEY" -M.has = function() - return os.getenv(M.api_key_name) and true or false -end - M.parse_message = function(opts) local code_prompt_obj = { type = "text", @@ -93,7 +88,7 @@ M.parse_curl_args = function(provider, code_opts) ["anthropic-beta"] = "prompt-caching-2024-07-31", } if not P.env.is_local("claude") then - headers["x-api-key"] = os.getenv(base.api_key_name or M.api_key_name) + headers["x-api-key"] = provider.parse_api_key() end return { diff --git a/lua/avante/providers/cohere.lua b/lua/avante/providers/cohere.lua index f5d57c8..b70f6dc 100644 --- a/lua/avante/providers/cohere.lua +++ b/lua/avante/providers/cohere.lua @@ -30,10 +30,6 @@ local M = {} M.api_key_name = "CO_API_KEY" -M.has = function() - return os.getenv(M.api_key_name) and true or false -end - M.parse_message = function(opts) local user_prompt = opts.base_prompt .. "\n\nCODE:\n" @@ -103,11 +99,11 @@ M.parse_curl_args = function(provider, code_opts) .. vim.version().patch, } if not P.env.is_local("openai") then - headers["Authorization"] = "Bearer " .. os.getenv(base.api_key_name or M.api_key_name) + headers["Authorization"] = "Bearer " .. provider.parse_api_key() end return { - url = Utils.trim(base.endpoint, { suffix = "/" }) .. "/v1/chat", + url = Utils.trim(base.endpoint, { suffix = "/" }) .. "/chat", proxy = base.proxy, insecure = base.allow_insecure, headers = headers, diff --git a/lua/avante/providers/copilot.lua b/lua/avante/providers/copilot.lua index 2899a9c..00fff16 100644 --- a/lua/avante/providers/copilot.lua +++ b/lua/avante/providers/copilot.lua @@ -1,7 +1,9 @@ +---@class AvanteCopilotProvider: AvanteSupportedProvider +---@field timeout number + local curl = require("plenary.curl") local Utils = require("avante.utils") -local Config = require("avante.config") local P = require("avante.providers") local O = require("avante.providers").openai @@ -196,7 +198,7 @@ M.parse_curl_args = function(provider, code_opts) end local response = curl.get(url, { - timeout = Config.copilot.timeout, + timeout = base.timeout, headers = headers, proxy = base.proxy, insecure = base.allow_insecure, diff --git a/lua/avante/providers/gemini.lua b/lua/avante/providers/gemini.lua index c0e2b65..3f09a1d 100644 --- a/lua/avante/providers/gemini.lua +++ b/lua/avante/providers/gemini.lua @@ -1,5 +1,4 @@ local Utils = require("avante.utils") -local Config = require("avante.config") local P = require("avante.providers") ---@class AvanteProviderFunctor @@ -7,10 +6,6 @@ local M = {} M.api_key_name = "GEMINI_API_KEY" -M.has = function() - return os.getenv(M.api_key_name) and true or false -end - M.parse_message = function(opts) local code_prompt_obj = { text = string.format("```%s\n%s```", opts.code_lang, opts.code_content), @@ -68,7 +63,7 @@ M.parse_curl_args = function(provider, code_opts) .. "/" .. base.model .. ":streamGenerateContent?alt=sse&key=" - .. os.getenv(base.api_key_name or M.api_key_name), + .. provider.parse_api_key(), proxy = base.proxy, insecure = base.allow_insecure, headers = { ["Content-Type"] = "application/json" }, diff --git a/lua/avante/providers/init.lua b/lua/avante/providers/init.lua index d1c395c..2244660 100644 --- a/lua/avante/providers/init.lua +++ b/lua/avante/providers/init.lua @@ -52,35 +52,24 @@ local Dressing = require("avante.ui.dressing") ---@field temperature? number ---@field max_tokens? number --- ----@class AvanteAzureProvider: AvanteDefaultBaseProvider ----@field deployment string ----@field api_version string ----@field temperature number ----@field max_tokens number ---- ----@class AvanteCopilotProvider: AvanteSupportedProvider ----@field timeout number ---- ----@class AvanteGeminiProvider: AvanteDefaultBaseProvider ----@field model string ---- ----@class AvanteProvider: AvanteDefaultBaseProvider ----@field parse_response_data AvanteResponseParser ----@field parse_curl_args AvanteCurlArgsParser ----@field parse_stream_data? AvanteStreamParser ---- ---@alias AvanteStreamParser fun(line: string, handler_opts: AvanteHandlerOptions): nil ---@alias AvanteChunkParser fun(chunk: string): any ---@alias AvanteCompleteParser fun(err: string|nil): nil ---@alias AvanteLLMConfigHandler fun(opts: AvanteSupportedProvider): AvanteDefaultBaseProvider, table --- +---@class AvanteProvider: AvanteSupportedProvider +---@field parse_response_data AvanteResponseParser +---@field parse_curl_args? AvanteCurlArgsParser +---@field parse_stream_data? AvanteStreamParser +--- ---@class AvanteProviderFunctor ---@field parse_message AvanteMessageParser ---@field parse_response AvanteResponseParser ---@field parse_curl_args AvanteCurlArgsParser ----@field setup? fun(): nil +---@field setup fun(): nil ---@field has fun(): boolean ---@field api_key_name string +---@field parse_api_key fun(): string | nil ---@field parse_stream_data? AvanteStreamParser --- ---@class avante.Providers @@ -92,38 +81,50 @@ local Dressing = require("avante.ui.dressing") ---@field cohere AvanteProviderFunctor local M = {} -setmetatable(M, { - ---@param t avante.Providers - ---@param k Provider - __index = function(t, k) - if Config.vendors[k] ~= nil then - ---@type AvanteProvider - local v = Config.vendors[k] - - -- Patch from vendors similar to supported providers. - ---@type AvanteProviderFunctor - t[k] = setmetatable({}, { __index = v }) - -- Hack for aliasing and makes it sane for us. - t[k].parse_response = v.parse_response_data - t[k].has = function() - return os.getenv(t[k].api_key_name) and true or false - end - - return t[k] - end - - ---@type AvanteProviderFunctor - t[k] = require("avante.providers." .. k) - return t[k] - end, -}) - ---@class EnvironmentHandler local E = {} ---@private E._once = false +---@private +---@type table +E.cache = {} + +---@param Opts AvanteSupportedProvider | AvanteProviderFunctor +---@return string | nil +E.parse_envvar = function(Opts) + local api_key_name = Opts.api_key_name + if api_key_name == nil then + error("Requires api_key_name") + end + + if E.cache[api_key_name] ~= nil then + return E.cache[api_key_name] + end + + local cmd = api_key_name:match("^cmd:(.*)") + + local key = nil + if cmd ~= nil then + local ok, job = pcall(vim.system, vim.split(cmd, " ", { trimempty = true }), { text = true }) + if not ok then + Utils.error("Failed to execute command to retrieve secrets: " .. cmd, { once = true, title = "Avante" }) + else + local out = job:wait() + key = out.stdout + end + else + key = os.getenv(api_key_name) + end + + if key ~= nil then + E.cache[api_key_name] = key + end + + return key +end + --- initialize the environment variable for current neovim session. --- This will only run once and spawn a UI for users to input the envvar. ---@param opts {refresh: boolean, provider: AvanteProviderFunctor} @@ -131,7 +132,8 @@ E._once = false E.setup = function(opts) local var = opts.provider.api_key_name - if var == M.AVANTE_INTERNAL_KEY then + -- check if var is a all caps string + if var == M.AVANTE_INTERNAL_KEY or var:match("^cmd:(.*)") then return end @@ -149,51 +151,54 @@ E.setup = function(opts) end end - if refresh then + local function mount_dressing_buffer() vim.defer_fn(function() - Dressing.initialize_input_buffer({ opts = { prompt = "Enter " .. var .. ": " }, on_confirm = on_confirm }) + -- only mount if given buffer is not of buftype ministarter, dashboard, alpha, qf + local exclude_buftypes = { "qf", "nofile" } + local exclude_filetypes = { + "NvimTree", + "Outline", + "help", + "dashboard", + "alpha", + "qf", + "ministarter", + "TelescopePrompt", + "gitcommit", + "gitrebase", + "DressingInput", + } + if + not vim.tbl_contains(exclude_buftypes, vim.bo.buftype) + and not vim.tbl_contains(exclude_filetypes, vim.bo.filetype) + and not opts.provider.has() + then + Dressing.initialize_input_buffer({ + opts = { prompt = "Enter " .. var .. ": " }, + on_confirm = on_confirm, + }) + end end, 200) - elseif not E._once then + end + + if refresh then + mount_dressing_buffer() + return + end + + if not E._once then E._once = true api.nvim_create_autocmd({ "BufEnter", "BufWinEnter", "WinEnter" }, { pattern = "*", once = true, - callback = function() - vim.defer_fn(function() - -- only mount if given buffer is not of buftype ministarter, dashboard, alpha, qf - local exclude_buftypes = { "qf", "nofile" } - local exclude_filetypes = { - "NvimTree", - "Outline", - "help", - "dashboard", - "alpha", - "qf", - "ministarter", - "TelescopePrompt", - "gitcommit", - "gitrebase", - "DressingInput", - } - if - not vim.tbl_contains(exclude_buftypes, vim.bo.buftype) - and not vim.tbl_contains(exclude_filetypes, vim.bo.filetype) - and not opts.provider.has() - then - Dressing.initialize_input_buffer({ - opts = { prompt = "Enter " .. var .. ": " }, - on_confirm = on_confirm, - }) - end - end, 200) - end, + callback = mount_dressing_buffer, }) end end ---@param provider Provider E.is_local = function(provider) - local cur = M.get(provider) + local cur = M.get_config(provider) return cur["local"] ~= nil and cur["local"] or false end @@ -201,14 +206,47 @@ M.env = E M.AVANTE_INTERNAL_KEY = "__avante_env_internal" +M = setmetatable(M, { + ---@param t avante.Providers + ---@param k Provider + __index = function(t, k) + ---@type AvanteProviderFunctor + local Opts = M.get_config(k) + + if Config.vendors[k] ~= nil then + Opts.parse_response = Opts.parse_response_data + t[k] = Opts + else + t[k] = vim.tbl_deep_extend("keep", Opts, require("avante.providers." .. k)) + end + + t[k].parse_api_key = function() + return E.parse_envvar(t[k]) + end + + if t[k].has == nil then + t[k].has = function() + return E.parse_envvar(t[k]) ~= nil + end + end + + if t[k].setup == nil then + t[k].setup = function() + t[k].parse_api_key() + end + end + + return t[k] + end, +}) + M.setup = function() ---@type AvanteProviderFunctor local provider = M[Config.provider] E.setup({ provider = provider }) - - if provider.setup ~= nil then + vim.schedule(function() provider.setup() - end + end) M.commands() end @@ -216,6 +254,8 @@ end ---@private ---@param provider Provider function M.refresh(provider) + require("avante.config").override({ provider = provider }) + ---@type AvanteProviderFunctor local p = M[Config.provider] if not p.has() then @@ -223,7 +263,6 @@ function M.refresh(provider) else Utils.info("Switch to provider: " .. provider, { once = true, title = "Avante" }) end - require("avante.config").override({ provider = provider }) end local default_providers = { "openai", "claude", "azure", "gemini", "copilot" } @@ -242,7 +281,8 @@ M.commands = function() end local prefix = line:match("^%s*AvanteSwitchProvider (%w*)") or "" -- join two tables - local Keys = vim.list_extend(default_providers, vim.tbl_keys(Config.vendors or {})) + local Keys = vim.list_extend({}, default_providers) + Keys = vim.list_extend(Keys, vim.tbl_keys(Config.vendors or {})) return vim.tbl_filter(function(key) return key:find(prefix) == 1 end, Keys) @@ -280,7 +320,8 @@ end ---@private ---@param provider Provider -M.get = function(provider) +---@return AvanteProviderFunctor +M.get_config = function(provider) local cur = Config.get_provider(provider or Config.provider) return type(cur) == "function" and cur() or cur end diff --git a/lua/avante/providers/openai.lua b/lua/avante/providers/openai.lua index 635f03e..c362b17 100644 --- a/lua/avante/providers/openai.lua +++ b/lua/avante/providers/openai.lua @@ -26,10 +26,6 @@ local M = {} M.api_key_name = "OPENAI_API_KEY" -M.has = function() - return os.getenv(M.api_key_name) and true or false -end - M.parse_message = function(opts) local user_prompt = opts.base_prompt .. "\n\nCODE:\n" @@ -91,7 +87,7 @@ M.parse_curl_args = function(provider, code_opts) ["Content-Type"] = "application/json", } if not P.env.is_local("openai") then - headers["Authorization"] = "Bearer " .. os.getenv(base.api_key_name or M.api_key_name) + headers["Authorization"] = "Bearer " .. provider.parse_api_key() end return {