From d2775135a3efeda596a2fe446daaed9b0193f708 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Fri, 23 Aug 2024 09:36:40 -0400 Subject: [PATCH] feat(llm): cohere support (#167) should be good set of defaults now, one in US, one in canada, and microsoft :/ Signed-off-by: Aaron Pham --- lua/avante/config.lua | 12 +++- lua/avante/llm.lua | 6 +- lua/avante/providers/cohere.lua | 121 ++++++++++++++++++++++++++++++++ lua/avante/providers/init.lua | 3 +- lua/avante/sidebar.lua | 3 + 5 files changed, 142 insertions(+), 3 deletions(-) create mode 100644 lua/avante/providers/cohere.lua diff --git a/lua/avante/config.lua b/lua/avante/config.lua index 7ca2a65..6de7ee5 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -42,14 +42,24 @@ M.defaults = { claude = { endpoint = "https://api.anthropic.com", model = "claude-3-5-sonnet-20240620", - ["local"] = false, temperature = 0, max_tokens = 4096, + ["local"] = false, }, ---@type AvanteGeminiProvider gemini = { endpoint = "https://generativelanguage.googleapis.com/v1beta/models", model = "gemini-1.5-pro", + temperature = 0, + max_tokens = 4096, + ["local"] = false, + }, + ---@type AvanteGeminiProvider + cohere = { + endpoint = "https://api.cohere.com", + model = "command-r-plus", + temperature = 0, + max_tokens = 3072, ["local"] = false, }, ---To add support for custom provider, follow the format below diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index 520303d..6be209d 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -137,7 +137,11 @@ M.stream = function(question, code_lang, code_content, selected_content_content, end Provider.parse_stream_data(data, handler_opts) else - parse_stream_data(data) + if Provider.parse_stream_data ~= nil then + Provider.parse_stream_data(data, handler_opts) + else + parse_stream_data(data) + end end end) end, diff --git a/lua/avante/providers/cohere.lua b/lua/avante/providers/cohere.lua new file mode 100644 index 0000000..f5d57c8 --- /dev/null +++ b/lua/avante/providers/cohere.lua @@ -0,0 +1,121 @@ +local Utils = require("avante.utils") +local P = require("avante.providers") + +---@alias CohereFinishReason "COMPLETE" | "LENGTH" | "ERROR" +--- +---@class CohereChatStreamResponse +---@field event_type "stream-start" | "text-generation" | "stream-end" +---@field is_finished boolean +--- +---@class CohereTextGenerationResponse: CohereChatStreamResponse +---@field text string +--- +---@class CohereStreamEndResponse: CohereChatStreamResponse +---@field response CohereChatResponse +---@field finish_reason CohereFinishReason +--- +---@class CohereChatResponse +---@field text string +---@field generation_id string +---@field chat_history CohereMessage[] +---@field finish_reason CohereFinishReason +---@field meta {api_version: {version: integer}, billed_units: {input_tokens: integer, output_tokens: integer}, tokens: {input_tokens: integer, output_tokens: integer}} +--- +---@class CohereMessage +---@field role? "USER" | "SYSTEM" | "CHATBOT" +---@field message string +--- +---@class AvanteProviderFunctor +local M = {} + +M.api_key_name = "CO_API_KEY" + +M.has = function() + return os.getenv(M.api_key_name) and true or false +end + +M.parse_message = function(opts) + local user_prompt = opts.base_prompt + .. "\n\nCODE:\n" + .. "```" + .. opts.code_lang + .. "\n" + .. opts.code_content + .. "\n```" + .. "\n\nQUESTION:\n" + .. opts.question + + if opts.selected_code_content ~= nil then + user_prompt = opts.base_prompt + .. "\n\nCODE CONTEXT:\n" + .. "```" + .. opts.code_lang + .. "\n" + .. opts.code_content + .. "\n```" + .. "\n\nCODE:\n" + .. "```" + .. opts.code_lang + .. "\n" + .. opts.selected_code_content + .. "\n```" + .. "\n\nQUESTION:\n" + .. opts.question + end + + return { + preamble = opts.system_prompt, + message = user_prompt, + } +end + +M.parse_stream_data = function(data, opts) + ---@type CohereChatStreamResponse + local json = vim.json.decode(data) + if json.is_finished then + opts.on_complete(nil) + return + end + if json.event_type ~= nil then + ---@cast json CohereStreamEndResponse + if json.event_type == "stream-end" and json.finish_reason == "COMPLETE" then + opts.on_complete(nil) + return + end + ---@cast json CohereTextGenerationResponse + if json.event_type == "text-generation" then + opts.on_chunk(json.text) + end + end +end + +M.parse_curl_args = function(provider, code_opts) + local base, body_opts = P.parse_config(provider) + + local headers = { + ["Accept"] = "application/json", + ["Content-Type"] = "application/json", + ["X-Client-Name"] = "avante.nvim/Neovim/" + .. vim.version().major + .. "." + .. vim.version().minor + .. "." + .. vim.version().patch, + } + if not P.env.is_local("openai") then + headers["Authorization"] = "Bearer " .. os.getenv(base.api_key_name or M.api_key_name) + end + + return { + url = Utils.trim(base.endpoint, { suffix = "/" }) .. "/v1/chat", + proxy = base.proxy, + insecure = base.allow_insecure, + headers = headers, + body = vim.tbl_deep_extend("force", { + model = base.model, + stream = true, + }, M.parse_message(code_opts), body_opts), + } +end + +return M diff --git a/lua/avante/providers/init.lua b/lua/avante/providers/init.lua index 9251ad6..d1c395c 100644 --- a/lua/avante/providers/init.lua +++ b/lua/avante/providers/init.lua @@ -46,6 +46,7 @@ local Dressing = require("avante.ui.dressing") ---@field local? boolean ---@field proxy? string ---@field allow_insecure? boolean +---@field api_key_name? string --- ---@class AvanteSupportedProvider: AvanteDefaultBaseProvider ---@field temperature? number @@ -64,7 +65,6 @@ local Dressing = require("avante.ui.dressing") ---@field model string --- ---@class AvanteProvider: AvanteDefaultBaseProvider ----@field api_key_name string ---@field parse_response_data AvanteResponseParser ---@field parse_curl_args AvanteCurlArgsParser ---@field parse_stream_data? AvanteStreamParser @@ -89,6 +89,7 @@ local Dressing = require("avante.ui.dressing") ---@field claude AvanteProviderFunctor ---@field azure AvanteProviderFunctor ---@field gemini AvanteProviderFunctor +---@field cohere AvanteProviderFunctor local M = {} setmetatable(M, { diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index 6d79598..03a4693 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -1183,6 +1183,9 @@ Available commands: chat_history = {} save_chat_history(self, chat_history) self:update_content("Chat history cleared", { focus = false, scroll = false }) + vim.defer_fn(function() + self:close() + end, 1000) return else -- Unknown command