fix(llm): persistent key check for override class (#158)

* fix(llm): make sure to allow passing custom module

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>

* fix: correct custom class

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>

* fix: correct attribute

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>

---------

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
Aaron Pham 2024-08-22 23:52:49 -04:00 committed by GitHub
parent 49fabfc358
commit 6475407d0d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 24 additions and 26 deletions

View File

@ -204,7 +204,8 @@ M.get_provider = function(provider)
end
end
M.BASE_PROVIDER_KEYS = { "endpoint", "model", "local", "deployment", "api_version", "proxy", "allow_insecure" }
M.BASE_PROVIDER_KEYS =
{ "endpoint", "model", "local", "deployment", "api_version", "proxy", "allow_insecure", "api_key_name" }
---@return {width: integer, height: integer}
function M.get_sidebar_layout_options()

View File

@ -6,10 +6,10 @@ local O = require("avante.providers").openai
---@class AvanteProviderFunctor
local M = {}
M.API_KEY = "AZURE_OPENAI_API_KEY"
M.api_key_name = "AZURE_OPENAI_API_KEY"
M.has = function()
return os.getenv(M.API_KEY) and true or false
return os.getenv(M.api_key_name) and true or false
end
M.parse_message = O.parse_message
@ -22,7 +22,7 @@ M.parse_curl_args = function(provider, code_opts)
["Content-Type"] = "application/json",
}
if not P.env.is_local("azure") then
headers["api-key"] = os.getenv(M.API_KEY)
headers["api-key"] = os.getenv(base.api_key_name or M.api_key_name)
end
return {

View File

@ -6,10 +6,10 @@ local P = require("avante.providers")
---@class AvanteProviderFunctor
local M = {}
M.API_KEY = "ANTHROPIC_API_KEY"
M.api_key_name = "ANTHROPIC_API_KEY"
M.has = function()
return os.getenv(M.API_KEY) and true or false
return os.getenv(M.api_key_name) and true or false
end
M.parse_message = function(opts)
@ -93,7 +93,7 @@ M.parse_curl_args = function(provider, code_opts)
["anthropic-beta"] = "prompt-caching-2024-07-31",
}
if not P.env.is_local("claude") then
headers["x-api-key"] = os.getenv(M.API_KEY)
headers["x-api-key"] = os.getenv(base.api_key_name or M.api_key_name)
end
return {

View File

@ -142,7 +142,7 @@ H.generate_headers = function(token, sessionid, machineid)
return headers
end
M.API_KEY = P.AVANTE_INTERNAL_KEY
M.api_key_name = P.AVANTE_INTERNAL_KEY
M.has = function()
if Utils.has("copilot.lua") or Utils.has("copilot.vim") or H.find_config_path() then

View File

@ -6,10 +6,10 @@ local O = require("avante.providers").openai
---@class AvanteProviderFunctor
local M = {}
M.API_KEY = "DEEPSEEK_API_KEY"
M.api_key_name = "DEEPSEEK_API_KEY"
M.has = function()
return os.getenv(M.API_KEY) and true or false
return os.getenv(M.api_key_name) and true or false
end
M.parse_message = O.parse_message
@ -22,7 +22,7 @@ M.parse_curl_args = function(provider, code_opts)
["Content-Type"] = "application/json",
}
if not P.env.is_local("deepseek") then
headers["Authorization"] = "Bearer " .. os.getenv(M.API_KEY)
headers["Authorization"] = "Bearer " .. os.getenv(base.api_key_name or M.api_key_name)
end
return {

View File

@ -5,10 +5,10 @@ local P = require("avante.providers")
---@class AvanteProviderFunctor
local M = {}
M.API_KEY = "GROQ_API_KEY"
M.api_key_name = "GROQ_API_KEY"
M.has = function()
return os.getenv(M.API_KEY) and true or false
return os.getenv(M.api_key_name) and true or false
end
M.parse_message = function(opts)
@ -68,7 +68,7 @@ M.parse_curl_args = function(provider, code_opts)
.. "/"
.. base.model
.. ":streamGenerateContent?alt=sse&key="
.. os.getenv(M.API_KEY),
.. os.getenv(base.api_key_name or M.api_key_name),
proxy = base.proxy,
insecure = base.allow_insecure,
headers = { ["Content-Type"] = "application/json" },

View File

@ -6,10 +6,10 @@ local O = require("avante.providers").openai
---@class AvanteProviderFunctor
local M = {}
M.API_KEY = "GROQ_API_KEY"
M.api_key_name = "GROQ_API_KEY"
M.has = function()
return os.getenv(M.API_KEY) and true or false
return os.getenv(M.api_key_name) and true or false
end
M.parse_message = O.parse_message
@ -22,7 +22,7 @@ M.parse_curl_args = function(provider, code_opts)
["Content-Type"] = "application/json",
}
if not P.env.is_local("groq") then
headers["Authorization"] = "Bearer " .. os.getenv(M.API_KEY)
headers["Authorization"] = "Bearer " .. os.getenv(base.api_key_name or M.api_key_name)
end
return {

View File

@ -80,7 +80,7 @@ local Dressing = require("avante.ui.dressing")
---@field parse_curl_args AvanteCurlArgsParser
---@field setup? fun(): nil
---@field has fun(): boolean
---@field API_KEY string
---@field api_key_name string
---@field parse_stream_data? AvanteStreamParser
---
---@class avante.Providers
@ -102,13 +102,10 @@ setmetatable(M, {
local v = Config.vendors[k]
-- Patch from vendors similar to supported providers.
---@type AvanteProviderFunctor
t[k] = setmetatable({}, { __index = v })
t[k].API_KEY = v.api_key_name
-- Hack for aliasing and makes it sane for us.
t[k].parse_response = v.parse_response_data
t[k].has = function()
return os.getenv(v.api_key_name) and true or false
end
return t[k]
end
@ -130,7 +127,7 @@ E._once = false
---@param opts {refresh: boolean, provider: AvanteProviderFunctor}
---@private
E.setup = function(opts)
local var = opts.provider.API_KEY
local var = opts.provider.api_key_name
if var == M.AVANTE_INTERNAL_KEY then
return

View File

@ -24,10 +24,10 @@ local P = require("avante.providers")
---@class AvanteProviderFunctor
local M = {}
M.API_KEY = "OPENAI_API_KEY"
M.api_key_name = "OPENAI_API_KEY"
M.has = function()
return os.getenv(M.API_KEY) and true or false
return os.getenv(M.api_key_name) and true or false
end
M.parse_message = function(opts)
@ -91,7 +91,7 @@ M.parse_curl_args = function(provider, code_opts)
["Content-Type"] = "application/json",
}
if not P.env.is_local("openai") then
headers["Authorization"] = "Bearer " .. os.getenv(M.API_KEY)
headers["Authorization"] = "Bearer " .. os.getenv(base.api_key_name or M.api_key_name)
end
return {