refactor: better value name (#1261)
This commit is contained in:
parent
9bad591e8a
commit
ce55d7ac9e
@ -25,8 +25,8 @@ M.generate_prompts = function(opts)
|
||||
local Provider = opts.provider or P[Config.provider]
|
||||
local mode = opts.mode or "planning"
|
||||
---@type AvanteProviderFunctor | AvanteBedrockProviderFunctor
|
||||
local _, body_opts = P.parse_config(Provider)
|
||||
local max_tokens = body_opts.max_tokens or 4096
|
||||
local _, request_body = P.parse_config(Provider)
|
||||
local max_tokens = request_body.max_tokens or 4096
|
||||
|
||||
-- Check if the instructions contains an image path
|
||||
local image_paths = {}
|
||||
|
@ -18,31 +18,34 @@ M.parse_response = O.parse_response
|
||||
M.parse_response_without_stream = O.parse_response_without_stream
|
||||
|
||||
M.parse_curl_args = function(provider, prompt_opts)
|
||||
local base, body_opts = P.parse_config(provider)
|
||||
local provider_conf, request_body = P.parse_config(provider)
|
||||
|
||||
local headers = {
|
||||
["Content-Type"] = "application/json",
|
||||
}
|
||||
if P.env.require_api_key(base) then headers["api-key"] = provider.parse_api_key() end
|
||||
if P.env.require_api_key(provider_conf) then headers["api-key"] = provider.parse_api_key() end
|
||||
|
||||
-- NOTE: When using "o" series set the supported parameters only
|
||||
if O.is_o_series_model(base.model) then
|
||||
body_opts.max_tokens = nil
|
||||
body_opts.temperature = 1
|
||||
if O.is_o_series_model(provider_conf.model) then
|
||||
request_body.max_tokens = nil
|
||||
request_body.temperature = 1
|
||||
end
|
||||
|
||||
return {
|
||||
url = Utils.url_join(
|
||||
base.endpoint,
|
||||
"/openai/deployments/" .. base.deployment .. "/chat/completions?api-version=" .. base.api_version
|
||||
provider_conf.endpoint,
|
||||
"/openai/deployments/"
|
||||
.. provider_conf.deployment
|
||||
.. "/chat/completions?api-version="
|
||||
.. provider_conf.api_version
|
||||
),
|
||||
proxy = base.proxy,
|
||||
insecure = base.allow_insecure,
|
||||
proxy = provider_conf.proxy,
|
||||
insecure = provider_conf.allow_insecure,
|
||||
headers = headers,
|
||||
body = vim.tbl_deep_extend("force", {
|
||||
messages = M.parse_messages(prompt_opts),
|
||||
stream = true,
|
||||
}, body_opts),
|
||||
}, request_body),
|
||||
}
|
||||
end
|
||||
|
||||
|
@ -1,5 +1,4 @@
|
||||
local Utils = require("avante.utils")
|
||||
local Clipboard = require("avante.clipboard")
|
||||
local P = require("avante.providers")
|
||||
|
||||
---@alias AvanteBedrockPayloadBuilder fun(prompt_opts: AvantePromptOptions, body_opts: table<string, any>): table<string, any>
|
||||
@ -17,17 +16,14 @@ M.api_key_name = "BEDROCK_KEYS"
|
||||
M.use_xml_format = true
|
||||
|
||||
M.load_model_handler = function()
|
||||
local base, _ = P.parse_config(P["bedrock"])
|
||||
local bedrock_model = base.model
|
||||
if base.model:match("anthropic") then bedrock_model = "claude" end
|
||||
local provider_conf, _ = P.parse_config(P["bedrock"])
|
||||
local bedrock_model = provider_conf.model
|
||||
if provider_conf.model:match("anthropic") then bedrock_model = "claude" end
|
||||
|
||||
local ok, model_module = pcall(require, "avante.providers.bedrock." .. bedrock_model)
|
||||
if ok then
|
||||
return model_module
|
||||
else
|
||||
local error_msg = "Bedrock model handler not found: " .. bedrock_model
|
||||
Utils.error(error_msg, { once = true, title = "Avante" })
|
||||
end
|
||||
if ok then return model_module end
|
||||
local error_msg = "Bedrock model handler not found: " .. bedrock_model
|
||||
error(error_msg)
|
||||
end
|
||||
|
||||
M.parse_response = function(ctx, data_stream, event_state, opts)
|
||||
@ -46,8 +42,8 @@ M.parse_stream_data = function(data, opts)
|
||||
-- The `type` field in the decoded JSON determines how the response is handled.
|
||||
local bedrock_match = data:gmatch("event(%b{})")
|
||||
for bedrock_data_match in bedrock_match do
|
||||
local data = vim.json.decode(bedrock_data_match)
|
||||
local data_stream = vim.base64.decode(data.bytes)
|
||||
local jsn = vim.json.decode(bedrock_data_match)
|
||||
local data_stream = vim.base64.decode(jsn.bytes)
|
||||
local json = vim.json.decode(data_stream)
|
||||
M.parse_response({}, data_stream, json.type, opts)
|
||||
end
|
||||
@ -60,6 +56,7 @@ M.parse_curl_args = function(provider, prompt_opts)
|
||||
local base, body_opts = P.parse_config(provider)
|
||||
|
||||
local api_key = provider.parse_api_key()
|
||||
if api_key == nil then error("Cannot get the bedrock api key!") end
|
||||
local parts = vim.split(api_key, ",")
|
||||
local aws_access_key_id = parts[1]
|
||||
local aws_secret_access_key = parts[2]
|
||||
@ -108,7 +105,6 @@ M.on_error = function(result)
|
||||
end
|
||||
|
||||
local error_msg = body.error.message
|
||||
local error_type = body.error.type
|
||||
|
||||
Utils.error(error_msg, { once = true, title = "Avante" })
|
||||
end
|
||||
|
@ -226,7 +226,7 @@ end
|
||||
---@param prompt_opts AvantePromptOptions
|
||||
---@return table
|
||||
M.parse_curl_args = function(provider, prompt_opts)
|
||||
local base, body_opts = P.parse_config(provider)
|
||||
local provider_conf, request_body = P.parse_config(provider)
|
||||
|
||||
local headers = {
|
||||
["Content-Type"] = "application/json",
|
||||
@ -234,7 +234,7 @@ M.parse_curl_args = function(provider, prompt_opts)
|
||||
["anthropic-beta"] = "prompt-caching-2024-07-31",
|
||||
}
|
||||
|
||||
if P.env.require_api_key(base) then headers["x-api-key"] = provider.parse_api_key() end
|
||||
if P.env.require_api_key(provider_conf) then headers["x-api-key"] = provider.parse_api_key() end
|
||||
|
||||
local messages = M.parse_messages(prompt_opts)
|
||||
|
||||
@ -246,12 +246,12 @@ M.parse_curl_args = function(provider, prompt_opts)
|
||||
end
|
||||
|
||||
return {
|
||||
url = Utils.url_join(base.endpoint, "/v1/messages"),
|
||||
proxy = base.proxy,
|
||||
insecure = base.allow_insecure,
|
||||
url = Utils.url_join(provider_conf.endpoint, "/v1/messages"),
|
||||
proxy = provider_conf.proxy,
|
||||
insecure = provider_conf.allow_insecure,
|
||||
headers = headers,
|
||||
body = vim.tbl_deep_extend("force", {
|
||||
model = base.model,
|
||||
model = provider_conf.model,
|
||||
system = {
|
||||
{
|
||||
type = "text",
|
||||
@ -262,7 +262,7 @@ M.parse_curl_args = function(provider, prompt_opts)
|
||||
messages = messages,
|
||||
tools = tools,
|
||||
stream = true,
|
||||
}, body_opts),
|
||||
}, request_body),
|
||||
}
|
||||
end
|
||||
|
||||
|
@ -70,7 +70,7 @@ M.parse_stream_data = function(data, opts)
|
||||
end
|
||||
|
||||
M.parse_curl_args = function(provider, prompt_opts)
|
||||
local base, body_opts = P.parse_config(provider)
|
||||
local provider_conf, request_body = P.parse_config(provider)
|
||||
|
||||
local headers = {
|
||||
["Accept"] = "application/json",
|
||||
@ -82,17 +82,17 @@ M.parse_curl_args = function(provider, prompt_opts)
|
||||
.. "."
|
||||
.. vim.version().patch,
|
||||
}
|
||||
if P.env.require_api_key(base) then headers["Authorization"] = "Bearer " .. provider.parse_api_key() end
|
||||
if P.env.require_api_key(provider_conf) then headers["Authorization"] = "Bearer " .. provider.parse_api_key() end
|
||||
|
||||
return {
|
||||
url = Utils.url_join(base.endpoint, "/chat"),
|
||||
proxy = base.proxy,
|
||||
insecure = base.allow_insecure,
|
||||
url = Utils.url_join(provider_conf.endpoint, "/chat"),
|
||||
proxy = provider_conf.proxy,
|
||||
insecure = provider_conf.allow_insecure,
|
||||
headers = headers,
|
||||
body = vim.tbl_deep_extend("force", {
|
||||
model = base.model,
|
||||
model = provider_conf.model,
|
||||
stream = true,
|
||||
}, M.parse_messages(prompt_opts), body_opts),
|
||||
}, M.parse_messages(prompt_opts), request_body),
|
||||
}
|
||||
end
|
||||
|
||||
|
@ -249,7 +249,7 @@ M.parse_curl_args = function(provider, prompt_opts)
|
||||
-- (this should rarely happen, as we refresh the token in the background)
|
||||
H.refresh_token(false, false)
|
||||
|
||||
local base, body_opts = P.parse_config(provider)
|
||||
local provider_conf, request_body = P.parse_config(provider)
|
||||
|
||||
local tools = {}
|
||||
if prompt_opts.tools then
|
||||
@ -259,10 +259,10 @@ M.parse_curl_args = function(provider, prompt_opts)
|
||||
end
|
||||
|
||||
return {
|
||||
url = H.chat_completion_url(base.endpoint),
|
||||
timeout = base.timeout,
|
||||
proxy = base.proxy,
|
||||
insecure = base.allow_insecure,
|
||||
url = H.chat_completion_url(provider_conf.endpoint),
|
||||
timeout = provider_conf.timeout,
|
||||
proxy = provider_conf.proxy,
|
||||
insecure = provider_conf.allow_insecure,
|
||||
headers = {
|
||||
["Content-Type"] = "application/json",
|
||||
["Authorization"] = "Bearer " .. M.state.github_token.token,
|
||||
@ -270,11 +270,11 @@ M.parse_curl_args = function(provider, prompt_opts)
|
||||
["Editor-Version"] = ("Neovim/%s.%s.%s"):format(vim.version().major, vim.version().minor, vim.version().patch),
|
||||
},
|
||||
body = vim.tbl_deep_extend("force", {
|
||||
model = base.model,
|
||||
model = provider_conf.model,
|
||||
messages = M.parse_messages(prompt_opts),
|
||||
stream = true,
|
||||
tools = tools,
|
||||
}, body_opts),
|
||||
}, request_body),
|
||||
}
|
||||
end
|
||||
|
||||
|
@ -82,26 +82,29 @@ M.parse_response = function(ctx, data_stream, _, opts)
|
||||
end
|
||||
|
||||
M.parse_curl_args = function(provider, prompt_opts)
|
||||
local base, body_opts = P.parse_config(provider)
|
||||
local provider_conf, request_body = P.parse_config(provider)
|
||||
|
||||
body_opts = vim.tbl_deep_extend("force", body_opts, {
|
||||
request_body = vim.tbl_deep_extend("force", request_body, {
|
||||
generationConfig = {
|
||||
temperature = body_opts.temperature,
|
||||
maxOutputTokens = body_opts.max_tokens,
|
||||
temperature = request_body.temperature,
|
||||
maxOutputTokens = request_body.max_tokens,
|
||||
},
|
||||
})
|
||||
body_opts.temperature = nil
|
||||
body_opts.max_tokens = nil
|
||||
request_body.temperature = nil
|
||||
request_body.max_tokens = nil
|
||||
|
||||
local api_key = provider.parse_api_key()
|
||||
if api_key == nil then error("Cannot get the gemini api key!") end
|
||||
|
||||
return {
|
||||
url = Utils.url_join(base.endpoint, base.model .. ":streamGenerateContent?alt=sse&key=" .. api_key),
|
||||
proxy = base.proxy,
|
||||
insecure = base.allow_insecure,
|
||||
url = Utils.url_join(
|
||||
provider_conf.endpoint,
|
||||
provider_conf.model .. ":streamGenerateContent?alt=sse&key=" .. api_key
|
||||
),
|
||||
proxy = provider_conf.proxy,
|
||||
insecure = provider_conf.allow_insecure,
|
||||
headers = { ["Content-Type"] = "application/json" },
|
||||
body = vim.tbl_deep_extend("force", {}, M.parse_messages(prompt_opts), body_opts),
|
||||
body = vim.tbl_deep_extend("force", {}, M.parse_messages(prompt_opts), request_body),
|
||||
}
|
||||
end
|
||||
|
||||
|
@ -346,9 +346,9 @@ M = setmetatable(M, {
|
||||
if t[k].has == nil then t[k].has = function() return E.parse_envvar(t[k]) ~= nil end end
|
||||
|
||||
if t[k].setup == nil then
|
||||
local base = M.parse_config(t[k])
|
||||
local provider_conf = M.parse_config(t[k])
|
||||
t[k].setup = function()
|
||||
if E.require_api_key(base) then t[k].parse_api_key() end
|
||||
if E.require_api_key(provider_conf) then t[k].parse_api_key() end
|
||||
require("avante.tokenizers").setup(t[k].tokenizer_id)
|
||||
end
|
||||
end
|
||||
|
@ -275,14 +275,14 @@ M.parse_response_without_stream = function(data, _, opts)
|
||||
end
|
||||
|
||||
M.parse_curl_args = function(provider, prompt_opts)
|
||||
local base, body_opts = P.parse_config(provider)
|
||||
local disable_tools = base.disable_tools or false
|
||||
local provider_conf, request_body = P.parse_config(provider)
|
||||
local disable_tools = provider_conf.disable_tools or false
|
||||
|
||||
local headers = {
|
||||
["Content-Type"] = "application/json",
|
||||
}
|
||||
|
||||
if P.env.require_api_key(base) then
|
||||
if P.env.require_api_key(provider_conf) then
|
||||
local api_key = provider.parse_api_key()
|
||||
if api_key == nil then
|
||||
error(Config.provider .. " API key is not set, please set it in your environment variable or config file")
|
||||
@ -290,18 +290,18 @@ M.parse_curl_args = function(provider, prompt_opts)
|
||||
headers["Authorization"] = "Bearer " .. api_key
|
||||
end
|
||||
|
||||
if M.is_openrouter(base.endpoint) then
|
||||
if M.is_openrouter(provider_conf.endpoint) then
|
||||
headers["HTTP-Referer"] = "https://github.com/yetone/avante.nvim"
|
||||
headers["X-Title"] = "Avante.nvim"
|
||||
body_opts.include_reasoning = true
|
||||
request_body.include_reasoning = true
|
||||
end
|
||||
|
||||
-- NOTE: When using "o" series set the supported parameters only
|
||||
local stream = true
|
||||
if M.is_o_series_model(base.model) then
|
||||
body_opts.max_completion_tokens = body_opts.max_tokens
|
||||
body_opts.max_tokens = nil
|
||||
body_opts.temperature = 1
|
||||
if M.is_o_series_model(provider_conf.model) then
|
||||
request_body.max_completion_tokens = request_body.max_tokens
|
||||
request_body.max_tokens = nil
|
||||
request_body.temperature = 1
|
||||
end
|
||||
|
||||
local tools = nil
|
||||
@ -312,20 +312,20 @@ M.parse_curl_args = function(provider, prompt_opts)
|
||||
end
|
||||
end
|
||||
|
||||
Utils.debug("endpoint", base.endpoint)
|
||||
Utils.debug("model", base.model)
|
||||
Utils.debug("endpoint", provider_conf.endpoint)
|
||||
Utils.debug("model", provider_conf.model)
|
||||
|
||||
return {
|
||||
url = Utils.url_join(base.endpoint, "/chat/completions"),
|
||||
proxy = base.proxy,
|
||||
insecure = base.allow_insecure,
|
||||
url = Utils.url_join(provider_conf.endpoint, "/chat/completions"),
|
||||
proxy = provider_conf.proxy,
|
||||
insecure = provider_conf.allow_insecure,
|
||||
headers = headers,
|
||||
body = vim.tbl_deep_extend("force", {
|
||||
model = base.model,
|
||||
model = provider_conf.model,
|
||||
messages = M.parse_messages(prompt_opts),
|
||||
stream = stream,
|
||||
tools = tools,
|
||||
}, body_opts),
|
||||
}, request_body),
|
||||
}
|
||||
end
|
||||
|
||||
|
@ -32,22 +32,22 @@ M.parse_api_key = function()
|
||||
end
|
||||
|
||||
M.parse_curl_args = function(provider, prompt_opts)
|
||||
local base, body_opts = P.parse_config(provider)
|
||||
local provider_conf, request_body = P.parse_config(provider)
|
||||
local location = vim.fn.getenv("LOCATION") or "default-location"
|
||||
local project_id = vim.fn.getenv("PROJECT_ID") or "default-project-id"
|
||||
local model_id = base.model or "default-model-id"
|
||||
local url = base.endpoint:gsub("LOCATION", location):gsub("PROJECT_ID", project_id)
|
||||
local model_id = provider_conf.model or "default-model-id"
|
||||
local url = provider_conf.endpoint:gsub("LOCATION", location):gsub("PROJECT_ID", project_id)
|
||||
|
||||
url = string.format("%s/%s:streamGenerateContent?alt=sse", url, model_id)
|
||||
|
||||
body_opts = vim.tbl_deep_extend("force", body_opts, {
|
||||
request_body = vim.tbl_deep_extend("force", request_body, {
|
||||
generationConfig = {
|
||||
temperature = body_opts.temperature,
|
||||
maxOutputTokens = body_opts.max_tokens,
|
||||
temperature = request_body.temperature,
|
||||
maxOutputTokens = request_body.max_tokens,
|
||||
},
|
||||
})
|
||||
body_opts.temperature = nil
|
||||
body_opts.max_tokens = nil
|
||||
request_body.temperature = nil
|
||||
request_body.max_tokens = nil
|
||||
local bearer_token = M.parse_api_key()
|
||||
|
||||
return {
|
||||
@ -56,9 +56,9 @@ M.parse_curl_args = function(provider, prompt_opts)
|
||||
["Authorization"] = "Bearer " .. bearer_token,
|
||||
["Content-Type"] = "application/json; charset=utf-8",
|
||||
},
|
||||
proxy = base.proxy,
|
||||
insecure = base.allow_insecure,
|
||||
body = vim.tbl_deep_extend("force", {}, M.parse_messages(prompt_opts), body_opts),
|
||||
proxy = provider_conf.proxy,
|
||||
insecure = provider_conf.allow_insecure,
|
||||
body = vim.tbl_deep_extend("force", {}, M.parse_messages(prompt_opts), request_body),
|
||||
}
|
||||
end
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user