fix(claude): sending state manually (#84)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
parent ba06b9bd9d
commit 2463c896f1

README.md | 35
@@ -82,33 +82,13 @@ _See [config.lua#L9](./lua/avante/config.lua) for the full config_
```lua
{
  ---@alias Provider "openai" | "claude" | "azure" | "deepseek" | "groq"
  provider = "claude", -- "claude" or "openai" or "azure" or "deepseek" or "groq"
  openai = {
    endpoint = "https://api.openai.com",
    model = "gpt-4o",
    temperature = 0,
    max_tokens = 4096,
  },
  azure = {
    endpoint = "", -- example: "https://<your-resource-name>.openai.azure.com"
    deployment = "", -- Azure deployment name (e.g., "gpt-4o", "my-gpt-4o-deployment")
    api_version = "2024-06-01",
    temperature = 0,
    max_tokens = 4096,
  },
  provider = "claude",
  claude = {
    endpoint = "https://api.anthropic.com",
    model = "claude-3-5-sonnet-20240620",
    temperature = 0,
    max_tokens = 4096,
  },
  highlights = {
    ---@type AvanteConflictHighlights
    diff = {
      current = "DiffText",
      incoming = "DiffAdd",
    },
  },
  mappings = {
    ask = "<leader>aa",
    edit = "<leader>ae",
@@ -127,6 +107,7 @@ _See [config.lua#L9](./lua/avante/config.lua) for the full config_
      prev = "[[",
    },
  },
  hints = { enabled = true },
  windows = {
    wrap_line = true,
    width = 30, -- default % based on available width
@@ -301,13 +282,13 @@ A custom provider should following the following spec:
  ---@type fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput
  parse_curl_args = function(opts, code_opts) end
  --- This function will be used to parse incoming SSE stream
  --- It takes in the data stream as the first argument, followed by opts retrieved from given buffer.
  --- It takes in the data stream as the first argument, followed by SSE event state, and opts
  --- retrieved from given buffer.
  --- This opts include:
  --- - on_chunk: (fun(chunk: string): any) this is invoked on parsing correct delta chunk
  --- - on_complete: (fun(err: string|nil): any) this is invoked on either complete call or error chunk
  --- - event_state: SSE event state.
  ---@type fun(data_stream: string, opts: ResponseParser): nil
  parse_response_data = function(data_stream, opts) end
  ---@type fun(data_stream: string, opts: ResponseParser, event_state: string): nil
  parse_response_data = function(data_stream, event_state, opts) end
}
```

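For illustration, a minimal sketch of a custom `parse_response_data` written against the new `(data_stream, event_state, opts)` signature; the event names and JSON payload shape are assumptions for a hypothetical Anthropic-style vendor and are not part of this commit:

```lua
-- Hypothetical vendor handler; event names and JSON fields are illustrative assumptions.
parse_response_data = function(data_stream, event_state, opts)
  if event_state == "content_block_delta" then
    local ok, json = pcall(vim.json.decode, data_stream)
    if ok and json.delta and json.delta.text then
      opts.on_chunk(json.delta.text) -- stream each text delta to the caller
    end
  elseif event_state == "message_stop" then
    opts.on_complete(nil) -- finished cleanly
  elseif event_state == "error" then
    opts.on_complete(vim.json.decode(data_stream)) -- surface the error payload
  end
end,
```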
@@ -341,9 +322,9 @@ vendors = {
      }
    end,
    -- The below function is used if the vendors has specific SSE spec that is not claude or openai.
    parse_response_data = function(data_stream, opts)
    parse_response_data = function(data_stream, event_state, opts)
      local Llm = require "avante.llm"
      Llm.parse_openai_response(data_stream, opts)
      Llm.parse_openai_response(data_stream, event_state, opts)
    end,
  },
},

@@ -204,10 +204,9 @@ Remember: Accurate line numbers are CRITICAL. The range start_line to end_line m
---@alias AvanteCurlArgsBuilder fun(code_opts: AvantePromptOptions): AvanteCurlOutput
---
---@class ResponseParser
---@field event_state string
---@field on_chunk fun(chunk: string): any
---@field on_complete fun(err: string|nil): any
---@alias AvanteResponseParser fun(data_stream: string, opts: ResponseParser): nil
---@alias AvanteResponseParser fun(data_stream: string, event_state: string, opts: ResponseParser): nil
---
---@class AvanteProvider
---@field endpoint string
@@ -215,6 +214,9 @@ Remember: Accurate line numbers are CRITICAL. The range start_line to end_line m
---@field api_key_name string
---@field parse_response_data AvanteResponseParser
---@field parse_curl_args fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput
---
---@alias AvanteChunkParser fun(chunk: string): any
---@alias AvanteCompleteParser fun(err: string|nil): nil

------------------------------Anthropic------------------------------

@@ -278,14 +280,14 @@ H.make_claude_message = function(opts)
end

---@type AvanteResponseParser
H.parse_claude_response = function(data_stream, opts)
  if opts.event_state == "content_block_delta" then
H.parse_claude_response = function(data_stream, event_state, opts)
  if event_state == "content_block_delta" then
    local json = vim.json.decode(data_stream)
    opts.on_chunk(json.delta.text)
  elseif opts.event_state == "message_stop" then
  elseif event_state == "message_stop" then
    opts.on_complete(nil)
    return
  elseif opts.event_state == "error" then
  elseif event_state == "error" then
    opts.on_complete(vim.json.decode(data_stream))
  end
end
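The `event_state` values handled above come from the `event:` lines of Anthropic's SSE stream; an abridged, illustrative excerpt of such a stream looks like:

```
event: content_block_delta
data: {"type":"content_block_delta","delta":{"type":"text_delta","text":"Hello"}}

event: message_stop
data: {"type":"message_stop"}
```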
@@ -351,7 +353,7 @@ H.make_openai_message = function(opts)
end

---@type AvanteResponseParser
H.parse_openai_response = function(data_stream, opts)
H.parse_openai_response = function(data_stream, _, opts)
  if data_stream:match('"%[DONE%]":') then
    opts.on_complete(nil)
    return
@@ -477,8 +479,8 @@ local active_job = nil
---@param code_lang string
---@param code_content string
---@param selected_content_content string | nil
---@param on_chunk fun(chunk: string): any
---@param on_complete fun(err: string|nil): any
---@param on_chunk AvanteChunkParser
---@param on_complete AvanteCompleteParser
M.stream = function(question, code_lang, code_content, selected_content_content, on_chunk, on_complete)
  local provider = Config.provider

@@ -488,7 +490,8 @@ M.stream = function(question, code_lang, code_content, selected_content_content,
    code_content = code_content,
    selected_code_content = selected_content_content,
  }
  local handler_opts = { on_chunk = on_chunk, on_complete = on_complete, event_state = nil }
  local current_event_state = nil
  local handler_opts = { on_chunk = on_chunk, on_complete = on_complete }

  ---@type AvanteCurlOutput
  local spec = nil
@@ -502,6 +505,7 @@ M.stream = function(question, code_lang, code_content, selected_content_content,
    ProviderConfig = Config.vendors[provider]
    spec = ProviderConfig.parse_curl_args(ProviderConfig, code_opts)
  end
  --- If the provider doesn't have stream set, we set it to true
  if spec.body.stream == nil then
    spec = vim.tbl_deep_extend("force", spec, {
      body = { stream = true },
@@ -512,15 +516,15 @@ M.stream = function(question, code_lang, code_content, selected_content_content,
  local function parse_and_call(line)
    local event = line:match("^event: (.+)$")
    if event then
      handler_opts.event_state = event
      current_event_state = event
      return
    end
    local data_match = line:match("^data: (.+)$")
    if data_match then
      if ProviderConfig ~= nil then
        ProviderConfig.parse_response_data(data_match, handler_opts)
        ProviderConfig.parse_response_data(data_match, current_event_state, handler_opts)
      else
        H["parse_" .. provider .. "_response"](data_match, handler_opts)
        H["parse_" .. provider .. "_response"](data_match, current_event_state, handler_opts)
      end
    end
  end

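Because the event name and its payload arrive on separate SSE lines, `parse_and_call` now remembers the last `event:` value in a local and forwards it explicitly with each `data:` line instead of stashing it on `handler_opts`. A self-contained sketch of that flow, where the sample lines are made up and `print` stands in for the real parser call:

```lua
-- Stand-alone illustration of the state tracking done in parse_and_call above.
local current_event_state = nil

local function parse_and_call(line)
  local event = line:match("^event: (.+)$")
  if event then
    current_event_state = event -- remember the event until its data line arrives
    return
  end
  local data = line:match("^data: (.+)$")
  if data then
    -- the remembered state is passed explicitly, in the new argument order
    print(("parse_response_data(%q, %q, handler_opts)"):format(data, current_event_state))
  end
end

for _, line in ipairs({
  "event: content_block_delta",
  'data: {"delta":{"text":"Hello"}}',
  "event: message_stop",
  "data: {}",
}) do
  parse_and_call(line)
end
```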
@@ -266,6 +266,9 @@ function Sidebar:update_content(content, opts)
    end)
  else
    vim.defer_fn(function()
      if self.view.buf == nil then
        return
      end
      api.nvim_set_option_value("modifiable", true, { buf = self.view.buf })
      api.nvim_buf_set_lines(self.view.buf, 0, -1, false, vim.split(content, "\n"))
      api.nvim_set_option_value("modifiable", false, { buf = self.view.buf })
@@ -500,7 +503,7 @@ local function get_conflict_content(content, snippets)
    table.insert(result, "=======")

    for _, line in ipairs(vim.split(snippet.content, "\n")) do
      line = Utils.trim_line_number_prefix(line)
      line = line:gsub("^L%d+: ", "")
      table.insert(result, line)
    end

@@ -680,14 +683,18 @@ function Sidebar:render()

    local filetype = api.nvim_get_option_value("filetype", { buf = self.code.buf })

    Llm.stream(request, filetype, content_with_line_numbers, selected_code_content_with_line_numbers, function(chunk)
    ---@type AvanteChunkParser
    local on_chunk = function(chunk)
      signal.is_loading = true
      full_response = full_response .. chunk
      self:update_content(chunk, { stream = true, scroll = false })
      vim.schedule(function()
        vim.cmd("redraw")
      end)
    end, function(err)
    end

    ---@type AvanteCompleteParser
    local on_complete = function(err)
      signal.is_loading = false

      if err ~= nil then
@@ -716,7 +723,16 @@ function Sidebar:render()
        response = full_response,
      })
      save_chat_history(self, chat_history)
    end)
    end

    Llm.stream(
      request,
      filetype,
      content_with_line_numbers,
      selected_code_content_with_line_numbers,
      on_chunk,
      on_complete
    )

    if Config.behaviour.auto_apply_diff_after_generation then
      apply()