fix(env): remove fallback and respect one env (#60)

Make sure to set the corresponding env.
This commit is contained in:
Hanchin Hsieh 2024-08-18 17:54:29 +08:00 committed by GitHub
parent d885bd9680
commit c8a764b3a1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -12,17 +12,13 @@ local Tiktoken = require("avante.tiktoken")
---@class avante.AiBot ---@class avante.AiBot
local M = {} local M = {}
---@class Environment: table<[string], any>
---@field [string] string the environment variable name
---@field fallback? string Optional fallback API key environment variable name
---@class EnvironmentHandler: table<[Provider], string> ---@class EnvironmentHandler: table<[Provider], string>
local E = { local E = {
---@type table<Provider, Environment | string> ---@type table<Provider, string>
env = { env = {
openai = "OPENAI_API_KEY", openai = "OPENAI_API_KEY",
claude = "ANTHROPIC_API_KEY", claude = "ANTHROPIC_API_KEY",
azure = { "AZURE_OPENAI_API_KEY", fallback = "OPENAI_API_KEY" }, azure = "AZURE_OPENAI_API_KEY",
}, },
_once = false, _once = false,
} }
@ -30,20 +26,7 @@ local E = {
E = setmetatable(E, { E = setmetatable(E, {
---@param k Provider ---@param k Provider
__index = function(_, k) __index = function(_, k)
local envvar = E.env[k] return os.getenv(E.env[k]) and true or false
if type(envvar) == "string" then
local value = os.getenv(envvar)
return value and true or false
elseif type(envvar) == "table" then
local main_key = envvar[1]
local value = os.getenv(main_key)
if value then
return true
elseif envvar.fallback then
return os.getenv(envvar.fallback) and true or false
end
end
return false
end, end,
}) })
@ -137,8 +120,17 @@ E.setup = function(var)
vim.defer_fn(function() vim.defer_fn(function()
-- only mount if given buffer is not of buftype ministarter, dashboard, alpha, qf -- only mount if given buffer is not of buftype ministarter, dashboard, alpha, qf
local exclude_buftypes = { "dashboard", "alpha", "qf", "nofile" } local exclude_buftypes = { "dashboard", "alpha", "qf", "nofile" }
local exclude_filetypes = local exclude_filetypes = {
{ "NvimTree", "Outline", "help", "dashboard", "alpha", "qf", "ministarter", "TelescopePrompt", "gitcommit" } "NvimTree",
"Outline",
"help",
"dashboard",
"alpha",
"qf",
"ministarter",
"TelescopePrompt",
"gitcommit",
}
if if
not vim.tbl_contains(exclude_buftypes, vim.bo.buftype) not vim.tbl_contains(exclude_buftypes, vim.bo.buftype)
and not vim.tbl_contains(exclude_filetypes, vim.bo.filetype) and not vim.tbl_contains(exclude_filetypes, vim.bo.filetype)