feat: supports reasoning_content ()

This commit is contained in:
yetone 2025-02-02 02:12:14 +08:00 committed by GitHub
parent d1286e7bfb
commit b5ac768416
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 23 additions and 9 deletions

@ -139,6 +139,8 @@ M._stream = function(opts)
---@type AvanteCurlOutput ---@type AvanteCurlOutput
local spec = Provider.parse_curl_args(Provider, code_opts) local spec = Provider.parse_curl_args(Provider, code_opts)
local resp_ctx = {}
---@param line string ---@param line string
local function parse_stream_data(line) local function parse_stream_data(line)
local event = line:match("^event: (.+)$") local event = line:match("^event: (.+)$")
@ -147,7 +149,7 @@ M._stream = function(opts)
return return
end end
local data_match = line:match("^data: (.+)$") local data_match = line:match("^data: (.+)$")
if data_match then Provider.parse_response(data_match, current_event_state, handler_opts) end if data_match then Provider.parse_response(resp_ctx, data_match, current_event_state, handler_opts) end
end end
local function parse_response_without_stream(data) local function parse_response_without_stream(data)

@ -77,7 +77,7 @@ M.parse_messages = function(opts)
return messages return messages
end end
M.parse_response = function(data_stream, event_state, opts) M.parse_response = function(ctx, data_stream, event_state, opts)
if event_state == nil then if event_state == nil then
if data_stream:match('"content_block_delta"') then if data_stream:match('"content_block_delta"') then
event_state = "content_block_delta" event_state = "content_block_delta"

@ -64,7 +64,7 @@ M.parse_messages = function(opts)
} }
end end
M.parse_response = function(data_stream, _, opts) M.parse_response = function(ctx, data_stream, _, opts)
local ok, json = pcall(vim.json.decode, data_stream) local ok, json = pcall(vim.json.decode, data_stream)
if not ok then opts.on_complete(json) end if not ok then opts.on_complete(json) end
if json.candidates then if json.candidates then

@ -37,7 +37,7 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil }
---@class ResponseParser ---@class ResponseParser
---@field on_chunk fun(chunk: string): any ---@field on_chunk fun(chunk: string): any
---@field on_complete fun(err: string|nil): any ---@field on_complete fun(err: string|nil): any
---@alias AvanteResponseParser fun(data_stream: string, event_state: string, opts: ResponseParser): nil ---@alias AvanteResponseParser fun(ctx: any, data_stream: string, event_state: string, opts: ResponseParser): nil
--- ---
---@class AvanteDefaultBaseProvider: table<string, any> ---@class AvanteDefaultBaseProvider: table<string, any>
---@field endpoint? string ---@field endpoint? string

@ -26,7 +26,8 @@ local P = require("avante.providers")
--- ---
---@class OpenAIMessage ---@class OpenAIMessage
---@field role? "user" | "system" | "assistant" ---@field role? "user" | "system" | "assistant"
---@field content string ---@field content? string
---@field reasoning_content? string
--- ---
---@class AvanteProviderFunctor ---@class AvanteProviderFunctor
local M = {} local M = {}
@ -106,19 +107,30 @@ M.parse_messages = function(opts)
return final_messages return final_messages
end end
M.parse_response = function(data_stream, _, opts) M.parse_response = function(ctx, data_stream, _, opts)
if data_stream:match('"%[DONE%]":') then if data_stream:match('"%[DONE%]":') then
opts.on_complete(nil) opts.on_complete(nil)
return return
end end
if data_stream:match('"delta":') then if data_stream:match('"delta":') then
---@type OpenAIChatResponse ---@type OpenAIChatResponse
local json = vim.json.decode(data_stream) local jsn = vim.json.decode(data_stream)
if json.choices and json.choices[1] then Utils.debug("jsn", jsn)
local choice = json.choices[1] if jsn.choices and jsn.choices[1] then
local choice = jsn.choices[1]
if choice.finish_reason == "stop" or choice.finish_reason == "eos_token" then if choice.finish_reason == "stop" or choice.finish_reason == "eos_token" then
opts.on_complete(nil) opts.on_complete(nil)
elseif choice.delta.reasoning_content and choice.delta.reasoning_content ~= vim.NIL then
if ctx.returned_think_start_tag == nil or not ctx.returned_think_start_tag then
ctx.returned_think_start_tag = true
opts.on_chunk("<think>\n")
end
opts.on_chunk(choice.delta.reasoning_content)
elseif choice.delta.content then elseif choice.delta.content then
if ctx.returned_think_end_tag == nil or not ctx.returned_think_end_tag then
ctx.returned_think_end_tag = true
opts.on_chunk("\n</think>\n\n")
end
if choice.delta.content ~= vim.NIL then opts.on_chunk(choice.delta.content) end if choice.delta.content ~= vim.NIL then opts.on_chunk(choice.delta.content) end
end end
end end