diff --git a/lua/avante/config.lua b/lua/avante/config.lua index 21d4d91..0b679fe 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -5,11 +5,10 @@ local Utils = require("avante.utils") ---@class avante.CoreConfig: avante.Config local M = {} - ---@class avante.Config M.defaults = { debug = false, - ---@alias Provider "claude" | "openai" | "azure" | "gemini" | "cohere" | "copilot" | [string] + ---@alias Provider "claude" | "openai" | "azure" | "gemini" | "vertex" | "cohere" | "copilot" | [string] provider = "claude", -- Only recommend using Claude auto_suggestions_provider = "claude", ---@alias Tokenizer "tiktoken" | "hf" @@ -66,6 +65,15 @@ M.defaults = { ["local"] = false, }, ---@type AvanteSupportedProvider + vertex = { + endpoint = "https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION/publishers/google/models", + model = "gemini-1.5-flash-latest", + timeout = 30000, -- Timeout in milliseconds + temperature = 0, + max_tokens = 4096, + ["local"] = false, + }, + ---@type AvanteSupportedProvider cohere = { endpoint = "https://api.cohere.com/v2", model = "command-r-plus-08-2024", diff --git a/lua/avante/providers/vertex.lua b/lua/avante/providers/vertex.lua new file mode 100644 index 0000000..67a4b2b --- /dev/null +++ b/lua/avante/providers/vertex.lua @@ -0,0 +1,62 @@ +local P = require("avante.providers") +local Gemini = require("avante.providers.gemini") + +---@class AvanteProviderFunctor +local M = {} + +M.api_key_name = "cmd:gcloud auth application-default print-access-token" + +M.role_map = { + user = "user", + assistant = "model", +} + +M.parse_messages = Gemini.parse_messages +M.parse_response = Gemini.parse_response + +local function execute_command(command) + local handle = io.popen(command) + local result = handle:read("*a") + handle:close() + return result:match("^%s*(.-)%s*$") +end + +M.parse_api_key = function() + if M.api_key_name:match("^cmd:") then + local command = M.api_key_name:sub(5) + return execute_command(command) + else + return vim.fn.getenv(M.api_key_name) + end +end + +M.parse_curl_args = function(provider, code_opts) + local base, body_opts = P.parse_config(provider) + local location = vim.fn.getenv("LOCATION") or "default-location" + local project_id = vim.fn.getenv("PROJECT_ID") or "default-project-id" + local model_id = base.model or "default-model-id" + local url = base.endpoint:gsub("LOCATION", location):gsub("PROJECT_ID", project_id) + + url = string.format("%s/%s:streamGenerateContent?alt=sse", url, model_id) + + body_opts = vim.tbl_deep_extend("force", body_opts, { + generationConfig = { + temperature = body_opts.temperature, + maxOutputTokens = body_opts.max_tokens, + }, + }) + body_opts.temperature = nil + body_opts.max_tokens = nil + return { + url = url, + headers = { + ["Authorization"] = "Bearer " .. M.parse_api_key(), + ["Content-Type"] = "application/json; charset=utf-8", + }, + proxy = base.proxy, + insecure = base.allow_insecure, + body = vim.tbl_deep_extend("force", {}, M.parse_messages(code_opts), body_opts), + } +end + +return M