fix(claude): sending state manually (#84)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
parent
ba06b9bd9d
commit
2463c896f1
35
README.md
35
README.md
@ -82,33 +82,13 @@ _See [config.lua#L9](./lua/avante/config.lua) for the full config_
|
|||||||
```lua
|
```lua
|
||||||
{
|
{
|
||||||
---@alias Provider "openai" | "claude" | "azure" | "deepseek" | "groq"
|
---@alias Provider "openai" | "claude" | "azure" | "deepseek" | "groq"
|
||||||
provider = "claude", -- "claude" or "openai" or "azure" or "deepseek" or "groq"
|
provider = "claude",
|
||||||
openai = {
|
|
||||||
endpoint = "https://api.openai.com",
|
|
||||||
model = "gpt-4o",
|
|
||||||
temperature = 0,
|
|
||||||
max_tokens = 4096,
|
|
||||||
},
|
|
||||||
azure = {
|
|
||||||
endpoint = "", -- example: "https://<your-resource-name>.openai.azure.com"
|
|
||||||
deployment = "", -- Azure deployment name (e.g., "gpt-4o", "my-gpt-4o-deployment")
|
|
||||||
api_version = "2024-06-01",
|
|
||||||
temperature = 0,
|
|
||||||
max_tokens = 4096,
|
|
||||||
},
|
|
||||||
claude = {
|
claude = {
|
||||||
endpoint = "https://api.anthropic.com",
|
endpoint = "https://api.anthropic.com",
|
||||||
model = "claude-3-5-sonnet-20240620",
|
model = "claude-3-5-sonnet-20240620",
|
||||||
temperature = 0,
|
temperature = 0,
|
||||||
max_tokens = 4096,
|
max_tokens = 4096,
|
||||||
},
|
},
|
||||||
highlights = {
|
|
||||||
---@type AvanteConflictHighlights
|
|
||||||
diff = {
|
|
||||||
current = "DiffText",
|
|
||||||
incoming = "DiffAdd",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
mappings = {
|
mappings = {
|
||||||
ask = "<leader>aa",
|
ask = "<leader>aa",
|
||||||
edit = "<leader>ae",
|
edit = "<leader>ae",
|
||||||
@ -127,6 +107,7 @@ _See [config.lua#L9](./lua/avante/config.lua) for the full config_
|
|||||||
prev = "[[",
|
prev = "[[",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
hints = { enabled = true },
|
||||||
windows = {
|
windows = {
|
||||||
wrap_line = true,
|
wrap_line = true,
|
||||||
width = 30, -- default % based on available width
|
width = 30, -- default % based on available width
|
||||||
@ -301,13 +282,13 @@ A custom provider should following the following spec:
|
|||||||
---@type fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput
|
---@type fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput
|
||||||
parse_curl_args = function(opts, code_opts) end
|
parse_curl_args = function(opts, code_opts) end
|
||||||
--- This function will be used to parse incoming SSE stream
|
--- This function will be used to parse incoming SSE stream
|
||||||
--- It takes in the data stream as the first argument, followed by opts retrieved from given buffer.
|
--- It takes in the data stream as the first argument, followed by SSE event state, and opts
|
||||||
|
--- retrieved from given buffer.
|
||||||
--- This opts include:
|
--- This opts include:
|
||||||
--- - on_chunk: (fun(chunk: string): any) this is invoked on parsing correct delta chunk
|
--- - on_chunk: (fun(chunk: string): any) this is invoked on parsing correct delta chunk
|
||||||
--- - on_complete: (fun(err: string|nil): any) this is invoked on either complete call or error chunk
|
--- - on_complete: (fun(err: string|nil): any) this is invoked on either complete call or error chunk
|
||||||
--- - event_state: SSE event state.
|
---@type fun(data_stream: string, opts: ResponseParser, event_state: string): nil
|
||||||
---@type fun(data_stream: string, opts: ResponseParser): nil
|
parse_response_data = function(data_stream, event_state, opts) end
|
||||||
parse_response_data = function(data_stream, opts) end
|
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -341,9 +322,9 @@ vendors = {
|
|||||||
}
|
}
|
||||||
end,
|
end,
|
||||||
-- The below function is used if the vendors has specific SSE spec that is not claude or openai.
|
-- The below function is used if the vendors has specific SSE spec that is not claude or openai.
|
||||||
parse_response_data = function(data_stream, opts)
|
parse_response_data = function(data_stream, event_state, opts)
|
||||||
local Llm = require "avante.llm"
|
local Llm = require "avante.llm"
|
||||||
Llm.parse_openai_response(data_stream, opts)
|
Llm.parse_openai_response(data_stream, event_state, opts)
|
||||||
end,
|
end,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -204,10 +204,9 @@ Remember: Accurate line numbers are CRITICAL. The range start_line to end_line m
|
|||||||
---@alias AvanteCurlArgsBuilder fun(code_opts: AvantePromptOptions): AvanteCurlOutput
|
---@alias AvanteCurlArgsBuilder fun(code_opts: AvantePromptOptions): AvanteCurlOutput
|
||||||
---
|
---
|
||||||
---@class ResponseParser
|
---@class ResponseParser
|
||||||
---@field event_state string
|
|
||||||
---@field on_chunk fun(chunk: string): any
|
---@field on_chunk fun(chunk: string): any
|
||||||
---@field on_complete fun(err: string|nil): any
|
---@field on_complete fun(err: string|nil): any
|
||||||
---@alias AvanteResponseParser fun(data_stream: string, opts: ResponseParser): nil
|
---@alias AvanteResponseParser fun(data_stream: string, event_state: string, opts: ResponseParser): nil
|
||||||
---
|
---
|
||||||
---@class AvanteProvider
|
---@class AvanteProvider
|
||||||
---@field endpoint string
|
---@field endpoint string
|
||||||
@ -215,6 +214,9 @@ Remember: Accurate line numbers are CRITICAL. The range start_line to end_line m
|
|||||||
---@field api_key_name string
|
---@field api_key_name string
|
||||||
---@field parse_response_data AvanteResponseParser
|
---@field parse_response_data AvanteResponseParser
|
||||||
---@field parse_curl_args fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput
|
---@field parse_curl_args fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput
|
||||||
|
---
|
||||||
|
---@alias AvanteChunkParser fun(chunk: string): any
|
||||||
|
---@alias AvanteCompleteParser fun(err: string|nil): nil
|
||||||
|
|
||||||
------------------------------Anthropic------------------------------
|
------------------------------Anthropic------------------------------
|
||||||
|
|
||||||
@ -278,14 +280,14 @@ H.make_claude_message = function(opts)
|
|||||||
end
|
end
|
||||||
|
|
||||||
---@type AvanteResponseParser
|
---@type AvanteResponseParser
|
||||||
H.parse_claude_response = function(data_stream, opts)
|
H.parse_claude_response = function(data_stream, event_state, opts)
|
||||||
if opts.event_state == "content_block_delta" then
|
if event_state == "content_block_delta" then
|
||||||
local json = vim.json.decode(data_stream)
|
local json = vim.json.decode(data_stream)
|
||||||
opts.on_chunk(json.delta.text)
|
opts.on_chunk(json.delta.text)
|
||||||
elseif opts.event_state == "message_stop" then
|
elseif event_state == "message_stop" then
|
||||||
opts.on_complete(nil)
|
opts.on_complete(nil)
|
||||||
return
|
return
|
||||||
elseif opts.event_state == "error" then
|
elseif event_state == "error" then
|
||||||
opts.on_complete(vim.json.decode(data_stream))
|
opts.on_complete(vim.json.decode(data_stream))
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
@ -351,7 +353,7 @@ H.make_openai_message = function(opts)
|
|||||||
end
|
end
|
||||||
|
|
||||||
---@type AvanteResponseParser
|
---@type AvanteResponseParser
|
||||||
H.parse_openai_response = function(data_stream, opts)
|
H.parse_openai_response = function(data_stream, _, opts)
|
||||||
if data_stream:match('"%[DONE%]":') then
|
if data_stream:match('"%[DONE%]":') then
|
||||||
opts.on_complete(nil)
|
opts.on_complete(nil)
|
||||||
return
|
return
|
||||||
@ -477,8 +479,8 @@ local active_job = nil
|
|||||||
---@param code_lang string
|
---@param code_lang string
|
||||||
---@param code_content string
|
---@param code_content string
|
||||||
---@param selected_content_content string | nil
|
---@param selected_content_content string | nil
|
||||||
---@param on_chunk fun(chunk: string): any
|
---@param on_chunk AvanteChunkParser
|
||||||
---@param on_complete fun(err: string|nil): any
|
---@param on_complete AvanteCompleteParser
|
||||||
M.stream = function(question, code_lang, code_content, selected_content_content, on_chunk, on_complete)
|
M.stream = function(question, code_lang, code_content, selected_content_content, on_chunk, on_complete)
|
||||||
local provider = Config.provider
|
local provider = Config.provider
|
||||||
|
|
||||||
@ -488,7 +490,8 @@ M.stream = function(question, code_lang, code_content, selected_content_content,
|
|||||||
code_content = code_content,
|
code_content = code_content,
|
||||||
selected_code_content = selected_content_content,
|
selected_code_content = selected_content_content,
|
||||||
}
|
}
|
||||||
local handler_opts = { on_chunk = on_chunk, on_complete = on_complete, event_state = nil }
|
local current_event_state = nil
|
||||||
|
local handler_opts = { on_chunk = on_chunk, on_complete = on_complete }
|
||||||
|
|
||||||
---@type AvanteCurlOutput
|
---@type AvanteCurlOutput
|
||||||
local spec = nil
|
local spec = nil
|
||||||
@ -502,6 +505,7 @@ M.stream = function(question, code_lang, code_content, selected_content_content,
|
|||||||
ProviderConfig = Config.vendors[provider]
|
ProviderConfig = Config.vendors[provider]
|
||||||
spec = ProviderConfig.parse_curl_args(ProviderConfig, code_opts)
|
spec = ProviderConfig.parse_curl_args(ProviderConfig, code_opts)
|
||||||
end
|
end
|
||||||
|
--- If the provider doesn't have stream set, we set it to true
|
||||||
if spec.body.stream == nil then
|
if spec.body.stream == nil then
|
||||||
spec = vim.tbl_deep_extend("force", spec, {
|
spec = vim.tbl_deep_extend("force", spec, {
|
||||||
body = { stream = true },
|
body = { stream = true },
|
||||||
@ -512,15 +516,15 @@ M.stream = function(question, code_lang, code_content, selected_content_content,
|
|||||||
local function parse_and_call(line)
|
local function parse_and_call(line)
|
||||||
local event = line:match("^event: (.+)$")
|
local event = line:match("^event: (.+)$")
|
||||||
if event then
|
if event then
|
||||||
handler_opts.event_state = event
|
current_event_state = event
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
local data_match = line:match("^data: (.+)$")
|
local data_match = line:match("^data: (.+)$")
|
||||||
if data_match then
|
if data_match then
|
||||||
if ProviderConfig ~= nil then
|
if ProviderConfig ~= nil then
|
||||||
ProviderConfig.parse_response_data(data_match, handler_opts)
|
ProviderConfig.parse_response_data(data_match, current_event_state, handler_opts)
|
||||||
else
|
else
|
||||||
H["parse_" .. provider .. "_response"](data_match, handler_opts)
|
H["parse_" .. provider .. "_response"](data_match, current_event_state, handler_opts)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -266,6 +266,9 @@ function Sidebar:update_content(content, opts)
|
|||||||
end)
|
end)
|
||||||
else
|
else
|
||||||
vim.defer_fn(function()
|
vim.defer_fn(function()
|
||||||
|
if self.view.buf == nil then
|
||||||
|
return
|
||||||
|
end
|
||||||
api.nvim_set_option_value("modifiable", true, { buf = self.view.buf })
|
api.nvim_set_option_value("modifiable", true, { buf = self.view.buf })
|
||||||
api.nvim_buf_set_lines(self.view.buf, 0, -1, false, vim.split(content, "\n"))
|
api.nvim_buf_set_lines(self.view.buf, 0, -1, false, vim.split(content, "\n"))
|
||||||
api.nvim_set_option_value("modifiable", false, { buf = self.view.buf })
|
api.nvim_set_option_value("modifiable", false, { buf = self.view.buf })
|
||||||
@ -500,7 +503,7 @@ local function get_conflict_content(content, snippets)
|
|||||||
table.insert(result, "=======")
|
table.insert(result, "=======")
|
||||||
|
|
||||||
for _, line in ipairs(vim.split(snippet.content, "\n")) do
|
for _, line in ipairs(vim.split(snippet.content, "\n")) do
|
||||||
line = Utils.trim_line_number_prefix(line)
|
line = line:gsub("^L%d+: ", "")
|
||||||
table.insert(result, line)
|
table.insert(result, line)
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -680,14 +683,18 @@ function Sidebar:render()
|
|||||||
|
|
||||||
local filetype = api.nvim_get_option_value("filetype", { buf = self.code.buf })
|
local filetype = api.nvim_get_option_value("filetype", { buf = self.code.buf })
|
||||||
|
|
||||||
Llm.stream(request, filetype, content_with_line_numbers, selected_code_content_with_line_numbers, function(chunk)
|
---@type AvanteChunkParser
|
||||||
|
local on_chunk = function(chunk)
|
||||||
signal.is_loading = true
|
signal.is_loading = true
|
||||||
full_response = full_response .. chunk
|
full_response = full_response .. chunk
|
||||||
self:update_content(chunk, { stream = true, scroll = false })
|
self:update_content(chunk, { stream = true, scroll = false })
|
||||||
vim.schedule(function()
|
vim.schedule(function()
|
||||||
vim.cmd("redraw")
|
vim.cmd("redraw")
|
||||||
end)
|
end)
|
||||||
end, function(err)
|
end
|
||||||
|
|
||||||
|
---@type AvanteCompleteParser
|
||||||
|
local on_complete = function(err)
|
||||||
signal.is_loading = false
|
signal.is_loading = false
|
||||||
|
|
||||||
if err ~= nil then
|
if err ~= nil then
|
||||||
@ -716,7 +723,16 @@ function Sidebar:render()
|
|||||||
response = full_response,
|
response = full_response,
|
||||||
})
|
})
|
||||||
save_chat_history(self, chat_history)
|
save_chat_history(self, chat_history)
|
||||||
end)
|
end
|
||||||
|
|
||||||
|
Llm.stream(
|
||||||
|
request,
|
||||||
|
filetype,
|
||||||
|
content_with_line_numbers,
|
||||||
|
selected_code_content_with_line_numbers,
|
||||||
|
on_chunk,
|
||||||
|
on_complete
|
||||||
|
)
|
||||||
|
|
||||||
if Config.behaviour.auto_apply_diff_after_generation then
|
if Config.behaviour.auto_apply_diff_after_generation then
|
||||||
apply()
|
apply()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user