fix(llm): persistent key check for override class (#158)
* fix(llm): make sure to allow passing custom module Signed-off-by: Aaron Pham <contact@aarnphm.xyz> * fix: correct custom class Signed-off-by: Aaron Pham <contact@aarnphm.xyz> * fix: correct attribute Signed-off-by: Aaron Pham <contact@aarnphm.xyz> --------- Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
parent
49fabfc358
commit
6475407d0d
@ -204,7 +204,8 @@ M.get_provider = function(provider)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
M.BASE_PROVIDER_KEYS = { "endpoint", "model", "local", "deployment", "api_version", "proxy", "allow_insecure" }
|
M.BASE_PROVIDER_KEYS =
|
||||||
|
{ "endpoint", "model", "local", "deployment", "api_version", "proxy", "allow_insecure", "api_key_name" }
|
||||||
|
|
||||||
---@return {width: integer, height: integer}
|
---@return {width: integer, height: integer}
|
||||||
function M.get_sidebar_layout_options()
|
function M.get_sidebar_layout_options()
|
||||||
|
@ -6,10 +6,10 @@ local O = require("avante.providers").openai
|
|||||||
---@class AvanteProviderFunctor
|
---@class AvanteProviderFunctor
|
||||||
local M = {}
|
local M = {}
|
||||||
|
|
||||||
M.API_KEY = "AZURE_OPENAI_API_KEY"
|
M.api_key_name = "AZURE_OPENAI_API_KEY"
|
||||||
|
|
||||||
M.has = function()
|
M.has = function()
|
||||||
return os.getenv(M.API_KEY) and true or false
|
return os.getenv(M.api_key_name) and true or false
|
||||||
end
|
end
|
||||||
|
|
||||||
M.parse_message = O.parse_message
|
M.parse_message = O.parse_message
|
||||||
@ -22,7 +22,7 @@ M.parse_curl_args = function(provider, code_opts)
|
|||||||
["Content-Type"] = "application/json",
|
["Content-Type"] = "application/json",
|
||||||
}
|
}
|
||||||
if not P.env.is_local("azure") then
|
if not P.env.is_local("azure") then
|
||||||
headers["api-key"] = os.getenv(M.API_KEY)
|
headers["api-key"] = os.getenv(base.api_key_name or M.api_key_name)
|
||||||
end
|
end
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
@ -6,10 +6,10 @@ local P = require("avante.providers")
|
|||||||
---@class AvanteProviderFunctor
|
---@class AvanteProviderFunctor
|
||||||
local M = {}
|
local M = {}
|
||||||
|
|
||||||
M.API_KEY = "ANTHROPIC_API_KEY"
|
M.api_key_name = "ANTHROPIC_API_KEY"
|
||||||
|
|
||||||
M.has = function()
|
M.has = function()
|
||||||
return os.getenv(M.API_KEY) and true or false
|
return os.getenv(M.api_key_name) and true or false
|
||||||
end
|
end
|
||||||
|
|
||||||
M.parse_message = function(opts)
|
M.parse_message = function(opts)
|
||||||
@ -93,7 +93,7 @@ M.parse_curl_args = function(provider, code_opts)
|
|||||||
["anthropic-beta"] = "prompt-caching-2024-07-31",
|
["anthropic-beta"] = "prompt-caching-2024-07-31",
|
||||||
}
|
}
|
||||||
if not P.env.is_local("claude") then
|
if not P.env.is_local("claude") then
|
||||||
headers["x-api-key"] = os.getenv(M.API_KEY)
|
headers["x-api-key"] = os.getenv(base.api_key_name or M.api_key_name)
|
||||||
end
|
end
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
@ -142,7 +142,7 @@ H.generate_headers = function(token, sessionid, machineid)
|
|||||||
return headers
|
return headers
|
||||||
end
|
end
|
||||||
|
|
||||||
M.API_KEY = P.AVANTE_INTERNAL_KEY
|
M.api_key_name = P.AVANTE_INTERNAL_KEY
|
||||||
|
|
||||||
M.has = function()
|
M.has = function()
|
||||||
if Utils.has("copilot.lua") or Utils.has("copilot.vim") or H.find_config_path() then
|
if Utils.has("copilot.lua") or Utils.has("copilot.vim") or H.find_config_path() then
|
||||||
|
@ -6,10 +6,10 @@ local O = require("avante.providers").openai
|
|||||||
---@class AvanteProviderFunctor
|
---@class AvanteProviderFunctor
|
||||||
local M = {}
|
local M = {}
|
||||||
|
|
||||||
M.API_KEY = "DEEPSEEK_API_KEY"
|
M.api_key_name = "DEEPSEEK_API_KEY"
|
||||||
|
|
||||||
M.has = function()
|
M.has = function()
|
||||||
return os.getenv(M.API_KEY) and true or false
|
return os.getenv(M.api_key_name) and true or false
|
||||||
end
|
end
|
||||||
|
|
||||||
M.parse_message = O.parse_message
|
M.parse_message = O.parse_message
|
||||||
@ -22,7 +22,7 @@ M.parse_curl_args = function(provider, code_opts)
|
|||||||
["Content-Type"] = "application/json",
|
["Content-Type"] = "application/json",
|
||||||
}
|
}
|
||||||
if not P.env.is_local("deepseek") then
|
if not P.env.is_local("deepseek") then
|
||||||
headers["Authorization"] = "Bearer " .. os.getenv(M.API_KEY)
|
headers["Authorization"] = "Bearer " .. os.getenv(base.api_key_name or M.api_key_name)
|
||||||
end
|
end
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
@ -5,10 +5,10 @@ local P = require("avante.providers")
|
|||||||
---@class AvanteProviderFunctor
|
---@class AvanteProviderFunctor
|
||||||
local M = {}
|
local M = {}
|
||||||
|
|
||||||
M.API_KEY = "GROQ_API_KEY"
|
M.api_key_name = "GROQ_API_KEY"
|
||||||
|
|
||||||
M.has = function()
|
M.has = function()
|
||||||
return os.getenv(M.API_KEY) and true or false
|
return os.getenv(M.api_key_name) and true or false
|
||||||
end
|
end
|
||||||
|
|
||||||
M.parse_message = function(opts)
|
M.parse_message = function(opts)
|
||||||
@ -68,7 +68,7 @@ M.parse_curl_args = function(provider, code_opts)
|
|||||||
.. "/"
|
.. "/"
|
||||||
.. base.model
|
.. base.model
|
||||||
.. ":streamGenerateContent?alt=sse&key="
|
.. ":streamGenerateContent?alt=sse&key="
|
||||||
.. os.getenv(M.API_KEY),
|
.. os.getenv(base.api_key_name or M.api_key_name),
|
||||||
proxy = base.proxy,
|
proxy = base.proxy,
|
||||||
insecure = base.allow_insecure,
|
insecure = base.allow_insecure,
|
||||||
headers = { ["Content-Type"] = "application/json" },
|
headers = { ["Content-Type"] = "application/json" },
|
||||||
|
@ -6,10 +6,10 @@ local O = require("avante.providers").openai
|
|||||||
---@class AvanteProviderFunctor
|
---@class AvanteProviderFunctor
|
||||||
local M = {}
|
local M = {}
|
||||||
|
|
||||||
M.API_KEY = "GROQ_API_KEY"
|
M.api_key_name = "GROQ_API_KEY"
|
||||||
|
|
||||||
M.has = function()
|
M.has = function()
|
||||||
return os.getenv(M.API_KEY) and true or false
|
return os.getenv(M.api_key_name) and true or false
|
||||||
end
|
end
|
||||||
|
|
||||||
M.parse_message = O.parse_message
|
M.parse_message = O.parse_message
|
||||||
@ -22,7 +22,7 @@ M.parse_curl_args = function(provider, code_opts)
|
|||||||
["Content-Type"] = "application/json",
|
["Content-Type"] = "application/json",
|
||||||
}
|
}
|
||||||
if not P.env.is_local("groq") then
|
if not P.env.is_local("groq") then
|
||||||
headers["Authorization"] = "Bearer " .. os.getenv(M.API_KEY)
|
headers["Authorization"] = "Bearer " .. os.getenv(base.api_key_name or M.api_key_name)
|
||||||
end
|
end
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
@ -80,7 +80,7 @@ local Dressing = require("avante.ui.dressing")
|
|||||||
---@field parse_curl_args AvanteCurlArgsParser
|
---@field parse_curl_args AvanteCurlArgsParser
|
||||||
---@field setup? fun(): nil
|
---@field setup? fun(): nil
|
||||||
---@field has fun(): boolean
|
---@field has fun(): boolean
|
||||||
---@field API_KEY string
|
---@field api_key_name string
|
||||||
---@field parse_stream_data? AvanteStreamParser
|
---@field parse_stream_data? AvanteStreamParser
|
||||||
---
|
---
|
||||||
---@class avante.Providers
|
---@class avante.Providers
|
||||||
@ -102,13 +102,10 @@ setmetatable(M, {
|
|||||||
local v = Config.vendors[k]
|
local v = Config.vendors[k]
|
||||||
|
|
||||||
-- Patch from vendors similar to supported providers.
|
-- Patch from vendors similar to supported providers.
|
||||||
|
---@type AvanteProviderFunctor
|
||||||
t[k] = setmetatable({}, { __index = v })
|
t[k] = setmetatable({}, { __index = v })
|
||||||
t[k].API_KEY = v.api_key_name
|
|
||||||
-- Hack for aliasing and makes it sane for us.
|
-- Hack for aliasing and makes it sane for us.
|
||||||
t[k].parse_response = v.parse_response_data
|
t[k].parse_response = v.parse_response_data
|
||||||
t[k].has = function()
|
|
||||||
return os.getenv(v.api_key_name) and true or false
|
|
||||||
end
|
|
||||||
|
|
||||||
return t[k]
|
return t[k]
|
||||||
end
|
end
|
||||||
@ -130,7 +127,7 @@ E._once = false
|
|||||||
---@param opts {refresh: boolean, provider: AvanteProviderFunctor}
|
---@param opts {refresh: boolean, provider: AvanteProviderFunctor}
|
||||||
---@private
|
---@private
|
||||||
E.setup = function(opts)
|
E.setup = function(opts)
|
||||||
local var = opts.provider.API_KEY
|
local var = opts.provider.api_key_name
|
||||||
|
|
||||||
if var == M.AVANTE_INTERNAL_KEY then
|
if var == M.AVANTE_INTERNAL_KEY then
|
||||||
return
|
return
|
||||||
|
@ -24,10 +24,10 @@ local P = require("avante.providers")
|
|||||||
---@class AvanteProviderFunctor
|
---@class AvanteProviderFunctor
|
||||||
local M = {}
|
local M = {}
|
||||||
|
|
||||||
M.API_KEY = "OPENAI_API_KEY"
|
M.api_key_name = "OPENAI_API_KEY"
|
||||||
|
|
||||||
M.has = function()
|
M.has = function()
|
||||||
return os.getenv(M.API_KEY) and true or false
|
return os.getenv(M.api_key_name) and true or false
|
||||||
end
|
end
|
||||||
|
|
||||||
M.parse_message = function(opts)
|
M.parse_message = function(opts)
|
||||||
@ -91,7 +91,7 @@ M.parse_curl_args = function(provider, code_opts)
|
|||||||
["Content-Type"] = "application/json",
|
["Content-Type"] = "application/json",
|
||||||
}
|
}
|
||||||
if not P.env.is_local("openai") then
|
if not P.env.is_local("openai") then
|
||||||
headers["Authorization"] = "Bearer " .. os.getenv(M.API_KEY)
|
headers["Authorization"] = "Bearer " .. os.getenv(base.api_key_name or M.api_key_name)
|
||||||
end
|
end
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user