fix(claude): sending state manually (#84)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
Aaron Pham 2024-08-19 06:11:02 -04:00 committed by GitHub
parent ba06b9bd9d
commit 2463c896f1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 45 additions and 44 deletions

View File

@ -82,33 +82,13 @@ _See [config.lua#L9](./lua/avante/config.lua) for the full config_
```lua ```lua
{ {
---@alias Provider "openai" | "claude" | "azure" | "deepseek" | "groq" ---@alias Provider "openai" | "claude" | "azure" | "deepseek" | "groq"
provider = "claude", -- "claude" or "openai" or "azure" or "deepseek" or "groq" provider = "claude",
openai = {
endpoint = "https://api.openai.com",
model = "gpt-4o",
temperature = 0,
max_tokens = 4096,
},
azure = {
endpoint = "", -- example: "https://<your-resource-name>.openai.azure.com"
deployment = "", -- Azure deployment name (e.g., "gpt-4o", "my-gpt-4o-deployment")
api_version = "2024-06-01",
temperature = 0,
max_tokens = 4096,
},
claude = { claude = {
endpoint = "https://api.anthropic.com", endpoint = "https://api.anthropic.com",
model = "claude-3-5-sonnet-20240620", model = "claude-3-5-sonnet-20240620",
temperature = 0, temperature = 0,
max_tokens = 4096, max_tokens = 4096,
}, },
highlights = {
---@type AvanteConflictHighlights
diff = {
current = "DiffText",
incoming = "DiffAdd",
},
},
mappings = { mappings = {
ask = "<leader>aa", ask = "<leader>aa",
edit = "<leader>ae", edit = "<leader>ae",
@ -127,6 +107,7 @@ _See [config.lua#L9](./lua/avante/config.lua) for the full config_
prev = "[[", prev = "[[",
}, },
}, },
hints = { enabled = true },
windows = { windows = {
wrap_line = true, wrap_line = true,
width = 30, -- default % based on available width width = 30, -- default % based on available width
@ -301,13 +282,13 @@ A custom provider should following the following spec:
---@type fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput ---@type fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput
parse_curl_args = function(opts, code_opts) end parse_curl_args = function(opts, code_opts) end
--- This function will be used to parse incoming SSE stream --- This function will be used to parse incoming SSE stream
--- It takes in the data stream as the first argument, followed by opts retrieved from given buffer. --- It takes in the data stream as the first argument, followed by SSE event state, and opts
--- retrieved from given buffer.
--- This opts include: --- This opts include:
--- - on_chunk: (fun(chunk: string): any) this is invoked on parsing correct delta chunk --- - on_chunk: (fun(chunk: string): any) this is invoked on parsing correct delta chunk
--- - on_complete: (fun(err: string|nil): any) this is invoked on either complete call or error chunk --- - on_complete: (fun(err: string|nil): any) this is invoked on either complete call or error chunk
--- - event_state: SSE event state. ---@type fun(data_stream: string, opts: ResponseParser, event_state: string): nil
---@type fun(data_stream: string, opts: ResponseParser): nil parse_response_data = function(data_stream, event_state, opts) end
parse_response_data = function(data_stream, opts) end
} }
``` ```
@ -341,9 +322,9 @@ vendors = {
} }
end, end,
-- The below function is used if the vendors has specific SSE spec that is not claude or openai. -- The below function is used if the vendors has specific SSE spec that is not claude or openai.
parse_response_data = function(data_stream, opts) parse_response_data = function(data_stream, event_state, opts)
local Llm = require "avante.llm" local Llm = require "avante.llm"
Llm.parse_openai_response(data_stream, opts) Llm.parse_openai_response(data_stream, event_state, opts)
end, end,
}, },
}, },

View File

@ -204,10 +204,9 @@ Remember: Accurate line numbers are CRITICAL. The range start_line to end_line m
---@alias AvanteCurlArgsBuilder fun(code_opts: AvantePromptOptions): AvanteCurlOutput ---@alias AvanteCurlArgsBuilder fun(code_opts: AvantePromptOptions): AvanteCurlOutput
--- ---
---@class ResponseParser ---@class ResponseParser
---@field event_state string
---@field on_chunk fun(chunk: string): any ---@field on_chunk fun(chunk: string): any
---@field on_complete fun(err: string|nil): any ---@field on_complete fun(err: string|nil): any
---@alias AvanteResponseParser fun(data_stream: string, opts: ResponseParser): nil ---@alias AvanteResponseParser fun(data_stream: string, event_state: string, opts: ResponseParser): nil
--- ---
---@class AvanteProvider ---@class AvanteProvider
---@field endpoint string ---@field endpoint string
@ -215,6 +214,9 @@ Remember: Accurate line numbers are CRITICAL. The range start_line to end_line m
---@field api_key_name string ---@field api_key_name string
---@field parse_response_data AvanteResponseParser ---@field parse_response_data AvanteResponseParser
---@field parse_curl_args fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput ---@field parse_curl_args fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput
---
---@alias AvanteChunkParser fun(chunk: string): any
---@alias AvanteCompleteParser fun(err: string|nil): nil
------------------------------Anthropic------------------------------ ------------------------------Anthropic------------------------------
@ -278,14 +280,14 @@ H.make_claude_message = function(opts)
end end
---@type AvanteResponseParser ---@type AvanteResponseParser
H.parse_claude_response = function(data_stream, opts) H.parse_claude_response = function(data_stream, event_state, opts)
if opts.event_state == "content_block_delta" then if event_state == "content_block_delta" then
local json = vim.json.decode(data_stream) local json = vim.json.decode(data_stream)
opts.on_chunk(json.delta.text) opts.on_chunk(json.delta.text)
elseif opts.event_state == "message_stop" then elseif event_state == "message_stop" then
opts.on_complete(nil) opts.on_complete(nil)
return return
elseif opts.event_state == "error" then elseif event_state == "error" then
opts.on_complete(vim.json.decode(data_stream)) opts.on_complete(vim.json.decode(data_stream))
end end
end end
@ -351,7 +353,7 @@ H.make_openai_message = function(opts)
end end
---@type AvanteResponseParser ---@type AvanteResponseParser
H.parse_openai_response = function(data_stream, opts) H.parse_openai_response = function(data_stream, _, opts)
if data_stream:match('"%[DONE%]":') then if data_stream:match('"%[DONE%]":') then
opts.on_complete(nil) opts.on_complete(nil)
return return
@ -477,8 +479,8 @@ local active_job = nil
---@param code_lang string ---@param code_lang string
---@param code_content string ---@param code_content string
---@param selected_content_content string | nil ---@param selected_content_content string | nil
---@param on_chunk fun(chunk: string): any ---@param on_chunk AvanteChunkParser
---@param on_complete fun(err: string|nil): any ---@param on_complete AvanteCompleteParser
M.stream = function(question, code_lang, code_content, selected_content_content, on_chunk, on_complete) M.stream = function(question, code_lang, code_content, selected_content_content, on_chunk, on_complete)
local provider = Config.provider local provider = Config.provider
@ -488,7 +490,8 @@ M.stream = function(question, code_lang, code_content, selected_content_content,
code_content = code_content, code_content = code_content,
selected_code_content = selected_content_content, selected_code_content = selected_content_content,
} }
local handler_opts = { on_chunk = on_chunk, on_complete = on_complete, event_state = nil } local current_event_state = nil
local handler_opts = { on_chunk = on_chunk, on_complete = on_complete }
---@type AvanteCurlOutput ---@type AvanteCurlOutput
local spec = nil local spec = nil
@ -502,6 +505,7 @@ M.stream = function(question, code_lang, code_content, selected_content_content,
ProviderConfig = Config.vendors[provider] ProviderConfig = Config.vendors[provider]
spec = ProviderConfig.parse_curl_args(ProviderConfig, code_opts) spec = ProviderConfig.parse_curl_args(ProviderConfig, code_opts)
end end
--- If the provider doesn't have stream set, we set it to true
if spec.body.stream == nil then if spec.body.stream == nil then
spec = vim.tbl_deep_extend("force", spec, { spec = vim.tbl_deep_extend("force", spec, {
body = { stream = true }, body = { stream = true },
@ -512,15 +516,15 @@ M.stream = function(question, code_lang, code_content, selected_content_content,
local function parse_and_call(line) local function parse_and_call(line)
local event = line:match("^event: (.+)$") local event = line:match("^event: (.+)$")
if event then if event then
handler_opts.event_state = event current_event_state = event
return return
end end
local data_match = line:match("^data: (.+)$") local data_match = line:match("^data: (.+)$")
if data_match then if data_match then
if ProviderConfig ~= nil then if ProviderConfig ~= nil then
ProviderConfig.parse_response_data(data_match, handler_opts) ProviderConfig.parse_response_data(data_match, current_event_state, handler_opts)
else else
H["parse_" .. provider .. "_response"](data_match, handler_opts) H["parse_" .. provider .. "_response"](data_match, current_event_state, handler_opts)
end end
end end
end end

View File

@ -266,6 +266,9 @@ function Sidebar:update_content(content, opts)
end) end)
else else
vim.defer_fn(function() vim.defer_fn(function()
if self.view.buf == nil then
return
end
api.nvim_set_option_value("modifiable", true, { buf = self.view.buf }) api.nvim_set_option_value("modifiable", true, { buf = self.view.buf })
api.nvim_buf_set_lines(self.view.buf, 0, -1, false, vim.split(content, "\n")) api.nvim_buf_set_lines(self.view.buf, 0, -1, false, vim.split(content, "\n"))
api.nvim_set_option_value("modifiable", false, { buf = self.view.buf }) api.nvim_set_option_value("modifiable", false, { buf = self.view.buf })
@ -500,7 +503,7 @@ local function get_conflict_content(content, snippets)
table.insert(result, "=======") table.insert(result, "=======")
for _, line in ipairs(vim.split(snippet.content, "\n")) do for _, line in ipairs(vim.split(snippet.content, "\n")) do
line = Utils.trim_line_number_prefix(line) line = line:gsub("^L%d+: ", "")
table.insert(result, line) table.insert(result, line)
end end
@ -680,14 +683,18 @@ function Sidebar:render()
local filetype = api.nvim_get_option_value("filetype", { buf = self.code.buf }) local filetype = api.nvim_get_option_value("filetype", { buf = self.code.buf })
Llm.stream(request, filetype, content_with_line_numbers, selected_code_content_with_line_numbers, function(chunk) ---@type AvanteChunkParser
local on_chunk = function(chunk)
signal.is_loading = true signal.is_loading = true
full_response = full_response .. chunk full_response = full_response .. chunk
self:update_content(chunk, { stream = true, scroll = false }) self:update_content(chunk, { stream = true, scroll = false })
vim.schedule(function() vim.schedule(function()
vim.cmd("redraw") vim.cmd("redraw")
end) end)
end, function(err) end
---@type AvanteCompleteParser
local on_complete = function(err)
signal.is_loading = false signal.is_loading = false
if err ~= nil then if err ~= nil then
@ -716,7 +723,16 @@ function Sidebar:render()
response = full_response, response = full_response,
}) })
save_chat_history(self, chat_history) save_chat_history(self, chat_history)
end) end
Llm.stream(
request,
filetype,
content_with_line_numbers,
selected_code_content_with_line_numbers,
on_chunk,
on_complete
)
if Config.behaviour.auto_apply_diff_after_generation then if Config.behaviour.auto_apply_diff_after_generation then
apply() apply()