Feat: Add Amazon Bedrock provider (#1167)
This commit is contained in:
parent
cd7390de21
commit
43269cc07f
@ -425,6 +425,12 @@ Given its early stage, `avante.nvim` currently supports the following basic func
|
|||||||
> ```sh
|
> ```sh
|
||||||
> export AZURE_OPENAI_API_KEY=your-api-key
|
> export AZURE_OPENAI_API_KEY=your-api-key
|
||||||
> ```
|
> ```
|
||||||
|
>
|
||||||
|
> For Amazon Bedrock:
|
||||||
|
>
|
||||||
|
> ```sh
|
||||||
|
> export BEDROCK_KEYS=aws_access_key_id,aws_secret_access_key,aws_region
|
||||||
|
> ```
|
||||||
|
|
||||||
1. Open a code file in Neovim.
|
1. Open a code file in Neovim.
|
||||||
2. Use the `:AvanteAsk` command to query the AI about the code.
|
2. Use the `:AvanteAsk` command to query the AI about the code.
|
||||||
|
@ -56,6 +56,13 @@ M._defaults = {
|
|||||||
max_tokens = 8000,
|
max_tokens = 8000,
|
||||||
},
|
},
|
||||||
---@type AvanteSupportedProvider
|
---@type AvanteSupportedProvider
|
||||||
|
bedrock = {
|
||||||
|
model = "anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||||
|
timeout = 30000, -- Timeout in milliseconds
|
||||||
|
temperature = 0,
|
||||||
|
max_tokens = 8000,
|
||||||
|
},
|
||||||
|
---@type AvanteSupportedProvider
|
||||||
gemini = {
|
gemini = {
|
||||||
endpoint = "https://generativelanguage.googleapis.com/v1beta/models",
|
endpoint = "https://generativelanguage.googleapis.com/v1beta/models",
|
||||||
model = "gemini-1.5-flash-latest",
|
model = "gemini-1.5-flash-latest",
|
||||||
|
@ -23,7 +23,7 @@ local group = api.nvim_create_augroup("avante_llm", { clear = true })
|
|||||||
M.generate_prompts = function(opts)
|
M.generate_prompts = function(opts)
|
||||||
local Provider = opts.provider or P[Config.provider]
|
local Provider = opts.provider or P[Config.provider]
|
||||||
local mode = opts.mode or "planning"
|
local mode = opts.mode or "planning"
|
||||||
---@type AvanteProviderFunctor
|
---@type AvanteProviderFunctor | AvanteBedrockProviderFunctor
|
||||||
local _, body_opts = P.parse_config(Provider)
|
local _, body_opts = P.parse_config(Provider)
|
||||||
local max_tokens = body_opts.max_tokens or 4096
|
local max_tokens = body_opts.max_tokens or 4096
|
||||||
|
|
||||||
@ -380,7 +380,7 @@ end
|
|||||||
---@field ask boolean
|
---@field ask boolean
|
||||||
---@field instructions string
|
---@field instructions string
|
||||||
---@field mode LlmMode
|
---@field mode LlmMode
|
||||||
---@field provider AvanteProviderFunctor | nil
|
---@field provider AvanteProviderFunctor | AvanteBedrockProviderFunctor | nil
|
||||||
---
|
---
|
||||||
---@class StreamOptions: GeneratePromptsOptions
|
---@class StreamOptions: GeneratePromptsOptions
|
||||||
---@field on_chunk AvanteChunkParser
|
---@field on_chunk AvanteChunkParser
|
||||||
|
113
lua/avante/providers/bedrock.lua
Normal file
113
lua/avante/providers/bedrock.lua
Normal file
@ -0,0 +1,113 @@
|
|||||||
|
local Utils = require("avante.utils")
|
||||||
|
local Clipboard = require("avante.clipboard")
|
||||||
|
local P = require("avante.providers")
|
||||||
|
|
||||||
|
---@alias AvanteBedrockPayloadBuilder fun(prompt_opts: AvantePromptOptions, body_opts: table<string, any>): table<string, any>
|
||||||
|
---
|
||||||
|
---@class AvanteBedrockModelHandler
|
||||||
|
---@field role_map table<"user" | "assistant", string>
|
||||||
|
---@field parse_messages AvanteMessagesParser
|
||||||
|
---@field parse_response AvanteResponseParser
|
||||||
|
---@field build_bedrock_payload AvanteBedrockPayloadBuilder
|
||||||
|
|
||||||
|
---@class AvanteBedrockProviderFunctor
|
||||||
|
local M = {}
|
||||||
|
|
||||||
|
M.api_key_name = "BEDROCK_KEYS"
|
||||||
|
M.use_xml_format = true
|
||||||
|
|
||||||
|
M.load_model_handler = function()
|
||||||
|
local base, _ = P.parse_config(P["bedrock"])
|
||||||
|
local bedrock_model = base.model
|
||||||
|
if base.model:match("anthropic") then bedrock_model = "claude" end
|
||||||
|
|
||||||
|
local ok, model_module = pcall(require, "avante.providers.bedrock." .. bedrock_model)
|
||||||
|
if ok then
|
||||||
|
return model_module
|
||||||
|
else
|
||||||
|
local error_msg = "Bedrock model handler not found: " .. bedrock_model
|
||||||
|
Utils.error(error_msg, { once = true, title = "Avante" })
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
M.parse_response = function(ctx, data_stream, event_state, opts)
|
||||||
|
local model_handler = M.load_model_handler()
|
||||||
|
return model_handler.parse_response(ctx, data_stream, event_state, opts)
|
||||||
|
end
|
||||||
|
|
||||||
|
M.build_bedrock_payload = function(prompt_opts, body_opts)
|
||||||
|
local model_handler = M.load_model_handler()
|
||||||
|
return model_handler.build_bedrock_payload(prompt_opts, body_opts)
|
||||||
|
end
|
||||||
|
|
||||||
|
M.parse_stream_data = function(data, opts)
|
||||||
|
-- @NOTE: Decode and process Bedrock response
|
||||||
|
-- Each response contains a Base64-encoded `bytes` field, which is decoded into JSON.
|
||||||
|
-- The `type` field in the decoded JSON determines how the response is handled.
|
||||||
|
local bedrock_match = data:gmatch("event(%b{})")
|
||||||
|
for bedrock_data_match in bedrock_match do
|
||||||
|
local data = vim.json.decode(bedrock_data_match)
|
||||||
|
local data_stream = vim.base64.decode(data.bytes)
|
||||||
|
local json = vim.json.decode(data_stream)
|
||||||
|
M.parse_response({}, data_stream, json.type, opts)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
---@param provider AvanteBedrockProviderFunctor
|
||||||
|
---@param prompt_opts AvantePromptOptions
|
||||||
|
---@return table
|
||||||
|
M.parse_curl_args = function(provider, prompt_opts)
|
||||||
|
local base, body_opts = P.parse_config(provider)
|
||||||
|
|
||||||
|
local api_key = provider.parse_api_key()
|
||||||
|
local parts = vim.split(api_key, ",")
|
||||||
|
local aws_access_key_id = parts[1]
|
||||||
|
local aws_secret_access_key = parts[2]
|
||||||
|
local aws_region = parts[3]
|
||||||
|
|
||||||
|
local endpoint = string.format(
|
||||||
|
"https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke-with-response-stream",
|
||||||
|
aws_region,
|
||||||
|
base.model
|
||||||
|
)
|
||||||
|
|
||||||
|
local headers = {
|
||||||
|
["Content-Type"] = "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
local body_payload = M.build_bedrock_payload(prompt_opts, body_opts)
|
||||||
|
|
||||||
|
local rawArgs = {
|
||||||
|
"--aws-sigv4",
|
||||||
|
string.format("aws:amz:%s:bedrock", aws_region),
|
||||||
|
"--user",
|
||||||
|
string.format("%s:%s", aws_access_key_id, aws_secret_access_key),
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
url = endpoint,
|
||||||
|
proxy = base.proxy,
|
||||||
|
insecure = base.allow_insecure,
|
||||||
|
headers = headers,
|
||||||
|
body = body_payload,
|
||||||
|
rawArgs = rawArgs,
|
||||||
|
}
|
||||||
|
end
|
||||||
|
|
||||||
|
M.on_error = function(result)
|
||||||
|
if not result.body then
|
||||||
|
return Utils.error("API request failed with status " .. result.status, { once = true, title = "Avante" })
|
||||||
|
end
|
||||||
|
|
||||||
|
local ok, body = pcall(vim.json.decode, result.body)
|
||||||
|
if not (ok and body and body.error) then
|
||||||
|
return Utils.error("Failed to parse error response", { once = true, title = "Avante" })
|
||||||
|
end
|
||||||
|
|
||||||
|
local error_msg = body.error.message
|
||||||
|
local error_type = body.error.type
|
||||||
|
|
||||||
|
Utils.error(error_msg, { once = true, title = "Avante" })
|
||||||
|
end
|
||||||
|
|
||||||
|
return M
|
73
lua/avante/providers/bedrock/claude.lua
Normal file
73
lua/avante/providers/bedrock/claude.lua
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
---@class AvanteBedrockClaudeTextMessage
|
||||||
|
---@field type "text"
|
||||||
|
---@field text string
|
||||||
|
---
|
||||||
|
---@class AvanteBedrockClaudeMessage
|
||||||
|
---@field role "user" | "assistant"
|
||||||
|
---@field content [AvanteBedrockClaudeTextMessage][]
|
||||||
|
|
||||||
|
---@class AvanteBedrockModelHandler
|
||||||
|
local M = {}
|
||||||
|
|
||||||
|
M.role_map = {
|
||||||
|
user = "user",
|
||||||
|
assistant = "assistant",
|
||||||
|
}
|
||||||
|
|
||||||
|
M.parse_messages = function(opts)
|
||||||
|
---@type AvanteBedrockClaudeMessage[]
|
||||||
|
local messages = {}
|
||||||
|
|
||||||
|
for _, message in ipairs(opts.messages) do
|
||||||
|
table.insert(messages, {
|
||||||
|
role = M.role_map[message.role],
|
||||||
|
content = {
|
||||||
|
{
|
||||||
|
type = "text",
|
||||||
|
text = message.content,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
end
|
||||||
|
|
||||||
|
return messages
|
||||||
|
end
|
||||||
|
|
||||||
|
M.parse_response = function(ctx, data_stream, event_state, opts)
|
||||||
|
if event_state == nil then
|
||||||
|
if data_stream:match('"content_block_delta"') then
|
||||||
|
event_state = "content_block_delta"
|
||||||
|
elseif data_stream:match('"message_stop"') then
|
||||||
|
event_state = "message_stop"
|
||||||
|
end
|
||||||
|
end
|
||||||
|
if event_state == "content_block_delta" then
|
||||||
|
local ok, json = pcall(vim.json.decode, data_stream)
|
||||||
|
if not ok then return end
|
||||||
|
opts.on_chunk(json.delta.text)
|
||||||
|
elseif event_state == "message_stop" then
|
||||||
|
opts.on_complete(nil)
|
||||||
|
return
|
||||||
|
elseif event_state == "error" then
|
||||||
|
opts.on_complete(vim.json.decode(data_stream))
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
---@param prompt_opts AvantePromptOptions
|
||||||
|
---@param body_opts table
|
||||||
|
---@return table
|
||||||
|
M.build_bedrock_payload = function(prompt_opts, body_opts)
|
||||||
|
local system_prompt = prompt_opts.system_prompt or ""
|
||||||
|
local messages = M.parse_messages(prompt_opts)
|
||||||
|
local max_tokens = body_opts.max_tokens or 2000
|
||||||
|
local temperature = body_opts.temperature or 0.7
|
||||||
|
local payload = {
|
||||||
|
anthropic_version = "bedrock-2023-05-31",
|
||||||
|
max_tokens = max_tokens,
|
||||||
|
messages = messages,
|
||||||
|
system = system_prompt,
|
||||||
|
}
|
||||||
|
return vim.tbl_deep_extend("force", payload, body_opts or {})
|
||||||
|
end
|
||||||
|
|
||||||
|
return M
|
@ -32,7 +32,7 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil }
|
|||||||
---@alias AvanteMessagesParser fun(opts: AvantePromptOptions): AvanteChatMessage[]
|
---@alias AvanteMessagesParser fun(opts: AvantePromptOptions): AvanteChatMessage[]
|
||||||
---
|
---
|
||||||
---@class AvanteCurlOutput: {url: string, proxy: string, insecure: boolean, body: table<string, any> | string, headers: table<string, string>, rawArgs: string[] | nil}
|
---@class AvanteCurlOutput: {url: string, proxy: string, insecure: boolean, body: table<string, any> | string, headers: table<string, string>, rawArgs: string[] | nil}
|
||||||
---@alias AvanteCurlArgsParser fun(opts: AvanteProvider | AvanteProviderFunctor, code_opts: AvantePromptOptions): AvanteCurlOutput
|
---@alias AvanteCurlArgsParser fun(opts: AvanteProvider | AvanteProviderFunctor | AvanteBedrockProviderFunctor, code_opts: AvantePromptOptions): AvanteCurlOutput
|
||||||
---
|
---
|
||||||
---@class ResponseParser
|
---@class ResponseParser
|
||||||
---@field on_chunk fun(chunk: string): any
|
---@field on_chunk fun(chunk: string): any
|
||||||
@ -80,6 +80,21 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil }
|
|||||||
---@field parse_stream_data? AvanteStreamParser
|
---@field parse_stream_data? AvanteStreamParser
|
||||||
---@field on_error? fun(result: table<string, any>): nil
|
---@field on_error? fun(result: table<string, any>): nil
|
||||||
---
|
---
|
||||||
|
---@class AvanteBedrockProviderFunctor
|
||||||
|
---@field parse_response AvanteResponseParser
|
||||||
|
---@field parse_curl_args AvanteCurlArgsParser
|
||||||
|
---@field setup fun(): nil
|
||||||
|
---@field has fun(): boolean
|
||||||
|
---@field api_key_name string
|
||||||
|
---@field tokenizer_id string | "gpt-4o"
|
||||||
|
---@field use_xml_format boolean
|
||||||
|
---@field model? string
|
||||||
|
---@field parse_api_key fun(): string | nil
|
||||||
|
---@field parse_stream_data? AvanteStreamParser
|
||||||
|
---@field on_error? fun(result: table<string, any>): nil
|
||||||
|
---@field load_model_handler fun(): AvanteBedrockModelHandler
|
||||||
|
---@field build_bedrock_payload? fun(prompt_opts: AvantePromptOptions, body_opts: table<string, any>): table<string, any>
|
||||||
|
---
|
||||||
---@class avante.Providers
|
---@class avante.Providers
|
||||||
---@field openai AvanteProviderFunctor
|
---@field openai AvanteProviderFunctor
|
||||||
---@field claude AvanteProviderFunctor
|
---@field claude AvanteProviderFunctor
|
||||||
@ -87,6 +102,7 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil }
|
|||||||
---@field azure AvanteProviderFunctor
|
---@field azure AvanteProviderFunctor
|
||||||
---@field gemini AvanteProviderFunctor
|
---@field gemini AvanteProviderFunctor
|
||||||
---@field cohere AvanteProviderFunctor
|
---@field cohere AvanteProviderFunctor
|
||||||
|
---@field bedrock AvanteBedrockProviderFunctor
|
||||||
local M = {}
|
local M = {}
|
||||||
|
|
||||||
---@class EnvironmentHandler
|
---@class EnvironmentHandler
|
||||||
@ -96,7 +112,7 @@ local E = {}
|
|||||||
---@type table<string, string>
|
---@type table<string, string>
|
||||||
E.cache = {}
|
E.cache = {}
|
||||||
|
|
||||||
---@param Opts AvanteSupportedProvider | AvanteProviderFunctor
|
---@param Opts AvanteSupportedProvider | AvanteProviderFunctor | AvanteBedrockProviderFunctor
|
||||||
---@return string | nil
|
---@return string | nil
|
||||||
E.parse_envvar = function(Opts)
|
E.parse_envvar = function(Opts)
|
||||||
local api_key_name = Opts.api_key_name
|
local api_key_name = Opts.api_key_name
|
||||||
@ -158,7 +174,7 @@ end
|
|||||||
|
|
||||||
--- initialize the environment variable for current neovim session.
|
--- initialize the environment variable for current neovim session.
|
||||||
--- This will only run once and spawn a UI for users to input the envvar.
|
--- This will only run once and spawn a UI for users to input the envvar.
|
||||||
---@param opts {refresh: boolean, provider: AvanteProviderFunctor}
|
---@param opts {refresh: boolean, provider: AvanteProviderFunctor | AvanteBedrockProviderFunctor}
|
||||||
---@private
|
---@private
|
||||||
E.setup = function(opts)
|
E.setup = function(opts)
|
||||||
opts.provider.setup()
|
opts.provider.setup()
|
||||||
@ -267,7 +283,7 @@ M = setmetatable(M, {
|
|||||||
---@param t avante.Providers
|
---@param t avante.Providers
|
||||||
---@param k Provider
|
---@param k Provider
|
||||||
__index = function(t, k)
|
__index = function(t, k)
|
||||||
---@type AvanteProviderFunctor
|
---@type AvanteProviderFunctor | AvanteBedrockProviderFunctor
|
||||||
local Opts = M.get_config(k)
|
local Opts = M.get_config(k)
|
||||||
|
|
||||||
---@diagnostic disable: undefined-field,no-unknown,inject-field
|
---@diagnostic disable: undefined-field,no-unknown,inject-field
|
||||||
@ -311,7 +327,7 @@ M = setmetatable(M, {
|
|||||||
M.setup = function()
|
M.setup = function()
|
||||||
vim.g.avante_login = false
|
vim.g.avante_login = false
|
||||||
|
|
||||||
---@type AvanteProviderFunctor
|
---@type AvanteProviderFunctor | AvanteBedrockProviderFunctor
|
||||||
local provider = M[Config.provider]
|
local provider = M[Config.provider]
|
||||||
local auto_suggestions_provider = M[Config.auto_suggestions_provider]
|
local auto_suggestions_provider = M[Config.auto_suggestions_provider]
|
||||||
E.setup({ provider = provider })
|
E.setup({ provider = provider })
|
||||||
@ -325,13 +341,13 @@ end
|
|||||||
function M.refresh(provider)
|
function M.refresh(provider)
|
||||||
require("avante.config").override({ provider = provider })
|
require("avante.config").override({ provider = provider })
|
||||||
|
|
||||||
---@type AvanteProviderFunctor
|
---@type AvanteProviderFunctor | AvanteBedrockProviderFunctor
|
||||||
local p = M[Config.provider]
|
local p = M[Config.provider]
|
||||||
E.setup({ provider = p, refresh = true })
|
E.setup({ provider = p, refresh = true })
|
||||||
Utils.info("Switch to provider: " .. provider, { once = true, title = "Avante" })
|
Utils.info("Switch to provider: " .. provider, { once = true, title = "Avante" })
|
||||||
end
|
end
|
||||||
|
|
||||||
---@param opts AvanteProvider | AvanteSupportedProvider | AvanteProviderFunctor
|
---@param opts AvanteProvider | AvanteSupportedProvider | AvanteProviderFunctor | AvanteBedrockProviderFunctor
|
||||||
---@return AvanteDefaultBaseProvider, table<string, any>
|
---@return AvanteDefaultBaseProvider, table<string, any>
|
||||||
M.parse_config = function(opts)
|
M.parse_config = function(opts)
|
||||||
---@type AvanteDefaultBaseProvider
|
---@type AvanteDefaultBaseProvider
|
||||||
@ -356,7 +372,7 @@ end
|
|||||||
|
|
||||||
---@private
|
---@private
|
||||||
---@param provider Provider
|
---@param provider Provider
|
||||||
---@return AvanteProviderFunctor
|
---@return AvanteProviderFunctor | AvanteBedrockProviderFunctor
|
||||||
M.get_config = function(provider)
|
M.get_config = function(provider)
|
||||||
provider = provider or Config.provider
|
provider = provider or Config.provider
|
||||||
local cur = Config.get_provider(provider)
|
local cur = Config.get_provider(provider)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user