refactor(ai): streaming chunks to avoid excessive redraw. (#73)
* perf(ai): token streaming with quick refactoring Signed-off-by: Aaron Pham <contact@aarnphm.xyz> * fix: window resize and AvanteSwitchProvider Signed-off-by: Aaron Pham <contact@aarnphm.xyz> * revert: config change Signed-off-by: Aaron Pham <contact@aarnphm.xyz> * chore: return early Signed-off-by: Aaron Pham <contact@aarnphm.xyz> --------- Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
		
							parent
							
								
									0fddfc7d8f
								
							
						
					
					
						commit
						5fa4f701dd
					
				@ -1,4 +1,3 @@
 | 
				
			|||||||
local fn = vim.fn
 | 
					 | 
				
			||||||
local api = vim.api
 | 
					local api = vim.api
 | 
				
			||||||
 | 
					
 | 
				
			||||||
local curl = require("plenary.curl")
 | 
					local curl = require("plenary.curl")
 | 
				
			||||||
@ -6,10 +5,17 @@ local curl = require("plenary.curl")
 | 
				
			|||||||
local Utils = require("avante.utils")
 | 
					local Utils = require("avante.utils")
 | 
				
			||||||
local Config = require("avante.config")
 | 
					local Config = require("avante.config")
 | 
				
			||||||
local Tiktoken = require("avante.tiktoken")
 | 
					local Tiktoken = require("avante.tiktoken")
 | 
				
			||||||
 | 
					local Dressing = require("avante.ui.dressing")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					---@private
 | 
				
			||||||
 | 
					---@class AvanteAiBotInternal
 | 
				
			||||||
 | 
					local H = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
---@class avante.AiBot
 | 
					---@class avante.AiBot
 | 
				
			||||||
local M = {}
 | 
					local M = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					M.CANCEL_PATTERN = "AvanteAiBotEscape"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
---@class EnvironmentHandler: table<[Provider], string>
 | 
					---@class EnvironmentHandler: table<[Provider], string>
 | 
				
			||||||
local E = {
 | 
					local E = {
 | 
				
			||||||
  ---@type table<Provider, string>
 | 
					  ---@type table<Provider, string>
 | 
				
			||||||
@ -20,7 +26,6 @@ local E = {
 | 
				
			|||||||
    deepseek = "DEEPSEEK_API_KEY",
 | 
					    deepseek = "DEEPSEEK_API_KEY",
 | 
				
			||||||
    groq = "GROQ_API_KEY",
 | 
					    groq = "GROQ_API_KEY",
 | 
				
			||||||
  },
 | 
					  },
 | 
				
			||||||
  _once = false,
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
E = setmetatable(E, {
 | 
					E = setmetatable(E, {
 | 
				
			||||||
@ -29,30 +34,32 @@ E = setmetatable(E, {
 | 
				
			|||||||
    return os.getenv(E.env[k]) and true or false
 | 
					    return os.getenv(E.env[k]) and true or false
 | 
				
			||||||
  end,
 | 
					  end,
 | 
				
			||||||
})
 | 
					})
 | 
				
			||||||
 | 
					E._once = false
 | 
				
			||||||
 | 
					
 | 
				
			||||||
--- return the environment variable name for the given provider
 | 
					--- return the environment variable name for the given provider
 | 
				
			||||||
---@param provider? Provider
 | 
					---@param provider? Provider
 | 
				
			||||||
---@return string the envvar key
 | 
					---@return string the envvar key
 | 
				
			||||||
E.key = function(provider)
 | 
					E.key = function(provider)
 | 
				
			||||||
  provider = provider or Config.provider
 | 
					  return E.env[provider or Config.provider]
 | 
				
			||||||
  local var = E.env[provider]
 | 
					 | 
				
			||||||
  return type(var) == "table" and var[1] ---@cast var string
 | 
					 | 
				
			||||||
    or var
 | 
					 | 
				
			||||||
end
 | 
					end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
E.setup = function(var)
 | 
					---@param provider? Provider
 | 
				
			||||||
  local Dressing = require("avante.ui.dressing")
 | 
					E.value = function(provider)
 | 
				
			||||||
 | 
					  return os.getenv(E.key(provider or Config.provider))
 | 
				
			||||||
 | 
					end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  if E._once then
 | 
					--- intialize the environment variable for current neovim session.
 | 
				
			||||||
    return
 | 
					--- This will only run once and spawn a UI for users to input the envvar.
 | 
				
			||||||
  end
 | 
					---@param var Provider supported providers
 | 
				
			||||||
 | 
					---@param refresh? boolean
 | 
				
			||||||
 | 
					E.setup = function(var, refresh)
 | 
				
			||||||
 | 
					  refresh = refresh or false
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  ---@param value string
 | 
					  ---@param value string
 | 
				
			||||||
  ---@return nil
 | 
					  ---@return nil
 | 
				
			||||||
  local function on_confirm(value)
 | 
					  local function on_confirm(value)
 | 
				
			||||||
    if value then
 | 
					    if value then
 | 
				
			||||||
      vim.fn.setenv(var, value)
 | 
					      vim.fn.setenv(var, value)
 | 
				
			||||||
      E._once = true
 | 
					 | 
				
			||||||
    else
 | 
					    else
 | 
				
			||||||
      if not E[Config.provider] then
 | 
					      if not E[Config.provider] then
 | 
				
			||||||
        vim.notify_once("Failed to set " .. var .. ". Avante won't work as expected", vim.log.levels.WARN)
 | 
					        vim.notify_once("Failed to set " .. var .. ". Avante won't work as expected", vim.log.levels.WARN)
 | 
				
			||||||
@ -60,35 +67,45 @@ E.setup = function(var)
 | 
				
			|||||||
    end
 | 
					    end
 | 
				
			||||||
  end
 | 
					  end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  api.nvim_create_autocmd({ "BufEnter", "BufWinEnter" }, {
 | 
					  if refresh then
 | 
				
			||||||
    pattern = "*",
 | 
					    vim.defer_fn(function()
 | 
				
			||||||
    once = true,
 | 
					      Dressing.initialize_input_buffer({ opts = { prompt = "Enter " .. var .. ": " }, on_confirm = on_confirm })
 | 
				
			||||||
    callback = function()
 | 
					    end, 200)
 | 
				
			||||||
      vim.defer_fn(function()
 | 
					  elseif not E._once then
 | 
				
			||||||
        -- only mount if given buffer is not of buftype ministarter, dashboard, alpha, qf
 | 
					    E._once = true
 | 
				
			||||||
        local exclude_buftypes = { "dashboard", "alpha", "qf", "nofile" }
 | 
					    api.nvim_create_autocmd({ "BufEnter", "BufWinEnter" }, {
 | 
				
			||||||
        local exclude_filetypes = {
 | 
					      pattern = "*",
 | 
				
			||||||
          "NvimTree",
 | 
					      once = true,
 | 
				
			||||||
          "Outline",
 | 
					      callback = function()
 | 
				
			||||||
          "help",
 | 
					        vim.defer_fn(function()
 | 
				
			||||||
          "dashboard",
 | 
					          -- only mount if given buffer is not of buftype ministarter, dashboard, alpha, qf
 | 
				
			||||||
          "alpha",
 | 
					          local exclude_buftypes = { "dashboard", "alpha", "qf", "nofile" }
 | 
				
			||||||
          "qf",
 | 
					          local exclude_filetypes = {
 | 
				
			||||||
          "ministarter",
 | 
					            "NvimTree",
 | 
				
			||||||
          "TelescopePrompt",
 | 
					            "Outline",
 | 
				
			||||||
          "gitcommit",
 | 
					            "help",
 | 
				
			||||||
        }
 | 
					            "dashboard",
 | 
				
			||||||
        if
 | 
					            "alpha",
 | 
				
			||||||
          not vim.tbl_contains(exclude_buftypes, vim.bo.buftype)
 | 
					            "qf",
 | 
				
			||||||
          and not vim.tbl_contains(exclude_filetypes, vim.bo.filetype)
 | 
					            "ministarter",
 | 
				
			||||||
        then
 | 
					            "TelescopePrompt",
 | 
				
			||||||
          Dressing.initialize_input_buffer({ opts = { prompt = "Enter " .. var .. ": " }, on_confirm = on_confirm })
 | 
					            "gitcommit",
 | 
				
			||||||
        end
 | 
					            "gitrebase",
 | 
				
			||||||
      end, 200)
 | 
					          }
 | 
				
			||||||
    end,
 | 
					          if
 | 
				
			||||||
  })
 | 
					            not vim.tbl_contains(exclude_buftypes, vim.bo.buftype)
 | 
				
			||||||
 | 
					            and not vim.tbl_contains(exclude_filetypes, vim.bo.filetype)
 | 
				
			||||||
 | 
					          then
 | 
				
			||||||
 | 
					            Dressing.initialize_input_buffer({ opts = { prompt = "Enter " .. var .. ": " }, on_confirm = on_confirm })
 | 
				
			||||||
 | 
					          end
 | 
				
			||||||
 | 
					        end, 200)
 | 
				
			||||||
 | 
					      end,
 | 
				
			||||||
 | 
					    })
 | 
				
			||||||
 | 
					  end
 | 
				
			||||||
end
 | 
					end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					------------------------------Prompt and type------------------------------
 | 
				
			||||||
 | 
					
 | 
				
			||||||
local system_prompt = [[
 | 
					local system_prompt = [[
 | 
				
			||||||
You are an excellent programming expert.
 | 
					You are an excellent programming expert.
 | 
				
			||||||
]]
 | 
					]]
 | 
				
			||||||
@ -137,38 +154,49 @@ Replace lines: {{start_line}}-{{end_line}}
 | 
				
			|||||||
Remember: Accurate line numbers are CRITICAL. The range start_line to end_line must include ALL lines to be replaced, from the very first to the very last. Double-check every range before finalizing your response, paying special attention to the start_line to ensure it hasn't shifted down. Ensure that your line numbers perfectly match the original code structure without any overall shift.
 | 
					Remember: Accurate line numbers are CRITICAL. The range start_line to end_line must include ALL lines to be replaced, from the very first to the very last. Double-check every range before finalizing your response, paying special attention to the start_line to ensure it hasn't shifted down. Ensure that your line numbers perfectly match the original code structure without any overall shift.
 | 
				
			||||||
]]
 | 
					]]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
local function call_claude_api_stream(question, code_lang, code_content, selected_code_content, on_chunk, on_complete)
 | 
					---@class AvantePromptOptions: table<[string], string>
 | 
				
			||||||
  local api_key = os.getenv(E.key("claude"))
 | 
					---@field question string
 | 
				
			||||||
 | 
					---@field code_lang string
 | 
				
			||||||
 | 
					---@field code_content string
 | 
				
			||||||
 | 
					---@field selected_code_content? string
 | 
				
			||||||
 | 
					---
 | 
				
			||||||
 | 
					---@alias AvanteAiMessageBuilder fun(opts: AvantePromptOptions): {role: "user" | "system", content: string | table<string, any>}[]
 | 
				
			||||||
 | 
					---
 | 
				
			||||||
 | 
					---@class AvanteCurlOutput: {url: string, body: table<string, any> | string, headers: table<string, string>}
 | 
				
			||||||
 | 
					---@alias AvanteCurlArgsBuilder fun(code_opts: AvantePromptOptions): AvanteCurlOutput
 | 
				
			||||||
 | 
					---
 | 
				
			||||||
 | 
					---@class ResponseParser
 | 
				
			||||||
 | 
					---@field event_state string
 | 
				
			||||||
 | 
					---@field on_chunk fun(chunk: string): any
 | 
				
			||||||
 | 
					---@field on_complete fun(err: string|nil): any
 | 
				
			||||||
 | 
					---@field on_error? fun(err_type: string): nil
 | 
				
			||||||
 | 
					---@alias AvanteAiResponseParser fun(data_stream: string, opts: ResponseParser): nil
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  local tokens = Config.claude.max_tokens
 | 
					------------------------------Anthropic------------------------------
 | 
				
			||||||
  local headers = {
 | 
					 | 
				
			||||||
    ["Content-Type"] = "application/json",
 | 
					 | 
				
			||||||
    ["x-api-key"] = api_key,
 | 
					 | 
				
			||||||
    ["anthropic-version"] = "2023-06-01",
 | 
					 | 
				
			||||||
    ["anthropic-beta"] = "prompt-caching-2024-07-31",
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					---@type AvanteAiMessageBuilder
 | 
				
			||||||
 | 
					H.make_claude_message = function(opts)
 | 
				
			||||||
  local code_prompt_obj = {
 | 
					  local code_prompt_obj = {
 | 
				
			||||||
    type = "text",
 | 
					    type = "text",
 | 
				
			||||||
    text = string.format("<code>```%s\n%s```</code>", code_lang, code_content),
 | 
					    text = string.format("<code>```%s\n%s```</code>", opts.code_lang, opts.code_content),
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  if Tiktoken.count(code_prompt_obj.text) > 1024 then
 | 
					  if Tiktoken.count(code_prompt_obj.text) > 1024 then
 | 
				
			||||||
    code_prompt_obj.cache_control = { type = "ephemeral" }
 | 
					    code_prompt_obj.cache_control = { type = "ephemeral" }
 | 
				
			||||||
  end
 | 
					  end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  if selected_code_content then
 | 
					  if opts.selected_code_content then
 | 
				
			||||||
    code_prompt_obj.text = string.format("<code_context>```%s\n%s```</code_context>", code_lang, code_content)
 | 
					    code_prompt_obj.text = string.format("<code_context>```%s\n%s```</code_context>", opts.code_lang, opts.code_content)
 | 
				
			||||||
  end
 | 
					  end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  local message_content = {
 | 
					  local message_content = {
 | 
				
			||||||
    code_prompt_obj,
 | 
					    code_prompt_obj,
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  if selected_code_content then
 | 
					  if opts.selected_code_content then
 | 
				
			||||||
    local selected_code_obj = {
 | 
					    local selected_code_obj = {
 | 
				
			||||||
      type = "text",
 | 
					      type = "text",
 | 
				
			||||||
      text = string.format("<code>```%s\n%s```</code>", code_lang, selected_code_content),
 | 
					      text = string.format("<code>```%s\n%s```</code>", opts.code_lang, opts.selected_code_content),
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if Tiktoken.count(selected_code_obj.text) > 1024 then
 | 
					    if Tiktoken.count(selected_code_obj.text) > 1024 then
 | 
				
			||||||
@ -180,7 +208,7 @@ local function call_claude_api_stream(question, code_lang, code_content, selecte
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
  table.insert(message_content, {
 | 
					  table.insert(message_content, {
 | 
				
			||||||
    type = "text",
 | 
					    type = "text",
 | 
				
			||||||
    text = string.format("<question>%s</question>", question),
 | 
					    text = string.format("<question>%s</question>", opts.question),
 | 
				
			||||||
  })
 | 
					  })
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  local user_prompt = base_user_prompt
 | 
					  local user_prompt = base_user_prompt
 | 
				
			||||||
@ -196,220 +224,284 @@ local function call_claude_api_stream(question, code_lang, code_content, selecte
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
  table.insert(message_content, user_prompt_obj)
 | 
					  table.insert(message_content, user_prompt_obj)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  local body = {
 | 
					  return {
 | 
				
			||||||
    model = Config.claude.model,
 | 
					    {
 | 
				
			||||||
    system = system_prompt,
 | 
					      role = "user",
 | 
				
			||||||
    messages = {
 | 
					      content = message_content,
 | 
				
			||||||
      {
 | 
					 | 
				
			||||||
        role = "user",
 | 
					 | 
				
			||||||
        content = message_content,
 | 
					 | 
				
			||||||
      },
 | 
					 | 
				
			||||||
    },
 | 
					    },
 | 
				
			||||||
    stream = true,
 | 
					 | 
				
			||||||
    temperature = Config.claude.temperature,
 | 
					 | 
				
			||||||
    max_tokens = tokens,
 | 
					 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					 | 
				
			||||||
  local url = Utils.trim_suffix(Config.claude.endpoint, "/") .. "/v1/messages"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  curl.post(url, {
 | 
					 | 
				
			||||||
    ---@diagnostic disable-next-line: unused-local
 | 
					 | 
				
			||||||
    stream = function(err, data, job)
 | 
					 | 
				
			||||||
      if err then
 | 
					 | 
				
			||||||
        on_complete(err)
 | 
					 | 
				
			||||||
        return
 | 
					 | 
				
			||||||
      end
 | 
					 | 
				
			||||||
      if not data then
 | 
					 | 
				
			||||||
        return
 | 
					 | 
				
			||||||
      end
 | 
					 | 
				
			||||||
      for _, line in ipairs(vim.split(data, "\n")) do
 | 
					 | 
				
			||||||
        if line:sub(1, 6) ~= "data: " then
 | 
					 | 
				
			||||||
          return
 | 
					 | 
				
			||||||
        end
 | 
					 | 
				
			||||||
        vim.schedule(function()
 | 
					 | 
				
			||||||
          local success, parsed = pcall(fn.json_decode, line:sub(7))
 | 
					 | 
				
			||||||
          if not success then
 | 
					 | 
				
			||||||
            error("Error: failed to parse json: " .. parsed)
 | 
					 | 
				
			||||||
            return
 | 
					 | 
				
			||||||
          end
 | 
					 | 
				
			||||||
          if parsed and parsed.type == "content_block_delta" then
 | 
					 | 
				
			||||||
            on_chunk(parsed.delta.text)
 | 
					 | 
				
			||||||
          elseif parsed and parsed.type == "message_stop" then
 | 
					 | 
				
			||||||
            -- Stream request completed
 | 
					 | 
				
			||||||
            on_complete(nil)
 | 
					 | 
				
			||||||
          elseif parsed and parsed.type == "error" then
 | 
					 | 
				
			||||||
            -- Stream request completed
 | 
					 | 
				
			||||||
            on_complete(parsed)
 | 
					 | 
				
			||||||
          end
 | 
					 | 
				
			||||||
        end)
 | 
					 | 
				
			||||||
      end
 | 
					 | 
				
			||||||
    end,
 | 
					 | 
				
			||||||
    headers = headers,
 | 
					 | 
				
			||||||
    body = fn.json_encode(body),
 | 
					 | 
				
			||||||
  })
 | 
					 | 
				
			||||||
end
 | 
					end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
local function call_openai_api_stream(question, code_lang, code_content, selected_code_content, on_chunk, on_complete)
 | 
					---@type AvanteAiResponseParser
 | 
				
			||||||
  local api_key = os.getenv(E.key("openai"))
 | 
					H.parse_claude_response = function(data_stream, opts)
 | 
				
			||||||
 | 
					  if opts.event_state == "content_block_delta" then
 | 
				
			||||||
 | 
					    local json = vim.json.decode(data_stream)
 | 
				
			||||||
 | 
					    opts.on_chunk(json.delta.text)
 | 
				
			||||||
 | 
					  elseif opts.event_state == "message_stop" then
 | 
				
			||||||
 | 
					    opts.on_complete(nil)
 | 
				
			||||||
 | 
					    return
 | 
				
			||||||
 | 
					  elseif opts.event_state == "error" then
 | 
				
			||||||
 | 
					    opts.on_complete(vim.json.decode(data_stream))
 | 
				
			||||||
 | 
					  end
 | 
				
			||||||
 | 
					end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					---@type AvanteCurlArgsBuilder
 | 
				
			||||||
 | 
					H.make_claude_curl_args = function(code_opts)
 | 
				
			||||||
 | 
					  return {
 | 
				
			||||||
 | 
					    url = Utils.trim(Config.claude.endpoint, { suffix = "/" }) .. "/v1/messages",
 | 
				
			||||||
 | 
					    headers = {
 | 
				
			||||||
 | 
					      ["Content-Type"] = "application/json",
 | 
				
			||||||
 | 
					      ["x-api-key"] = E.value("claude"),
 | 
				
			||||||
 | 
					      ["anthropic-version"] = "2023-06-01",
 | 
				
			||||||
 | 
					      ["anthropic-beta"] = "prompt-caching-2024-07-31",
 | 
				
			||||||
 | 
					    },
 | 
				
			||||||
 | 
					    body = {
 | 
				
			||||||
 | 
					      model = Config.claude.model,
 | 
				
			||||||
 | 
					      system = system_prompt,
 | 
				
			||||||
 | 
					      stream = true,
 | 
				
			||||||
 | 
					      messages = H.make_claude_message(code_opts),
 | 
				
			||||||
 | 
					      temperature = Config.claude.temperature,
 | 
				
			||||||
 | 
					      max_tokens = Config.claude.max_tokens,
 | 
				
			||||||
 | 
					    },
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					------------------------------OpenAI------------------------------
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					---@type AvanteAiMessageBuilder
 | 
				
			||||||
 | 
					H.make_openai_message = function(opts)
 | 
				
			||||||
  local user_prompt = base_user_prompt
 | 
					  local user_prompt = base_user_prompt
 | 
				
			||||||
    .. "\n\nCODE:\n"
 | 
					    .. "\n\nCODE:\n"
 | 
				
			||||||
    .. "```"
 | 
					    .. "```"
 | 
				
			||||||
    .. code_lang
 | 
					    .. opts.code_lang
 | 
				
			||||||
    .. "\n"
 | 
					    .. "\n"
 | 
				
			||||||
    .. code_content
 | 
					    .. opts.code_content
 | 
				
			||||||
    .. "\n```"
 | 
					    .. "\n```"
 | 
				
			||||||
    .. "\n\nQUESTION:\n"
 | 
					    .. "\n\nQUESTION:\n"
 | 
				
			||||||
    .. question
 | 
					    .. opts.question
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  if selected_code_content then
 | 
					  if opts.selected_code_content ~= nil then
 | 
				
			||||||
    user_prompt = base_user_prompt
 | 
					    user_prompt = base_user_prompt
 | 
				
			||||||
      .. "\n\nCODE CONTEXT:\n"
 | 
					      .. "\n\nCODE CONTEXT:\n"
 | 
				
			||||||
      .. "```"
 | 
					      .. "```"
 | 
				
			||||||
      .. code_lang
 | 
					      .. opts.code_lang
 | 
				
			||||||
      .. "\n"
 | 
					      .. "\n"
 | 
				
			||||||
      .. code_content
 | 
					      .. opts.code_content
 | 
				
			||||||
      .. "\n```"
 | 
					      .. "\n```"
 | 
				
			||||||
      .. "\n\nCODE:\n"
 | 
					      .. "\n\nCODE:\n"
 | 
				
			||||||
      .. "```"
 | 
					      .. "```"
 | 
				
			||||||
      .. code_lang
 | 
					      .. opts.code_lang
 | 
				
			||||||
      .. "\n"
 | 
					      .. "\n"
 | 
				
			||||||
      .. selected_code_content
 | 
					      .. opts.selected_code_content
 | 
				
			||||||
      .. "\n```"
 | 
					      .. "\n```"
 | 
				
			||||||
      .. "\n\nQUESTION:\n"
 | 
					      .. "\n\nQUESTION:\n"
 | 
				
			||||||
      .. question
 | 
					      .. opts.question
 | 
				
			||||||
  end
 | 
					  end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  local url, headers, body
 | 
					  return {
 | 
				
			||||||
  if Config.provider == "azure" then
 | 
					    { role = "system", content = system_prompt },
 | 
				
			||||||
    api_key = os.getenv(E.key("azure"))
 | 
					    { role = "user", content = user_prompt },
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					---@type AvanteAiResponseParser
 | 
				
			||||||
 | 
					H.parse_openai_response = function(data_stream, opts)
 | 
				
			||||||
 | 
					  if data_stream:match('"%[DONE%]":') then
 | 
				
			||||||
 | 
					    opts.on_complete(nil)
 | 
				
			||||||
 | 
					    return
 | 
				
			||||||
 | 
					  end
 | 
				
			||||||
 | 
					  if data_stream:match('"delta":') then
 | 
				
			||||||
 | 
					    local json = vim.json.decode(data_stream)
 | 
				
			||||||
 | 
					    if json.choices and json.choices[1] then
 | 
				
			||||||
 | 
					      local choice = json.choices[1]
 | 
				
			||||||
 | 
					      if choice.finish_reason == "stop" then
 | 
				
			||||||
 | 
					        opts.on_complete(nil)
 | 
				
			||||||
 | 
					      elseif choice.delta.content then
 | 
				
			||||||
 | 
					        opts.on_chunk(choice.delta.content)
 | 
				
			||||||
 | 
					      end
 | 
				
			||||||
 | 
					    end
 | 
				
			||||||
 | 
					  end
 | 
				
			||||||
 | 
					end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					---@type AvanteCurlArgsBuilder
 | 
				
			||||||
 | 
					H.make_openai_curl_args = function(code_opts)
 | 
				
			||||||
 | 
					  return {
 | 
				
			||||||
 | 
					    url = Utils.trim(Config.openai.endpoint, { suffix = "/" }) .. "/v1/chat/completions",
 | 
				
			||||||
 | 
					    headers = {
 | 
				
			||||||
 | 
					      ["Content-Type"] = "application/json",
 | 
				
			||||||
 | 
					      ["Authorization"] = "Bearer " .. E.value("openai"),
 | 
				
			||||||
 | 
					    },
 | 
				
			||||||
 | 
					    body = {
 | 
				
			||||||
 | 
					      model = Config.openai.model,
 | 
				
			||||||
 | 
					      messages = H.make_openai_message(code_opts),
 | 
				
			||||||
 | 
					      temperature = Config.openai.temperature,
 | 
				
			||||||
 | 
					      max_tokens = Config.openai.max_tokens,
 | 
				
			||||||
 | 
					      stream = true,
 | 
				
			||||||
 | 
					    },
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					------------------------------Azure------------------------------
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					---@type AvanteAiMessageBuilder
 | 
				
			||||||
 | 
					H.make_azure_message = H.make_openai_message
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					---@type AvanteAiResponseParser
 | 
				
			||||||
 | 
					H.parse_azure_response = H.parse_openai_response
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					---@type AvanteCurlArgsBuilder
 | 
				
			||||||
 | 
					H.make_azure_curl_args = function(code_opts)
 | 
				
			||||||
 | 
					  return {
 | 
				
			||||||
    url = Config.azure.endpoint
 | 
					    url = Config.azure.endpoint
 | 
				
			||||||
      .. "/openai/deployments/"
 | 
					      .. "/openai/deployments/"
 | 
				
			||||||
      .. Config.azure.deployment
 | 
					      .. Config.azure.deployment
 | 
				
			||||||
      .. "/chat/completions?api-version="
 | 
					      .. "/chat/completions?api-version="
 | 
				
			||||||
      .. Config.azure.api_version
 | 
					      .. Config.azure.api_version,
 | 
				
			||||||
    headers = {
 | 
					    headers = {
 | 
				
			||||||
      ["Content-Type"] = "application/json",
 | 
					      ["Content-Type"] = "application/json",
 | 
				
			||||||
      ["api-key"] = api_key,
 | 
					      ["api-key"] = E.value("azure"),
 | 
				
			||||||
    }
 | 
					    },
 | 
				
			||||||
    body = {
 | 
					    body = {
 | 
				
			||||||
      messages = {
 | 
					      messages = H.make_openai_message(code_opts),
 | 
				
			||||||
        { role = "system", content = system_prompt },
 | 
					 | 
				
			||||||
        { role = "user", content = user_prompt },
 | 
					 | 
				
			||||||
      },
 | 
					 | 
				
			||||||
      temperature = Config.azure.temperature,
 | 
					      temperature = Config.azure.temperature,
 | 
				
			||||||
      max_tokens = Config.azure.max_tokens,
 | 
					      max_tokens = Config.azure.max_tokens,
 | 
				
			||||||
      stream = true,
 | 
					      stream = true,
 | 
				
			||||||
    }
 | 
					    },
 | 
				
			||||||
  elseif Config.provider == "deepseek" then
 | 
					  }
 | 
				
			||||||
    api_key = os.getenv(E.key("deepseek"))
 | 
					end
 | 
				
			||||||
    url = Utils.trim_suffix(Config.deepseek.endpoint, "/") .. "/chat/completions"
 | 
					
 | 
				
			||||||
 | 
					------------------------------Deepseek------------------------------
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					---@type AvanteAiMessageBuilder
 | 
				
			||||||
 | 
					H.make_deepseek_message = H.make_openai_message
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					---@type AvanteAiResponseParser
 | 
				
			||||||
 | 
					H.parse_deepseek_response = H.parse_openai_response
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					---@type AvanteCurlArgsBuilder
 | 
				
			||||||
 | 
					H.make_deepseek_curl_args = function(code_opts)
 | 
				
			||||||
 | 
					  return {
 | 
				
			||||||
 | 
					    url = Utils.trim(Config.deepseek.endpoint, { suffix = "/" }) .. "/chat/completions",
 | 
				
			||||||
    headers = {
 | 
					    headers = {
 | 
				
			||||||
      ["Content-Type"] = "application/json",
 | 
					      ["Content-Type"] = "application/json",
 | 
				
			||||||
      ["Authorization"] = "Bearer " .. api_key,
 | 
					      ["Authorization"] = "Bearer " .. E.value("deepseek"),
 | 
				
			||||||
    }
 | 
					    },
 | 
				
			||||||
    body = {
 | 
					    body = {
 | 
				
			||||||
      model = Config.deepseek.model,
 | 
					      model = Config.deepseek.model,
 | 
				
			||||||
      messages = {
 | 
					      messages = H.make_openai_message(code_opts),
 | 
				
			||||||
        { role = "system", content = system_prompt },
 | 
					 | 
				
			||||||
        { role = "user", content = user_prompt },
 | 
					 | 
				
			||||||
      },
 | 
					 | 
				
			||||||
      temperature = Config.deepseek.temperature,
 | 
					      temperature = Config.deepseek.temperature,
 | 
				
			||||||
      max_tokens = Config.deepseek.max_tokens,
 | 
					      max_tokens = Config.deepseek.max_tokens,
 | 
				
			||||||
      stream = true,
 | 
					      stream = true,
 | 
				
			||||||
    }
 | 
					    },
 | 
				
			||||||
  elseif Config.provider == "groq" then
 | 
					  }
 | 
				
			||||||
    api_key = os.getenv(E.key("groq"))
 | 
					end
 | 
				
			||||||
    url = Utils.trim_suffix(Config.groq.endpoint, "/") .. "/openai/v1/chat/completions"
 | 
					
 | 
				
			||||||
 | 
					------------------------------Grok------------------------------
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					---@type AvanteAiMessageBuilder
 | 
				
			||||||
 | 
					H.make_groq_message = H.make_openai_message
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					---@type AvanteAiResponseParser
 | 
				
			||||||
 | 
					H.parse_groq_response = H.parse_openai_response
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					---@type AvanteCurlArgsBuilder
 | 
				
			||||||
 | 
					H.make_groq_curl_args = function(code_opts)
 | 
				
			||||||
 | 
					  return {
 | 
				
			||||||
 | 
					    url = Utils.trim(Config.groq.endpoint, { suffix = "/" }) .. "/openai/v1/chat/completions",
 | 
				
			||||||
    headers = {
 | 
					    headers = {
 | 
				
			||||||
      ["Content-Type"] = "application/json",
 | 
					      ["Content-Type"] = "application/json",
 | 
				
			||||||
      ["Authorization"] = "Bearer " .. api_key,
 | 
					      ["Authorization"] = "Bearer " .. E.value("groq"),
 | 
				
			||||||
    }
 | 
					    },
 | 
				
			||||||
    body = {
 | 
					    body = {
 | 
				
			||||||
      model = Config.groq.model,
 | 
					      model = Config.groq.model,
 | 
				
			||||||
      messages = {
 | 
					      messages = H.make_openai_message(code_opts),
 | 
				
			||||||
        { role = "system", content = system_prompt },
 | 
					 | 
				
			||||||
        { role = "user", content = user_prompt },
 | 
					 | 
				
			||||||
      },
 | 
					 | 
				
			||||||
      temperature = Config.groq.temperature,
 | 
					      temperature = Config.groq.temperature,
 | 
				
			||||||
      max_tokens = Config.groq.max_tokens,
 | 
					      max_tokens = Config.groq.max_tokens,
 | 
				
			||||||
      stream = true,
 | 
					      stream = true,
 | 
				
			||||||
    }
 | 
					    },
 | 
				
			||||||
  else
 | 
					  }
 | 
				
			||||||
    url = Utils.trim_suffix(Config.openai.endpoint, "/") .. "/v1/chat/completions"
 | 
					 | 
				
			||||||
    headers = {
 | 
					 | 
				
			||||||
      ["Content-Type"] = "application/json",
 | 
					 | 
				
			||||||
      ["Authorization"] = "Bearer " .. api_key,
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    body = {
 | 
					 | 
				
			||||||
      model = Config.openai.model,
 | 
					 | 
				
			||||||
      messages = {
 | 
					 | 
				
			||||||
        { role = "system", content = system_prompt },
 | 
					 | 
				
			||||||
        { role = "user", content = user_prompt },
 | 
					 | 
				
			||||||
      },
 | 
					 | 
				
			||||||
      temperature = Config.openai.temperature,
 | 
					 | 
				
			||||||
      max_tokens = Config.openai.max_tokens,
 | 
					 | 
				
			||||||
      stream = true,
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
  end
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  curl.post(url, {
 | 
					 | 
				
			||||||
    ---@diagnostic disable-next-line: unused-local
 | 
					 | 
				
			||||||
    stream = function(err, data, job)
 | 
					 | 
				
			||||||
      if err then
 | 
					 | 
				
			||||||
        on_complete(err)
 | 
					 | 
				
			||||||
        return
 | 
					 | 
				
			||||||
      end
 | 
					 | 
				
			||||||
      if not data then
 | 
					 | 
				
			||||||
        return
 | 
					 | 
				
			||||||
      end
 | 
					 | 
				
			||||||
      for _, line in ipairs(vim.split(data, "\n")) do
 | 
					 | 
				
			||||||
        if line:sub(1, 6) ~= "data: " then
 | 
					 | 
				
			||||||
          return
 | 
					 | 
				
			||||||
        end
 | 
					 | 
				
			||||||
        vim.schedule(function()
 | 
					 | 
				
			||||||
          local piece = line:sub(7)
 | 
					 | 
				
			||||||
          local success, parsed = pcall(fn.json_decode, piece)
 | 
					 | 
				
			||||||
          if not success then
 | 
					 | 
				
			||||||
            if piece == "[DONE]" then
 | 
					 | 
				
			||||||
              on_complete(nil)
 | 
					 | 
				
			||||||
              return
 | 
					 | 
				
			||||||
            end
 | 
					 | 
				
			||||||
            error("Error: failed to parse json: " .. parsed)
 | 
					 | 
				
			||||||
            return
 | 
					 | 
				
			||||||
          end
 | 
					 | 
				
			||||||
          if parsed and parsed.choices and parsed.choices[1] then
 | 
					 | 
				
			||||||
            local choice = parsed.choices[1]
 | 
					 | 
				
			||||||
            if choice.finish_reason == "stop" then
 | 
					 | 
				
			||||||
              on_complete(nil)
 | 
					 | 
				
			||||||
            elseif choice.delta and choice.delta.content then
 | 
					 | 
				
			||||||
              on_chunk(choice.delta.content)
 | 
					 | 
				
			||||||
            end
 | 
					 | 
				
			||||||
          end
 | 
					 | 
				
			||||||
        end)
 | 
					 | 
				
			||||||
      end
 | 
					 | 
				
			||||||
    end,
 | 
					 | 
				
			||||||
    headers = headers,
 | 
					 | 
				
			||||||
    body = fn.json_encode(body),
 | 
					 | 
				
			||||||
  })
 | 
					 | 
				
			||||||
end
 | 
					end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					------------------------------Logic------------------------------
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					local group = vim.api.nvim_create_augroup("AvanteAiBot", { clear = true })
 | 
				
			||||||
 | 
					local active_job = nil
 | 
				
			||||||
 | 
					
 | 
				
			||||||
---@param question string
 | 
					---@param question string
 | 
				
			||||||
---@param code_lang string
 | 
					---@param code_lang string
 | 
				
			||||||
---@param code_content string
 | 
					---@param code_content string
 | 
				
			||||||
---@param selected_content_content string | nil
 | 
					---@param selected_content_content string | nil
 | 
				
			||||||
---@param on_chunk fun(chunk: string): any
 | 
					---@param on_chunk fun(chunk: string): any
 | 
				
			||||||
---@param on_complete fun(err: string|nil): any
 | 
					---@param on_complete fun(err: string|nil): any
 | 
				
			||||||
function M.call_ai_api_stream(question, code_lang, code_content, selected_content_content, on_chunk, on_complete)
 | 
					M.invoke_llm_stream = function(question, code_lang, code_content, selected_content_content, on_chunk, on_complete)
 | 
				
			||||||
  if
 | 
					  local provider = Config.provider
 | 
				
			||||||
    Config.provider == "openai"
 | 
					  local event_state = nil
 | 
				
			||||||
    or Config.provider == "azure"
 | 
					
 | 
				
			||||||
    or Config.provider == "deepseek"
 | 
					  ---@type AvanteCurlOutput
 | 
				
			||||||
    or Config.provider == "groq"
 | 
					  local spec = H["make_" .. provider .. "_curl_args"]({
 | 
				
			||||||
  then
 | 
					    question = question,
 | 
				
			||||||
    call_openai_api_stream(question, code_lang, code_content, selected_content_content, on_chunk, on_complete)
 | 
					    code_lang = code_lang,
 | 
				
			||||||
  elseif Config.provider == "claude" then
 | 
					    code_content = code_content,
 | 
				
			||||||
    call_claude_api_stream(question, code_lang, code_content, selected_content_content, on_chunk, on_complete)
 | 
					    selected_code_content = selected_content_content,
 | 
				
			||||||
 | 
					  })
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  ---@param line string
 | 
				
			||||||
 | 
					  local function parse_and_call(line)
 | 
				
			||||||
 | 
					    local event = line:match("^event: (.+)$")
 | 
				
			||||||
 | 
					    if event then
 | 
				
			||||||
 | 
					      event_state = event
 | 
				
			||||||
 | 
					      return
 | 
				
			||||||
 | 
					    end
 | 
				
			||||||
 | 
					    local data_match = line:match("^data: (.+)$")
 | 
				
			||||||
 | 
					    if data_match then
 | 
				
			||||||
 | 
					      H["parse_" .. provider .. "_response"](
 | 
				
			||||||
 | 
					        data_match,
 | 
				
			||||||
 | 
					        vim.deepcopy({ on_chunk = on_chunk, on_complete = on_complete, event_state = event_state }, true)
 | 
				
			||||||
 | 
					      )
 | 
				
			||||||
 | 
					    end
 | 
				
			||||||
  end
 | 
					  end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  if active_job then
 | 
				
			||||||
 | 
					    active_job:shutdown()
 | 
				
			||||||
 | 
					    active_job = nil
 | 
				
			||||||
 | 
					  end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  active_job = curl.post(spec.url, {
 | 
				
			||||||
 | 
					    headers = spec.headers,
 | 
				
			||||||
 | 
					    body = vim.json.encode(spec.body),
 | 
				
			||||||
 | 
					    stream = function(err, data, _)
 | 
				
			||||||
 | 
					      if err then
 | 
				
			||||||
 | 
					        on_complete(err)
 | 
				
			||||||
 | 
					        return
 | 
				
			||||||
 | 
					      end
 | 
				
			||||||
 | 
					      if not data then
 | 
				
			||||||
 | 
					        return
 | 
				
			||||||
 | 
					      end
 | 
				
			||||||
 | 
					      vim.schedule(function()
 | 
				
			||||||
 | 
					        parse_and_call(data)
 | 
				
			||||||
 | 
					      end)
 | 
				
			||||||
 | 
					    end,
 | 
				
			||||||
 | 
					    on_error = function(err)
 | 
				
			||||||
 | 
					      on_complete(err)
 | 
				
			||||||
 | 
					    end,
 | 
				
			||||||
 | 
					    callback = function(_)
 | 
				
			||||||
 | 
					      active_job = nil
 | 
				
			||||||
 | 
					    end,
 | 
				
			||||||
 | 
					  })
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  api.nvim_create_autocmd("User", {
 | 
				
			||||||
 | 
					    group = group,
 | 
				
			||||||
 | 
					    pattern = M.CANCEL_PATTERN,
 | 
				
			||||||
 | 
					    callback = function()
 | 
				
			||||||
 | 
					      if active_job then
 | 
				
			||||||
 | 
					        active_job:shutdown()
 | 
				
			||||||
 | 
					        vim.notify("LLM request cancelled", vim.log.levels.DEBUG)
 | 
				
			||||||
 | 
					        active_job = nil
 | 
				
			||||||
 | 
					      end
 | 
				
			||||||
 | 
					    end,
 | 
				
			||||||
 | 
					  })
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return active_job
 | 
				
			||||||
end
 | 
					end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
function M.setup()
 | 
					function M.setup()
 | 
				
			||||||
@ -417,6 +509,38 @@ function M.setup()
 | 
				
			|||||||
  if not has then
 | 
					  if not has then
 | 
				
			||||||
    E.setup(E.key())
 | 
					    E.setup(E.key())
 | 
				
			||||||
  end
 | 
					  end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  M.commands()
 | 
				
			||||||
 | 
					end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					---@param provider Provider
 | 
				
			||||||
 | 
					function M.refresh(provider)
 | 
				
			||||||
 | 
					  local has = E[provider]
 | 
				
			||||||
 | 
					  if not has then
 | 
				
			||||||
 | 
					    E.setup(E.key(provider), true)
 | 
				
			||||||
 | 
					  else
 | 
				
			||||||
 | 
					    vim.notify_once("Switch to provider: " .. provider, vim.log.levels.INFO)
 | 
				
			||||||
 | 
					  end
 | 
				
			||||||
 | 
					  require("avante").setup({ provider = provider })
 | 
				
			||||||
 | 
					end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					M.commands = function()
 | 
				
			||||||
 | 
					  api.nvim_create_user_command("AvanteSwitchProvider", function(args)
 | 
				
			||||||
 | 
					    local cmd = vim.trim(args.args or "")
 | 
				
			||||||
 | 
					    M.refresh(cmd)
 | 
				
			||||||
 | 
					  end, {
 | 
				
			||||||
 | 
					    nargs = 1,
 | 
				
			||||||
 | 
					    desc = "avante: switch provider",
 | 
				
			||||||
 | 
					    complete = function(_, line)
 | 
				
			||||||
 | 
					      if line:match("^%s*AvanteSwitchProvider %w") then
 | 
				
			||||||
 | 
					        return {}
 | 
				
			||||||
 | 
					      end
 | 
				
			||||||
 | 
					      local prefix = line:match("^%s*AvanteSwitchProvider (%w*)") or ""
 | 
				
			||||||
 | 
					      return vim.tbl_filter(function(key)
 | 
				
			||||||
 | 
					        return key:find(prefix) == 1
 | 
				
			||||||
 | 
					      end, vim.tbl_keys(E.env))
 | 
				
			||||||
 | 
					    end,
 | 
				
			||||||
 | 
					  })
 | 
				
			||||||
end
 | 
					end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
return M
 | 
					return M
 | 
				
			||||||
 | 
				
			|||||||
@ -68,7 +68,7 @@ M.defaults = {
 | 
				
			|||||||
    },
 | 
					    },
 | 
				
			||||||
  },
 | 
					  },
 | 
				
			||||||
  windows = {
 | 
					  windows = {
 | 
				
			||||||
    wrap_line = true,
 | 
					    wrap_line = true, -- similar to vim.o.wrap
 | 
				
			||||||
    width = 30, -- default % based on available width
 | 
					    width = 30, -- default % based on available width
 | 
				
			||||||
  },
 | 
					  },
 | 
				
			||||||
  --- @class AvanteConflictUserConfig
 | 
					  --- @class AvanteConflictUserConfig
 | 
				
			||||||
 | 
				
			|||||||
@ -143,14 +143,11 @@ function M._init(id)
 | 
				
			|||||||
  return M
 | 
					  return M
 | 
				
			||||||
end
 | 
					end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
M.open = function()
 | 
					 | 
				
			||||||
  M._init(api.nvim_get_current_tabpage())._get(false):open()
 | 
					 | 
				
			||||||
end
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
M.toggle = function()
 | 
					M.toggle = function()
 | 
				
			||||||
  local sidebar = M._get()
 | 
					  local sidebar = M._get()
 | 
				
			||||||
  if not sidebar then
 | 
					  if not sidebar then
 | 
				
			||||||
    M.open()
 | 
					    M._init(api.nvim_get_current_tabpage())
 | 
				
			||||||
 | 
					    M.current.sidebar:open()
 | 
				
			||||||
    return true
 | 
					    return true
 | 
				
			||||||
  end
 | 
					  end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -141,6 +141,15 @@ function Sidebar:intialize()
 | 
				
			|||||||
      mode = { "n" },
 | 
					      mode = { "n" },
 | 
				
			||||||
      key = "q",
 | 
					      key = "q",
 | 
				
			||||||
      handler = function()
 | 
					      handler = function()
 | 
				
			||||||
 | 
					        api.nvim_exec_autocmds("User", { pattern = AiBot.CANCEL_PATTERN })
 | 
				
			||||||
 | 
					        self.renderer:close()
 | 
				
			||||||
 | 
					      end,
 | 
				
			||||||
 | 
					    },
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					      mode = { "n" },
 | 
				
			||||||
 | 
					      key = "<Esc>",
 | 
				
			||||||
 | 
					      handler = function()
 | 
				
			||||||
 | 
					        api.nvim_exec_autocmds("User", { pattern = AiBot.CANCEL_PATTERN })
 | 
				
			||||||
        self.renderer:close()
 | 
					        self.renderer:close()
 | 
				
			||||||
      end,
 | 
					      end,
 | 
				
			||||||
    },
 | 
					    },
 | 
				
			||||||
@ -171,10 +180,16 @@ function Sidebar:intialize()
 | 
				
			|||||||
    api.nvim_create_autocmd("VimResized", {
 | 
					    api.nvim_create_autocmd("VimResized", {
 | 
				
			||||||
      group = self.augroup,
 | 
					      group = self.augroup,
 | 
				
			||||||
      callback = function()
 | 
					      callback = function()
 | 
				
			||||||
 | 
					        if not self.view:is_open() then
 | 
				
			||||||
 | 
					          return
 | 
				
			||||||
 | 
					        end
 | 
				
			||||||
        local new_layout = Config.get_renderer_layout_options()
 | 
					        local new_layout = Config.get_renderer_layout_options()
 | 
				
			||||||
        vim.api.nvim_win_set_width(self.view.win, new_layout.width)
 | 
					        vim.api.nvim_win_set_width(self.view.win, new_layout.width)
 | 
				
			||||||
        vim.api.nvim_win_set_height(self.view.win, new_layout.height)
 | 
					        vim.api.nvim_win_set_height(self.view.win, new_layout.height)
 | 
				
			||||||
        self.renderer:set_size({ width = new_layout.width, height = new_layout.height })
 | 
					        self.renderer:set_size({ width = new_layout.width, height = new_layout.height })
 | 
				
			||||||
 | 
					        vim.defer_fn(function()
 | 
				
			||||||
 | 
					          vim.cmd("AvanteRefresh")
 | 
				
			||||||
 | 
					        end, 200)
 | 
				
			||||||
      end,
 | 
					      end,
 | 
				
			||||||
    })
 | 
					    })
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -228,34 +243,51 @@ function Sidebar:is_focused()
 | 
				
			|||||||
end
 | 
					end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
---@param content string concatenated content of the buffer
 | 
					---@param content string concatenated content of the buffer
 | 
				
			||||||
---@param opts? {focus?: boolean, scroll?: boolean, callback?: fun(): nil} whether to focus the result view
 | 
					---@param opts? {focus?: boolean, stream?: boolean, scroll?: boolean, callback?: fun(): nil} whether to focus the result view
 | 
				
			||||||
function Sidebar:update_content(content, opts)
 | 
					function Sidebar:update_content(content, opts)
 | 
				
			||||||
  opts = vim.tbl_deep_extend("force", { focus = true, scroll = true, callback = nil }, opts or {})
 | 
					  opts = vim.tbl_deep_extend("force", { focus = true, scroll = true, stream = false, callback = nil }, opts or {})
 | 
				
			||||||
  vim.defer_fn(function()
 | 
					  if opts.stream then
 | 
				
			||||||
    api.nvim_set_option_value("modifiable", true, { buf = self.view.buf })
 | 
					    vim.schedule(function()
 | 
				
			||||||
    api.nvim_buf_set_lines(self.view.buf, 0, -1, false, vim.split(content, "\n"))
 | 
					      api.nvim_set_option_value("modifiable", true, { buf = self.view.buf })
 | 
				
			||||||
    api.nvim_set_option_value("modifiable", false, { buf = self.view.buf })
 | 
					      local current_window = vim.api.nvim_get_current_win()
 | 
				
			||||||
    api.nvim_set_option_value("filetype", "Avante", { buf = self.view.buf })
 | 
					      local cursor_position = vim.api.nvim_win_get_cursor(current_window)
 | 
				
			||||||
    if opts.callback ~= nil then
 | 
					      local row, col = cursor_position[1], cursor_position[2]
 | 
				
			||||||
      opts.callback()
 | 
					 | 
				
			||||||
    end
 | 
					 | 
				
			||||||
    if opts.focus and not self:is_focused() then
 | 
					 | 
				
			||||||
      xpcall(function()
 | 
					 | 
				
			||||||
        --- set cursor to bottom of result view
 | 
					 | 
				
			||||||
        api.nvim_set_current_win(self.winid.result)
 | 
					 | 
				
			||||||
      end, function(err)
 | 
					 | 
				
			||||||
        return err
 | 
					 | 
				
			||||||
      end)
 | 
					 | 
				
			||||||
    end
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if opts.scroll then
 | 
					      local lines = vim.split(content, "\n")
 | 
				
			||||||
      xpcall(function()
 | 
					
 | 
				
			||||||
        api.nvim_win_set_cursor(self.winid.result, { api.nvim_buf_line_count(self.bufnr.result), 0 })
 | 
					      vim.api.nvim_put(lines, "c", true, true)
 | 
				
			||||||
      end, function(err)
 | 
					
 | 
				
			||||||
        return err
 | 
					      local num_lines = #lines
 | 
				
			||||||
      end)
 | 
					      local last_line_length = #lines[num_lines]
 | 
				
			||||||
    end
 | 
					      vim.api.nvim_win_set_cursor(current_window, { row + num_lines - 1, col + last_line_length })
 | 
				
			||||||
  end, 0)
 | 
					    end)
 | 
				
			||||||
 | 
					  else
 | 
				
			||||||
 | 
					    vim.defer_fn(function()
 | 
				
			||||||
 | 
					      api.nvim_set_option_value("modifiable", true, { buf = self.view.buf })
 | 
				
			||||||
 | 
					      api.nvim_buf_set_lines(self.view.buf, 0, -1, false, vim.split(content, "\n"))
 | 
				
			||||||
 | 
					      api.nvim_set_option_value("modifiable", false, { buf = self.view.buf })
 | 
				
			||||||
 | 
					      api.nvim_set_option_value("filetype", "Avante", { buf = self.view.buf })
 | 
				
			||||||
 | 
					      if opts.callback ~= nil then
 | 
				
			||||||
 | 
					        opts.callback()
 | 
				
			||||||
 | 
					      end
 | 
				
			||||||
 | 
					      if opts.focus and not self:is_focused() then
 | 
				
			||||||
 | 
					        xpcall(function()
 | 
				
			||||||
 | 
					          --- set cursor to bottom of result view
 | 
				
			||||||
 | 
					          api.nvim_set_current_win(self.winid.result)
 | 
				
			||||||
 | 
					        end, function(err)
 | 
				
			||||||
 | 
					          return err
 | 
				
			||||||
 | 
					        end)
 | 
				
			||||||
 | 
					      end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      if opts.scroll then
 | 
				
			||||||
 | 
					        xpcall(function()
 | 
				
			||||||
 | 
					          api.nvim_win_set_cursor(self.winid.result, { api.nvim_buf_line_count(self.bufnr.result), 0 })
 | 
				
			||||||
 | 
					        end, function(err)
 | 
				
			||||||
 | 
					          return err
 | 
				
			||||||
 | 
					        end)
 | 
				
			||||||
 | 
					      end
 | 
				
			||||||
 | 
					    end, 0)
 | 
				
			||||||
 | 
					  end
 | 
				
			||||||
  return self
 | 
					  return self
 | 
				
			||||||
end
 | 
					end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -638,7 +670,7 @@ function Sidebar:render()
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    local filetype = api.nvim_get_option_value("filetype", { buf = self.code.buf })
 | 
					    local filetype = api.nvim_get_option_value("filetype", { buf = self.code.buf })
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    AiBot.call_ai_api_stream(
 | 
					    AiBot.invoke_llm_stream(
 | 
				
			||||||
      request,
 | 
					      request,
 | 
				
			||||||
      filetype,
 | 
					      filetype,
 | 
				
			||||||
      content_with_line_numbers,
 | 
					      content_with_line_numbers,
 | 
				
			||||||
@ -646,7 +678,7 @@ function Sidebar:render()
 | 
				
			|||||||
      function(chunk)
 | 
					      function(chunk)
 | 
				
			||||||
        signal.is_loading = true
 | 
					        signal.is_loading = true
 | 
				
			||||||
        full_response = full_response .. chunk
 | 
					        full_response = full_response .. chunk
 | 
				
			||||||
        self:update_content(content_prefix .. full_response)
 | 
					        self:update_content(chunk, { stream = true, scroll = false })
 | 
				
			||||||
        vim.schedule(function()
 | 
					        vim.schedule(function()
 | 
				
			||||||
          vim.cmd("redraw")
 | 
					          vim.cmd("redraw")
 | 
				
			||||||
        end)
 | 
					        end)
 | 
				
			||||||
 | 
				
			|||||||
@ -1,8 +1,17 @@
 | 
				
			|||||||
local Range = require("avante.range")
 | 
					local Range = require("avante.range")
 | 
				
			||||||
local SelectionResult = require("avante.selection_result")
 | 
					local SelectionResult = require("avante.selection_result")
 | 
				
			||||||
local M = {}
 | 
					local M = {}
 | 
				
			||||||
function M.trim_suffix(str, suffix)
 | 
					---@param str string
 | 
				
			||||||
  return string.gsub(str, suffix .. "$", "")
 | 
					---@param opts? {suffix?: string, prefix?: string}
 | 
				
			||||||
 | 
					function M.trim(str, opts)
 | 
				
			||||||
 | 
					  if not opts then
 | 
				
			||||||
 | 
					    return str
 | 
				
			||||||
 | 
					  end
 | 
				
			||||||
 | 
					  if opts.suffix then
 | 
				
			||||||
 | 
					    return str:sub(-1) == opts.suffix and str:sub(1, -2) or str
 | 
				
			||||||
 | 
					  elseif opts.prefix then
 | 
				
			||||||
 | 
					    return str:sub(1, 1) == opts.prefix and str:sub(2) or str
 | 
				
			||||||
 | 
					  end
 | 
				
			||||||
end
 | 
					end
 | 
				
			||||||
function M.trim_line_number_prefix(line)
 | 
					function M.trim_line_number_prefix(line)
 | 
				
			||||||
  return line:gsub("^L%d+: ", "")
 | 
					  return line:gsub("^L%d+: ", "")
 | 
				
			||||||
 | 
				
			|||||||
@ -42,14 +42,17 @@ function View:setup(split_command, size)
 | 
				
			|||||||
  api.nvim_set_option_value("foldcolumn", "0", { win = self.win })
 | 
					  api.nvim_set_option_value("foldcolumn", "0", { win = self.win })
 | 
				
			||||||
  api.nvim_set_option_value("number", false, { win = self.win })
 | 
					  api.nvim_set_option_value("number", false, { win = self.win })
 | 
				
			||||||
  api.nvim_set_option_value("relativenumber", false, { win = self.win })
 | 
					  api.nvim_set_option_value("relativenumber", false, { win = self.win })
 | 
				
			||||||
 | 
					  api.nvim_set_option_value("winfixwidth", true, { win = self.win })
 | 
				
			||||||
  api.nvim_set_option_value("list", false, { win = self.win })
 | 
					  api.nvim_set_option_value("list", false, { win = self.win })
 | 
				
			||||||
  api.nvim_set_option_value("wrap", Config.windows.wrap_line, { win = self.win })
 | 
					  api.nvim_set_option_value("wrap", Config.windows.wrap_line, { win = self.win })
 | 
				
			||||||
  api.nvim_set_option_value("winhl", "", { win = self.win })
 | 
					  api.nvim_set_option_value("winhl", "", { win = self.win })
 | 
				
			||||||
 | 
					  api.nvim_set_option_value("linebreak", true, { win = self.win }) -- only has effect when wrap=true
 | 
				
			||||||
 | 
					  api.nvim_set_option_value("breakindent", true, { win = self.win }) -- only has effect when wrap=true
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  -- buffer stuff
 | 
					  -- buffer stuff
 | 
				
			||||||
  xpcall(function()
 | 
					  xpcall(function()
 | 
				
			||||||
    api.nvim_buf_set_name(self.buf, RESULT_BUF_NAME)
 | 
					    api.nvim_buf_set_name(self.buf, RESULT_BUF_NAME)
 | 
				
			||||||
  end, function(err) end)
 | 
					  end, function(_) end)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  return self
 | 
					  return self
 | 
				
			||||||
end
 | 
					end
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user