feat(llm): add support for parsing secret vault (#200)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
parent
8d375dd591
commit
a7d3defa3d
@ -15,6 +15,7 @@ M.defaults = {
|
|||||||
openai = {
|
openai = {
|
||||||
endpoint = "https://api.openai.com/v1",
|
endpoint = "https://api.openai.com/v1",
|
||||||
model = "gpt-4o",
|
model = "gpt-4o",
|
||||||
|
timeout = 30000, -- Timeout in milliseconds
|
||||||
temperature = 0,
|
temperature = 0,
|
||||||
max_tokens = 4096,
|
max_tokens = 4096,
|
||||||
["local"] = false,
|
["local"] = false,
|
||||||
@ -34,6 +35,7 @@ M.defaults = {
|
|||||||
endpoint = "", -- example: "https://<your-resource-name>.openai.azure.com"
|
endpoint = "", -- example: "https://<your-resource-name>.openai.azure.com"
|
||||||
deployment = "", -- Azure deployment name (e.g., "gpt-4o", "my-gpt-4o-deployment")
|
deployment = "", -- Azure deployment name (e.g., "gpt-4o", "my-gpt-4o-deployment")
|
||||||
api_version = "2024-06-01",
|
api_version = "2024-06-01",
|
||||||
|
timeout = 30000, -- Timeout in milliseconds
|
||||||
temperature = 0,
|
temperature = 0,
|
||||||
max_tokens = 4096,
|
max_tokens = 4096,
|
||||||
["local"] = false,
|
["local"] = false,
|
||||||
@ -42,22 +44,25 @@ M.defaults = {
|
|||||||
claude = {
|
claude = {
|
||||||
endpoint = "https://api.anthropic.com",
|
endpoint = "https://api.anthropic.com",
|
||||||
model = "claude-3-5-sonnet-20240620",
|
model = "claude-3-5-sonnet-20240620",
|
||||||
|
timeout = 30000, -- Timeout in milliseconds
|
||||||
temperature = 0,
|
temperature = 0,
|
||||||
max_tokens = 4096,
|
max_tokens = 4096,
|
||||||
["local"] = false,
|
["local"] = false,
|
||||||
},
|
},
|
||||||
---@type AvanteGeminiProvider
|
---@type AvanteSupportedProvider
|
||||||
gemini = {
|
gemini = {
|
||||||
endpoint = "https://generativelanguage.googleapis.com/v1beta/models",
|
endpoint = "https://generativelanguage.googleapis.com/v1beta/models",
|
||||||
model = "gemini-1.5-pro",
|
model = "gemini-1.5-pro",
|
||||||
|
timeout = 30000, -- Timeout in milliseconds
|
||||||
temperature = 0,
|
temperature = 0,
|
||||||
max_tokens = 4096,
|
max_tokens = 4096,
|
||||||
["local"] = false,
|
["local"] = false,
|
||||||
},
|
},
|
||||||
---@type AvanteGeminiProvider
|
---@type AvanteGeminiProvider
|
||||||
cohere = {
|
cohere = {
|
||||||
endpoint = "https://api.cohere.com",
|
endpoint = "https://api.cohere.com/v1",
|
||||||
model = "command-r-plus",
|
model = "command-r-plus",
|
||||||
|
timeout = 30000, -- Timeout in milliseconds
|
||||||
temperature = 0,
|
temperature = 0,
|
||||||
max_tokens = 3072,
|
max_tokens = 3072,
|
||||||
["local"] = false,
|
["local"] = false,
|
||||||
@ -189,7 +194,7 @@ end
|
|||||||
|
|
||||||
---get supported providers
|
---get supported providers
|
||||||
---@param provider Provider
|
---@param provider Provider
|
||||||
---@return AvanteProvider | fun(): AvanteProvider
|
---@return AvanteProviderFunctor
|
||||||
M.get_provider = function(provider)
|
M.get_provider = function(provider)
|
||||||
if M.options[provider] ~= nil then
|
if M.options[provider] ~= nil then
|
||||||
return vim.deepcopy(M.options[provider], true)
|
return vim.deepcopy(M.options[provider], true)
|
||||||
@ -200,8 +205,17 @@ M.get_provider = function(provider)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
M.BASE_PROVIDER_KEYS =
|
M.BASE_PROVIDER_KEYS = {
|
||||||
{ "endpoint", "model", "local", "deployment", "api_version", "proxy", "allow_insecure", "api_key_name" }
|
"endpoint",
|
||||||
|
"model",
|
||||||
|
"local",
|
||||||
|
"deployment",
|
||||||
|
"api_version",
|
||||||
|
"proxy",
|
||||||
|
"allow_insecure",
|
||||||
|
"api_key_name",
|
||||||
|
"timeout",
|
||||||
|
}
|
||||||
|
|
||||||
---@return {width: integer, height: integer}
|
---@return {width: integer, height: integer}
|
||||||
function M.get_sidebar_layout_options()
|
function M.get_sidebar_layout_options()
|
||||||
|
@ -94,7 +94,7 @@ M.stream = function(question, code_lang, code_content, selected_content_content,
|
|||||||
---@type AvanteHandlerOptions
|
---@type AvanteHandlerOptions
|
||||||
local handler_opts = { on_chunk = on_chunk, on_complete = on_complete }
|
local handler_opts = { on_chunk = on_chunk, on_complete = on_complete }
|
||||||
---@type AvanteCurlOutput
|
---@type AvanteCurlOutput
|
||||||
local spec = Provider.parse_curl_args(Config.get_provider(provider), code_opts)
|
local spec = Provider.parse_curl_args(Provider, code_opts)
|
||||||
|
|
||||||
---@param line string
|
---@param line string
|
||||||
local function parse_stream_data(line)
|
local function parse_stream_data(line)
|
||||||
|
@ -1,5 +1,10 @@
|
|||||||
|
---@class AvanteAzureProvider: AvanteDefaultBaseProvider
|
||||||
|
---@field deployment string
|
||||||
|
---@field api_version string
|
||||||
|
---@field temperature number
|
||||||
|
---@field max_tokens number
|
||||||
|
|
||||||
local Utils = require("avante.utils")
|
local Utils = require("avante.utils")
|
||||||
local Config = require("avante.config")
|
|
||||||
local P = require("avante.providers")
|
local P = require("avante.providers")
|
||||||
local O = require("avante.providers").openai
|
local O = require("avante.providers").openai
|
||||||
|
|
||||||
@ -8,10 +13,6 @@ local M = {}
|
|||||||
|
|
||||||
M.api_key_name = "AZURE_OPENAI_API_KEY"
|
M.api_key_name = "AZURE_OPENAI_API_KEY"
|
||||||
|
|
||||||
M.has = function()
|
|
||||||
return os.getenv(M.api_key_name) and true or false
|
|
||||||
end
|
|
||||||
|
|
||||||
M.parse_message = O.parse_message
|
M.parse_message = O.parse_message
|
||||||
M.parse_response = O.parse_response
|
M.parse_response = O.parse_response
|
||||||
|
|
||||||
@ -22,7 +23,7 @@ M.parse_curl_args = function(provider, code_opts)
|
|||||||
["Content-Type"] = "application/json",
|
["Content-Type"] = "application/json",
|
||||||
}
|
}
|
||||||
if not P.env.is_local("azure") then
|
if not P.env.is_local("azure") then
|
||||||
headers["api-key"] = os.getenv(base.api_key_name or M.api_key_name)
|
headers["api-key"] = provider.parse_api_key()
|
||||||
end
|
end
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
local Utils = require("avante.utils")
|
local Utils = require("avante.utils")
|
||||||
local Config = require("avante.config")
|
|
||||||
local Tiktoken = require("avante.tiktoken")
|
local Tiktoken = require("avante.tiktoken")
|
||||||
local P = require("avante.providers")
|
local P = require("avante.providers")
|
||||||
|
|
||||||
@ -8,10 +7,6 @@ local M = {}
|
|||||||
|
|
||||||
M.api_key_name = "ANTHROPIC_API_KEY"
|
M.api_key_name = "ANTHROPIC_API_KEY"
|
||||||
|
|
||||||
M.has = function()
|
|
||||||
return os.getenv(M.api_key_name) and true or false
|
|
||||||
end
|
|
||||||
|
|
||||||
M.parse_message = function(opts)
|
M.parse_message = function(opts)
|
||||||
local code_prompt_obj = {
|
local code_prompt_obj = {
|
||||||
type = "text",
|
type = "text",
|
||||||
@ -93,7 +88,7 @@ M.parse_curl_args = function(provider, code_opts)
|
|||||||
["anthropic-beta"] = "prompt-caching-2024-07-31",
|
["anthropic-beta"] = "prompt-caching-2024-07-31",
|
||||||
}
|
}
|
||||||
if not P.env.is_local("claude") then
|
if not P.env.is_local("claude") then
|
||||||
headers["x-api-key"] = os.getenv(base.api_key_name or M.api_key_name)
|
headers["x-api-key"] = provider.parse_api_key()
|
||||||
end
|
end
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
@ -30,10 +30,6 @@ local M = {}
|
|||||||
|
|
||||||
M.api_key_name = "CO_API_KEY"
|
M.api_key_name = "CO_API_KEY"
|
||||||
|
|
||||||
M.has = function()
|
|
||||||
return os.getenv(M.api_key_name) and true or false
|
|
||||||
end
|
|
||||||
|
|
||||||
M.parse_message = function(opts)
|
M.parse_message = function(opts)
|
||||||
local user_prompt = opts.base_prompt
|
local user_prompt = opts.base_prompt
|
||||||
.. "\n\nCODE:\n"
|
.. "\n\nCODE:\n"
|
||||||
@ -103,11 +99,11 @@ M.parse_curl_args = function(provider, code_opts)
|
|||||||
.. vim.version().patch,
|
.. vim.version().patch,
|
||||||
}
|
}
|
||||||
if not P.env.is_local("openai") then
|
if not P.env.is_local("openai") then
|
||||||
headers["Authorization"] = "Bearer " .. os.getenv(base.api_key_name or M.api_key_name)
|
headers["Authorization"] = "Bearer " .. provider.parse_api_key()
|
||||||
end
|
end
|
||||||
|
|
||||||
return {
|
return {
|
||||||
url = Utils.trim(base.endpoint, { suffix = "/" }) .. "/v1/chat",
|
url = Utils.trim(base.endpoint, { suffix = "/" }) .. "/chat",
|
||||||
proxy = base.proxy,
|
proxy = base.proxy,
|
||||||
insecure = base.allow_insecure,
|
insecure = base.allow_insecure,
|
||||||
headers = headers,
|
headers = headers,
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
|
---@class AvanteCopilotProvider: AvanteSupportedProvider
|
||||||
|
---@field timeout number
|
||||||
|
|
||||||
local curl = require("plenary.curl")
|
local curl = require("plenary.curl")
|
||||||
|
|
||||||
local Utils = require("avante.utils")
|
local Utils = require("avante.utils")
|
||||||
local Config = require("avante.config")
|
|
||||||
local P = require("avante.providers")
|
local P = require("avante.providers")
|
||||||
local O = require("avante.providers").openai
|
local O = require("avante.providers").openai
|
||||||
|
|
||||||
@ -196,7 +198,7 @@ M.parse_curl_args = function(provider, code_opts)
|
|||||||
end
|
end
|
||||||
|
|
||||||
local response = curl.get(url, {
|
local response = curl.get(url, {
|
||||||
timeout = Config.copilot.timeout,
|
timeout = base.timeout,
|
||||||
headers = headers,
|
headers = headers,
|
||||||
proxy = base.proxy,
|
proxy = base.proxy,
|
||||||
insecure = base.allow_insecure,
|
insecure = base.allow_insecure,
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
local Utils = require("avante.utils")
|
local Utils = require("avante.utils")
|
||||||
local Config = require("avante.config")
|
|
||||||
local P = require("avante.providers")
|
local P = require("avante.providers")
|
||||||
|
|
||||||
---@class AvanteProviderFunctor
|
---@class AvanteProviderFunctor
|
||||||
@ -7,10 +6,6 @@ local M = {}
|
|||||||
|
|
||||||
M.api_key_name = "GEMINI_API_KEY"
|
M.api_key_name = "GEMINI_API_KEY"
|
||||||
|
|
||||||
M.has = function()
|
|
||||||
return os.getenv(M.api_key_name) and true or false
|
|
||||||
end
|
|
||||||
|
|
||||||
M.parse_message = function(opts)
|
M.parse_message = function(opts)
|
||||||
local code_prompt_obj = {
|
local code_prompt_obj = {
|
||||||
text = string.format("<code>```%s\n%s```</code>", opts.code_lang, opts.code_content),
|
text = string.format("<code>```%s\n%s```</code>", opts.code_lang, opts.code_content),
|
||||||
@ -68,7 +63,7 @@ M.parse_curl_args = function(provider, code_opts)
|
|||||||
.. "/"
|
.. "/"
|
||||||
.. base.model
|
.. base.model
|
||||||
.. ":streamGenerateContent?alt=sse&key="
|
.. ":streamGenerateContent?alt=sse&key="
|
||||||
.. os.getenv(base.api_key_name or M.api_key_name),
|
.. provider.parse_api_key(),
|
||||||
proxy = base.proxy,
|
proxy = base.proxy,
|
||||||
insecure = base.allow_insecure,
|
insecure = base.allow_insecure,
|
||||||
headers = { ["Content-Type"] = "application/json" },
|
headers = { ["Content-Type"] = "application/json" },
|
||||||
|
@ -52,35 +52,24 @@ local Dressing = require("avante.ui.dressing")
|
|||||||
---@field temperature? number
|
---@field temperature? number
|
||||||
---@field max_tokens? number
|
---@field max_tokens? number
|
||||||
---
|
---
|
||||||
---@class AvanteAzureProvider: AvanteDefaultBaseProvider
|
|
||||||
---@field deployment string
|
|
||||||
---@field api_version string
|
|
||||||
---@field temperature number
|
|
||||||
---@field max_tokens number
|
|
||||||
---
|
|
||||||
---@class AvanteCopilotProvider: AvanteSupportedProvider
|
|
||||||
---@field timeout number
|
|
||||||
---
|
|
||||||
---@class AvanteGeminiProvider: AvanteDefaultBaseProvider
|
|
||||||
---@field model string
|
|
||||||
---
|
|
||||||
---@class AvanteProvider: AvanteDefaultBaseProvider
|
|
||||||
---@field parse_response_data AvanteResponseParser
|
|
||||||
---@field parse_curl_args AvanteCurlArgsParser
|
|
||||||
---@field parse_stream_data? AvanteStreamParser
|
|
||||||
---
|
|
||||||
---@alias AvanteStreamParser fun(line: string, handler_opts: AvanteHandlerOptions): nil
|
---@alias AvanteStreamParser fun(line: string, handler_opts: AvanteHandlerOptions): nil
|
||||||
---@alias AvanteChunkParser fun(chunk: string): any
|
---@alias AvanteChunkParser fun(chunk: string): any
|
||||||
---@alias AvanteCompleteParser fun(err: string|nil): nil
|
---@alias AvanteCompleteParser fun(err: string|nil): nil
|
||||||
---@alias AvanteLLMConfigHandler fun(opts: AvanteSupportedProvider): AvanteDefaultBaseProvider, table<string, any>
|
---@alias AvanteLLMConfigHandler fun(opts: AvanteSupportedProvider): AvanteDefaultBaseProvider, table<string, any>
|
||||||
---
|
---
|
||||||
|
---@class AvanteProvider: AvanteSupportedProvider
|
||||||
|
---@field parse_response_data AvanteResponseParser
|
||||||
|
---@field parse_curl_args? AvanteCurlArgsParser
|
||||||
|
---@field parse_stream_data? AvanteStreamParser
|
||||||
|
---
|
||||||
---@class AvanteProviderFunctor
|
---@class AvanteProviderFunctor
|
||||||
---@field parse_message AvanteMessageParser
|
---@field parse_message AvanteMessageParser
|
||||||
---@field parse_response AvanteResponseParser
|
---@field parse_response AvanteResponseParser
|
||||||
---@field parse_curl_args AvanteCurlArgsParser
|
---@field parse_curl_args AvanteCurlArgsParser
|
||||||
---@field setup? fun(): nil
|
---@field setup fun(): nil
|
||||||
---@field has fun(): boolean
|
---@field has fun(): boolean
|
||||||
---@field api_key_name string
|
---@field api_key_name string
|
||||||
|
---@field parse_api_key fun(): string | nil
|
||||||
---@field parse_stream_data? AvanteStreamParser
|
---@field parse_stream_data? AvanteStreamParser
|
||||||
---
|
---
|
||||||
---@class avante.Providers
|
---@class avante.Providers
|
||||||
@ -92,38 +81,50 @@ local Dressing = require("avante.ui.dressing")
|
|||||||
---@field cohere AvanteProviderFunctor
|
---@field cohere AvanteProviderFunctor
|
||||||
local M = {}
|
local M = {}
|
||||||
|
|
||||||
setmetatable(M, {
|
|
||||||
---@param t avante.Providers
|
|
||||||
---@param k Provider
|
|
||||||
__index = function(t, k)
|
|
||||||
if Config.vendors[k] ~= nil then
|
|
||||||
---@type AvanteProvider
|
|
||||||
local v = Config.vendors[k]
|
|
||||||
|
|
||||||
-- Patch from vendors similar to supported providers.
|
|
||||||
---@type AvanteProviderFunctor
|
|
||||||
t[k] = setmetatable({}, { __index = v })
|
|
||||||
-- Hack for aliasing and makes it sane for us.
|
|
||||||
t[k].parse_response = v.parse_response_data
|
|
||||||
t[k].has = function()
|
|
||||||
return os.getenv(t[k].api_key_name) and true or false
|
|
||||||
end
|
|
||||||
|
|
||||||
return t[k]
|
|
||||||
end
|
|
||||||
|
|
||||||
---@type AvanteProviderFunctor
|
|
||||||
t[k] = require("avante.providers." .. k)
|
|
||||||
return t[k]
|
|
||||||
end,
|
|
||||||
})
|
|
||||||
|
|
||||||
---@class EnvironmentHandler
|
---@class EnvironmentHandler
|
||||||
local E = {}
|
local E = {}
|
||||||
|
|
||||||
---@private
|
---@private
|
||||||
E._once = false
|
E._once = false
|
||||||
|
|
||||||
|
---@private
|
||||||
|
---@type table<string, string>
|
||||||
|
E.cache = {}
|
||||||
|
|
||||||
|
---@param Opts AvanteSupportedProvider | AvanteProviderFunctor
|
||||||
|
---@return string | nil
|
||||||
|
E.parse_envvar = function(Opts)
|
||||||
|
local api_key_name = Opts.api_key_name
|
||||||
|
if api_key_name == nil then
|
||||||
|
error("Requires api_key_name")
|
||||||
|
end
|
||||||
|
|
||||||
|
if E.cache[api_key_name] ~= nil then
|
||||||
|
return E.cache[api_key_name]
|
||||||
|
end
|
||||||
|
|
||||||
|
local cmd = api_key_name:match("^cmd:(.*)")
|
||||||
|
|
||||||
|
local key = nil
|
||||||
|
if cmd ~= nil then
|
||||||
|
local ok, job = pcall(vim.system, vim.split(cmd, " ", { trimempty = true }), { text = true })
|
||||||
|
if not ok then
|
||||||
|
Utils.error("Failed to execute command to retrieve secrets: " .. cmd, { once = true, title = "Avante" })
|
||||||
|
else
|
||||||
|
local out = job:wait()
|
||||||
|
key = out.stdout
|
||||||
|
end
|
||||||
|
else
|
||||||
|
key = os.getenv(api_key_name)
|
||||||
|
end
|
||||||
|
|
||||||
|
if key ~= nil then
|
||||||
|
E.cache[api_key_name] = key
|
||||||
|
end
|
||||||
|
|
||||||
|
return key
|
||||||
|
end
|
||||||
|
|
||||||
--- initialize the environment variable for current neovim session.
|
--- initialize the environment variable for current neovim session.
|
||||||
--- This will only run once and spawn a UI for users to input the envvar.
|
--- This will only run once and spawn a UI for users to input the envvar.
|
||||||
---@param opts {refresh: boolean, provider: AvanteProviderFunctor}
|
---@param opts {refresh: boolean, provider: AvanteProviderFunctor}
|
||||||
@ -131,7 +132,8 @@ E._once = false
|
|||||||
E.setup = function(opts)
|
E.setup = function(opts)
|
||||||
local var = opts.provider.api_key_name
|
local var = opts.provider.api_key_name
|
||||||
|
|
||||||
if var == M.AVANTE_INTERNAL_KEY then
|
-- check if var is a all caps string
|
||||||
|
if var == M.AVANTE_INTERNAL_KEY or var:match("^cmd:(.*)") then
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -149,51 +151,54 @@ E.setup = function(opts)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
if refresh then
|
local function mount_dressing_buffer()
|
||||||
vim.defer_fn(function()
|
vim.defer_fn(function()
|
||||||
Dressing.initialize_input_buffer({ opts = { prompt = "Enter " .. var .. ": " }, on_confirm = on_confirm })
|
-- only mount if given buffer is not of buftype ministarter, dashboard, alpha, qf
|
||||||
|
local exclude_buftypes = { "qf", "nofile" }
|
||||||
|
local exclude_filetypes = {
|
||||||
|
"NvimTree",
|
||||||
|
"Outline",
|
||||||
|
"help",
|
||||||
|
"dashboard",
|
||||||
|
"alpha",
|
||||||
|
"qf",
|
||||||
|
"ministarter",
|
||||||
|
"TelescopePrompt",
|
||||||
|
"gitcommit",
|
||||||
|
"gitrebase",
|
||||||
|
"DressingInput",
|
||||||
|
}
|
||||||
|
if
|
||||||
|
not vim.tbl_contains(exclude_buftypes, vim.bo.buftype)
|
||||||
|
and not vim.tbl_contains(exclude_filetypes, vim.bo.filetype)
|
||||||
|
and not opts.provider.has()
|
||||||
|
then
|
||||||
|
Dressing.initialize_input_buffer({
|
||||||
|
opts = { prompt = "Enter " .. var .. ": " },
|
||||||
|
on_confirm = on_confirm,
|
||||||
|
})
|
||||||
|
end
|
||||||
end, 200)
|
end, 200)
|
||||||
elseif not E._once then
|
end
|
||||||
|
|
||||||
|
if refresh then
|
||||||
|
mount_dressing_buffer()
|
||||||
|
return
|
||||||
|
end
|
||||||
|
|
||||||
|
if not E._once then
|
||||||
E._once = true
|
E._once = true
|
||||||
api.nvim_create_autocmd({ "BufEnter", "BufWinEnter", "WinEnter" }, {
|
api.nvim_create_autocmd({ "BufEnter", "BufWinEnter", "WinEnter" }, {
|
||||||
pattern = "*",
|
pattern = "*",
|
||||||
once = true,
|
once = true,
|
||||||
callback = function()
|
callback = mount_dressing_buffer,
|
||||||
vim.defer_fn(function()
|
|
||||||
-- only mount if given buffer is not of buftype ministarter, dashboard, alpha, qf
|
|
||||||
local exclude_buftypes = { "qf", "nofile" }
|
|
||||||
local exclude_filetypes = {
|
|
||||||
"NvimTree",
|
|
||||||
"Outline",
|
|
||||||
"help",
|
|
||||||
"dashboard",
|
|
||||||
"alpha",
|
|
||||||
"qf",
|
|
||||||
"ministarter",
|
|
||||||
"TelescopePrompt",
|
|
||||||
"gitcommit",
|
|
||||||
"gitrebase",
|
|
||||||
"DressingInput",
|
|
||||||
}
|
|
||||||
if
|
|
||||||
not vim.tbl_contains(exclude_buftypes, vim.bo.buftype)
|
|
||||||
and not vim.tbl_contains(exclude_filetypes, vim.bo.filetype)
|
|
||||||
and not opts.provider.has()
|
|
||||||
then
|
|
||||||
Dressing.initialize_input_buffer({
|
|
||||||
opts = { prompt = "Enter " .. var .. ": " },
|
|
||||||
on_confirm = on_confirm,
|
|
||||||
})
|
|
||||||
end
|
|
||||||
end, 200)
|
|
||||||
end,
|
|
||||||
})
|
})
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
---@param provider Provider
|
---@param provider Provider
|
||||||
E.is_local = function(provider)
|
E.is_local = function(provider)
|
||||||
local cur = M.get(provider)
|
local cur = M.get_config(provider)
|
||||||
return cur["local"] ~= nil and cur["local"] or false
|
return cur["local"] ~= nil and cur["local"] or false
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -201,14 +206,47 @@ M.env = E
|
|||||||
|
|
||||||
M.AVANTE_INTERNAL_KEY = "__avante_env_internal"
|
M.AVANTE_INTERNAL_KEY = "__avante_env_internal"
|
||||||
|
|
||||||
|
M = setmetatable(M, {
|
||||||
|
---@param t avante.Providers
|
||||||
|
---@param k Provider
|
||||||
|
__index = function(t, k)
|
||||||
|
---@type AvanteProviderFunctor
|
||||||
|
local Opts = M.get_config(k)
|
||||||
|
|
||||||
|
if Config.vendors[k] ~= nil then
|
||||||
|
Opts.parse_response = Opts.parse_response_data
|
||||||
|
t[k] = Opts
|
||||||
|
else
|
||||||
|
t[k] = vim.tbl_deep_extend("keep", Opts, require("avante.providers." .. k))
|
||||||
|
end
|
||||||
|
|
||||||
|
t[k].parse_api_key = function()
|
||||||
|
return E.parse_envvar(t[k])
|
||||||
|
end
|
||||||
|
|
||||||
|
if t[k].has == nil then
|
||||||
|
t[k].has = function()
|
||||||
|
return E.parse_envvar(t[k]) ~= nil
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
if t[k].setup == nil then
|
||||||
|
t[k].setup = function()
|
||||||
|
t[k].parse_api_key()
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
return t[k]
|
||||||
|
end,
|
||||||
|
})
|
||||||
|
|
||||||
M.setup = function()
|
M.setup = function()
|
||||||
---@type AvanteProviderFunctor
|
---@type AvanteProviderFunctor
|
||||||
local provider = M[Config.provider]
|
local provider = M[Config.provider]
|
||||||
E.setup({ provider = provider })
|
E.setup({ provider = provider })
|
||||||
|
vim.schedule(function()
|
||||||
if provider.setup ~= nil then
|
|
||||||
provider.setup()
|
provider.setup()
|
||||||
end
|
end)
|
||||||
|
|
||||||
M.commands()
|
M.commands()
|
||||||
end
|
end
|
||||||
@ -216,6 +254,8 @@ end
|
|||||||
---@private
|
---@private
|
||||||
---@param provider Provider
|
---@param provider Provider
|
||||||
function M.refresh(provider)
|
function M.refresh(provider)
|
||||||
|
require("avante.config").override({ provider = provider })
|
||||||
|
|
||||||
---@type AvanteProviderFunctor
|
---@type AvanteProviderFunctor
|
||||||
local p = M[Config.provider]
|
local p = M[Config.provider]
|
||||||
if not p.has() then
|
if not p.has() then
|
||||||
@ -223,7 +263,6 @@ function M.refresh(provider)
|
|||||||
else
|
else
|
||||||
Utils.info("Switch to provider: " .. provider, { once = true, title = "Avante" })
|
Utils.info("Switch to provider: " .. provider, { once = true, title = "Avante" })
|
||||||
end
|
end
|
||||||
require("avante.config").override({ provider = provider })
|
|
||||||
end
|
end
|
||||||
|
|
||||||
local default_providers = { "openai", "claude", "azure", "gemini", "copilot" }
|
local default_providers = { "openai", "claude", "azure", "gemini", "copilot" }
|
||||||
@ -242,7 +281,8 @@ M.commands = function()
|
|||||||
end
|
end
|
||||||
local prefix = line:match("^%s*AvanteSwitchProvider (%w*)") or ""
|
local prefix = line:match("^%s*AvanteSwitchProvider (%w*)") or ""
|
||||||
-- join two tables
|
-- join two tables
|
||||||
local Keys = vim.list_extend(default_providers, vim.tbl_keys(Config.vendors or {}))
|
local Keys = vim.list_extend({}, default_providers)
|
||||||
|
Keys = vim.list_extend(Keys, vim.tbl_keys(Config.vendors or {}))
|
||||||
return vim.tbl_filter(function(key)
|
return vim.tbl_filter(function(key)
|
||||||
return key:find(prefix) == 1
|
return key:find(prefix) == 1
|
||||||
end, Keys)
|
end, Keys)
|
||||||
@ -280,7 +320,8 @@ end
|
|||||||
|
|
||||||
---@private
|
---@private
|
||||||
---@param provider Provider
|
---@param provider Provider
|
||||||
M.get = function(provider)
|
---@return AvanteProviderFunctor
|
||||||
|
M.get_config = function(provider)
|
||||||
local cur = Config.get_provider(provider or Config.provider)
|
local cur = Config.get_provider(provider or Config.provider)
|
||||||
return type(cur) == "function" and cur() or cur
|
return type(cur) == "function" and cur() or cur
|
||||||
end
|
end
|
||||||
|
@ -26,10 +26,6 @@ local M = {}
|
|||||||
|
|
||||||
M.api_key_name = "OPENAI_API_KEY"
|
M.api_key_name = "OPENAI_API_KEY"
|
||||||
|
|
||||||
M.has = function()
|
|
||||||
return os.getenv(M.api_key_name) and true or false
|
|
||||||
end
|
|
||||||
|
|
||||||
M.parse_message = function(opts)
|
M.parse_message = function(opts)
|
||||||
local user_prompt = opts.base_prompt
|
local user_prompt = opts.base_prompt
|
||||||
.. "\n\nCODE:\n"
|
.. "\n\nCODE:\n"
|
||||||
@ -91,7 +87,7 @@ M.parse_curl_args = function(provider, code_opts)
|
|||||||
["Content-Type"] = "application/json",
|
["Content-Type"] = "application/json",
|
||||||
}
|
}
|
||||||
if not P.env.is_local("openai") then
|
if not P.env.is_local("openai") then
|
||||||
headers["Authorization"] = "Bearer " .. os.getenv(base.api_key_name or M.api_key_name)
|
headers["Authorization"] = "Bearer " .. provider.parse_api_key()
|
||||||
end
|
end
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user