feat: ✨ Added vertex AI provider for orgs using gemini (#840)
Co-authored-by: Shourya Sharma <shourya.sharma@complyadvantage.com>
This commit is contained in:
parent
af33f3a602
commit
839a8ee25a
@ -5,11 +5,10 @@ local Utils = require("avante.utils")
|
|||||||
|
|
||||||
---@class avante.CoreConfig: avante.Config
|
---@class avante.CoreConfig: avante.Config
|
||||||
local M = {}
|
local M = {}
|
||||||
|
|
||||||
---@class avante.Config
|
---@class avante.Config
|
||||||
M.defaults = {
|
M.defaults = {
|
||||||
debug = false,
|
debug = false,
|
||||||
---@alias Provider "claude" | "openai" | "azure" | "gemini" | "cohere" | "copilot" | [string]
|
---@alias Provider "claude" | "openai" | "azure" | "gemini" | "vertex" | "cohere" | "copilot" | [string]
|
||||||
provider = "claude", -- Only recommend using Claude
|
provider = "claude", -- Only recommend using Claude
|
||||||
auto_suggestions_provider = "claude",
|
auto_suggestions_provider = "claude",
|
||||||
---@alias Tokenizer "tiktoken" | "hf"
|
---@alias Tokenizer "tiktoken" | "hf"
|
||||||
@ -66,6 +65,15 @@ M.defaults = {
|
|||||||
["local"] = false,
|
["local"] = false,
|
||||||
},
|
},
|
||||||
---@type AvanteSupportedProvider
|
---@type AvanteSupportedProvider
|
||||||
|
vertex = {
|
||||||
|
endpoint = "https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION/publishers/google/models",
|
||||||
|
model = "gemini-1.5-flash-latest",
|
||||||
|
timeout = 30000, -- Timeout in milliseconds
|
||||||
|
temperature = 0,
|
||||||
|
max_tokens = 4096,
|
||||||
|
["local"] = false,
|
||||||
|
},
|
||||||
|
---@type AvanteSupportedProvider
|
||||||
cohere = {
|
cohere = {
|
||||||
endpoint = "https://api.cohere.com/v2",
|
endpoint = "https://api.cohere.com/v2",
|
||||||
model = "command-r-plus-08-2024",
|
model = "command-r-plus-08-2024",
|
||||||
|
62
lua/avante/providers/vertex.lua
Normal file
62
lua/avante/providers/vertex.lua
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
local P = require("avante.providers")
|
||||||
|
local Gemini = require("avante.providers.gemini")
|
||||||
|
|
||||||
|
---@class AvanteProviderFunctor
|
||||||
|
local M = {}
|
||||||
|
|
||||||
|
M.api_key_name = "cmd:gcloud auth application-default print-access-token"
|
||||||
|
|
||||||
|
M.role_map = {
|
||||||
|
user = "user",
|
||||||
|
assistant = "model",
|
||||||
|
}
|
||||||
|
|
||||||
|
M.parse_messages = Gemini.parse_messages
|
||||||
|
M.parse_response = Gemini.parse_response
|
||||||
|
|
||||||
|
local function execute_command(command)
|
||||||
|
local handle = io.popen(command)
|
||||||
|
local result = handle:read("*a")
|
||||||
|
handle:close()
|
||||||
|
return result:match("^%s*(.-)%s*$")
|
||||||
|
end
|
||||||
|
|
||||||
|
M.parse_api_key = function()
|
||||||
|
if M.api_key_name:match("^cmd:") then
|
||||||
|
local command = M.api_key_name:sub(5)
|
||||||
|
return execute_command(command)
|
||||||
|
else
|
||||||
|
return vim.fn.getenv(M.api_key_name)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
M.parse_curl_args = function(provider, code_opts)
|
||||||
|
local base, body_opts = P.parse_config(provider)
|
||||||
|
local location = vim.fn.getenv("LOCATION") or "default-location"
|
||||||
|
local project_id = vim.fn.getenv("PROJECT_ID") or "default-project-id"
|
||||||
|
local model_id = base.model or "default-model-id"
|
||||||
|
local url = base.endpoint:gsub("LOCATION", location):gsub("PROJECT_ID", project_id)
|
||||||
|
|
||||||
|
url = string.format("%s/%s:streamGenerateContent?alt=sse", url, model_id)
|
||||||
|
|
||||||
|
body_opts = vim.tbl_deep_extend("force", body_opts, {
|
||||||
|
generationConfig = {
|
||||||
|
temperature = body_opts.temperature,
|
||||||
|
maxOutputTokens = body_opts.max_tokens,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
body_opts.temperature = nil
|
||||||
|
body_opts.max_tokens = nil
|
||||||
|
return {
|
||||||
|
url = url,
|
||||||
|
headers = {
|
||||||
|
["Authorization"] = "Bearer " .. M.parse_api_key(),
|
||||||
|
["Content-Type"] = "application/json; charset=utf-8",
|
||||||
|
},
|
||||||
|
proxy = base.proxy,
|
||||||
|
insecure = base.allow_insecure,
|
||||||
|
body = vim.tbl_deep_extend("force", {}, M.parse_messages(code_opts), body_opts),
|
||||||
|
}
|
||||||
|
end
|
||||||
|
|
||||||
|
return M
|
Loading…
x
Reference in New Issue
Block a user