feat(llm): support local LLM (#86)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
Aaron Pham 2024-08-19 08:35:36 -04:00 committed by GitHub
parent 330d214c14
commit 02eb39ae48
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 67 additions and 5 deletions

View File

@ -330,6 +330,22 @@ vendors = {
}, },
``` ```
## Local LLM
If you want to use local LLM that has a OpenAI-compatible server, set `["local"] = true`:
```lua
openai = {
endpoint = "http://127.0.0.1:3000",
model = "code-gemma",
temperature = 0,
max_tokens = 4096,
["local"] = true,
},
```
You will be responsible for setting up the server yourself before using Neovim.
</details> </details>
## License ## License

View File

@ -9,42 +9,52 @@ M.defaults = {
debug = false, debug = false,
---Currently, default supported providers include "claude", "openai", "azure", "deepseek", "groq" ---Currently, default supported providers include "claude", "openai", "azure", "deepseek", "groq"
---For custom provider, see README.md ---For custom provider, see README.md
---@alias Provider "openai" | "claude" | "azure" | "deepseek" | "groq" | [string] ---@alias Provider "openai" | "claude" | "azure" | "deepseek" | "groq" | string
provider = "claude", provider = "claude",
---@type AvanteSupportedProvider
openai = { openai = {
endpoint = "https://api.openai.com", endpoint = "https://api.openai.com",
model = "gpt-4o", model = "gpt-4o",
temperature = 0, temperature = 0,
max_tokens = 4096, max_tokens = 4096,
["local"] = false,
}, },
---@type AvanteAzureProvider
azure = { azure = {
endpoint = "", -- example: "https://<your-resource-name>.openai.azure.com" endpoint = "", -- example: "https://<your-resource-name>.openai.azure.com"
deployment = "", -- Azure deployment name (e.g., "gpt-4o", "my-gpt-4o-deployment") deployment = "", -- Azure deployment name (e.g., "gpt-4o", "my-gpt-4o-deployment")
api_version = "2024-06-01", api_version = "2024-06-01",
temperature = 0, temperature = 0,
max_tokens = 4096, max_tokens = 4096,
["local"] = false,
}, },
---@type AvanteSupportedProvider
claude = { claude = {
endpoint = "https://api.anthropic.com", endpoint = "https://api.anthropic.com",
model = "claude-3-5-sonnet-20240620", model = "claude-3-5-sonnet-20240620",
temperature = 0, temperature = 0,
max_tokens = 4096, max_tokens = 4096,
["local"] = false,
}, },
---@type AvanteSupportedProvider
deepseek = { deepseek = {
endpoint = "https://api.deepseek.com", endpoint = "https://api.deepseek.com",
model = "deepseek-coder", model = "deepseek-coder",
temperature = 0, temperature = 0,
max_tokens = 4096, max_tokens = 4096,
["local"] = false,
}, },
---@type AvanteSupportedProvider
groq = { groq = {
endpoint = "https://api.groq.com", endpoint = "https://api.groq.com",
model = "llama-3.1-70b-versatile", model = "llama-3.1-70b-versatile",
temperature = 0, temperature = 0,
max_tokens = 4096, max_tokens = 4096,
["local"] = false,
}, },
---To add support for custom provider, follow the format below ---To add support for custom provider, follow the format below
---See https://github.com/yetone/avante.nvim/README.md#custom-providers for more details ---See https://github.com/yetone/avante.nvim/README.md#custom-providers for more details
---@type table<string, AvanteProvider> ---@type {[string]: AvanteProvider}
vendors = {}, vendors = {},
---Specify the behaviour of avante.nvim ---Specify the behaviour of avante.nvim
---1. auto_apply_diff_after_generation: Whether to automatically apply diff after LLM response. ---1. auto_apply_diff_after_generation: Whether to automatically apply diff after LLM response.

View File

@ -28,16 +28,23 @@ local E = {
}, },
} }
E = setmetatable(E, { setmetatable(E, {
---@param k Provider ---@param k Provider
__index = function(_, k) __index = function(_, k)
local builtins = E.env[k] local builtins = E.env[k]
if builtins then if builtins then
if Config.options[k]["local"] then
return true
end
return os.getenv(builtins) and true or false return os.getenv(builtins) and true or false
end end
---@type AvanteProvider | nil
local external = Config.vendors[k] local external = Config.vendors[k]
if external then if external then
if external["local"] then
return true
end
return os.getenv(external.api_key_name) and true or false return os.getenv(external.api_key_name) and true or false
end end
end, end,
@ -46,6 +53,7 @@ E = setmetatable(E, {
---@private ---@private
E._once = false E._once = false
---@param provider Provider
E.is_default = function(provider) E.is_default = function(provider)
return E.env[provider] and true or false return E.env[provider] and true or false
end end
@ -60,6 +68,7 @@ E.key = function(provider)
return E.env[provider] return E.env[provider]
end end
---@type AvanteProvider | nil
local external = Config.vendors[provider] local external = Config.vendors[provider]
if external then if external then
return external.api_key_name return external.api_key_name
@ -68,8 +77,22 @@ E.key = function(provider)
end end
end end
---@param provider Provider
E.is_local = function(provider)
if Config.options[provider] then
return Config.options[provider]["local"]
elseif Config.vendors[provider] then
return Config.vendors[provider]["local"]
else
return false
end
end
---@param provider? Provider ---@param provider? Provider
E.value = function(provider) E.value = function(provider)
if E.is_local(provider or Config.provider) then
return "dummy"
end
return os.getenv(E.key(provider or Config.provider)) return os.getenv(E.key(provider or Config.provider))
end end
@ -88,7 +111,7 @@ E.setup = function(var, refresh)
vim.fn.setenv(var, value) vim.fn.setenv(var, value)
else else
if not E[Config.provider] then if not E[Config.provider] then
Util.warn("Failed to set " .. var .. ". Avante won't work as expected", { once = true, title = "Avante" }) Utils.warn("Failed to set " .. var .. ". Avante won't work as expected", { once = true, title = "Avante" })
end end
end end
end end
@ -208,9 +231,22 @@ Remember: Accurate line numbers are CRITICAL. The range start_line to end_line m
---@field on_complete fun(err: string|nil): any ---@field on_complete fun(err: string|nil): any
---@alias AvanteResponseParser fun(data_stream: string, event_state: string, opts: ResponseParser): nil ---@alias AvanteResponseParser fun(data_stream: string, event_state: string, opts: ResponseParser): nil
--- ---
---@class AvanteProvider ---@class AvanteDefaultBaseProvider
---@field endpoint string ---@field endpoint string
---@field local? boolean
---
---@class AvanteSupportedProvider: AvanteDefaultBaseProvider
---@field model string ---@field model string
---@field temperature number
---@field max_tokens number
---
---@class AvanteAzureProvider: AvanteDefaultBaseProvider
---@field deployment string
---@field api_version string
---@field temperature number
---@field max_tokens number
---
---@class AvanteProvider: AvanteDefaultBaseProvider
---@field api_key_name string ---@field api_key_name string
---@field parse_response_data AvanteResponseParser ---@field parse_response_data AvanteResponseParser
---@field parse_curl_args fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput ---@field parse_curl_args fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput