diff --git a/README.md b/README.md index 55bcd76..0bfb669 100644 --- a/README.md +++ b/README.md @@ -82,33 +82,13 @@ _See [config.lua#L9](./lua/avante/config.lua) for the full config_ ```lua { ---@alias Provider "openai" | "claude" | "azure" | "deepseek" | "groq" - provider = "claude", -- "claude" or "openai" or "azure" or "deepseek" or "groq" - openai = { - endpoint = "https://api.openai.com", - model = "gpt-4o", - temperature = 0, - max_tokens = 4096, - }, - azure = { - endpoint = "", -- example: "https://.openai.azure.com" - deployment = "", -- Azure deployment name (e.g., "gpt-4o", "my-gpt-4o-deployment") - api_version = "2024-06-01", - temperature = 0, - max_tokens = 4096, - }, + provider = "claude", claude = { endpoint = "https://api.anthropic.com", model = "claude-3-5-sonnet-20240620", temperature = 0, max_tokens = 4096, }, - highlights = { - ---@type AvanteConflictHighlights - diff = { - current = "DiffText", - incoming = "DiffAdd", - }, - }, mappings = { ask = "aa", edit = "ae", @@ -127,6 +107,7 @@ _See [config.lua#L9](./lua/avante/config.lua) for the full config_ prev = "[[", }, }, + hints = { enabled = true }, windows = { wrap_line = true, width = 30, -- default % based on available width @@ -301,13 +282,13 @@ A custom provider should following the following spec: ---@type fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput parse_curl_args = function(opts, code_opts) end --- This function will be used to parse incoming SSE stream - --- It takes in the data stream as the first argument, followed by opts retrieved from given buffer. + --- It takes in the data stream as the first argument, followed by SSE event state, and opts + --- retrieved from given buffer. --- This opts include: --- - on_chunk: (fun(chunk: string): any) this is invoked on parsing correct delta chunk --- - on_complete: (fun(err: string|nil): any) this is invoked on either complete call or error chunk - --- - event_state: SSE event state. - ---@type fun(data_stream: string, opts: ResponseParser): nil - parse_response_data = function(data_stream, opts) end + ---@type fun(data_stream: string, opts: ResponseParser, event_state: string): nil + parse_response_data = function(data_stream, event_state, opts) end } ``` @@ -341,9 +322,9 @@ vendors = { } end, -- The below function is used if the vendors has specific SSE spec that is not claude or openai. - parse_response_data = function(data_stream, opts) + parse_response_data = function(data_stream, event_state, opts) local Llm = require "avante.llm" - Llm.parse_openai_response(data_stream, opts) + Llm.parse_openai_response(data_stream, event_state, opts) end, }, }, diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index 8e69258..b2fa24c 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -204,10 +204,9 @@ Remember: Accurate line numbers are CRITICAL. The range start_line to end_line m ---@alias AvanteCurlArgsBuilder fun(code_opts: AvantePromptOptions): AvanteCurlOutput --- ---@class ResponseParser ----@field event_state string ---@field on_chunk fun(chunk: string): any ---@field on_complete fun(err: string|nil): any ----@alias AvanteResponseParser fun(data_stream: string, opts: ResponseParser): nil +---@alias AvanteResponseParser fun(data_stream: string, event_state: string, opts: ResponseParser): nil --- ---@class AvanteProvider ---@field endpoint string @@ -215,6 +214,9 @@ Remember: Accurate line numbers are CRITICAL. The range start_line to end_line m ---@field api_key_name string ---@field parse_response_data AvanteResponseParser ---@field parse_curl_args fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput +--- +---@alias AvanteChunkParser fun(chunk: string): any +---@alias AvanteCompleteParser fun(err: string|nil): nil ------------------------------Anthropic------------------------------ @@ -278,14 +280,14 @@ H.make_claude_message = function(opts) end ---@type AvanteResponseParser -H.parse_claude_response = function(data_stream, opts) - if opts.event_state == "content_block_delta" then +H.parse_claude_response = function(data_stream, event_state, opts) + if event_state == "content_block_delta" then local json = vim.json.decode(data_stream) opts.on_chunk(json.delta.text) - elseif opts.event_state == "message_stop" then + elseif event_state == "message_stop" then opts.on_complete(nil) return - elseif opts.event_state == "error" then + elseif event_state == "error" then opts.on_complete(vim.json.decode(data_stream)) end end @@ -351,7 +353,7 @@ H.make_openai_message = function(opts) end ---@type AvanteResponseParser -H.parse_openai_response = function(data_stream, opts) +H.parse_openai_response = function(data_stream, _, opts) if data_stream:match('"%[DONE%]":') then opts.on_complete(nil) return @@ -477,8 +479,8 @@ local active_job = nil ---@param code_lang string ---@param code_content string ---@param selected_content_content string | nil ----@param on_chunk fun(chunk: string): any ----@param on_complete fun(err: string|nil): any +---@param on_chunk AvanteChunkParser +---@param on_complete AvanteCompleteParser M.stream = function(question, code_lang, code_content, selected_content_content, on_chunk, on_complete) local provider = Config.provider @@ -488,7 +490,8 @@ M.stream = function(question, code_lang, code_content, selected_content_content, code_content = code_content, selected_code_content = selected_content_content, } - local handler_opts = { on_chunk = on_chunk, on_complete = on_complete, event_state = nil } + local current_event_state = nil + local handler_opts = { on_chunk = on_chunk, on_complete = on_complete } ---@type AvanteCurlOutput local spec = nil @@ -502,6 +505,7 @@ M.stream = function(question, code_lang, code_content, selected_content_content, ProviderConfig = Config.vendors[provider] spec = ProviderConfig.parse_curl_args(ProviderConfig, code_opts) end + --- If the provider doesn't have stream set, we set it to true if spec.body.stream == nil then spec = vim.tbl_deep_extend("force", spec, { body = { stream = true }, @@ -512,15 +516,15 @@ M.stream = function(question, code_lang, code_content, selected_content_content, local function parse_and_call(line) local event = line:match("^event: (.+)$") if event then - handler_opts.event_state = event + current_event_state = event return end local data_match = line:match("^data: (.+)$") if data_match then if ProviderConfig ~= nil then - ProviderConfig.parse_response_data(data_match, handler_opts) + ProviderConfig.parse_response_data(data_match, current_event_state, handler_opts) else - H["parse_" .. provider .. "_response"](data_match, handler_opts) + H["parse_" .. provider .. "_response"](data_match, current_event_state, handler_opts) end end end diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index c5635ed..7968522 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -266,6 +266,9 @@ function Sidebar:update_content(content, opts) end) else vim.defer_fn(function() + if self.view.buf == nil then + return + end api.nvim_set_option_value("modifiable", true, { buf = self.view.buf }) api.nvim_buf_set_lines(self.view.buf, 0, -1, false, vim.split(content, "\n")) api.nvim_set_option_value("modifiable", false, { buf = self.view.buf }) @@ -500,7 +503,7 @@ local function get_conflict_content(content, snippets) table.insert(result, "=======") for _, line in ipairs(vim.split(snippet.content, "\n")) do - line = Utils.trim_line_number_prefix(line) + line = line:gsub("^L%d+: ", "") table.insert(result, line) end @@ -680,14 +683,18 @@ function Sidebar:render() local filetype = api.nvim_get_option_value("filetype", { buf = self.code.buf }) - Llm.stream(request, filetype, content_with_line_numbers, selected_code_content_with_line_numbers, function(chunk) + ---@type AvanteChunkParser + local on_chunk = function(chunk) signal.is_loading = true full_response = full_response .. chunk self:update_content(chunk, { stream = true, scroll = false }) vim.schedule(function() vim.cmd("redraw") end) - end, function(err) + end + + ---@type AvanteCompleteParser + local on_complete = function(err) signal.is_loading = false if err ~= nil then @@ -716,7 +723,16 @@ function Sidebar:render() response = full_response, }) save_chat_history(self, chat_history) - end) + end + + Llm.stream( + request, + filetype, + content_with_line_numbers, + selected_code_content_with_line_numbers, + on_chunk, + on_complete + ) if Config.behaviour.auto_apply_diff_after_generation then apply()