feat: supports openai o1-preview
* feat: make O1 models on openai work by handle non-streams & correct parameters * chore: set temperature automatically when using o1 models
This commit is contained in:
parent
64f26b9e72
commit
d74c9d0417
@ -114,6 +114,10 @@ M.stream = function(opts)
|
|||||||
if data_match then Provider.parse_response(data_match, current_event_state, handler_opts) end
|
if data_match then Provider.parse_response(data_match, current_event_state, handler_opts) end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
local function parse_response_without_stream(data)
|
||||||
|
Provider.parse_response_without_stream(data, current_event_state, handler_opts)
|
||||||
|
end
|
||||||
|
|
||||||
local completed = false
|
local completed = false
|
||||||
|
|
||||||
local active_job
|
local active_job
|
||||||
@ -170,6 +174,14 @@ M.stream = function(opts)
|
|||||||
end
|
end
|
||||||
end)
|
end)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
-- If stream is not enabled, then handle the response here
|
||||||
|
if spec.body.stream == false and result.status == 200 then
|
||||||
|
vim.schedule(function()
|
||||||
|
completed = true
|
||||||
|
parse_response_without_stream(result.body)
|
||||||
|
end)
|
||||||
|
end
|
||||||
end,
|
end,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -9,7 +9,7 @@ local P = require("avante.providers")
|
|||||||
---@field created integer
|
---@field created integer
|
||||||
---@field model string
|
---@field model string
|
||||||
---@field system_fingerprint string
|
---@field system_fingerprint string
|
||||||
---@field choices? OpenAIResponseChoice[]
|
---@field choices? OpenAIResponseChoice[] | OpenAIResponseChoiceComplete[]
|
||||||
---@field usage {prompt_tokens: integer, completion_tokens: integer, total_tokens: integer}
|
---@field usage {prompt_tokens: integer, completion_tokens: integer, total_tokens: integer}
|
||||||
---
|
---
|
||||||
---@class OpenAIResponseChoice
|
---@class OpenAIResponseChoice
|
||||||
@ -18,6 +18,12 @@ local P = require("avante.providers")
|
|||||||
---@field logprobs? integer
|
---@field logprobs? integer
|
||||||
---@field finish_reason? "stop" | "length"
|
---@field finish_reason? "stop" | "length"
|
||||||
---
|
---
|
||||||
|
---@class OpenAIResponseChoiceComplete
|
||||||
|
---@field message OpenAIMessage
|
||||||
|
---@field finish_reason "stop" | "length"
|
||||||
|
---@field index integer
|
||||||
|
---@field logprobs integer
|
||||||
|
---
|
||||||
---@class OpenAIMessage
|
---@class OpenAIMessage
|
||||||
---@field role? "user" | "system" | "assistant"
|
---@field role? "user" | "system" | "assistant"
|
||||||
---@field content string
|
---@field content string
|
||||||
@ -50,10 +56,22 @@ M.parse_message = function(opts)
|
|||||||
end)
|
end)
|
||||||
end
|
end
|
||||||
|
|
||||||
return {
|
local messages = {}
|
||||||
{ role = "system", content = opts.system_prompt },
|
local provider = P[Config.provider]
|
||||||
{ role = "user", content = user_content },
|
local base, _ = P.parse_config(provider)
|
||||||
}
|
|
||||||
|
-- NOTE: Handle the case where the selected model is the `o1` model
|
||||||
|
-- "o1" models are "smart" enough to understand user prompt as a system prompt in this context
|
||||||
|
if base.model and string.find(base.model, "o1") then
|
||||||
|
table.insert(messages, { role = "user", content = opts.system_prompt })
|
||||||
|
else
|
||||||
|
table.insert(messages, { role = "system", content = opts.system_prompt })
|
||||||
|
end
|
||||||
|
|
||||||
|
-- User message after the prompt
|
||||||
|
table.insert(messages, { role = "user", content = user_content })
|
||||||
|
|
||||||
|
return messages
|
||||||
end
|
end
|
||||||
|
|
||||||
M.parse_response = function(data_stream, _, opts)
|
M.parse_response = function(data_stream, _, opts)
|
||||||
@ -75,6 +93,18 @@ M.parse_response = function(data_stream, _, opts)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
M.parse_response_without_stream = function(data, _, opts)
|
||||||
|
---@type OpenAIChatResponse
|
||||||
|
local json = vim.json.decode(data)
|
||||||
|
if json.choices and json.choices[1] then
|
||||||
|
local choice = json.choices[1]
|
||||||
|
if choice.message and choice.message.content then
|
||||||
|
opts.on_chunk(choice.message.content)
|
||||||
|
vim.schedule(function() opts.on_complete(nil) end)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
M.parse_curl_args = function(provider, code_opts)
|
M.parse_curl_args = function(provider, code_opts)
|
||||||
local base, body_opts = P.parse_config(provider)
|
local base, body_opts = P.parse_config(provider)
|
||||||
|
|
||||||
@ -83,6 +113,14 @@ M.parse_curl_args = function(provider, code_opts)
|
|||||||
}
|
}
|
||||||
if not P.env.is_local("openai") then headers["Authorization"] = "Bearer " .. provider.parse_api_key() end
|
if not P.env.is_local("openai") then headers["Authorization"] = "Bearer " .. provider.parse_api_key() end
|
||||||
|
|
||||||
|
-- NOTE: When using "o1" set the supported parameters only
|
||||||
|
local stream = true
|
||||||
|
if base.model and string.find(base.model, "o1") then
|
||||||
|
stream = false
|
||||||
|
body_opts.max_tokens = nil
|
||||||
|
body_opts.temperature = 1
|
||||||
|
end
|
||||||
|
|
||||||
return {
|
return {
|
||||||
url = Utils.trim(base.endpoint, { suffix = "/" }) .. "/chat/completions",
|
url = Utils.trim(base.endpoint, { suffix = "/" }) .. "/chat/completions",
|
||||||
proxy = base.proxy,
|
proxy = base.proxy,
|
||||||
@ -91,7 +129,7 @@ M.parse_curl_args = function(provider, code_opts)
|
|||||||
body = vim.tbl_deep_extend("force", {
|
body = vim.tbl_deep_extend("force", {
|
||||||
model = base.model,
|
model = base.model,
|
||||||
messages = M.parse_message(code_opts),
|
messages = M.parse_message(code_opts),
|
||||||
stream = true,
|
stream = stream,
|
||||||
}, body_opts),
|
}, body_opts),
|
||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
Loading…
x
Reference in New Issue
Block a user