feat(providers): add support for custom vendors (#74)
* feat(providers): add support for custom vendors Signed-off-by: Aaron Pham <contact@aarnphm.xyz> * fix: override configuration not setup Signed-off-by: Aaron Pham <contact@aarnphm.xyz> --------- Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
		
							parent
							
								
									5fa4f701dd
								
							
						
					
					
						commit
						2700cad921
					
				
							
								
								
									
										93
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										93
									
								
								README.md
									
									
									
									
									
								
							@ -258,6 +258,99 @@ lua_ls = {
 | 
			
		||||
 | 
			
		||||
Then you can set `dev = true` in your `lazy` config for development.
 | 
			
		||||
 | 
			
		||||
## Custom Providers
 | 
			
		||||
 | 
			
		||||
To add support for custom providers, one add `AvanteProvider` spec into `opts.vendors`:
 | 
			
		||||
 | 
			
		||||
```lua
 | 
			
		||||
{
 | 
			
		||||
  provider = "my-custom-provider", -- You can then change this provider here
 | 
			
		||||
  vendors = {
 | 
			
		||||
    ["my-custom-provider"] = {...}
 | 
			
		||||
  },
 | 
			
		||||
  windows = {
 | 
			
		||||
    wrap_line = true,
 | 
			
		||||
    width = 30, -- default % based on available width
 | 
			
		||||
  },
 | 
			
		||||
  --- @class AvanteConflictUserConfig
 | 
			
		||||
  diff = {
 | 
			
		||||
    debug = false,
 | 
			
		||||
    autojump = true,
 | 
			
		||||
    ---@type string | fun(): any
 | 
			
		||||
    list_opener = "copen",
 | 
			
		||||
  },
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
A custom provider should following the following spec:
 | 
			
		||||
 | 
			
		||||
```lua
 | 
			
		||||
---@type AvanteProvider
 | 
			
		||||
{
 | 
			
		||||
  endpoint = "https://api.openai.com/v1/chat/completions", -- The full endpoint of the provider
 | 
			
		||||
  model = "gpt-4o", -- The model name to use with this provider
 | 
			
		||||
  api_key_name = "OPENAI_API_KEY", -- The name of the environment variable that contains the API key
 | 
			
		||||
  --- This function below will be used to parse in cURL arguments.
 | 
			
		||||
  --- It takes in the provider options as the first argument, followed by code_opts retrieved from given buffer.
 | 
			
		||||
  --- This code_opts include:
 | 
			
		||||
  --- - question: Input from the users
 | 
			
		||||
  --- - code_lang: the language of given code buffer
 | 
			
		||||
  --- - code_content: content of code buffer
 | 
			
		||||
  --- - selected_code_content: (optional) If given code content is selected in visual mode as context.
 | 
			
		||||
  ---@type fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput
 | 
			
		||||
  parse_curl_args = function(opts, code_opts) end
 | 
			
		||||
  --- This function will be used to parse incoming SSE stream
 | 
			
		||||
  --- It takes in the data stream as the first argument, followed by opts retrieved from given buffer.
 | 
			
		||||
  --- This opts include:
 | 
			
		||||
  --- - on_chunk: (fun(chunk: string): any) this is invoked on parsing correct delta chunk
 | 
			
		||||
  --- - on_complete: (fun(err: string|nil): any) this is invoked on either complete call or error chunk
 | 
			
		||||
  --- - event_state: SSE event state.
 | 
			
		||||
  ---@type fun(data_stream: string, opts: ResponseParser): nil
 | 
			
		||||
  parse_response_data = function(data_stream, opts) end
 | 
			
		||||
}
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
<details>
 | 
			
		||||
<summary>Full working example of perplexity</summary>
 | 
			
		||||
 | 
			
		||||
```lua
 | 
			
		||||
vendors = {
 | 
			
		||||
  ---@type AvanteProvider
 | 
			
		||||
  perplexity = {
 | 
			
		||||
    endpoint = "https://api.perplexity.ai/chat/completions",
 | 
			
		||||
    model = "llama-3.1-sonar-large-128k-online",
 | 
			
		||||
    api_key_name = "PPLX_API_KEY",
 | 
			
		||||
    --- this function below will be used to parse in cURL arguments.
 | 
			
		||||
    parse_curl_args = function(opts, code_opts)
 | 
			
		||||
      local Llm = require "avante.llm"
 | 
			
		||||
      return {
 | 
			
		||||
        url = opts.endpoint,
 | 
			
		||||
        headers = {
 | 
			
		||||
          ["Accept"] = "application/json",
 | 
			
		||||
          ["Content-Type"] = "application/json",
 | 
			
		||||
          ["Authorization"] = "Bearer " .. os.getenv(opts.api_key_name),
 | 
			
		||||
        },
 | 
			
		||||
        body = {
 | 
			
		||||
          model = opts.model,
 | 
			
		||||
          messages = Llm.make_openai_message(code_opts), -- you can make your own message, but this is very advanced
 | 
			
		||||
          temperature = 0,
 | 
			
		||||
          max_tokens = 8192,
 | 
			
		||||
          stream = true, -- this will be set by default.
 | 
			
		||||
        },
 | 
			
		||||
      }
 | 
			
		||||
    end,
 | 
			
		||||
    -- The below function is used if the vendors has specific SSE spec that is not claude or openai.
 | 
			
		||||
    parse_response_data = function(data_stream, opts)
 | 
			
		||||
      local Llm = require "avante.llm"
 | 
			
		||||
      Llm.parse_openai_response(data_stream, opts)
 | 
			
		||||
    end,
 | 
			
		||||
  },
 | 
			
		||||
},
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
</details>
 | 
			
		||||
 | 
			
		||||
## License
 | 
			
		||||
 | 
			
		||||
avante.nvim is licensed under the Apache License. For more details, please refer to the [LICENSE](./LICENSE) file.
 | 
			
		||||
 | 
			
		||||
@ -6,7 +6,7 @@ local M = {}
 | 
			
		||||
 | 
			
		||||
---@class avante.Config
 | 
			
		||||
M.defaults = {
 | 
			
		||||
  ---@alias Provider "openai" | "claude" | "azure" | "deepseek" | "groq"
 | 
			
		||||
  ---@alias Provider "openai" | "claude" | "azure" | "deepseek" | "groq" | [string]
 | 
			
		||||
  provider = "claude", -- "claude" or "openai" or "azure" or "deepseek" or "groq"
 | 
			
		||||
  openai = {
 | 
			
		||||
    endpoint = "https://api.openai.com",
 | 
			
		||||
@ -39,6 +39,10 @@ M.defaults = {
 | 
			
		||||
    temperature = 0,
 | 
			
		||||
    max_tokens = 4096,
 | 
			
		||||
  },
 | 
			
		||||
  --- To add support for custom provider, follow the format below
 | 
			
		||||
  --- See https://github.com/yetone/avante.nvim/README.md#custom-providers for more details
 | 
			
		||||
  ---@type table<string, AvanteProvider>
 | 
			
		||||
  vendors = {},
 | 
			
		||||
  behaviour = {
 | 
			
		||||
    auto_apply_diff_after_generation = false, -- Whether to automatically apply diff after LLM response.
 | 
			
		||||
  },
 | 
			
		||||
@ -100,6 +104,11 @@ function M.setup(opts)
 | 
			
		||||
  )
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
---@param opts? avante.Config
 | 
			
		||||
function M.override(opts)
 | 
			
		||||
  M.options = vim.tbl_deep_extend("force", M.options, opts or {})
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
M = setmetatable(M, {
 | 
			
		||||
  __index = function(_, k)
 | 
			
		||||
    if M.options[k] then
 | 
			
		||||
 | 
			
		||||
@ -201,7 +201,7 @@ function M.setup(opts)
 | 
			
		||||
  end
 | 
			
		||||
 | 
			
		||||
  require("avante.diff").setup()
 | 
			
		||||
  require("avante.ai_bot").setup()
 | 
			
		||||
  require("avante.llm").setup()
 | 
			
		||||
 | 
			
		||||
  -- setup helpers
 | 
			
		||||
  H.autocmds()
 | 
			
		||||
 | 
			
		||||
@ -8,13 +8,13 @@ local Tiktoken = require("avante.tiktoken")
 | 
			
		||||
local Dressing = require("avante.ui.dressing")
 | 
			
		||||
 | 
			
		||||
---@private
 | 
			
		||||
---@class AvanteAiBotInternal
 | 
			
		||||
---@class AvanteLLMInternal
 | 
			
		||||
local H = {}
 | 
			
		||||
 | 
			
		||||
---@class avante.AiBot
 | 
			
		||||
---@class avante.LLM
 | 
			
		||||
local M = {}
 | 
			
		||||
 | 
			
		||||
M.CANCEL_PATTERN = "AvanteAiBotEscape"
 | 
			
		||||
M.CANCEL_PATTERN = "AvanteLLMEscape"
 | 
			
		||||
 | 
			
		||||
---@class EnvironmentHandler: table<[Provider], string>
 | 
			
		||||
local E = {
 | 
			
		||||
@ -31,16 +31,41 @@ local E = {
 | 
			
		||||
E = setmetatable(E, {
 | 
			
		||||
  ---@param k Provider
 | 
			
		||||
  __index = function(_, k)
 | 
			
		||||
    return os.getenv(E.env[k]) and true or false
 | 
			
		||||
    local builtins = E.env[k]
 | 
			
		||||
    if builtins then
 | 
			
		||||
      return os.getenv(builtins) and true or false
 | 
			
		||||
    end
 | 
			
		||||
 | 
			
		||||
    local external = Config.vendors[k]
 | 
			
		||||
    if external then
 | 
			
		||||
      return os.getenv(external.api_key_name) and true or false
 | 
			
		||||
    end
 | 
			
		||||
  end,
 | 
			
		||||
})
 | 
			
		||||
 | 
			
		||||
---@private
 | 
			
		||||
E._once = false
 | 
			
		||||
 | 
			
		||||
E.is_default = function(provider)
 | 
			
		||||
  return E.env[provider] and true or false
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
--- return the environment variable name for the given provider
 | 
			
		||||
---@param provider? Provider
 | 
			
		||||
---@return string the envvar key
 | 
			
		||||
E.key = function(provider)
 | 
			
		||||
  return E.env[provider or Config.provider]
 | 
			
		||||
  provider = provider or Config.provider
 | 
			
		||||
 | 
			
		||||
  if E.is_default(provider) then
 | 
			
		||||
    return E.env[provider]
 | 
			
		||||
  end
 | 
			
		||||
 | 
			
		||||
  local external = Config.vendors[provider]
 | 
			
		||||
  if external then
 | 
			
		||||
    return external.api_key_name
 | 
			
		||||
  else
 | 
			
		||||
    error("Failed to find provider: " .. provider, 2)
 | 
			
		||||
  end
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
---@param provider? Provider
 | 
			
		||||
@ -52,6 +77,7 @@ end
 | 
			
		||||
--- This will only run once and spawn a UI for users to input the envvar.
 | 
			
		||||
---@param var Provider supported providers
 | 
			
		||||
---@param refresh? boolean
 | 
			
		||||
---@private
 | 
			
		||||
E.setup = function(var, refresh)
 | 
			
		||||
  refresh = refresh or false
 | 
			
		||||
 | 
			
		||||
@ -160,7 +186,19 @@ Remember: Accurate line numbers are CRITICAL. The range start_line to end_line m
 | 
			
		||||
---@field code_content string
 | 
			
		||||
---@field selected_code_content? string
 | 
			
		||||
---
 | 
			
		||||
---@alias AvanteAiMessageBuilder fun(opts: AvantePromptOptions): {role: "user" | "system", content: string | table<string, any>}[]
 | 
			
		||||
---@class AvanteBaseMessage
 | 
			
		||||
---@field role "user" | "system"
 | 
			
		||||
---@field content string
 | 
			
		||||
---
 | 
			
		||||
---@class AvanteClaudeMessage: AvanteBaseMessage
 | 
			
		||||
---@field role "user"
 | 
			
		||||
---@field content {type: "text", text: string, cache_control?: {type: "ephemeral"}}[]
 | 
			
		||||
---
 | 
			
		||||
---@alias AvanteOpenAIMessage AvanteBaseMessage
 | 
			
		||||
---
 | 
			
		||||
---@alias AvanteChatMessage AvanteClaudeMessage | AvanteOpenAIMessage
 | 
			
		||||
---
 | 
			
		||||
---@alias AvanteAiMessageBuilder fun(opts: AvantePromptOptions): AvanteChatMessage[]
 | 
			
		||||
---
 | 
			
		||||
---@class AvanteCurlOutput: {url: string, body: table<string, any> | string, headers: table<string, string>}
 | 
			
		||||
---@alias AvanteCurlArgsBuilder fun(code_opts: AvantePromptOptions): AvanteCurlOutput
 | 
			
		||||
@ -169,12 +207,19 @@ Remember: Accurate line numbers are CRITICAL. The range start_line to end_line m
 | 
			
		||||
---@field event_state string
 | 
			
		||||
---@field on_chunk fun(chunk: string): any
 | 
			
		||||
---@field on_complete fun(err: string|nil): any
 | 
			
		||||
---@field on_error? fun(err_type: string): nil
 | 
			
		||||
---@alias AvanteAiResponseParser fun(data_stream: string, opts: ResponseParser): nil
 | 
			
		||||
---@alias AvanteResponseParser fun(data_stream: string, opts: ResponseParser): nil
 | 
			
		||||
---
 | 
			
		||||
---@class AvanteProvider
 | 
			
		||||
---@field endpoint string
 | 
			
		||||
---@field model string
 | 
			
		||||
---@field api_key_name string
 | 
			
		||||
---@field parse_response_data AvanteResponseParser
 | 
			
		||||
---@field parse_curl_args fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput
 | 
			
		||||
 | 
			
		||||
------------------------------Anthropic------------------------------
 | 
			
		||||
 | 
			
		||||
---@type AvanteAiMessageBuilder
 | 
			
		||||
---@param opts AvantePromptOptions
 | 
			
		||||
---@return AvanteClaudeMessage[]
 | 
			
		||||
H.make_claude_message = function(opts)
 | 
			
		||||
  local code_prompt_obj = {
 | 
			
		||||
    type = "text",
 | 
			
		||||
@ -232,7 +277,7 @@ H.make_claude_message = function(opts)
 | 
			
		||||
  }
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
---@type AvanteAiResponseParser
 | 
			
		||||
---@type AvanteResponseParser
 | 
			
		||||
H.parse_claude_response = function(data_stream, opts)
 | 
			
		||||
  if opts.event_state == "content_block_delta" then
 | 
			
		||||
    local json = vim.json.decode(data_stream)
 | 
			
		||||
@ -268,7 +313,8 @@ end
 | 
			
		||||
 | 
			
		||||
------------------------------OpenAI------------------------------
 | 
			
		||||
 | 
			
		||||
---@type AvanteAiMessageBuilder
 | 
			
		||||
---@param opts AvantePromptOptions
 | 
			
		||||
---@return AvanteOpenAIMessage[]
 | 
			
		||||
H.make_openai_message = function(opts)
 | 
			
		||||
  local user_prompt = base_user_prompt
 | 
			
		||||
    .. "\n\nCODE:\n"
 | 
			
		||||
@ -304,7 +350,7 @@ H.make_openai_message = function(opts)
 | 
			
		||||
  }
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
---@type AvanteAiResponseParser
 | 
			
		||||
---@type AvanteResponseParser
 | 
			
		||||
H.parse_openai_response = function(data_stream, opts)
 | 
			
		||||
  if data_stream:match('"%[DONE%]":') then
 | 
			
		||||
    opts.on_complete(nil)
 | 
			
		||||
@ -346,7 +392,7 @@ end
 | 
			
		||||
---@type AvanteAiMessageBuilder
 | 
			
		||||
H.make_azure_message = H.make_openai_message
 | 
			
		||||
 | 
			
		||||
---@type AvanteAiResponseParser
 | 
			
		||||
---@type AvanteResponseParser
 | 
			
		||||
H.parse_azure_response = H.parse_openai_response
 | 
			
		||||
 | 
			
		||||
---@type AvanteCurlArgsBuilder
 | 
			
		||||
@ -375,7 +421,7 @@ end
 | 
			
		||||
---@type AvanteAiMessageBuilder
 | 
			
		||||
H.make_deepseek_message = H.make_openai_message
 | 
			
		||||
 | 
			
		||||
---@type AvanteAiResponseParser
 | 
			
		||||
---@type AvanteResponseParser
 | 
			
		||||
H.parse_deepseek_response = H.parse_openai_response
 | 
			
		||||
 | 
			
		||||
---@type AvanteCurlArgsBuilder
 | 
			
		||||
@ -401,7 +447,7 @@ end
 | 
			
		||||
---@type AvanteAiMessageBuilder
 | 
			
		||||
H.make_groq_message = H.make_openai_message
 | 
			
		||||
 | 
			
		||||
---@type AvanteAiResponseParser
 | 
			
		||||
---@type AvanteResponseParser
 | 
			
		||||
H.parse_groq_response = H.parse_openai_response
 | 
			
		||||
 | 
			
		||||
---@type AvanteCurlArgsBuilder
 | 
			
		||||
@ -424,7 +470,7 @@ end
 | 
			
		||||
 | 
			
		||||
------------------------------Logic------------------------------
 | 
			
		||||
 | 
			
		||||
local group = vim.api.nvim_create_augroup("AvanteAiBot", { clear = true })
 | 
			
		||||
local group = vim.api.nvim_create_augroup("AvanteLLM", { clear = true })
 | 
			
		||||
local active_job = nil
 | 
			
		||||
 | 
			
		||||
---@param question string
 | 
			
		||||
@ -433,17 +479,35 @@ local active_job = nil
 | 
			
		||||
---@param selected_content_content string | nil
 | 
			
		||||
---@param on_chunk fun(chunk: string): any
 | 
			
		||||
---@param on_complete fun(err: string|nil): any
 | 
			
		||||
M.invoke_llm_stream = function(question, code_lang, code_content, selected_content_content, on_chunk, on_complete)
 | 
			
		||||
M.stream = function(question, code_lang, code_content, selected_content_content, on_chunk, on_complete)
 | 
			
		||||
  local provider = Config.provider
 | 
			
		||||
  local event_state = nil
 | 
			
		||||
 | 
			
		||||
  ---@type AvanteCurlOutput
 | 
			
		||||
  local spec = H["make_" .. provider .. "_curl_args"]({
 | 
			
		||||
  local code_opts = {
 | 
			
		||||
    question = question,
 | 
			
		||||
    code_lang = code_lang,
 | 
			
		||||
    code_content = code_content,
 | 
			
		||||
    selected_code_content = selected_content_content,
 | 
			
		||||
  })
 | 
			
		||||
  }
 | 
			
		||||
  local handler_opts = vim.deepcopy({ on_chunk = on_chunk, on_complete = on_complete, event_state = event_state }, true)
 | 
			
		||||
 | 
			
		||||
  ---@type AvanteCurlOutput
 | 
			
		||||
  local spec = nil
 | 
			
		||||
 | 
			
		||||
  ---@type AvanteProvider
 | 
			
		||||
  local ProviderConfig = nil
 | 
			
		||||
 | 
			
		||||
  if E.is_default(provider) then
 | 
			
		||||
    spec = H["make_" .. provider .. "_curl_args"](code_opts)
 | 
			
		||||
  else
 | 
			
		||||
    ProviderConfig = Config.vendors[provider]
 | 
			
		||||
    spec = ProviderConfig.parse_curl_args(ProviderConfig, code_opts)
 | 
			
		||||
  end
 | 
			
		||||
  if spec.body.stream == nil then
 | 
			
		||||
    spec = vim.tbl_deep_extend("force", spec, {
 | 
			
		||||
      body = { stream = true },
 | 
			
		||||
    })
 | 
			
		||||
  end
 | 
			
		||||
 | 
			
		||||
  ---@param line string
 | 
			
		||||
  local function parse_and_call(line)
 | 
			
		||||
@ -454,10 +518,11 @@ M.invoke_llm_stream = function(question, code_lang, code_content, selected_conte
 | 
			
		||||
    end
 | 
			
		||||
    local data_match = line:match("^data: (.+)$")
 | 
			
		||||
    if data_match then
 | 
			
		||||
      H["parse_" .. provider .. "_response"](
 | 
			
		||||
        data_match,
 | 
			
		||||
        vim.deepcopy({ on_chunk = on_chunk, on_complete = on_complete, event_state = event_state }, true)
 | 
			
		||||
      )
 | 
			
		||||
      if ProviderConfig ~= nil then
 | 
			
		||||
        ProviderConfig.parse_response_data(data_match, handler_opts)
 | 
			
		||||
      else
 | 
			
		||||
        H["parse_" .. provider .. "_response"](data_match, handler_opts)
 | 
			
		||||
      end
 | 
			
		||||
    end
 | 
			
		||||
  end
 | 
			
		||||
 | 
			
		||||
@ -521,7 +586,7 @@ function M.refresh(provider)
 | 
			
		||||
  else
 | 
			
		||||
    vim.notify_once("Switch to provider: " .. provider, vim.log.levels.INFO)
 | 
			
		||||
  end
 | 
			
		||||
  require("avante").setup({ provider = provider })
 | 
			
		||||
  require("avante.config").override({ provider = provider })
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
M.commands = function()
 | 
			
		||||
@ -536,11 +601,25 @@ M.commands = function()
 | 
			
		||||
        return {}
 | 
			
		||||
      end
 | 
			
		||||
      local prefix = line:match("^%s*AvanteSwitchProvider (%w*)") or ""
 | 
			
		||||
      -- join two tables
 | 
			
		||||
      local Keys = vim.list_extend(vim.tbl_keys(E.env), vim.tbl_keys(Config.vendors))
 | 
			
		||||
      return vim.tbl_filter(function(key)
 | 
			
		||||
        return key:find(prefix) == 1
 | 
			
		||||
      end, vim.tbl_keys(E.env))
 | 
			
		||||
      end, Keys)
 | 
			
		||||
    end,
 | 
			
		||||
  })
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
return M
 | 
			
		||||
return setmetatable(M, {
 | 
			
		||||
  __index = function(t, k)
 | 
			
		||||
    local h = H[k]
 | 
			
		||||
    if h then
 | 
			
		||||
      return H[k]
 | 
			
		||||
    end
 | 
			
		||||
    local v = t[k]
 | 
			
		||||
    if v then
 | 
			
		||||
      return t[k]
 | 
			
		||||
    end
 | 
			
		||||
    error("Failed to find key: " .. k)
 | 
			
		||||
  end,
 | 
			
		||||
})
 | 
			
		||||
@ -7,7 +7,7 @@ local N = require("nui-components")
 | 
			
		||||
local Config = require("avante.config")
 | 
			
		||||
local View = require("avante.view")
 | 
			
		||||
local Diff = require("avante.diff")
 | 
			
		||||
local AiBot = require("avante.ai_bot")
 | 
			
		||||
local Llm = require("avante.llm")
 | 
			
		||||
local Utils = require("avante.utils")
 | 
			
		||||
 | 
			
		||||
local VIEW_BUFFER_UPDATED_PATTERN = "AvanteViewBufferUpdated"
 | 
			
		||||
@ -141,7 +141,7 @@ function Sidebar:intialize()
 | 
			
		||||
      mode = { "n" },
 | 
			
		||||
      key = "q",
 | 
			
		||||
      handler = function()
 | 
			
		||||
        api.nvim_exec_autocmds("User", { pattern = AiBot.CANCEL_PATTERN })
 | 
			
		||||
        api.nvim_exec_autocmds("User", { pattern = Llm.CANCEL_PATTERN })
 | 
			
		||||
        self.renderer:close()
 | 
			
		||||
      end,
 | 
			
		||||
    },
 | 
			
		||||
@ -149,7 +149,7 @@ function Sidebar:intialize()
 | 
			
		||||
      mode = { "n" },
 | 
			
		||||
      key = "<Esc>",
 | 
			
		||||
      handler = function()
 | 
			
		||||
        api.nvim_exec_autocmds("User", { pattern = AiBot.CANCEL_PATTERN })
 | 
			
		||||
        api.nvim_exec_autocmds("User", { pattern = Llm.CANCEL_PATTERN })
 | 
			
		||||
        self.renderer:close()
 | 
			
		||||
      end,
 | 
			
		||||
    },
 | 
			
		||||
@ -245,6 +245,9 @@ end
 | 
			
		||||
---@param content string concatenated content of the buffer
 | 
			
		||||
---@param opts? {focus?: boolean, stream?: boolean, scroll?: boolean, callback?: fun(): nil} whether to focus the result view
 | 
			
		||||
function Sidebar:update_content(content, opts)
 | 
			
		||||
  if not self.view.buf then
 | 
			
		||||
    return
 | 
			
		||||
  end
 | 
			
		||||
  opts = vim.tbl_deep_extend("force", { focus = true, scroll = true, stream = false, callback = nil }, opts or {})
 | 
			
		||||
  if opts.stream then
 | 
			
		||||
    vim.schedule(function()
 | 
			
		||||
@ -643,9 +646,16 @@ function Sidebar:render()
 | 
			
		||||
    signal.is_loading = true
 | 
			
		||||
    local state = signal:get_value()
 | 
			
		||||
    local request = state.text
 | 
			
		||||
    ---@type string
 | 
			
		||||
    local model
 | 
			
		||||
 | 
			
		||||
    local provider_config = Config[Config.provider]
 | 
			
		||||
    local model = provider_config and provider_config.model or "default"
 | 
			
		||||
    local builtins_provider_config = Config[Config.provider]
 | 
			
		||||
    if builtins_provider_config ~= nil then
 | 
			
		||||
      model = builtins_provider_config.model
 | 
			
		||||
    else
 | 
			
		||||
      local vendor_provider_config = Config.vendors[Config.provider]
 | 
			
		||||
      model = vendor_provider_config and vendor_provider_config.model or "default"
 | 
			
		||||
    end
 | 
			
		||||
 | 
			
		||||
    local timestamp = get_timestamp()
 | 
			
		||||
 | 
			
		||||
@ -670,50 +680,43 @@ function Sidebar:render()
 | 
			
		||||
 | 
			
		||||
    local filetype = api.nvim_get_option_value("filetype", { buf = self.code.buf })
 | 
			
		||||
 | 
			
		||||
    AiBot.invoke_llm_stream(
 | 
			
		||||
      request,
 | 
			
		||||
      filetype,
 | 
			
		||||
      content_with_line_numbers,
 | 
			
		||||
      selected_code_content_with_line_numbers,
 | 
			
		||||
      function(chunk)
 | 
			
		||||
        signal.is_loading = true
 | 
			
		||||
        full_response = full_response .. chunk
 | 
			
		||||
        self:update_content(chunk, { stream = true, scroll = false })
 | 
			
		||||
        vim.schedule(function()
 | 
			
		||||
          vim.cmd("redraw")
 | 
			
		||||
        end)
 | 
			
		||||
      end,
 | 
			
		||||
      function(err)
 | 
			
		||||
        signal.is_loading = false
 | 
			
		||||
    Llm.stream(request, filetype, content_with_line_numbers, selected_code_content_with_line_numbers, function(chunk)
 | 
			
		||||
      signal.is_loading = true
 | 
			
		||||
      full_response = full_response .. chunk
 | 
			
		||||
      self:update_content(chunk, { stream = true, scroll = false })
 | 
			
		||||
      vim.schedule(function()
 | 
			
		||||
        vim.cmd("redraw")
 | 
			
		||||
      end)
 | 
			
		||||
    end, function(err)
 | 
			
		||||
      signal.is_loading = false
 | 
			
		||||
 | 
			
		||||
        if err ~= nil then
 | 
			
		||||
          self:update_content(content_prefix .. full_response .. "\n\n🚨 Error: " .. vim.inspect(err))
 | 
			
		||||
          return
 | 
			
		||||
        end
 | 
			
		||||
 | 
			
		||||
        -- Execute when the stream request is actually completed
 | 
			
		||||
        self:update_content(
 | 
			
		||||
          content_prefix
 | 
			
		||||
            .. full_response
 | 
			
		||||
            .. "\n\n🎉🎉🎉 **Generation complete!** Please review the code suggestions above.\n\n",
 | 
			
		||||
          {
 | 
			
		||||
            callback = function()
 | 
			
		||||
              api.nvim_exec_autocmds("User", { pattern = VIEW_BUFFER_UPDATED_PATTERN })
 | 
			
		||||
            end,
 | 
			
		||||
          }
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        -- Save chat history
 | 
			
		||||
        table.insert(chat_history or {}, {
 | 
			
		||||
          timestamp = timestamp,
 | 
			
		||||
          provider = Config.provider,
 | 
			
		||||
          model = model,
 | 
			
		||||
          request = request,
 | 
			
		||||
          response = full_response,
 | 
			
		||||
        })
 | 
			
		||||
        save_chat_history(self, chat_history)
 | 
			
		||||
      if err ~= nil then
 | 
			
		||||
        self:update_content(content_prefix .. full_response .. "\n\n🚨 Error: " .. vim.inspect(err))
 | 
			
		||||
        return
 | 
			
		||||
      end
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
      -- Execute when the stream request is actually completed
 | 
			
		||||
      self:update_content(
 | 
			
		||||
        content_prefix
 | 
			
		||||
          .. full_response
 | 
			
		||||
          .. "\n\n🎉🎉🎉 **Generation complete!** Please review the code suggestions above.\n\n",
 | 
			
		||||
        {
 | 
			
		||||
          callback = function()
 | 
			
		||||
            api.nvim_exec_autocmds("User", { pattern = VIEW_BUFFER_UPDATED_PATTERN })
 | 
			
		||||
          end,
 | 
			
		||||
        }
 | 
			
		||||
      )
 | 
			
		||||
 | 
			
		||||
      -- Save chat history
 | 
			
		||||
      table.insert(chat_history or {}, {
 | 
			
		||||
        timestamp = timestamp,
 | 
			
		||||
        provider = Config.provider,
 | 
			
		||||
        model = model,
 | 
			
		||||
        request = request,
 | 
			
		||||
        response = full_response,
 | 
			
		||||
      })
 | 
			
		||||
      save_chat_history(self, chat_history)
 | 
			
		||||
    end)
 | 
			
		||||
 | 
			
		||||
    if Config.behaviour.auto_apply_diff_after_generation then
 | 
			
		||||
      apply()
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user