From ec5d1abf34cd8c541e7454136d9ae7c44341fd50 Mon Sep 17 00:00:00 2001 From: Larry Lv Date: Sat, 4 Jan 2025 21:23:33 -0800 Subject: [PATCH] fix(openai): support all `o` series models (#1031) Before this change, since `max_completion_tokens` was not set for `o` series models, the completion request would sometimes time out. This makes sure it converts the `max_tokens` parameter to `max_completion_tokens` for `o` series models. I tested this change with `gpt-4o-mini`, `o1-mini` and `o3-mini`, and they all still work as expected. --- lua/avante/providers/openai.lua | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/lua/avante/providers/openai.lua b/lua/avante/providers/openai.lua index 56fd627..eccae99 100644 --- a/lua/avante/providers/openai.lua +++ b/lua/avante/providers/openai.lua @@ -54,6 +54,10 @@ M.get_user_message = function(opts) ) end +M.is_o_series_model = function(model) + return model and string.match(model, "^o%d+") ~= nil +end + M.parse_messages = function(opts) local messages = {} local provider = P[Config.provider] @@ -61,7 +65,7 @@ M.parse_messages = function(opts) -- NOTE: Handle the case where the selected model is the `o1` model -- "o1" models are "smart" enough to understand user prompt as a system prompt in this context - if base.model and string.find(base.model, "o1") then + if M.is_o_series_model(base.model) then table.insert(messages, { role = "user", content = opts.system_prompt }) else table.insert(messages, { role = "system", content = opts.system_prompt }) end @@ -150,9 +154,10 @@ M.parse_curl_args = function(provider, code_opts) headers["Authorization"] = "Bearer " .. 
api_key end - -- NOTE: When using "o1" set the supported parameters only + -- NOTE: When using "o" series set the supported parameters only local stream = true - if base.model and string.find(base.model, "o1") then + if M.is_o_series_model(base.model) then + body_opts.max_completion_tokens = body_opts.max_tokens body_opts.max_tokens = nil body_opts.temperature = 1 end