diff --git a/lua/avante/config.lua b/lua/avante/config.lua index 26479d0..08d68ca 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -7,9 +7,9 @@ local M = {} ---@class avante.Config M.defaults = { debug = false, - ---Currently, default supported providers include "claude", "openai", "azure", "deepseek", "groq" + ---Currently, default supported providers include "claude", "openai", "azure", "deepseek", "groq", "gemini" ---For custom provider, see README.md - ---@alias Provider "openai" | "claude" | "azure" | "deepseek" | "groq" | "copilot" | string + ---@alias Provider "openai" | "claude" | "azure" | "deepseek" | "groq" | "copilot" | "gemini" | string provider = "claude", ---@type AvanteSupportedProvider openai = { @@ -62,6 +62,13 @@ M.defaults = { max_tokens = 4096, ["local"] = false, }, + ---@type AvanteGeminiProvider + gemini = { + endpoint = "", + type = "gemini", + model = "gemini-1.5-pro", + options = {}, + }, ---To add support for custom provider, follow the format below ---See https://github.com/yetone/avante.nvim/README.md#custom-providers for more details ---@type {[string]: AvanteProvider} diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index 1cab2ae..73d6efc 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -54,6 +54,7 @@ local E = { azure = "AZURE_OPENAI_API_KEY", deepseek = "DEEPSEEK_API_KEY", groq = "GROQ_API_KEY", + gemini = "GEMINI_API_KEY", copilot = function() if Utils.has("copilot.lua") or Utils.has("copilot.vim") or Utils.copilot.find_config_path() then return true @@ -270,7 +271,11 @@ Remember: Accurate line numbers are CRITICAL. The range start_line to end_line m --- ---@alias AvanteOpenAIMessage AvanteBaseMessage --- ----@alias AvanteChatMessage AvanteClaudeMessage | AvanteOpenAIMessage +---@class AvanteGeminiMessage +---@field role "user" +---@field parts { text: string }[] +--- +---@alias AvanteChatMessage AvanteClaudeMessage | AvanteOpenAIMessage | AvanteGeminiMessage --- ---@alias AvanteAiMessageBuilder fun(opts: AvantePromptOptions): AvanteChatMessage[] --- @@ -302,6 +307,11 @@ Remember: Accurate line numbers are CRITICAL. The range start_line to end_line m ---@field allow_insecure boolean ---@field timeout number --- +---@class AvanteGeminiProvider: AvanteDefaultBaseProvider +---@field model string +---@field type string +---@field options table +--- ---@class AvanteProvider: AvanteDefaultBaseProvider ---@field model? string ---@field api_key_name string @@ -636,6 +646,93 @@ M.make_groq_curl_args = function(code_opts) } end +------------------------------Gemini------------------------------ + +---@param opts AvantePromptOptions +---@return AvanteGeminiMessage[] +M.make_gemini_message = function(opts) + local code_prompt_obj = { + text = string.format("```%s\n%s```", opts.code_lang, opts.code_content), + } + + if opts.selected_code_content then + code_prompt_obj.text = string.format("```%s\n%s```", opts.code_lang, opts.code_content) + end + + -- parts ready + local message_content = { + code_prompt_obj, + } + + if opts.selected_code_content then + local selected_code_obj = { + text = string.format("```%s\n%s```", opts.code_lang, opts.selected_code_content), + } + + table.insert(message_content, selected_code_obj) + end + + -- insert a part into parts + table.insert(message_content, { + text = string.format("%s", opts.question), + }) + + -- local user_prompt_obj = { + -- text = base_user_prompt, + -- } + + -- insert another part into parts + -- table.insert(message_content, user_prompt_obj) + + return { + { + role = "user", + parts = message_content, + }, + } +end + +---@type AvanteResponseParser +M.parse_gemini_response = function(data_stream, event_state, opts) + local json = vim.json.decode(data_stream) + opts.on_chunk(json.candidates[1].content.parts[1].text) +end + +---@type AvanteCurlArgsBuilder +M.make_gemini_curl_args = function(code_opts) + local endpoint = "" + if Config.gemini.endpoint == "" then + endpoint = "https://generativelanguage.googleapis.com/v1beta/models/" + .. Config.gemini.model + .. ":streamGenerateContent?alt=sse&key=" + .. E.value("gemini") + end + -- Prepare the body with contents and options (only if options are not empty) + local body = { + systemInstruction = { + role = "user", + parts = { + { + text = system_prompt .. base_user_prompt, + }, + }, + }, + contents = M.make_gemini_message(code_opts), + } + if next(Config.gemini.options) ~= nil then -- Check if options table is not empty + for k, v in pairs(Config.gemini.options) do + body[k] = v + end + end + return { + url = endpoint, + headers = { + ["Content-Type"] = "application/json", + }, + body = body, + } +end + ------------------------------Logic------------------------------ local group = vim.api.nvim_create_augroup("AvanteLLM", { clear = true })