-- Aaron Pham 0d8098e4eb
-- fix(style): add parentheses (#471)
-- Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
-- 2024-09-03 05:12:07 -04:00
--
-- 87 lines
-- 2.5 KiB
-- Lua

local Utils = require("avante.utils")
local P = require("avante.providers")
---@alias CohereFinishReason "COMPLETE" | "LENGTH" | "ERROR"
---
---@class CohereChatStreamResponse
---@field event_type "stream-start" | "text-generation" | "stream-end"
---@field is_finished boolean
---
---@class CohereTextGenerationResponse: CohereChatStreamResponse
---@field text string
---
---@class CohereStreamEndResponse: CohereChatStreamResponse
---@field response CohereChatResponse
---@field finish_reason CohereFinishReason
---
---@class CohereChatResponse
---@field text string
---@field generation_id string
---@field chat_history CohereMessage[]
---@field finish_reason CohereFinishReason
---@field meta {api_version: {version: integer}, billed_units: {input_tokens: integer, output_tokens: integer}, tokens: {input_tokens: integer, output_tokens: integer}}
---
---@class CohereMessage
---@field role? "USER" | "SYSTEM" | "CHATBOT"
---@field message string
---
---@class AvanteProviderFunctor
local M = {
  -- Name of the environment variable the API key is read from.
  api_key_name = "CO_API_KEY",
  -- Identifier of the tokenizer associated with this provider.
  tokenizer_id = "CohereForAI/c4ai-command-r-plus-08-2024",
}
--- Translate avante prompt options into the prompt fields of a Cohere chat request.
--- The system prompt is sent as `preamble` and the user prompt as `message`.
---@param opts table prompt options; reads `opts.system_prompt` and `opts.user_prompt`
---@return table request fields to merge into the request body
M.parse_message = function(opts)
  local request = {}
  request.preamble = opts.system_prompt
  request.message = opts.user_prompt
  return request
end
--- Handle one line of Cohere's streamed chat response.
--- Each non-blank line is decoded as a JSON event:
---   * `text-generation` events forward their `text` to `opts.on_chunk`;
---   * a terminal event (`is_finished == true`, or `event_type == "stream-end"`)
---     calls `opts.on_complete` and stops processing.
--- A terminal event whose `finish_reason` is "ERROR" is passed to `on_complete` as an
--- error string. Any other finish reason ("COMPLETE", "LENGTH", or none) completes with nil.
--- NOTE(review): assumes `on_complete` takes an optional error string as its first
--- argument (`nil` meaning success), consistent with the `on_complete(nil)` calls here.
---@param data string one raw line from the response stream
---@param opts table handler callbacks: `on_chunk(text)` and `on_complete(err)`
M.parse_stream_data = function(data, opts)
  -- Blank lines (e.g. a trailing newline) carry no event and would make the JSON decoder throw.
  if data == nil or data:match("^%s*$") then return end

  ---@type CohereChatStreamResponse
  local json = vim.json.decode(data)

  -- Terminal event: report success unless the stream explicitly ended in an error.
  -- Without this check, an "ERROR" finish would be indistinguishable from a normal completion.
  if json.is_finished or json.event_type == "stream-end" then
    ---@cast json CohereStreamEndResponse
    if json.finish_reason == "ERROR" then
      opts.on_complete("Cohere stream ended with finish_reason: ERROR")
    else
      opts.on_complete(nil)
    end
    return
  end

  if json.event_type == "text-generation" then
    ---@cast json CohereTextGenerationResponse
    opts.on_chunk(json.text)
  end
end
--- Assemble the curl arguments for a streaming request to the Cohere `/chat` endpoint.
---@param provider table provider spec passed to `P.parse_config`; must expose `parse_api_key()`
---@param code_opts table prompt options forwarded to `M.parse_message`
---@return table curl arguments with `url`, `proxy`, `insecure`, `headers` and `body`
M.parse_curl_args = function(provider, code_opts)
  local base, body_opts = P.parse_config(provider)

  -- Identify the client as avante.nvim together with the running Neovim version.
  local nvim = vim.version()
  local client_name = string.format("avante.nvim/Neovim/%d.%d.%d", nvim.major, nvim.minor, nvim.patch)

  local headers = {
    ["Accept"] = "application/json",
    ["Content-Type"] = "application/json",
    ["X-Client-Name"] = client_name,
  }
  -- Only send a bearer token when P.env does not report the provider as local.
  local needs_auth = not P.env.is_local("cohere")
  if needs_auth then
    headers["Authorization"] = "Bearer " .. provider.parse_api_key()
  end

  local url = Utils.trim(base.endpoint, { suffix = "/" }) .. "/chat"

  -- "force" merge: later tables win, so configured body options override the defaults
  -- and the prompt fields.
  local defaults = { model = base.model, stream = true }
  local body = vim.tbl_deep_extend("force", defaults, M.parse_message(code_opts), body_opts)

  return {
    url = url,
    proxy = base.proxy,
    insecure = base.allow_insecure,
    headers = headers,
    body = body,
  }
end

return M