feat(llm): add support for parsing secret vault (#200)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
Aaron Pham 2024-08-24 17:52:38 -04:00 committed by GitHub
parent 8d375dd591
commit a7d3defa3d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 161 additions and 121 deletions

View File

@ -15,6 +15,7 @@ M.defaults = {
openai = { openai = {
endpoint = "https://api.openai.com/v1", endpoint = "https://api.openai.com/v1",
model = "gpt-4o", model = "gpt-4o",
timeout = 30000, -- Timeout in milliseconds
temperature = 0, temperature = 0,
max_tokens = 4096, max_tokens = 4096,
["local"] = false, ["local"] = false,
@ -34,6 +35,7 @@ M.defaults = {
endpoint = "", -- example: "https://<your-resource-name>.openai.azure.com" endpoint = "", -- example: "https://<your-resource-name>.openai.azure.com"
deployment = "", -- Azure deployment name (e.g., "gpt-4o", "my-gpt-4o-deployment") deployment = "", -- Azure deployment name (e.g., "gpt-4o", "my-gpt-4o-deployment")
api_version = "2024-06-01", api_version = "2024-06-01",
timeout = 30000, -- Timeout in milliseconds
temperature = 0, temperature = 0,
max_tokens = 4096, max_tokens = 4096,
["local"] = false, ["local"] = false,
@ -42,22 +44,25 @@ M.defaults = {
claude = { claude = {
endpoint = "https://api.anthropic.com", endpoint = "https://api.anthropic.com",
model = "claude-3-5-sonnet-20240620", model = "claude-3-5-sonnet-20240620",
timeout = 30000, -- Timeout in milliseconds
temperature = 0, temperature = 0,
max_tokens = 4096, max_tokens = 4096,
["local"] = false, ["local"] = false,
}, },
---@type AvanteGeminiProvider ---@type AvanteSupportedProvider
gemini = { gemini = {
endpoint = "https://generativelanguage.googleapis.com/v1beta/models", endpoint = "https://generativelanguage.googleapis.com/v1beta/models",
model = "gemini-1.5-pro", model = "gemini-1.5-pro",
timeout = 30000, -- Timeout in milliseconds
temperature = 0, temperature = 0,
max_tokens = 4096, max_tokens = 4096,
["local"] = false, ["local"] = false,
}, },
---@type AvanteGeminiProvider ---@type AvanteGeminiProvider
cohere = { cohere = {
endpoint = "https://api.cohere.com", endpoint = "https://api.cohere.com/v1",
model = "command-r-plus", model = "command-r-plus",
timeout = 30000, -- Timeout in milliseconds
temperature = 0, temperature = 0,
max_tokens = 3072, max_tokens = 3072,
["local"] = false, ["local"] = false,
@ -189,7 +194,7 @@ end
---get supported providers ---get supported providers
---@param provider Provider ---@param provider Provider
---@return AvanteProvider | fun(): AvanteProvider ---@return AvanteProviderFunctor
M.get_provider = function(provider) M.get_provider = function(provider)
if M.options[provider] ~= nil then if M.options[provider] ~= nil then
return vim.deepcopy(M.options[provider], true) return vim.deepcopy(M.options[provider], true)
@ -200,8 +205,17 @@ M.get_provider = function(provider)
end end
end end
M.BASE_PROVIDER_KEYS = M.BASE_PROVIDER_KEYS = {
{ "endpoint", "model", "local", "deployment", "api_version", "proxy", "allow_insecure", "api_key_name" } "endpoint",
"model",
"local",
"deployment",
"api_version",
"proxy",
"allow_insecure",
"api_key_name",
"timeout",
}
---@return {width: integer, height: integer} ---@return {width: integer, height: integer}
function M.get_sidebar_layout_options() function M.get_sidebar_layout_options()

View File

@ -94,7 +94,7 @@ M.stream = function(question, code_lang, code_content, selected_content_content,
---@type AvanteHandlerOptions ---@type AvanteHandlerOptions
local handler_opts = { on_chunk = on_chunk, on_complete = on_complete } local handler_opts = { on_chunk = on_chunk, on_complete = on_complete }
---@type AvanteCurlOutput ---@type AvanteCurlOutput
local spec = Provider.parse_curl_args(Config.get_provider(provider), code_opts) local spec = Provider.parse_curl_args(Provider, code_opts)
---@param line string ---@param line string
local function parse_stream_data(line) local function parse_stream_data(line)

View File

@ -1,5 +1,10 @@
---@class AvanteAzureProvider: AvanteDefaultBaseProvider
---@field deployment string
---@field api_version string
---@field temperature number
---@field max_tokens number
local Utils = require("avante.utils") local Utils = require("avante.utils")
local Config = require("avante.config")
local P = require("avante.providers") local P = require("avante.providers")
local O = require("avante.providers").openai local O = require("avante.providers").openai
@ -8,10 +13,6 @@ local M = {}
M.api_key_name = "AZURE_OPENAI_API_KEY" M.api_key_name = "AZURE_OPENAI_API_KEY"
M.has = function()
return os.getenv(M.api_key_name) and true or false
end
M.parse_message = O.parse_message M.parse_message = O.parse_message
M.parse_response = O.parse_response M.parse_response = O.parse_response
@ -22,7 +23,7 @@ M.parse_curl_args = function(provider, code_opts)
["Content-Type"] = "application/json", ["Content-Type"] = "application/json",
} }
if not P.env.is_local("azure") then if not P.env.is_local("azure") then
headers["api-key"] = os.getenv(base.api_key_name or M.api_key_name) headers["api-key"] = provider.parse_api_key()
end end
return { return {

View File

@ -1,5 +1,4 @@
local Utils = require("avante.utils") local Utils = require("avante.utils")
local Config = require("avante.config")
local Tiktoken = require("avante.tiktoken") local Tiktoken = require("avante.tiktoken")
local P = require("avante.providers") local P = require("avante.providers")
@ -8,10 +7,6 @@ local M = {}
M.api_key_name = "ANTHROPIC_API_KEY" M.api_key_name = "ANTHROPIC_API_KEY"
M.has = function()
return os.getenv(M.api_key_name) and true or false
end
M.parse_message = function(opts) M.parse_message = function(opts)
local code_prompt_obj = { local code_prompt_obj = {
type = "text", type = "text",
@ -93,7 +88,7 @@ M.parse_curl_args = function(provider, code_opts)
["anthropic-beta"] = "prompt-caching-2024-07-31", ["anthropic-beta"] = "prompt-caching-2024-07-31",
} }
if not P.env.is_local("claude") then if not P.env.is_local("claude") then
headers["x-api-key"] = os.getenv(base.api_key_name or M.api_key_name) headers["x-api-key"] = provider.parse_api_key()
end end
return { return {

View File

@ -30,10 +30,6 @@ local M = {}
M.api_key_name = "CO_API_KEY" M.api_key_name = "CO_API_KEY"
M.has = function()
return os.getenv(M.api_key_name) and true or false
end
M.parse_message = function(opts) M.parse_message = function(opts)
local user_prompt = opts.base_prompt local user_prompt = opts.base_prompt
.. "\n\nCODE:\n" .. "\n\nCODE:\n"
@ -103,11 +99,11 @@ M.parse_curl_args = function(provider, code_opts)
.. vim.version().patch, .. vim.version().patch,
} }
if not P.env.is_local("openai") then if not P.env.is_local("openai") then
headers["Authorization"] = "Bearer " .. os.getenv(base.api_key_name or M.api_key_name) headers["Authorization"] = "Bearer " .. provider.parse_api_key()
end end
return { return {
url = Utils.trim(base.endpoint, { suffix = "/" }) .. "/v1/chat", url = Utils.trim(base.endpoint, { suffix = "/" }) .. "/chat",
proxy = base.proxy, proxy = base.proxy,
insecure = base.allow_insecure, insecure = base.allow_insecure,
headers = headers, headers = headers,

View File

@ -1,7 +1,9 @@
---@class AvanteCopilotProvider: AvanteSupportedProvider
---@field timeout number
local curl = require("plenary.curl") local curl = require("plenary.curl")
local Utils = require("avante.utils") local Utils = require("avante.utils")
local Config = require("avante.config")
local P = require("avante.providers") local P = require("avante.providers")
local O = require("avante.providers").openai local O = require("avante.providers").openai
@ -196,7 +198,7 @@ M.parse_curl_args = function(provider, code_opts)
end end
local response = curl.get(url, { local response = curl.get(url, {
timeout = Config.copilot.timeout, timeout = base.timeout,
headers = headers, headers = headers,
proxy = base.proxy, proxy = base.proxy,
insecure = base.allow_insecure, insecure = base.allow_insecure,

View File

@ -1,5 +1,4 @@
local Utils = require("avante.utils") local Utils = require("avante.utils")
local Config = require("avante.config")
local P = require("avante.providers") local P = require("avante.providers")
---@class AvanteProviderFunctor ---@class AvanteProviderFunctor
@ -7,10 +6,6 @@ local M = {}
M.api_key_name = "GEMINI_API_KEY" M.api_key_name = "GEMINI_API_KEY"
M.has = function()
return os.getenv(M.api_key_name) and true or false
end
M.parse_message = function(opts) M.parse_message = function(opts)
local code_prompt_obj = { local code_prompt_obj = {
text = string.format("<code>```%s\n%s```</code>", opts.code_lang, opts.code_content), text = string.format("<code>```%s\n%s```</code>", opts.code_lang, opts.code_content),
@ -68,7 +63,7 @@ M.parse_curl_args = function(provider, code_opts)
.. "/" .. "/"
.. base.model .. base.model
.. ":streamGenerateContent?alt=sse&key=" .. ":streamGenerateContent?alt=sse&key="
.. os.getenv(base.api_key_name or M.api_key_name), .. provider.parse_api_key(),
proxy = base.proxy, proxy = base.proxy,
insecure = base.allow_insecure, insecure = base.allow_insecure,
headers = { ["Content-Type"] = "application/json" }, headers = { ["Content-Type"] = "application/json" },

View File

@ -52,35 +52,24 @@ local Dressing = require("avante.ui.dressing")
---@field temperature? number ---@field temperature? number
---@field max_tokens? number ---@field max_tokens? number
--- ---
---@class AvanteAzureProvider: AvanteDefaultBaseProvider
---@field deployment string
---@field api_version string
---@field temperature number
---@field max_tokens number
---
---@class AvanteCopilotProvider: AvanteSupportedProvider
---@field timeout number
---
---@class AvanteGeminiProvider: AvanteDefaultBaseProvider
---@field model string
---
---@class AvanteProvider: AvanteDefaultBaseProvider
---@field parse_response_data AvanteResponseParser
---@field parse_curl_args AvanteCurlArgsParser
---@field parse_stream_data? AvanteStreamParser
---
---@alias AvanteStreamParser fun(line: string, handler_opts: AvanteHandlerOptions): nil ---@alias AvanteStreamParser fun(line: string, handler_opts: AvanteHandlerOptions): nil
---@alias AvanteChunkParser fun(chunk: string): any ---@alias AvanteChunkParser fun(chunk: string): any
---@alias AvanteCompleteParser fun(err: string|nil): nil ---@alias AvanteCompleteParser fun(err: string|nil): nil
---@alias AvanteLLMConfigHandler fun(opts: AvanteSupportedProvider): AvanteDefaultBaseProvider, table<string, any> ---@alias AvanteLLMConfigHandler fun(opts: AvanteSupportedProvider): AvanteDefaultBaseProvider, table<string, any>
--- ---
---@class AvanteProvider: AvanteSupportedProvider
---@field parse_response_data AvanteResponseParser
---@field parse_curl_args? AvanteCurlArgsParser
---@field parse_stream_data? AvanteStreamParser
---
---@class AvanteProviderFunctor ---@class AvanteProviderFunctor
---@field parse_message AvanteMessageParser ---@field parse_message AvanteMessageParser
---@field parse_response AvanteResponseParser ---@field parse_response AvanteResponseParser
---@field parse_curl_args AvanteCurlArgsParser ---@field parse_curl_args AvanteCurlArgsParser
---@field setup? fun(): nil ---@field setup fun(): nil
---@field has fun(): boolean ---@field has fun(): boolean
---@field api_key_name string ---@field api_key_name string
---@field parse_api_key fun(): string | nil
---@field parse_stream_data? AvanteStreamParser ---@field parse_stream_data? AvanteStreamParser
--- ---
---@class avante.Providers ---@class avante.Providers
@ -92,38 +81,50 @@ local Dressing = require("avante.ui.dressing")
---@field cohere AvanteProviderFunctor ---@field cohere AvanteProviderFunctor
local M = {} local M = {}
setmetatable(M, {
---@param t avante.Providers
---@param k Provider
__index = function(t, k)
if Config.vendors[k] ~= nil then
---@type AvanteProvider
local v = Config.vendors[k]
-- Patch from vendors similar to supported providers.
---@type AvanteProviderFunctor
t[k] = setmetatable({}, { __index = v })
-- Hack for aliasing and makes it sane for us.
t[k].parse_response = v.parse_response_data
t[k].has = function()
return os.getenv(t[k].api_key_name) and true or false
end
return t[k]
end
---@type AvanteProviderFunctor
t[k] = require("avante.providers." .. k)
return t[k]
end,
})
---@class EnvironmentHandler ---@class EnvironmentHandler
local E = {} local E = {}
---@private ---@private
E._once = false E._once = false
---@private
---@type table<string, string>
E.cache = {}
---@param Opts AvanteSupportedProvider | AvanteProviderFunctor
---@return string | nil
E.parse_envvar = function(Opts)
local api_key_name = Opts.api_key_name
if api_key_name == nil then
error("Requires api_key_name")
end
if E.cache[api_key_name] ~= nil then
return E.cache[api_key_name]
end
local cmd = api_key_name:match("^cmd:(.*)")
local key = nil
if cmd ~= nil then
local ok, job = pcall(vim.system, vim.split(cmd, " ", { trimempty = true }), { text = true })
if not ok then
Utils.error("Failed to execute command to retrieve secrets: " .. cmd, { once = true, title = "Avante" })
else
local out = job:wait()
key = out.stdout
end
else
key = os.getenv(api_key_name)
end
if key ~= nil then
E.cache[api_key_name] = key
end
return key
end
--- initialize the environment variable for current neovim session. --- initialize the environment variable for current neovim session.
--- This will only run once and spawn a UI for users to input the envvar. --- This will only run once and spawn a UI for users to input the envvar.
---@param opts {refresh: boolean, provider: AvanteProviderFunctor} ---@param opts {refresh: boolean, provider: AvanteProviderFunctor}
@ -131,7 +132,8 @@ E._once = false
E.setup = function(opts) E.setup = function(opts)
local var = opts.provider.api_key_name local var = opts.provider.api_key_name
if var == M.AVANTE_INTERNAL_KEY then -- check if var is a all caps string
if var == M.AVANTE_INTERNAL_KEY or var:match("^cmd:(.*)") then
return return
end end
@ -149,51 +151,54 @@ E.setup = function(opts)
end end
end end
if refresh then local function mount_dressing_buffer()
vim.defer_fn(function() vim.defer_fn(function()
Dressing.initialize_input_buffer({ opts = { prompt = "Enter " .. var .. ": " }, on_confirm = on_confirm }) -- only mount if given buffer is not of buftype ministarter, dashboard, alpha, qf
local exclude_buftypes = { "qf", "nofile" }
local exclude_filetypes = {
"NvimTree",
"Outline",
"help",
"dashboard",
"alpha",
"qf",
"ministarter",
"TelescopePrompt",
"gitcommit",
"gitrebase",
"DressingInput",
}
if
not vim.tbl_contains(exclude_buftypes, vim.bo.buftype)
and not vim.tbl_contains(exclude_filetypes, vim.bo.filetype)
and not opts.provider.has()
then
Dressing.initialize_input_buffer({
opts = { prompt = "Enter " .. var .. ": " },
on_confirm = on_confirm,
})
end
end, 200) end, 200)
elseif not E._once then end
if refresh then
mount_dressing_buffer()
return
end
if not E._once then
E._once = true E._once = true
api.nvim_create_autocmd({ "BufEnter", "BufWinEnter", "WinEnter" }, { api.nvim_create_autocmd({ "BufEnter", "BufWinEnter", "WinEnter" }, {
pattern = "*", pattern = "*",
once = true, once = true,
callback = function() callback = mount_dressing_buffer,
vim.defer_fn(function()
-- only mount if given buffer is not of buftype ministarter, dashboard, alpha, qf
local exclude_buftypes = { "qf", "nofile" }
local exclude_filetypes = {
"NvimTree",
"Outline",
"help",
"dashboard",
"alpha",
"qf",
"ministarter",
"TelescopePrompt",
"gitcommit",
"gitrebase",
"DressingInput",
}
if
not vim.tbl_contains(exclude_buftypes, vim.bo.buftype)
and not vim.tbl_contains(exclude_filetypes, vim.bo.filetype)
and not opts.provider.has()
then
Dressing.initialize_input_buffer({
opts = { prompt = "Enter " .. var .. ": " },
on_confirm = on_confirm,
})
end
end, 200)
end,
}) })
end end
end end
---@param provider Provider ---@param provider Provider
E.is_local = function(provider) E.is_local = function(provider)
local cur = M.get(provider) local cur = M.get_config(provider)
return cur["local"] ~= nil and cur["local"] or false return cur["local"] ~= nil and cur["local"] or false
end end
@ -201,14 +206,47 @@ M.env = E
M.AVANTE_INTERNAL_KEY = "__avante_env_internal" M.AVANTE_INTERNAL_KEY = "__avante_env_internal"
M = setmetatable(M, {
---@param t avante.Providers
---@param k Provider
__index = function(t, k)
---@type AvanteProviderFunctor
local Opts = M.get_config(k)
if Config.vendors[k] ~= nil then
Opts.parse_response = Opts.parse_response_data
t[k] = Opts
else
t[k] = vim.tbl_deep_extend("keep", Opts, require("avante.providers." .. k))
end
t[k].parse_api_key = function()
return E.parse_envvar(t[k])
end
if t[k].has == nil then
t[k].has = function()
return E.parse_envvar(t[k]) ~= nil
end
end
if t[k].setup == nil then
t[k].setup = function()
t[k].parse_api_key()
end
end
return t[k]
end,
})
M.setup = function() M.setup = function()
---@type AvanteProviderFunctor ---@type AvanteProviderFunctor
local provider = M[Config.provider] local provider = M[Config.provider]
E.setup({ provider = provider }) E.setup({ provider = provider })
vim.schedule(function()
if provider.setup ~= nil then
provider.setup() provider.setup()
end end)
M.commands() M.commands()
end end
@ -216,6 +254,8 @@ end
---@private ---@private
---@param provider Provider ---@param provider Provider
function M.refresh(provider) function M.refresh(provider)
require("avante.config").override({ provider = provider })
---@type AvanteProviderFunctor ---@type AvanteProviderFunctor
local p = M[Config.provider] local p = M[Config.provider]
if not p.has() then if not p.has() then
@ -223,7 +263,6 @@ function M.refresh(provider)
else else
Utils.info("Switch to provider: " .. provider, { once = true, title = "Avante" }) Utils.info("Switch to provider: " .. provider, { once = true, title = "Avante" })
end end
require("avante.config").override({ provider = provider })
end end
local default_providers = { "openai", "claude", "azure", "gemini", "copilot" } local default_providers = { "openai", "claude", "azure", "gemini", "copilot" }
@ -242,7 +281,8 @@ M.commands = function()
end end
local prefix = line:match("^%s*AvanteSwitchProvider (%w*)") or "" local prefix = line:match("^%s*AvanteSwitchProvider (%w*)") or ""
-- join two tables -- join two tables
local Keys = vim.list_extend(default_providers, vim.tbl_keys(Config.vendors or {})) local Keys = vim.list_extend({}, default_providers)
Keys = vim.list_extend(Keys, vim.tbl_keys(Config.vendors or {}))
return vim.tbl_filter(function(key) return vim.tbl_filter(function(key)
return key:find(prefix) == 1 return key:find(prefix) == 1
end, Keys) end, Keys)
@ -280,7 +320,8 @@ end
---@private ---@private
---@param provider Provider ---@param provider Provider
M.get = function(provider) ---@return AvanteProviderFunctor
M.get_config = function(provider)
local cur = Config.get_provider(provider or Config.provider) local cur = Config.get_provider(provider or Config.provider)
return type(cur) == "function" and cur() or cur return type(cur) == "function" and cur() or cur
end end

View File

@ -26,10 +26,6 @@ local M = {}
M.api_key_name = "OPENAI_API_KEY" M.api_key_name = "OPENAI_API_KEY"
M.has = function()
return os.getenv(M.api_key_name) and true or false
end
M.parse_message = function(opts) M.parse_message = function(opts)
local user_prompt = opts.base_prompt local user_prompt = opts.base_prompt
.. "\n\nCODE:\n" .. "\n\nCODE:\n"
@ -91,7 +87,7 @@ M.parse_curl_args = function(provider, code_opts)
["Content-Type"] = "application/json", ["Content-Type"] = "application/json",
} }
if not P.env.is_local("openai") then if not P.env.is_local("openai") then
headers["Authorization"] = "Bearer " .. os.getenv(base.api_key_name or M.api_key_name) headers["Authorization"] = "Bearer " .. provider.parse_api_key()
end end
return { return {