feat: support groq (#70)

* adding groq to config

* updated readme with groq
Author: franklin
Date: 2024-08-18 12:11:39 -04:00 (committed by GitHub)
Parent: 834bb9ea77
Commit: 0fddfc7d8f
3 changed files with 40 additions and 5 deletions

README.md

@@ -81,8 +81,8 @@ _See [config.lua#L9](./lua/avante/config.lua) for the full config_
 ```lua
 {
-  ---@alias Provider "openai" | "claude" | "azure" | "deepseek"
-  provider = "claude", -- "claude" or "openai" or "azure" or "deepseek"
+  ---@alias Provider "openai" | "claude" | "azure" | "deepseek" | "groq"
+  provider = "claude", -- "claude" or "openai" or "azure" or "deepseek" or "groq"
   openai = {
     endpoint = "https://api.openai.com",
     model = "gpt-4o",
@@ -173,6 +173,12 @@ Given its early stage, `avante.nvim` currently supports the following basic functionalities:
 > ```sh
 > export DEEPSEEK_API_KEY=your-api-key
 > ```
+>
+> For Groq
+>
+> ```sh
+> export GROQ_API_KEY=your-api-key
+> ```
 1. Open a code file in Neovim.
 2. Use the `:AvanteAsk` command to query the AI about the code.
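
Taken together, the README changes mean a user points `provider` at Groq and exports `GROQ_API_KEY`. As a minimal sketch of what that looks like in practice, assuming a lazy.nvim-style plugin spec (the spec shape and repo path are illustrative, not part of this commit):

```lua
-- Illustrative lazy.nvim spec; only `provider = "groq"` and the `groq`
-- table mirror options introduced by this commit.
{
  "yetone/avante.nvim",
  opts = {
    provider = "groq",
    groq = {
      endpoint = "https://api.groq.com",
      model = "llama-3.1-70b-versatile",
    },
  },
}
```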

lua/avante/ai_bot.lua

@@ -18,6 +18,7 @@ local E = {
   claude = "ANTHROPIC_API_KEY",
   azure = "AZURE_OPENAI_API_KEY",
   deepseek = "DEEPSEEK_API_KEY",
+  groq = "GROQ_API_KEY",
 },
 _once = false,
}
@@ -316,6 +317,23 @@ local function call_openai_api_stream(question, code_lang, code_content, selected_content_content, on_chunk, on_complete)
     max_tokens = Config.deepseek.max_tokens,
     stream = true,
   }
+  elseif Config.provider == "groq" then
+    api_key = os.getenv(E.key("groq"))
+    url = Utils.trim_suffix(Config.groq.endpoint, "/") .. "/openai/v1/chat/completions"
+    headers = {
+      ["Content-Type"] = "application/json",
+      ["Authorization"] = "Bearer " .. api_key,
+    }
+    body = {
+      model = Config.groq.model,
+      messages = {
+        { role = "system", content = system_prompt },
+        { role = "user", content = user_prompt },
+      },
+      temperature = Config.groq.temperature,
+      max_tokens = Config.groq.max_tokens,
+      stream = true,
+    }
   else
     url = Utils.trim_suffix(Config.openai.endpoint, "/") .. "/v1/chat/completions"
     headers = {
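
The new branch builds an OpenAI-compatible request against Groq's `/openai/v1/chat/completions` route. A quick way to sanity-check that endpoint outside the plugin, as a hedged sketch assuming Neovim 0.10+ (`vim.system`) and `GROQ_API_KEY` in the environment:

```lua
-- Standalone smoke test for the endpoint the new branch targets.
-- Non-streaming for simplicity; the plugin itself sets stream = true.
local body = vim.json.encode({
  model = "llama-3.1-70b-versatile",
  messages = { { role = "user", content = "Say hello" } },
})
vim.system({
  "curl", "-s", "https://api.groq.com/openai/v1/chat/completions",
  "-H", "Content-Type: application/json",
  "-H", "Authorization: Bearer " .. (os.getenv("GROQ_API_KEY") or ""),
  "-d", body,
}, { text = true }, function(out)
  -- out.stdout is raw JSON; choices[1].message.content holds the reply
  print(out.stdout)
end)
```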
@@ -382,7 +400,12 @@ end
 ---@param on_chunk fun(chunk: string): any
 ---@param on_complete fun(err: string|nil): any
 function M.call_ai_api_stream(question, code_lang, code_content, selected_content_content, on_chunk, on_complete)
-  if Config.provider == "openai" or Config.provider == "azure" or Config.provider == "deepseek" then
+  if
+    Config.provider == "openai"
+    or Config.provider == "azure"
+    or Config.provider == "deepseek"
+    or Config.provider == "groq"
+  then
     call_openai_api_stream(question, code_lang, code_content, selected_content_content, on_chunk, on_complete)
   elseif Config.provider == "claude" then
     call_claude_api_stream(question, code_lang, code_content, selected_content_content, on_chunk, on_complete)
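
With the dispatch widened, Groq traffic flows through the same streaming callbacks as the other OpenAI-compatible providers. A hedged usage sketch; the module path and the argument values are assumptions for illustration, not from this commit:

```lua
-- Hypothetical caller; assumes provider = "groq" is configured and the
-- module path matches the file this diff touches.
local ai_bot = require("avante.ai_bot")
ai_bot.call_ai_api_stream(
  "What does this function do?",                      -- question
  "lua",                                              -- code_lang
  "local function add(a, b) return a + b end",       -- code_content
  nil,                                                -- selected_content_content
  function(chunk) io.write(chunk) end,                -- on_chunk
  function(err) if err then vim.notify(err) end end   -- on_complete
)
```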

lua/avante/config.lua

@@ -6,8 +6,8 @@ local M = {}
 ---@class avante.Config
 M.defaults = {
-  ---@alias Provider "openai" | "claude" | "azure" | "deepseek"
-  provider = "claude", -- "claude" or "openai" or "azure" or "deepseek"
+  ---@alias Provider "openai" | "claude" | "azure" | "deepseek" | "groq"
+  provider = "claude", -- "claude" or "openai" or "azure" or "deepseek" or "groq"
   openai = {
     endpoint = "https://api.openai.com",
     model = "gpt-4o",
@@ -33,6 +33,12 @@ M.defaults = {
     temperature = 0,
     max_tokens = 4096,
   },
+  groq = {
+    endpoint = "https://api.groq.com",
+    model = "llama-3.1-70b-versatile",
+    temperature = 0,
+    max_tokens = 4096,
+  },
   behaviour = {
     auto_apply_diff_after_generation = false, -- Whether to automatically apply diff after LLM response.
   },
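
Because the new `groq` table follows the same shape as the other provider tables, any of its defaults can be overridden through the usual `setup()` call. A sketch, assuming the plugin exposes the standard `setup()` entry point; the override values are examples, not defaults from this commit:

```lua
require("avante").setup({
  provider = "groq",
  groq = {
    model = "llama-3.1-8b-instant", -- example: a smaller Groq model
    temperature = 0.2,              -- overrides the 0 default
  },
})
```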