feat: support openai o1-preview
* feat: make o1 models on OpenAI work by handling non-streaming responses and their restricted parameters
* chore: set temperature automatically when using o1 models
parent 64f26b9e72
commit d74c9d0417
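In short: at the time of this commit, OpenAI's o1 models rejected streaming responses, the `system` role, an explicit `max_tokens`, and any `temperature` other than 1, so the provider has to detect these models and adjust the request. The patch keys detection off the model name; a standalone sketch of that check (the `is_o1_model` helper name is illustrative, not part of the patch):

  -- Returns true for "o1", "o1-mini", "o1-preview", etc.
  local function is_o1_model(model)
    return model ~= nil and string.find(model, "o1") ~= nil
  end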
@@ -114,6 +114,10 @@ M.stream = function(opts)
       if data_match then Provider.parse_response(data_match, current_event_state, handler_opts) end
     end
 
+    local function parse_response_without_stream(data)
+      Provider.parse_response_without_stream(data, current_event_state, handler_opts)
+    end
+
     local completed = false
 
     local active_job
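The new local closure just forwards the complete response body to the provider. For context, `handler_opts` is the same callback table used on the streaming path; its assumed shape (field names taken from the parser code further down, bodies are placeholders):

  local handler_opts = {
    on_chunk = function(text) end,    -- receives generated text
    on_complete = function(err) end,  -- called once, with nil on success
  }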
@@ -170,6 +174,14 @@ M.stream = function(opts)
         end
       end)
     end
+
+    -- If stream is not enabled, then handle the response here
+    if spec.body.stream == false and result.status == 200 then
+      vim.schedule(function()
+        completed = true
+        parse_response_without_stream(result.body)
+      end)
+    end
   end,
 })
 
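Putting that callback in context: the request goes through plenary's curl wrapper, and with `stream = false` the whole completion arrives in a single JSON body instead of SSE chunks. A simplified sketch of the non-stream path under that assumption (the `request_without_stream` wrapper is illustrative; the real wiring in the module is more involved):

  local curl = require("plenary.curl")

  -- Hypothetical wrapper: fire one non-streaming request and hand the
  -- whole body to the parser when it arrives.
  local function request_without_stream(url, headers, spec, parse_response_without_stream)
    return curl.post(url, {
      headers = headers,
      body = vim.json.encode(spec.body),
      callback = function(result)
        if spec.body.stream == false and result.status == 200 then
          vim.schedule(function() parse_response_without_stream(result.body) end)
        end
      end,
    })
  end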
@@ -9,7 +9,7 @@ local P = require("avante.providers")
 ---@field created integer
 ---@field model string
 ---@field system_fingerprint string
----@field choices? OpenAIResponseChoice[]
+---@field choices? OpenAIResponseChoice[] | OpenAIResponseChoiceComplete[]
 ---@field usage {prompt_tokens: integer, completion_tokens: integer, total_tokens: integer}
 ---
 ---@class OpenAIResponseChoice
@@ -18,6 +18,12 @@ local P = require("avante.providers")
 ---@field logprobs? integer
 ---@field finish_reason? "stop" | "length"
 ---
+---@class OpenAIResponseChoiceComplete
+---@field message OpenAIMessage
+---@field finish_reason "stop" | "length"
+---@field index integer
+---@field logprobs integer
+---
 ---@class OpenAIMessage
 ---@field role? "user" | "system" | "assistant"
 ---@field content string
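These annotations mirror the JSON of a non-streaming chat completion. A trimmed example of a body that decodes into this shape with `OpenAIResponseChoiceComplete` choices (values are illustrative):

  local response = vim.json.decode([[{
    "id": "chatcmpl-123",
    "object": "chat.completion",
    "created": 1726000000,
    "model": "o1-preview",
    "choices": [
      { "index": 0,
        "message": { "role": "assistant", "content": "Hello!" },
        "finish_reason": "stop" }
    ],
    "usage": { "prompt_tokens": 9, "completion_tokens": 3, "total_tokens": 12 }
  }]])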
@@ -50,10 +56,22 @@ M.parse_message = function(opts)
     end)
   end
 
-  return {
-    { role = "system", content = opts.system_prompt },
-    { role = "user", content = user_content },
-  }
+  local messages = {}
+  local provider = P[Config.provider]
+  local base, _ = P.parse_config(provider)
+
+  -- NOTE: Handle the case where the selected model is an `o1` model
+  -- "o1" models are "smart" enough to understand the user prompt as a system prompt in this context
+  if base.model and string.find(base.model, "o1") then
+    table.insert(messages, { role = "user", content = opts.system_prompt })
+  else
+    table.insert(messages, { role = "system", content = opts.system_prompt })
+  end
+
+  -- User message after the prompt
+  table.insert(messages, { role = "user", content = user_content })
+
+  return messages
 end
 
 M.parse_response = function(data_stream, _, opts)
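The effect: for an o1 model the system prompt is demoted to a leading `user` message; otherwise it stays a `system` message. Given `opts.system_prompt = "Be terse"` and user content `"Refactor this"`, the returned tables would look like this (illustrative values):

  -- base.model = "o1-preview":
  -- { { role = "user", content = "Be terse" }, { role = "user", content = "Refactor this" } }
  --
  -- any other model:
  -- { { role = "system", content = "Be terse" }, { role = "user", content = "Refactor this" } }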
@@ -75,6 +93,18 @@ M.parse_response = function(data_stream, _, opts)
     end
   end
 
+M.parse_response_without_stream = function(data, _, opts)
+  ---@type OpenAIChatResponse
+  local json = vim.json.decode(data)
+  if json.choices and json.choices[1] then
+    local choice = json.choices[1]
+    if choice.message and choice.message.content then
+      opts.on_chunk(choice.message.content)
+      vim.schedule(function() opts.on_complete(nil) end)
+    end
+  end
+end
+
 M.parse_curl_args = function(provider, code_opts)
   local base, body_opts = P.parse_config(provider)
 
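A minimal exercise of the new parser (callback bodies are placeholders; the JSON is a trimmed completion body):

  M.parse_response_without_stream(
    '{"choices":[{"index":0,"message":{"role":"assistant","content":"Hi"},"finish_reason":"stop"}]}',
    nil,
    {
      on_chunk = function(text) print(text) end,          -- prints "Hi"
      on_complete = function(err) assert(err == nil) end, -- scheduled after the chunk
    }
  )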
@@ -83,6 +113,14 @@ M.parse_curl_args = function(provider, code_opts)
   }
   if not P.env.is_local("openai") then headers["Authorization"] = "Bearer " .. provider.parse_api_key() end
 
+  -- NOTE: When using "o1" models, send only the parameters they support
+  local stream = true
+  if base.model and string.find(base.model, "o1") then
+    stream = false
+    body_opts.max_tokens = nil
+    body_opts.temperature = 1
+  end
+
   return {
     url = Utils.trim(base.endpoint, { suffix = "/" }) .. "/chat/completions",
     proxy = base.proxy,
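With an o1 model selected, the body assembled below therefore ends up roughly as follows (illustrative values):

  -- {
  --   model = "o1-preview",
  --   messages = { ... },  -- both prompts sent as "user" messages
  --   stream = false,
  --   temperature = 1,     -- the only value o1 accepted at the time
  -- }                      -- max_tokens omitted (cleared above)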
@@ -91,7 +129,7 @@ M.parse_curl_args = function(provider, code_opts)
     body = vim.tbl_deep_extend("force", {
       model = base.model,
       messages = M.parse_message(code_opts),
-      stream = true,
+      stream = stream,
     }, body_opts),
   }
 end
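Note the `"force"` merge: keys in `body_opts` override the base table, which is how the `temperature = 1` set earlier actually reaches the request body. A quick illustration:

  -- Later tables win under "force":
  print(vim.inspect(vim.tbl_deep_extend("force", { stream = false, temperature = 0.7 }, { temperature = 1 })))
  -- => { stream = false, temperature = 1 }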