feat: add Gemini support (#142)

Co-authored-by: Jihun Kim <jihun.kim.uk@gmail.com>
This commit is contained in:
jihunkim0 2024-08-21 17:52:25 +01:00 committed by GitHub
parent c41ad591a1
commit 45a47075e5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 107 additions and 3 deletions

View File

@ -7,9 +7,9 @@ local M = {}
---@class avante.Config ---@class avante.Config
M.defaults = { M.defaults = {
debug = false, debug = false,
---Currently, default supported providers include "claude", "openai", "azure", "deepseek", "groq" ---Currently, default supported providers include "claude", "openai", "azure", "deepseek", "groq", "gemini"
---For custom provider, see README.md ---For custom provider, see README.md
---@alias Provider "openai" | "claude" | "azure" | "deepseek" | "groq" | "copilot" | string ---@alias Provider "openai" | "claude" | "azure" | "deepseek" | "groq" | "copilot" | "gemini" | string
provider = "claude", provider = "claude",
---@type AvanteSupportedProvider ---@type AvanteSupportedProvider
openai = { openai = {
@ -62,6 +62,13 @@ M.defaults = {
max_tokens = 4096, max_tokens = 4096,
["local"] = false, ["local"] = false,
}, },
---@type AvanteGeminiProvider
gemini = {
endpoint = "",
type = "gemini",
model = "gemini-1.5-pro",
options = {},
},
---To add support for custom provider, follow the format below ---To add support for custom provider, follow the format below
---See https://github.com/yetone/avante.nvim/README.md#custom-providers for more details ---See https://github.com/yetone/avante.nvim/README.md#custom-providers for more details
---@type {[string]: AvanteProvider} ---@type {[string]: AvanteProvider}

View File

@ -54,6 +54,7 @@ local E = {
azure = "AZURE_OPENAI_API_KEY", azure = "AZURE_OPENAI_API_KEY",
deepseek = "DEEPSEEK_API_KEY", deepseek = "DEEPSEEK_API_KEY",
groq = "GROQ_API_KEY", groq = "GROQ_API_KEY",
gemini = "GEMINI_API_KEY",
copilot = function() copilot = function()
if Utils.has("copilot.lua") or Utils.has("copilot.vim") or Utils.copilot.find_config_path() then if Utils.has("copilot.lua") or Utils.has("copilot.vim") or Utils.copilot.find_config_path() then
return true return true
@ -270,7 +271,11 @@ Remember: Accurate line numbers are CRITICAL. The range start_line to end_line m
--- ---
---@alias AvanteOpenAIMessage AvanteBaseMessage ---@alias AvanteOpenAIMessage AvanteBaseMessage
--- ---
---@alias AvanteChatMessage AvanteClaudeMessage | AvanteOpenAIMessage ---@class AvanteGeminiMessage
---@field role "user"
---@field parts { text: string }[]
---
---@alias AvanteChatMessage AvanteClaudeMessage | AvanteOpenAIMessage | AvanteGeminiMessage
--- ---
---@alias AvanteAiMessageBuilder fun(opts: AvantePromptOptions): AvanteChatMessage[] ---@alias AvanteAiMessageBuilder fun(opts: AvantePromptOptions): AvanteChatMessage[]
--- ---
@ -302,6 +307,11 @@ Remember: Accurate line numbers are CRITICAL. The range start_line to end_line m
---@field allow_insecure boolean ---@field allow_insecure boolean
---@field timeout number ---@field timeout number
--- ---
---@class AvanteGeminiProvider: AvanteDefaultBaseProvider
---@field model string
---@field type string
---@field options table
---
---@class AvanteProvider: AvanteDefaultBaseProvider ---@class AvanteProvider: AvanteDefaultBaseProvider
---@field model? string ---@field model? string
---@field api_key_name string ---@field api_key_name string
@ -636,6 +646,93 @@ M.make_groq_curl_args = function(code_opts)
} }
end end
------------------------------Gemini------------------------------
---@param opts AvantePromptOptions
---@return AvanteGeminiMessage[]
M.make_gemini_message = function(opts)
local code_prompt_obj = {
text = string.format("<code>```%s\n%s```</code>", opts.code_lang, opts.code_content),
}
if opts.selected_code_content then
code_prompt_obj.text = string.format("<code_context>```%s\n%s```</code_context>", opts.code_lang, opts.code_content)
end
-- parts ready
local message_content = {
code_prompt_obj,
}
if opts.selected_code_content then
local selected_code_obj = {
text = string.format("<code>```%s\n%s```</code>", opts.code_lang, opts.selected_code_content),
}
table.insert(message_content, selected_code_obj)
end
-- insert a part into parts
table.insert(message_content, {
text = string.format("<question>%s</question>", opts.question),
})
-- local user_prompt_obj = {
-- text = base_user_prompt,
-- }
-- insert another part into parts
-- table.insert(message_content, user_prompt_obj)
return {
{
role = "user",
parts = message_content,
},
}
end
---@type AvanteResponseParser
M.parse_gemini_response = function(data_stream, event_state, opts)
local json = vim.json.decode(data_stream)
opts.on_chunk(json.candidates[1].content.parts[1].text)
end
---@type AvanteCurlArgsBuilder
M.make_gemini_curl_args = function(code_opts)
local endpoint = ""
if Config.gemini.endpoint == "" then
endpoint = "https://generativelanguage.googleapis.com/v1beta/models/"
.. Config.gemini.model
.. ":streamGenerateContent?alt=sse&key="
.. E.value("gemini")
end
-- Prepare the body with contents and options (only if options are not empty)
local body = {
systemInstruction = {
role = "user",
parts = {
{
text = system_prompt .. base_user_prompt,
},
},
},
contents = M.make_gemini_message(code_opts),
}
if next(Config.gemini.options) ~= nil then -- Check if options table is not empty
for k, v in pairs(Config.gemini.options) do
body[k] = v
end
end
return {
url = endpoint,
headers = {
["Content-Type"] = "application/json",
},
body = body,
}
end
------------------------------Logic------------------------------ ------------------------------Logic------------------------------
local group = vim.api.nvim_create_augroup("AvanteLLM", { clear = true }) local group = vim.api.nvim_create_augroup("AvanteLLM", { clear = true })