Aaron Pham e8c71d931e
chore: run stylua [generated] (#460)
* chore: add stylua

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>

* chore: running stylua

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>

---------

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
2024-09-03 04:19:54 -04:00

87 lines
2.5 KiB
Lua

local Utils = require "avante.utils"
local P = require "avante.providers"
---@alias CohereFinishReason "COMPLETE" | "LENGTH" | "ERROR"
---
---@class CohereChatStreamResponse
---@field event_type "stream-start" | "text-generation" | "stream-end"
---@field is_finished boolean
---
---@class CohereTextGenerationResponse: CohereChatStreamResponse
---@field text string
---
---@class CohereStreamEndResponse: CohereChatStreamResponse
---@field response CohereChatResponse
---@field finish_reason CohereFinishReason
---
---@class CohereChatResponse
---@field text string
---@field generation_id string
---@field chat_history CohereMessage[]
---@field finish_reason CohereFinishReason
---@field meta {api_version: {version: integer}, billed_units: {input_tokens: integer, output_tokens: integer}, tokens: {input_tokens: integer, output_tokens: integer}}
---
---@class CohereMessage
---@field role? "USER" | "SYSTEM" | "CHATBOT"
---@field message string
---
---@class AvanteProviderFunctor
local M = {}
M.api_key_name = "CO_API_KEY"
M.tokenizer_id = "CohereForAI/c4ai-command-r-plus-08-2024"
M.parse_message = function(opts)
return {
preamble = opts.system_prompt,
message = opts.user_prompt,
}
end
M.parse_stream_data = function(data, opts)
---@type CohereChatStreamResponse
local json = vim.json.decode(data)
if json.is_finished then
opts.on_complete(nil)
return
end
if json.event_type ~= nil then
---@cast json CohereStreamEndResponse
if json.event_type == "stream-end" and json.finish_reason == "COMPLETE" then
opts.on_complete(nil)
return
end
---@cast json CohereTextGenerationResponse
if json.event_type == "text-generation" then opts.on_chunk(json.text) end
end
end
M.parse_curl_args = function(provider, code_opts)
local base, body_opts = P.parse_config(provider)
local headers = {
["Accept"] = "application/json",
["Content-Type"] = "application/json",
["X-Client-Name"] = "avante.nvim/Neovim/"
.. vim.version().major
.. "."
.. vim.version().minor
.. "."
.. vim.version().patch,
}
if not P.env.is_local "cohere" then headers["Authorization"] = "Bearer " .. provider.parse_api_key() end
return {
url = Utils.trim(base.endpoint, { suffix = "/" }) .. "/chat",
proxy = base.proxy,
insecure = base.allow_insecure,
headers = headers,
body = vim.tbl_deep_extend("force", {
model = base.model,
stream = true,
}, M.parse_message(code_opts), body_opts),
}
end
return M