feat(providers): add support for custom vendors (#74)
* feat(providers): add support for custom vendors Signed-off-by: Aaron Pham <contact@aarnphm.xyz> * fix: override configuration not setup Signed-off-by: Aaron Pham <contact@aarnphm.xyz> --------- Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
parent
5fa4f701dd
commit
2700cad921
93
README.md
93
README.md
@ -258,6 +258,99 @@ lua_ls = {
|
|||||||
|
|
||||||
Then you can set `dev = true` in your `lazy` config for development.
|
Then you can set `dev = true` in your `lazy` config for development.
|
||||||
|
|
||||||
|
## Custom Providers
|
||||||
|
|
||||||
|
To add support for custom providers, one add `AvanteProvider` spec into `opts.vendors`:
|
||||||
|
|
||||||
|
```lua
|
||||||
|
{
|
||||||
|
provider = "my-custom-provider", -- You can then change this provider here
|
||||||
|
vendors = {
|
||||||
|
["my-custom-provider"] = {...}
|
||||||
|
},
|
||||||
|
windows = {
|
||||||
|
wrap_line = true,
|
||||||
|
width = 30, -- default % based on available width
|
||||||
|
},
|
||||||
|
--- @class AvanteConflictUserConfig
|
||||||
|
diff = {
|
||||||
|
debug = false,
|
||||||
|
autojump = true,
|
||||||
|
---@type string | fun(): any
|
||||||
|
list_opener = "copen",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
A custom provider should following the following spec:
|
||||||
|
|
||||||
|
```lua
|
||||||
|
---@type AvanteProvider
|
||||||
|
{
|
||||||
|
endpoint = "https://api.openai.com/v1/chat/completions", -- The full endpoint of the provider
|
||||||
|
model = "gpt-4o", -- The model name to use with this provider
|
||||||
|
api_key_name = "OPENAI_API_KEY", -- The name of the environment variable that contains the API key
|
||||||
|
--- This function below will be used to parse in cURL arguments.
|
||||||
|
--- It takes in the provider options as the first argument, followed by code_opts retrieved from given buffer.
|
||||||
|
--- This code_opts include:
|
||||||
|
--- - question: Input from the users
|
||||||
|
--- - code_lang: the language of given code buffer
|
||||||
|
--- - code_content: content of code buffer
|
||||||
|
--- - selected_code_content: (optional) If given code content is selected in visual mode as context.
|
||||||
|
---@type fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput
|
||||||
|
parse_curl_args = function(opts, code_opts) end
|
||||||
|
--- This function will be used to parse incoming SSE stream
|
||||||
|
--- It takes in the data stream as the first argument, followed by opts retrieved from given buffer.
|
||||||
|
--- This opts include:
|
||||||
|
--- - on_chunk: (fun(chunk: string): any) this is invoked on parsing correct delta chunk
|
||||||
|
--- - on_complete: (fun(err: string|nil): any) this is invoked on either complete call or error chunk
|
||||||
|
--- - event_state: SSE event state.
|
||||||
|
---@type fun(data_stream: string, opts: ResponseParser): nil
|
||||||
|
parse_response_data = function(data_stream, opts) end
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Full working example of perplexity</summary>
|
||||||
|
|
||||||
|
```lua
|
||||||
|
vendors = {
|
||||||
|
---@type AvanteProvider
|
||||||
|
perplexity = {
|
||||||
|
endpoint = "https://api.perplexity.ai/chat/completions",
|
||||||
|
model = "llama-3.1-sonar-large-128k-online",
|
||||||
|
api_key_name = "PPLX_API_KEY",
|
||||||
|
--- this function below will be used to parse in cURL arguments.
|
||||||
|
parse_curl_args = function(opts, code_opts)
|
||||||
|
local Llm = require "avante.llm"
|
||||||
|
return {
|
||||||
|
url = opts.endpoint,
|
||||||
|
headers = {
|
||||||
|
["Accept"] = "application/json",
|
||||||
|
["Content-Type"] = "application/json",
|
||||||
|
["Authorization"] = "Bearer " .. os.getenv(opts.api_key_name),
|
||||||
|
},
|
||||||
|
body = {
|
||||||
|
model = opts.model,
|
||||||
|
messages = Llm.make_openai_message(code_opts), -- you can make your own message, but this is very advanced
|
||||||
|
temperature = 0,
|
||||||
|
max_tokens = 8192,
|
||||||
|
stream = true, -- this will be set by default.
|
||||||
|
},
|
||||||
|
}
|
||||||
|
end,
|
||||||
|
-- The below function is used if the vendors has specific SSE spec that is not claude or openai.
|
||||||
|
parse_response_data = function(data_stream, opts)
|
||||||
|
local Llm = require "avante.llm"
|
||||||
|
Llm.parse_openai_response(data_stream, opts)
|
||||||
|
end,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
avante.nvim is licensed under the Apache License. For more details, please refer to the [LICENSE](./LICENSE) file.
|
avante.nvim is licensed under the Apache License. For more details, please refer to the [LICENSE](./LICENSE) file.
|
||||||
|
@ -6,7 +6,7 @@ local M = {}
|
|||||||
|
|
||||||
---@class avante.Config
|
---@class avante.Config
|
||||||
M.defaults = {
|
M.defaults = {
|
||||||
---@alias Provider "openai" | "claude" | "azure" | "deepseek" | "groq"
|
---@alias Provider "openai" | "claude" | "azure" | "deepseek" | "groq" | [string]
|
||||||
provider = "claude", -- "claude" or "openai" or "azure" or "deepseek" or "groq"
|
provider = "claude", -- "claude" or "openai" or "azure" or "deepseek" or "groq"
|
||||||
openai = {
|
openai = {
|
||||||
endpoint = "https://api.openai.com",
|
endpoint = "https://api.openai.com",
|
||||||
@ -39,6 +39,10 @@ M.defaults = {
|
|||||||
temperature = 0,
|
temperature = 0,
|
||||||
max_tokens = 4096,
|
max_tokens = 4096,
|
||||||
},
|
},
|
||||||
|
--- To add support for custom provider, follow the format below
|
||||||
|
--- See https://github.com/yetone/avante.nvim/README.md#custom-providers for more details
|
||||||
|
---@type table<string, AvanteProvider>
|
||||||
|
vendors = {},
|
||||||
behaviour = {
|
behaviour = {
|
||||||
auto_apply_diff_after_generation = false, -- Whether to automatically apply diff after LLM response.
|
auto_apply_diff_after_generation = false, -- Whether to automatically apply diff after LLM response.
|
||||||
},
|
},
|
||||||
@ -100,6 +104,11 @@ function M.setup(opts)
|
|||||||
)
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
---@param opts? avante.Config
|
||||||
|
function M.override(opts)
|
||||||
|
M.options = vim.tbl_deep_extend("force", M.options, opts or {})
|
||||||
|
end
|
||||||
|
|
||||||
M = setmetatable(M, {
|
M = setmetatable(M, {
|
||||||
__index = function(_, k)
|
__index = function(_, k)
|
||||||
if M.options[k] then
|
if M.options[k] then
|
||||||
|
@ -201,7 +201,7 @@ function M.setup(opts)
|
|||||||
end
|
end
|
||||||
|
|
||||||
require("avante.diff").setup()
|
require("avante.diff").setup()
|
||||||
require("avante.ai_bot").setup()
|
require("avante.llm").setup()
|
||||||
|
|
||||||
-- setup helpers
|
-- setup helpers
|
||||||
H.autocmds()
|
H.autocmds()
|
||||||
|
@ -8,13 +8,13 @@ local Tiktoken = require("avante.tiktoken")
|
|||||||
local Dressing = require("avante.ui.dressing")
|
local Dressing = require("avante.ui.dressing")
|
||||||
|
|
||||||
---@private
|
---@private
|
||||||
---@class AvanteAiBotInternal
|
---@class AvanteLLMInternal
|
||||||
local H = {}
|
local H = {}
|
||||||
|
|
||||||
---@class avante.AiBot
|
---@class avante.LLM
|
||||||
local M = {}
|
local M = {}
|
||||||
|
|
||||||
M.CANCEL_PATTERN = "AvanteAiBotEscape"
|
M.CANCEL_PATTERN = "AvanteLLMEscape"
|
||||||
|
|
||||||
---@class EnvironmentHandler: table<[Provider], string>
|
---@class EnvironmentHandler: table<[Provider], string>
|
||||||
local E = {
|
local E = {
|
||||||
@ -31,16 +31,41 @@ local E = {
|
|||||||
E = setmetatable(E, {
|
E = setmetatable(E, {
|
||||||
---@param k Provider
|
---@param k Provider
|
||||||
__index = function(_, k)
|
__index = function(_, k)
|
||||||
return os.getenv(E.env[k]) and true or false
|
local builtins = E.env[k]
|
||||||
|
if builtins then
|
||||||
|
return os.getenv(builtins) and true or false
|
||||||
|
end
|
||||||
|
|
||||||
|
local external = Config.vendors[k]
|
||||||
|
if external then
|
||||||
|
return os.getenv(external.api_key_name) and true or false
|
||||||
|
end
|
||||||
end,
|
end,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
---@private
|
||||||
E._once = false
|
E._once = false
|
||||||
|
|
||||||
|
E.is_default = function(provider)
|
||||||
|
return E.env[provider] and true or false
|
||||||
|
end
|
||||||
|
|
||||||
--- return the environment variable name for the given provider
|
--- return the environment variable name for the given provider
|
||||||
---@param provider? Provider
|
---@param provider? Provider
|
||||||
---@return string the envvar key
|
---@return string the envvar key
|
||||||
E.key = function(provider)
|
E.key = function(provider)
|
||||||
return E.env[provider or Config.provider]
|
provider = provider or Config.provider
|
||||||
|
|
||||||
|
if E.is_default(provider) then
|
||||||
|
return E.env[provider]
|
||||||
|
end
|
||||||
|
|
||||||
|
local external = Config.vendors[provider]
|
||||||
|
if external then
|
||||||
|
return external.api_key_name
|
||||||
|
else
|
||||||
|
error("Failed to find provider: " .. provider, 2)
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
---@param provider? Provider
|
---@param provider? Provider
|
||||||
@ -52,6 +77,7 @@ end
|
|||||||
--- This will only run once and spawn a UI for users to input the envvar.
|
--- This will only run once and spawn a UI for users to input the envvar.
|
||||||
---@param var Provider supported providers
|
---@param var Provider supported providers
|
||||||
---@param refresh? boolean
|
---@param refresh? boolean
|
||||||
|
---@private
|
||||||
E.setup = function(var, refresh)
|
E.setup = function(var, refresh)
|
||||||
refresh = refresh or false
|
refresh = refresh or false
|
||||||
|
|
||||||
@ -160,7 +186,19 @@ Remember: Accurate line numbers are CRITICAL. The range start_line to end_line m
|
|||||||
---@field code_content string
|
---@field code_content string
|
||||||
---@field selected_code_content? string
|
---@field selected_code_content? string
|
||||||
---
|
---
|
||||||
---@alias AvanteAiMessageBuilder fun(opts: AvantePromptOptions): {role: "user" | "system", content: string | table<string, any>}[]
|
---@class AvanteBaseMessage
|
||||||
|
---@field role "user" | "system"
|
||||||
|
---@field content string
|
||||||
|
---
|
||||||
|
---@class AvanteClaudeMessage: AvanteBaseMessage
|
||||||
|
---@field role "user"
|
||||||
|
---@field content {type: "text", text: string, cache_control?: {type: "ephemeral"}}[]
|
||||||
|
---
|
||||||
|
---@alias AvanteOpenAIMessage AvanteBaseMessage
|
||||||
|
---
|
||||||
|
---@alias AvanteChatMessage AvanteClaudeMessage | AvanteOpenAIMessage
|
||||||
|
---
|
||||||
|
---@alias AvanteAiMessageBuilder fun(opts: AvantePromptOptions): AvanteChatMessage[]
|
||||||
---
|
---
|
||||||
---@class AvanteCurlOutput: {url: string, body: table<string, any> | string, headers: table<string, string>}
|
---@class AvanteCurlOutput: {url: string, body: table<string, any> | string, headers: table<string, string>}
|
||||||
---@alias AvanteCurlArgsBuilder fun(code_opts: AvantePromptOptions): AvanteCurlOutput
|
---@alias AvanteCurlArgsBuilder fun(code_opts: AvantePromptOptions): AvanteCurlOutput
|
||||||
@ -169,12 +207,19 @@ Remember: Accurate line numbers are CRITICAL. The range start_line to end_line m
|
|||||||
---@field event_state string
|
---@field event_state string
|
||||||
---@field on_chunk fun(chunk: string): any
|
---@field on_chunk fun(chunk: string): any
|
||||||
---@field on_complete fun(err: string|nil): any
|
---@field on_complete fun(err: string|nil): any
|
||||||
---@field on_error? fun(err_type: string): nil
|
---@alias AvanteResponseParser fun(data_stream: string, opts: ResponseParser): nil
|
||||||
---@alias AvanteAiResponseParser fun(data_stream: string, opts: ResponseParser): nil
|
---
|
||||||
|
---@class AvanteProvider
|
||||||
|
---@field endpoint string
|
||||||
|
---@field model string
|
||||||
|
---@field api_key_name string
|
||||||
|
---@field parse_response_data AvanteResponseParser
|
||||||
|
---@field parse_curl_args fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput
|
||||||
|
|
||||||
------------------------------Anthropic------------------------------
|
------------------------------Anthropic------------------------------
|
||||||
|
|
||||||
---@type AvanteAiMessageBuilder
|
---@param opts AvantePromptOptions
|
||||||
|
---@return AvanteClaudeMessage[]
|
||||||
H.make_claude_message = function(opts)
|
H.make_claude_message = function(opts)
|
||||||
local code_prompt_obj = {
|
local code_prompt_obj = {
|
||||||
type = "text",
|
type = "text",
|
||||||
@ -232,7 +277,7 @@ H.make_claude_message = function(opts)
|
|||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
|
||||||
---@type AvanteAiResponseParser
|
---@type AvanteResponseParser
|
||||||
H.parse_claude_response = function(data_stream, opts)
|
H.parse_claude_response = function(data_stream, opts)
|
||||||
if opts.event_state == "content_block_delta" then
|
if opts.event_state == "content_block_delta" then
|
||||||
local json = vim.json.decode(data_stream)
|
local json = vim.json.decode(data_stream)
|
||||||
@ -268,7 +313,8 @@ end
|
|||||||
|
|
||||||
------------------------------OpenAI------------------------------
|
------------------------------OpenAI------------------------------
|
||||||
|
|
||||||
---@type AvanteAiMessageBuilder
|
---@param opts AvantePromptOptions
|
||||||
|
---@return AvanteOpenAIMessage[]
|
||||||
H.make_openai_message = function(opts)
|
H.make_openai_message = function(opts)
|
||||||
local user_prompt = base_user_prompt
|
local user_prompt = base_user_prompt
|
||||||
.. "\n\nCODE:\n"
|
.. "\n\nCODE:\n"
|
||||||
@ -304,7 +350,7 @@ H.make_openai_message = function(opts)
|
|||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
|
||||||
---@type AvanteAiResponseParser
|
---@type AvanteResponseParser
|
||||||
H.parse_openai_response = function(data_stream, opts)
|
H.parse_openai_response = function(data_stream, opts)
|
||||||
if data_stream:match('"%[DONE%]":') then
|
if data_stream:match('"%[DONE%]":') then
|
||||||
opts.on_complete(nil)
|
opts.on_complete(nil)
|
||||||
@ -346,7 +392,7 @@ end
|
|||||||
---@type AvanteAiMessageBuilder
|
---@type AvanteAiMessageBuilder
|
||||||
H.make_azure_message = H.make_openai_message
|
H.make_azure_message = H.make_openai_message
|
||||||
|
|
||||||
---@type AvanteAiResponseParser
|
---@type AvanteResponseParser
|
||||||
H.parse_azure_response = H.parse_openai_response
|
H.parse_azure_response = H.parse_openai_response
|
||||||
|
|
||||||
---@type AvanteCurlArgsBuilder
|
---@type AvanteCurlArgsBuilder
|
||||||
@ -375,7 +421,7 @@ end
|
|||||||
---@type AvanteAiMessageBuilder
|
---@type AvanteAiMessageBuilder
|
||||||
H.make_deepseek_message = H.make_openai_message
|
H.make_deepseek_message = H.make_openai_message
|
||||||
|
|
||||||
---@type AvanteAiResponseParser
|
---@type AvanteResponseParser
|
||||||
H.parse_deepseek_response = H.parse_openai_response
|
H.parse_deepseek_response = H.parse_openai_response
|
||||||
|
|
||||||
---@type AvanteCurlArgsBuilder
|
---@type AvanteCurlArgsBuilder
|
||||||
@ -401,7 +447,7 @@ end
|
|||||||
---@type AvanteAiMessageBuilder
|
---@type AvanteAiMessageBuilder
|
||||||
H.make_groq_message = H.make_openai_message
|
H.make_groq_message = H.make_openai_message
|
||||||
|
|
||||||
---@type AvanteAiResponseParser
|
---@type AvanteResponseParser
|
||||||
H.parse_groq_response = H.parse_openai_response
|
H.parse_groq_response = H.parse_openai_response
|
||||||
|
|
||||||
---@type AvanteCurlArgsBuilder
|
---@type AvanteCurlArgsBuilder
|
||||||
@ -424,7 +470,7 @@ end
|
|||||||
|
|
||||||
------------------------------Logic------------------------------
|
------------------------------Logic------------------------------
|
||||||
|
|
||||||
local group = vim.api.nvim_create_augroup("AvanteAiBot", { clear = true })
|
local group = vim.api.nvim_create_augroup("AvanteLLM", { clear = true })
|
||||||
local active_job = nil
|
local active_job = nil
|
||||||
|
|
||||||
---@param question string
|
---@param question string
|
||||||
@ -433,17 +479,35 @@ local active_job = nil
|
|||||||
---@param selected_content_content string | nil
|
---@param selected_content_content string | nil
|
||||||
---@param on_chunk fun(chunk: string): any
|
---@param on_chunk fun(chunk: string): any
|
||||||
---@param on_complete fun(err: string|nil): any
|
---@param on_complete fun(err: string|nil): any
|
||||||
M.invoke_llm_stream = function(question, code_lang, code_content, selected_content_content, on_chunk, on_complete)
|
M.stream = function(question, code_lang, code_content, selected_content_content, on_chunk, on_complete)
|
||||||
local provider = Config.provider
|
local provider = Config.provider
|
||||||
local event_state = nil
|
local event_state = nil
|
||||||
|
|
||||||
---@type AvanteCurlOutput
|
local code_opts = {
|
||||||
local spec = H["make_" .. provider .. "_curl_args"]({
|
|
||||||
question = question,
|
question = question,
|
||||||
code_lang = code_lang,
|
code_lang = code_lang,
|
||||||
code_content = code_content,
|
code_content = code_content,
|
||||||
selected_code_content = selected_content_content,
|
selected_code_content = selected_content_content,
|
||||||
})
|
}
|
||||||
|
local handler_opts = vim.deepcopy({ on_chunk = on_chunk, on_complete = on_complete, event_state = event_state }, true)
|
||||||
|
|
||||||
|
---@type AvanteCurlOutput
|
||||||
|
local spec = nil
|
||||||
|
|
||||||
|
---@type AvanteProvider
|
||||||
|
local ProviderConfig = nil
|
||||||
|
|
||||||
|
if E.is_default(provider) then
|
||||||
|
spec = H["make_" .. provider .. "_curl_args"](code_opts)
|
||||||
|
else
|
||||||
|
ProviderConfig = Config.vendors[provider]
|
||||||
|
spec = ProviderConfig.parse_curl_args(ProviderConfig, code_opts)
|
||||||
|
end
|
||||||
|
if spec.body.stream == nil then
|
||||||
|
spec = vim.tbl_deep_extend("force", spec, {
|
||||||
|
body = { stream = true },
|
||||||
|
})
|
||||||
|
end
|
||||||
|
|
||||||
---@param line string
|
---@param line string
|
||||||
local function parse_and_call(line)
|
local function parse_and_call(line)
|
||||||
@ -454,10 +518,11 @@ M.invoke_llm_stream = function(question, code_lang, code_content, selected_conte
|
|||||||
end
|
end
|
||||||
local data_match = line:match("^data: (.+)$")
|
local data_match = line:match("^data: (.+)$")
|
||||||
if data_match then
|
if data_match then
|
||||||
H["parse_" .. provider .. "_response"](
|
if ProviderConfig ~= nil then
|
||||||
data_match,
|
ProviderConfig.parse_response_data(data_match, handler_opts)
|
||||||
vim.deepcopy({ on_chunk = on_chunk, on_complete = on_complete, event_state = event_state }, true)
|
else
|
||||||
)
|
H["parse_" .. provider .. "_response"](data_match, handler_opts)
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -521,7 +586,7 @@ function M.refresh(provider)
|
|||||||
else
|
else
|
||||||
vim.notify_once("Switch to provider: " .. provider, vim.log.levels.INFO)
|
vim.notify_once("Switch to provider: " .. provider, vim.log.levels.INFO)
|
||||||
end
|
end
|
||||||
require("avante").setup({ provider = provider })
|
require("avante.config").override({ provider = provider })
|
||||||
end
|
end
|
||||||
|
|
||||||
M.commands = function()
|
M.commands = function()
|
||||||
@ -536,11 +601,25 @@ M.commands = function()
|
|||||||
return {}
|
return {}
|
||||||
end
|
end
|
||||||
local prefix = line:match("^%s*AvanteSwitchProvider (%w*)") or ""
|
local prefix = line:match("^%s*AvanteSwitchProvider (%w*)") or ""
|
||||||
|
-- join two tables
|
||||||
|
local Keys = vim.list_extend(vim.tbl_keys(E.env), vim.tbl_keys(Config.vendors))
|
||||||
return vim.tbl_filter(function(key)
|
return vim.tbl_filter(function(key)
|
||||||
return key:find(prefix) == 1
|
return key:find(prefix) == 1
|
||||||
end, vim.tbl_keys(E.env))
|
end, Keys)
|
||||||
end,
|
end,
|
||||||
})
|
})
|
||||||
end
|
end
|
||||||
|
|
||||||
return M
|
return setmetatable(M, {
|
||||||
|
__index = function(t, k)
|
||||||
|
local h = H[k]
|
||||||
|
if h then
|
||||||
|
return H[k]
|
||||||
|
end
|
||||||
|
local v = t[k]
|
||||||
|
if v then
|
||||||
|
return t[k]
|
||||||
|
end
|
||||||
|
error("Failed to find key: " .. k)
|
||||||
|
end,
|
||||||
|
})
|
@ -7,7 +7,7 @@ local N = require("nui-components")
|
|||||||
local Config = require("avante.config")
|
local Config = require("avante.config")
|
||||||
local View = require("avante.view")
|
local View = require("avante.view")
|
||||||
local Diff = require("avante.diff")
|
local Diff = require("avante.diff")
|
||||||
local AiBot = require("avante.ai_bot")
|
local Llm = require("avante.llm")
|
||||||
local Utils = require("avante.utils")
|
local Utils = require("avante.utils")
|
||||||
|
|
||||||
local VIEW_BUFFER_UPDATED_PATTERN = "AvanteViewBufferUpdated"
|
local VIEW_BUFFER_UPDATED_PATTERN = "AvanteViewBufferUpdated"
|
||||||
@ -141,7 +141,7 @@ function Sidebar:intialize()
|
|||||||
mode = { "n" },
|
mode = { "n" },
|
||||||
key = "q",
|
key = "q",
|
||||||
handler = function()
|
handler = function()
|
||||||
api.nvim_exec_autocmds("User", { pattern = AiBot.CANCEL_PATTERN })
|
api.nvim_exec_autocmds("User", { pattern = Llm.CANCEL_PATTERN })
|
||||||
self.renderer:close()
|
self.renderer:close()
|
||||||
end,
|
end,
|
||||||
},
|
},
|
||||||
@ -149,7 +149,7 @@ function Sidebar:intialize()
|
|||||||
mode = { "n" },
|
mode = { "n" },
|
||||||
key = "<Esc>",
|
key = "<Esc>",
|
||||||
handler = function()
|
handler = function()
|
||||||
api.nvim_exec_autocmds("User", { pattern = AiBot.CANCEL_PATTERN })
|
api.nvim_exec_autocmds("User", { pattern = Llm.CANCEL_PATTERN })
|
||||||
self.renderer:close()
|
self.renderer:close()
|
||||||
end,
|
end,
|
||||||
},
|
},
|
||||||
@ -245,6 +245,9 @@ end
|
|||||||
---@param content string concatenated content of the buffer
|
---@param content string concatenated content of the buffer
|
||||||
---@param opts? {focus?: boolean, stream?: boolean, scroll?: boolean, callback?: fun(): nil} whether to focus the result view
|
---@param opts? {focus?: boolean, stream?: boolean, scroll?: boolean, callback?: fun(): nil} whether to focus the result view
|
||||||
function Sidebar:update_content(content, opts)
|
function Sidebar:update_content(content, opts)
|
||||||
|
if not self.view.buf then
|
||||||
|
return
|
||||||
|
end
|
||||||
opts = vim.tbl_deep_extend("force", { focus = true, scroll = true, stream = false, callback = nil }, opts or {})
|
opts = vim.tbl_deep_extend("force", { focus = true, scroll = true, stream = false, callback = nil }, opts or {})
|
||||||
if opts.stream then
|
if opts.stream then
|
||||||
vim.schedule(function()
|
vim.schedule(function()
|
||||||
@ -643,9 +646,16 @@ function Sidebar:render()
|
|||||||
signal.is_loading = true
|
signal.is_loading = true
|
||||||
local state = signal:get_value()
|
local state = signal:get_value()
|
||||||
local request = state.text
|
local request = state.text
|
||||||
|
---@type string
|
||||||
|
local model
|
||||||
|
|
||||||
local provider_config = Config[Config.provider]
|
local builtins_provider_config = Config[Config.provider]
|
||||||
local model = provider_config and provider_config.model or "default"
|
if builtins_provider_config ~= nil then
|
||||||
|
model = builtins_provider_config.model
|
||||||
|
else
|
||||||
|
local vendor_provider_config = Config.vendors[Config.provider]
|
||||||
|
model = vendor_provider_config and vendor_provider_config.model or "default"
|
||||||
|
end
|
||||||
|
|
||||||
local timestamp = get_timestamp()
|
local timestamp = get_timestamp()
|
||||||
|
|
||||||
@ -670,50 +680,43 @@ function Sidebar:render()
|
|||||||
|
|
||||||
local filetype = api.nvim_get_option_value("filetype", { buf = self.code.buf })
|
local filetype = api.nvim_get_option_value("filetype", { buf = self.code.buf })
|
||||||
|
|
||||||
AiBot.invoke_llm_stream(
|
Llm.stream(request, filetype, content_with_line_numbers, selected_code_content_with_line_numbers, function(chunk)
|
||||||
request,
|
signal.is_loading = true
|
||||||
filetype,
|
full_response = full_response .. chunk
|
||||||
content_with_line_numbers,
|
self:update_content(chunk, { stream = true, scroll = false })
|
||||||
selected_code_content_with_line_numbers,
|
vim.schedule(function()
|
||||||
function(chunk)
|
vim.cmd("redraw")
|
||||||
signal.is_loading = true
|
end)
|
||||||
full_response = full_response .. chunk
|
end, function(err)
|
||||||
self:update_content(chunk, { stream = true, scroll = false })
|
signal.is_loading = false
|
||||||
vim.schedule(function()
|
|
||||||
vim.cmd("redraw")
|
|
||||||
end)
|
|
||||||
end,
|
|
||||||
function(err)
|
|
||||||
signal.is_loading = false
|
|
||||||
|
|
||||||
if err ~= nil then
|
if err ~= nil then
|
||||||
self:update_content(content_prefix .. full_response .. "\n\n🚨 Error: " .. vim.inspect(err))
|
self:update_content(content_prefix .. full_response .. "\n\n🚨 Error: " .. vim.inspect(err))
|
||||||
return
|
return
|
||||||
end
|
|
||||||
|
|
||||||
-- Execute when the stream request is actually completed
|
|
||||||
self:update_content(
|
|
||||||
content_prefix
|
|
||||||
.. full_response
|
|
||||||
.. "\n\n🎉🎉🎉 **Generation complete!** Please review the code suggestions above.\n\n",
|
|
||||||
{
|
|
||||||
callback = function()
|
|
||||||
api.nvim_exec_autocmds("User", { pattern = VIEW_BUFFER_UPDATED_PATTERN })
|
|
||||||
end,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
-- Save chat history
|
|
||||||
table.insert(chat_history or {}, {
|
|
||||||
timestamp = timestamp,
|
|
||||||
provider = Config.provider,
|
|
||||||
model = model,
|
|
||||||
request = request,
|
|
||||||
response = full_response,
|
|
||||||
})
|
|
||||||
save_chat_history(self, chat_history)
|
|
||||||
end
|
end
|
||||||
)
|
|
||||||
|
-- Execute when the stream request is actually completed
|
||||||
|
self:update_content(
|
||||||
|
content_prefix
|
||||||
|
.. full_response
|
||||||
|
.. "\n\n🎉🎉🎉 **Generation complete!** Please review the code suggestions above.\n\n",
|
||||||
|
{
|
||||||
|
callback = function()
|
||||||
|
api.nvim_exec_autocmds("User", { pattern = VIEW_BUFFER_UPDATED_PATTERN })
|
||||||
|
end,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
-- Save chat history
|
||||||
|
table.insert(chat_history or {}, {
|
||||||
|
timestamp = timestamp,
|
||||||
|
provider = Config.provider,
|
||||||
|
model = model,
|
||||||
|
request = request,
|
||||||
|
response = full_response,
|
||||||
|
})
|
||||||
|
save_chat_history(self, chat_history)
|
||||||
|
end)
|
||||||
|
|
||||||
if Config.behaviour.auto_apply_diff_after_generation then
|
if Config.behaviour.auto_apply_diff_after_generation then
|
||||||
apply()
|
apply()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user