diff --git a/README.md b/README.md index d7e8b85..de40377 100644 --- a/README.md +++ b/README.md @@ -425,6 +425,12 @@ Given its early stage, `avante.nvim` currently supports the following basic func > ```sh > export AZURE_OPENAI_API_KEY=your-api-key > ``` +> +> For Amazon Bedrock: +> +> ```sh +> export BEDROCK_KEYS=aws_access_key_id,aws_secret_access_key,aws_region +> ``` 1. Open a code file in Neovim. 2. Use the `:AvanteAsk` command to query the AI about the code. diff --git a/lua/avante/config.lua b/lua/avante/config.lua index 163dc43..08e772d 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -56,6 +56,13 @@ M._defaults = { max_tokens = 8000, }, ---@type AvanteSupportedProvider + bedrock = { + model = "anthropic.claude-3-5-sonnet-20240620-v1:0", + timeout = 30000, -- Timeout in milliseconds + temperature = 0, + max_tokens = 8000, + }, + ---@type AvanteSupportedProvider gemini = { endpoint = "https://generativelanguage.googleapis.com/v1beta/models", model = "gemini-1.5-flash-latest", diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index fb2ed72..cd5d1bc 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -23,7 +23,7 @@ local group = api.nvim_create_augroup("avante_llm", { clear = true }) M.generate_prompts = function(opts) local Provider = opts.provider or P[Config.provider] local mode = opts.mode or "planning" - ---@type AvanteProviderFunctor + ---@type AvanteProviderFunctor | AvanteBedrockProviderFunctor local _, body_opts = P.parse_config(Provider) local max_tokens = body_opts.max_tokens or 4096 @@ -380,7 +380,7 @@ end ---@field ask boolean ---@field instructions string ---@field mode LlmMode ----@field provider AvanteProviderFunctor | nil +---@field provider AvanteProviderFunctor | AvanteBedrockProviderFunctor | nil --- ---@class StreamOptions: GeneratePromptsOptions ---@field on_chunk AvanteChunkParser diff --git a/lua/avante/providers/bedrock.lua b/lua/avante/providers/bedrock.lua new file mode 100644 index 0000000..2b0cd7e --- /dev/null +++ b/lua/avante/providers/bedrock.lua @@ -0,0 +1,113 @@ +local Utils = require("avante.utils") +local Clipboard = require("avante.clipboard") +local P = require("avante.providers") + +---@alias AvanteBedrockPayloadBuilder fun(prompt_opts: AvantePromptOptions, body_opts: table): table +--- +---@class AvanteBedrockModelHandler +---@field role_map table<"user" | "assistant", string> +---@field parse_messages AvanteMessagesParser +---@field parse_response AvanteResponseParser +---@field build_bedrock_payload AvanteBedrockPayloadBuilder + +---@class AvanteBedrockProviderFunctor +local M = {} + +M.api_key_name = "BEDROCK_KEYS" +M.use_xml_format = true + +M.load_model_handler = function() + local base, _ = P.parse_config(P["bedrock"]) + local bedrock_model = base.model + if base.model:match("anthropic") then bedrock_model = "claude" end + + local ok, model_module = pcall(require, "avante.providers.bedrock." .. bedrock_model) + if ok then + return model_module + else + local error_msg = "Bedrock model handler not found: " .. bedrock_model + Utils.error(error_msg, { once = true, title = "Avante" }) + end +end + +M.parse_response = function(ctx, data_stream, event_state, opts) + local model_handler = M.load_model_handler() + return model_handler.parse_response(ctx, data_stream, event_state, opts) +end + +M.build_bedrock_payload = function(prompt_opts, body_opts) + local model_handler = M.load_model_handler() + return model_handler.build_bedrock_payload(prompt_opts, body_opts) +end + +M.parse_stream_data = function(data, opts) + -- @NOTE: Decode and process Bedrock response + -- Each response contains a Base64-encoded `bytes` field, which is decoded into JSON. + -- The `type` field in the decoded JSON determines how the response is handled. + local bedrock_match = data:gmatch("event(%b{})") + for bedrock_data_match in bedrock_match do + local data = vim.json.decode(bedrock_data_match) + local data_stream = vim.base64.decode(data.bytes) + local json = vim.json.decode(data_stream) + M.parse_response({}, data_stream, json.type, opts) + end +end + +---@param provider AvanteBedrockProviderFunctor +---@param prompt_opts AvantePromptOptions +---@return table +M.parse_curl_args = function(provider, prompt_opts) + local base, body_opts = P.parse_config(provider) + + local api_key = provider.parse_api_key() + local parts = vim.split(api_key, ",") + local aws_access_key_id = parts[1] + local aws_secret_access_key = parts[2] + local aws_region = parts[3] + + local endpoint = string.format( + "https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke-with-response-stream", + aws_region, + base.model + ) + + local headers = { + ["Content-Type"] = "application/json", + } + + local body_payload = M.build_bedrock_payload(prompt_opts, body_opts) + + local rawArgs = { + "--aws-sigv4", + string.format("aws:amz:%s:bedrock", aws_region), + "--user", + string.format("%s:%s", aws_access_key_id, aws_secret_access_key), + } + + return { + url = endpoint, + proxy = base.proxy, + insecure = base.allow_insecure, + headers = headers, + body = body_payload, + rawArgs = rawArgs, + } +end + +M.on_error = function(result) + if not result.body then + return Utils.error("API request failed with status " .. result.status, { once = true, title = "Avante" }) + end + + local ok, body = pcall(vim.json.decode, result.body) + if not (ok and body and body.error) then + return Utils.error("Failed to parse error response", { once = true, title = "Avante" }) + end + + local error_msg = body.error.message + local error_type = body.error.type + + Utils.error(error_msg, { once = true, title = "Avante" }) +end + +return M diff --git a/lua/avante/providers/bedrock/claude.lua b/lua/avante/providers/bedrock/claude.lua new file mode 100644 index 0000000..47c6b40 --- /dev/null +++ b/lua/avante/providers/bedrock/claude.lua @@ -0,0 +1,73 @@ +---@class AvanteBedrockClaudeTextMessage +---@field type "text" +---@field text string +--- +---@class AvanteBedrockClaudeMessage +---@field role "user" | "assistant" +---@field content [AvanteBedrockClaudeTextMessage][] + +---@class AvanteBedrockModelHandler +local M = {} + +M.role_map = { + user = "user", + assistant = "assistant", +} + +M.parse_messages = function(opts) + ---@type AvanteBedrockClaudeMessage[] + local messages = {} + + for _, message in ipairs(opts.messages) do + table.insert(messages, { + role = M.role_map[message.role], + content = { + { + type = "text", + text = message.content, + }, + }, + }) + end + + return messages +end + +M.parse_response = function(ctx, data_stream, event_state, opts) + if event_state == nil then + if data_stream:match('"content_block_delta"') then + event_state = "content_block_delta" + elseif data_stream:match('"message_stop"') then + event_state = "message_stop" + end + end + if event_state == "content_block_delta" then + local ok, json = pcall(vim.json.decode, data_stream) + if not ok then return end + opts.on_chunk(json.delta.text) + elseif event_state == "message_stop" then + opts.on_complete(nil) + return + elseif event_state == "error" then + opts.on_complete(vim.json.decode(data_stream)) + end +end + +---@param prompt_opts AvantePromptOptions +---@param body_opts table +---@return table +M.build_bedrock_payload = function(prompt_opts, body_opts) + local system_prompt = prompt_opts.system_prompt or "" + local messages = M.parse_messages(prompt_opts) + local max_tokens = body_opts.max_tokens or 2000 + local temperature = body_opts.temperature or 0.7 + local payload = { + anthropic_version = "bedrock-2023-05-31", + max_tokens = max_tokens, + messages = messages, + system = system_prompt, + } + return vim.tbl_deep_extend("force", payload, body_opts or {}) +end + +return M diff --git a/lua/avante/providers/init.lua b/lua/avante/providers/init.lua index 375c458..1802f3a 100644 --- a/lua/avante/providers/init.lua +++ b/lua/avante/providers/init.lua @@ -32,7 +32,7 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil } ---@alias AvanteMessagesParser fun(opts: AvantePromptOptions): AvanteChatMessage[] --- ---@class AvanteCurlOutput: {url: string, proxy: string, insecure: boolean, body: table | string, headers: table, rawArgs: string[] | nil} ----@alias AvanteCurlArgsParser fun(opts: AvanteProvider | AvanteProviderFunctor, code_opts: AvantePromptOptions): AvanteCurlOutput +---@alias AvanteCurlArgsParser fun(opts: AvanteProvider | AvanteProviderFunctor | AvanteBedrockProviderFunctor, code_opts: AvantePromptOptions): AvanteCurlOutput --- ---@class ResponseParser ---@field on_chunk fun(chunk: string): any @@ -80,6 +80,21 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil } ---@field parse_stream_data? AvanteStreamParser ---@field on_error? fun(result: table): nil --- +---@class AvanteBedrockProviderFunctor +---@field parse_response AvanteResponseParser +---@field parse_curl_args AvanteCurlArgsParser +---@field setup fun(): nil +---@field has fun(): boolean +---@field api_key_name string +---@field tokenizer_id string | "gpt-4o" +---@field use_xml_format boolean +---@field model? string +---@field parse_api_key fun(): string | nil +---@field parse_stream_data? AvanteStreamParser +---@field on_error? fun(result: table): nil +---@field load_model_handler fun(): AvanteBedrockModelHandler +---@field build_bedrock_payload? fun(prompt_opts: AvantePromptOptions, body_opts: table): table +--- ---@class avante.Providers ---@field openai AvanteProviderFunctor ---@field claude AvanteProviderFunctor @@ -87,6 +102,7 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil } ---@field azure AvanteProviderFunctor ---@field gemini AvanteProviderFunctor ---@field cohere AvanteProviderFunctor +---@field bedrock AvanteBedrockProviderFunctor local M = {} ---@class EnvironmentHandler @@ -96,7 +112,7 @@ local E = {} ---@type table E.cache = {} ----@param Opts AvanteSupportedProvider | AvanteProviderFunctor +---@param Opts AvanteSupportedProvider | AvanteProviderFunctor | AvanteBedrockProviderFunctor ---@return string | nil E.parse_envvar = function(Opts) local api_key_name = Opts.api_key_name @@ -158,7 +174,7 @@ end --- initialize the environment variable for current neovim session. --- This will only run once and spawn a UI for users to input the envvar. ----@param opts {refresh: boolean, provider: AvanteProviderFunctor} +---@param opts {refresh: boolean, provider: AvanteProviderFunctor | AvanteBedrockProviderFunctor} ---@private E.setup = function(opts) opts.provider.setup() @@ -267,7 +283,7 @@ M = setmetatable(M, { ---@param t avante.Providers ---@param k Provider __index = function(t, k) - ---@type AvanteProviderFunctor + ---@type AvanteProviderFunctor | AvanteBedrockProviderFunctor local Opts = M.get_config(k) ---@diagnostic disable: undefined-field,no-unknown,inject-field @@ -311,7 +327,7 @@ M = setmetatable(M, { M.setup = function() vim.g.avante_login = false - ---@type AvanteProviderFunctor + ---@type AvanteProviderFunctor | AvanteBedrockProviderFunctor local provider = M[Config.provider] local auto_suggestions_provider = M[Config.auto_suggestions_provider] E.setup({ provider = provider }) @@ -325,13 +341,13 @@ end function M.refresh(provider) require("avante.config").override({ provider = provider }) - ---@type AvanteProviderFunctor + ---@type AvanteProviderFunctor | AvanteBedrockProviderFunctor local p = M[Config.provider] E.setup({ provider = p, refresh = true }) Utils.info("Switch to provider: " .. provider, { once = true, title = "Avante" }) end ----@param opts AvanteProvider | AvanteSupportedProvider | AvanteProviderFunctor +---@param opts AvanteProvider | AvanteSupportedProvider | AvanteProviderFunctor | AvanteBedrockProviderFunctor ---@return AvanteDefaultBaseProvider, table M.parse_config = function(opts) ---@type AvanteDefaultBaseProvider @@ -356,7 +372,7 @@ end ---@private ---@param provider Provider ----@return AvanteProviderFunctor +---@return AvanteProviderFunctor | AvanteBedrockProviderFunctor M.get_config = function(provider) provider = provider or Config.provider local cur = Config.get_provider(provider)