fix(claude): sending state manually (#84)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
		
							parent
							
								
									ba06b9bd9d
								
							
						
					
					
						commit
						2463c896f1
					
				
							
								
								
									
										35
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										35
									
								
								README.md
									
									
									
									
									
								
							@ -82,33 +82,13 @@ _See [config.lua#L9](./lua/avante/config.lua) for the full config_
 | 
				
			|||||||
```lua
 | 
					```lua
 | 
				
			||||||
{
 | 
					{
 | 
				
			||||||
  ---@alias Provider "openai" | "claude" | "azure" | "deepseek" | "groq"
 | 
					  ---@alias Provider "openai" | "claude" | "azure" | "deepseek" | "groq"
 | 
				
			||||||
  provider = "claude", -- "claude" or "openai" or "azure" or "deepseek" or "groq"
 | 
					  provider = "claude",
 | 
				
			||||||
  openai = {
 | 
					 | 
				
			||||||
    endpoint = "https://api.openai.com",
 | 
					 | 
				
			||||||
    model = "gpt-4o",
 | 
					 | 
				
			||||||
    temperature = 0,
 | 
					 | 
				
			||||||
    max_tokens = 4096,
 | 
					 | 
				
			||||||
  },
 | 
					 | 
				
			||||||
  azure = {
 | 
					 | 
				
			||||||
    endpoint = "", -- example: "https://<your-resource-name>.openai.azure.com"
 | 
					 | 
				
			||||||
    deployment = "", -- Azure deployment name (e.g., "gpt-4o", "my-gpt-4o-deployment")
 | 
					 | 
				
			||||||
    api_version = "2024-06-01",
 | 
					 | 
				
			||||||
    temperature = 0,
 | 
					 | 
				
			||||||
    max_tokens = 4096,
 | 
					 | 
				
			||||||
  },
 | 
					 | 
				
			||||||
  claude = {
 | 
					  claude = {
 | 
				
			||||||
    endpoint = "https://api.anthropic.com",
 | 
					    endpoint = "https://api.anthropic.com",
 | 
				
			||||||
    model = "claude-3-5-sonnet-20240620",
 | 
					    model = "claude-3-5-sonnet-20240620",
 | 
				
			||||||
    temperature = 0,
 | 
					    temperature = 0,
 | 
				
			||||||
    max_tokens = 4096,
 | 
					    max_tokens = 4096,
 | 
				
			||||||
  },
 | 
					  },
 | 
				
			||||||
  highlights = {
 | 
					 | 
				
			||||||
    ---@type AvanteConflictHighlights
 | 
					 | 
				
			||||||
    diff = {
 | 
					 | 
				
			||||||
      current = "DiffText",
 | 
					 | 
				
			||||||
      incoming = "DiffAdd",
 | 
					 | 
				
			||||||
    },
 | 
					 | 
				
			||||||
  },
 | 
					 | 
				
			||||||
  mappings = {
 | 
					  mappings = {
 | 
				
			||||||
    ask = "<leader>aa",
 | 
					    ask = "<leader>aa",
 | 
				
			||||||
    edit = "<leader>ae",
 | 
					    edit = "<leader>ae",
 | 
				
			||||||
@ -127,6 +107,7 @@ _See [config.lua#L9](./lua/avante/config.lua) for the full config_
 | 
				
			|||||||
      prev = "[[",
 | 
					      prev = "[[",
 | 
				
			||||||
    },
 | 
					    },
 | 
				
			||||||
  },
 | 
					  },
 | 
				
			||||||
 | 
					  hints = { enabled = true },
 | 
				
			||||||
  windows = {
 | 
					  windows = {
 | 
				
			||||||
    wrap_line = true,
 | 
					    wrap_line = true,
 | 
				
			||||||
    width = 30, -- default % based on available width
 | 
					    width = 30, -- default % based on available width
 | 
				
			||||||
@ -301,13 +282,13 @@ A custom provider should following the following spec:
 | 
				
			|||||||
  ---@type fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput
 | 
					  ---@type fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput
 | 
				
			||||||
  parse_curl_args = function(opts, code_opts) end
 | 
					  parse_curl_args = function(opts, code_opts) end
 | 
				
			||||||
  --- This function will be used to parse incoming SSE stream
 | 
					  --- This function will be used to parse incoming SSE stream
 | 
				
			||||||
  --- It takes in the data stream as the first argument, followed by opts retrieved from given buffer.
 | 
					  --- It takes in the data stream as the first argument, followed by SSE event state, and opts
 | 
				
			||||||
 | 
					  --- retrieved from given buffer.
 | 
				
			||||||
  --- This opts include:
 | 
					  --- This opts include:
 | 
				
			||||||
  --- - on_chunk: (fun(chunk: string): any) this is invoked on parsing correct delta chunk
 | 
					  --- - on_chunk: (fun(chunk: string): any) this is invoked on parsing correct delta chunk
 | 
				
			||||||
  --- - on_complete: (fun(err: string|nil): any) this is invoked on either complete call or error chunk
 | 
					  --- - on_complete: (fun(err: string|nil): any) this is invoked on either complete call or error chunk
 | 
				
			||||||
  --- - event_state: SSE event state.
 | 
					  ---@type fun(data_stream: string, opts: ResponseParser, event_state: string): nil
 | 
				
			||||||
  ---@type fun(data_stream: string, opts: ResponseParser): nil
 | 
					  parse_response_data = function(data_stream, event_state, opts) end
 | 
				
			||||||
  parse_response_data = function(data_stream, opts) end
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
```
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -341,9 +322,9 @@ vendors = {
 | 
				
			|||||||
      }
 | 
					      }
 | 
				
			||||||
    end,
 | 
					    end,
 | 
				
			||||||
    -- The below function is used if the vendors has specific SSE spec that is not claude or openai.
 | 
					    -- The below function is used if the vendors has specific SSE spec that is not claude or openai.
 | 
				
			||||||
    parse_response_data = function(data_stream, opts)
 | 
					    parse_response_data = function(data_stream, event_state, opts)
 | 
				
			||||||
      local Llm = require "avante.llm"
 | 
					      local Llm = require "avante.llm"
 | 
				
			||||||
      Llm.parse_openai_response(data_stream, opts)
 | 
					      Llm.parse_openai_response(data_stream, event_state, opts)
 | 
				
			||||||
    end,
 | 
					    end,
 | 
				
			||||||
  },
 | 
					  },
 | 
				
			||||||
},
 | 
					},
 | 
				
			||||||
 | 
				
			|||||||
@ -204,10 +204,9 @@ Remember: Accurate line numbers are CRITICAL. The range start_line to end_line m
 | 
				
			|||||||
---@alias AvanteCurlArgsBuilder fun(code_opts: AvantePromptOptions): AvanteCurlOutput
 | 
					---@alias AvanteCurlArgsBuilder fun(code_opts: AvantePromptOptions): AvanteCurlOutput
 | 
				
			||||||
---
 | 
					---
 | 
				
			||||||
---@class ResponseParser
 | 
					---@class ResponseParser
 | 
				
			||||||
---@field event_state string
 | 
					 | 
				
			||||||
---@field on_chunk fun(chunk: string): any
 | 
					---@field on_chunk fun(chunk: string): any
 | 
				
			||||||
---@field on_complete fun(err: string|nil): any
 | 
					---@field on_complete fun(err: string|nil): any
 | 
				
			||||||
---@alias AvanteResponseParser fun(data_stream: string, opts: ResponseParser): nil
 | 
					---@alias AvanteResponseParser fun(data_stream: string, event_state: string, opts: ResponseParser): nil
 | 
				
			||||||
---
 | 
					---
 | 
				
			||||||
---@class AvanteProvider
 | 
					---@class AvanteProvider
 | 
				
			||||||
---@field endpoint string
 | 
					---@field endpoint string
 | 
				
			||||||
@ -215,6 +214,9 @@ Remember: Accurate line numbers are CRITICAL. The range start_line to end_line m
 | 
				
			|||||||
---@field api_key_name string
 | 
					---@field api_key_name string
 | 
				
			||||||
---@field parse_response_data AvanteResponseParser
 | 
					---@field parse_response_data AvanteResponseParser
 | 
				
			||||||
---@field parse_curl_args fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput
 | 
					---@field parse_curl_args fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput
 | 
				
			||||||
 | 
					---
 | 
				
			||||||
 | 
					---@alias AvanteChunkParser fun(chunk: string): any
 | 
				
			||||||
 | 
					---@alias AvanteCompleteParser fun(err: string|nil): nil
 | 
				
			||||||
 | 
					
 | 
				
			||||||
------------------------------Anthropic------------------------------
 | 
					------------------------------Anthropic------------------------------
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -278,14 +280,14 @@ H.make_claude_message = function(opts)
 | 
				
			|||||||
end
 | 
					end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
---@type AvanteResponseParser
 | 
					---@type AvanteResponseParser
 | 
				
			||||||
H.parse_claude_response = function(data_stream, opts)
 | 
					H.parse_claude_response = function(data_stream, event_state, opts)
 | 
				
			||||||
  if opts.event_state == "content_block_delta" then
 | 
					  if event_state == "content_block_delta" then
 | 
				
			||||||
    local json = vim.json.decode(data_stream)
 | 
					    local json = vim.json.decode(data_stream)
 | 
				
			||||||
    opts.on_chunk(json.delta.text)
 | 
					    opts.on_chunk(json.delta.text)
 | 
				
			||||||
  elseif opts.event_state == "message_stop" then
 | 
					  elseif event_state == "message_stop" then
 | 
				
			||||||
    opts.on_complete(nil)
 | 
					    opts.on_complete(nil)
 | 
				
			||||||
    return
 | 
					    return
 | 
				
			||||||
  elseif opts.event_state == "error" then
 | 
					  elseif event_state == "error" then
 | 
				
			||||||
    opts.on_complete(vim.json.decode(data_stream))
 | 
					    opts.on_complete(vim.json.decode(data_stream))
 | 
				
			||||||
  end
 | 
					  end
 | 
				
			||||||
end
 | 
					end
 | 
				
			||||||
@ -351,7 +353,7 @@ H.make_openai_message = function(opts)
 | 
				
			|||||||
end
 | 
					end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
---@type AvanteResponseParser
 | 
					---@type AvanteResponseParser
 | 
				
			||||||
H.parse_openai_response = function(data_stream, opts)
 | 
					H.parse_openai_response = function(data_stream, _, opts)
 | 
				
			||||||
  if data_stream:match('"%[DONE%]":') then
 | 
					  if data_stream:match('"%[DONE%]":') then
 | 
				
			||||||
    opts.on_complete(nil)
 | 
					    opts.on_complete(nil)
 | 
				
			||||||
    return
 | 
					    return
 | 
				
			||||||
@ -477,8 +479,8 @@ local active_job = nil
 | 
				
			|||||||
---@param code_lang string
 | 
					---@param code_lang string
 | 
				
			||||||
---@param code_content string
 | 
					---@param code_content string
 | 
				
			||||||
---@param selected_content_content string | nil
 | 
					---@param selected_content_content string | nil
 | 
				
			||||||
---@param on_chunk fun(chunk: string): any
 | 
					---@param on_chunk AvanteChunkParser
 | 
				
			||||||
---@param on_complete fun(err: string|nil): any
 | 
					---@param on_complete AvanteCompleteParser
 | 
				
			||||||
M.stream = function(question, code_lang, code_content, selected_content_content, on_chunk, on_complete)
 | 
					M.stream = function(question, code_lang, code_content, selected_content_content, on_chunk, on_complete)
 | 
				
			||||||
  local provider = Config.provider
 | 
					  local provider = Config.provider
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -488,7 +490,8 @@ M.stream = function(question, code_lang, code_content, selected_content_content,
 | 
				
			|||||||
    code_content = code_content,
 | 
					    code_content = code_content,
 | 
				
			||||||
    selected_code_content = selected_content_content,
 | 
					    selected_code_content = selected_content_content,
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
  local handler_opts = { on_chunk = on_chunk, on_complete = on_complete, event_state = nil }
 | 
					  local current_event_state = nil
 | 
				
			||||||
 | 
					  local handler_opts = { on_chunk = on_chunk, on_complete = on_complete }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  ---@type AvanteCurlOutput
 | 
					  ---@type AvanteCurlOutput
 | 
				
			||||||
  local spec = nil
 | 
					  local spec = nil
 | 
				
			||||||
@ -502,6 +505,7 @@ M.stream = function(question, code_lang, code_content, selected_content_content,
 | 
				
			|||||||
    ProviderConfig = Config.vendors[provider]
 | 
					    ProviderConfig = Config.vendors[provider]
 | 
				
			||||||
    spec = ProviderConfig.parse_curl_args(ProviderConfig, code_opts)
 | 
					    spec = ProviderConfig.parse_curl_args(ProviderConfig, code_opts)
 | 
				
			||||||
  end
 | 
					  end
 | 
				
			||||||
 | 
					  --- If the provider doesn't have stream set, we set it to true
 | 
				
			||||||
  if spec.body.stream == nil then
 | 
					  if spec.body.stream == nil then
 | 
				
			||||||
    spec = vim.tbl_deep_extend("force", spec, {
 | 
					    spec = vim.tbl_deep_extend("force", spec, {
 | 
				
			||||||
      body = { stream = true },
 | 
					      body = { stream = true },
 | 
				
			||||||
@ -512,15 +516,15 @@ M.stream = function(question, code_lang, code_content, selected_content_content,
 | 
				
			|||||||
  local function parse_and_call(line)
 | 
					  local function parse_and_call(line)
 | 
				
			||||||
    local event = line:match("^event: (.+)$")
 | 
					    local event = line:match("^event: (.+)$")
 | 
				
			||||||
    if event then
 | 
					    if event then
 | 
				
			||||||
      handler_opts.event_state = event
 | 
					      current_event_state = event
 | 
				
			||||||
      return
 | 
					      return
 | 
				
			||||||
    end
 | 
					    end
 | 
				
			||||||
    local data_match = line:match("^data: (.+)$")
 | 
					    local data_match = line:match("^data: (.+)$")
 | 
				
			||||||
    if data_match then
 | 
					    if data_match then
 | 
				
			||||||
      if ProviderConfig ~= nil then
 | 
					      if ProviderConfig ~= nil then
 | 
				
			||||||
        ProviderConfig.parse_response_data(data_match, handler_opts)
 | 
					        ProviderConfig.parse_response_data(data_match, current_event_state, handler_opts)
 | 
				
			||||||
      else
 | 
					      else
 | 
				
			||||||
        H["parse_" .. provider .. "_response"](data_match, handler_opts)
 | 
					        H["parse_" .. provider .. "_response"](data_match, current_event_state, handler_opts)
 | 
				
			||||||
      end
 | 
					      end
 | 
				
			||||||
    end
 | 
					    end
 | 
				
			||||||
  end
 | 
					  end
 | 
				
			||||||
 | 
				
			|||||||
@ -266,6 +266,9 @@ function Sidebar:update_content(content, opts)
 | 
				
			|||||||
    end)
 | 
					    end)
 | 
				
			||||||
  else
 | 
					  else
 | 
				
			||||||
    vim.defer_fn(function()
 | 
					    vim.defer_fn(function()
 | 
				
			||||||
 | 
					      if self.view.buf == nil then
 | 
				
			||||||
 | 
					        return
 | 
				
			||||||
 | 
					      end
 | 
				
			||||||
      api.nvim_set_option_value("modifiable", true, { buf = self.view.buf })
 | 
					      api.nvim_set_option_value("modifiable", true, { buf = self.view.buf })
 | 
				
			||||||
      api.nvim_buf_set_lines(self.view.buf, 0, -1, false, vim.split(content, "\n"))
 | 
					      api.nvim_buf_set_lines(self.view.buf, 0, -1, false, vim.split(content, "\n"))
 | 
				
			||||||
      api.nvim_set_option_value("modifiable", false, { buf = self.view.buf })
 | 
					      api.nvim_set_option_value("modifiable", false, { buf = self.view.buf })
 | 
				
			||||||
@ -500,7 +503,7 @@ local function get_conflict_content(content, snippets)
 | 
				
			|||||||
    table.insert(result, "=======")
 | 
					    table.insert(result, "=======")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    for _, line in ipairs(vim.split(snippet.content, "\n")) do
 | 
					    for _, line in ipairs(vim.split(snippet.content, "\n")) do
 | 
				
			||||||
      line = Utils.trim_line_number_prefix(line)
 | 
					      line = line:gsub("^L%d+: ", "")
 | 
				
			||||||
      table.insert(result, line)
 | 
					      table.insert(result, line)
 | 
				
			||||||
    end
 | 
					    end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -680,14 +683,18 @@ function Sidebar:render()
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    local filetype = api.nvim_get_option_value("filetype", { buf = self.code.buf })
 | 
					    local filetype = api.nvim_get_option_value("filetype", { buf = self.code.buf })
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Llm.stream(request, filetype, content_with_line_numbers, selected_code_content_with_line_numbers, function(chunk)
 | 
					    ---@type AvanteChunkParser
 | 
				
			||||||
 | 
					    local on_chunk = function(chunk)
 | 
				
			||||||
      signal.is_loading = true
 | 
					      signal.is_loading = true
 | 
				
			||||||
      full_response = full_response .. chunk
 | 
					      full_response = full_response .. chunk
 | 
				
			||||||
      self:update_content(chunk, { stream = true, scroll = false })
 | 
					      self:update_content(chunk, { stream = true, scroll = false })
 | 
				
			||||||
      vim.schedule(function()
 | 
					      vim.schedule(function()
 | 
				
			||||||
        vim.cmd("redraw")
 | 
					        vim.cmd("redraw")
 | 
				
			||||||
      end)
 | 
					      end)
 | 
				
			||||||
    end, function(err)
 | 
					    end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    ---@type AvanteCompleteParser
 | 
				
			||||||
 | 
					    local on_complete = function(err)
 | 
				
			||||||
      signal.is_loading = false
 | 
					      signal.is_loading = false
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      if err ~= nil then
 | 
					      if err ~= nil then
 | 
				
			||||||
@ -716,7 +723,16 @@ function Sidebar:render()
 | 
				
			|||||||
        response = full_response,
 | 
					        response = full_response,
 | 
				
			||||||
      })
 | 
					      })
 | 
				
			||||||
      save_chat_history(self, chat_history)
 | 
					      save_chat_history(self, chat_history)
 | 
				
			||||||
    end)
 | 
					    end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Llm.stream(
 | 
				
			||||||
 | 
					      request,
 | 
				
			||||||
 | 
					      filetype,
 | 
				
			||||||
 | 
					      content_with_line_numbers,
 | 
				
			||||||
 | 
					      selected_code_content_with_line_numbers,
 | 
				
			||||||
 | 
					      on_chunk,
 | 
				
			||||||
 | 
					      on_complete
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if Config.behaviour.auto_apply_diff_after_generation then
 | 
					    if Config.behaviour.auto_apply_diff_after_generation then
 | 
				
			||||||
      apply()
 | 
					      apply()
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user