106 lines
3.1 KiB
Lua
106 lines
3.1 KiB
Lua
local Utils = require("avante.utils")
|
|
local P = require("avante.providers")
|
|
|
|
---@alias CohereFinishReason "COMPLETE" | "LENGTH" | "ERROR"
|
|
---@alias CohereStreamType "message-start" | "content-start" | "content-delta" | "content-end" | "message-end"
|
|
---
|
|
---@class CohereChatContent
|
|
---@field type? CohereStreamType
|
|
---@field text string
|
|
---
|
|
---@class CohereChatMessage
|
|
---@field content CohereChatContent
|
|
---
|
|
---@class CohereChatStreamBase
|
|
---@field type CohereStreamType
|
|
---@field index integer
|
|
---
|
|
---@class CohereChatContentDelta: CohereChatStreamBase
|
|
---@field type "content-delta" | "content-start" | "content-end"
|
|
---@field delta? { message: CohereChatMessage }
|
|
---
|
|
---@class CohereChatMessageStart: CohereChatStreamBase
|
|
---@field type "message-start"
|
|
---@field delta { message: { role: "assistant" } }
|
|
---
|
|
---@class CohereChatMessageEnd: CohereChatStreamBase
|
|
---@field type "message-end"
|
|
---@field delta { finish_reason: CohereFinishReason, usage: CohereChatUsage }
|
|
---
|
|
---@class CohereChatUsage
|
|
---@field billed_units { input_tokens: integer, output_tokens: integer }
|
|
---@field tokens { input_tokens: integer, output_tokens: integer }
|
|
---
|
|
---@alias CohereChatResponse CohereChatContentDelta | CohereChatMessageStart | CohereChatMessageEnd
|
|
---
|
|
---@class CohereMessage
|
|
---@field type "text"
|
|
---@field text string
|
|
---
|
|
---@class AvanteProviderFunctor
|
|
local M = {
  -- Name of the API key setting consumed by `P.env.parse_envvar` in `M.setup`
  -- (presumably an environment variable name — confirm against avante.providers).
  api_key_name = "CO_API_KEY",
  -- Tokenizer definition passed to `avante.tokenizers.setup` in `M.setup`.
  tokenizer_id = "https://storage.googleapis.com/cohere-public/tokenizers/command-r-08-2024.json",
  -- Translation from avante message roles to Cohere chat roles (identity for both known roles).
  role_map = { user = "user", assistant = "assistant" },
}
|
|
|
|
---Build the `messages` payload for a Cohere chat request from avante's prompt options.
---
---The system prompt, when present and non-empty, becomes the first message.
---It is followed by every entry of `opts.messages`, in order, with its role
---translated through `M.role_map`.
---@param opts table prompt options; reads `opts.system_prompt` and the list `opts.messages`
---@return { messages: { role: string, content: any }[] }
M.parse_messages = function(opts)
  local messages = {}

  -- A system message with no content is not a valid chat message for Cohere,
  -- so only emit one when there is an actual prompt to send.
  local system_prompt = opts.system_prompt
  if system_prompt ~= nil and system_prompt ~= "" then
    messages[#messages + 1] = { role = "system", content = system_prompt }
  end

  -- ipairs walks the sequence in order. vim.iter would treat a table that is
  -- not a proper list as a dict and hand (key, value) pairs to the callback.
  for _, msg in ipairs(opts.messages or {}) do
    messages[#messages + 1] = { role = M.role_map[msg.role], content = msg.content }
  end

  return { messages = messages }
end
|
|
|
|
---Handle one JSON event from Cohere's streaming chat endpoint.
---
---* `content-delta` events forward their text to `opts.on_chunk`.
---* A `message-end` event always terminates the stream through `opts.on_stop`.
---  A finish reason of "ERROR" is reported as `reason = "error"`; every other
---  finish reason is reported as `reason = "complete"`.
---* All other event types are ignored.
---@param data string raw JSON payload of a single stream event
---@param opts table callbacks; uses `opts.on_chunk(text)` and `opts.on_stop(stop_opts)`
M.parse_stream_data = function(data, opts)
  ---@type CohereChatResponse
  local json = vim.json.decode(data)
  local event_type = json.type
  if event_type == nil then return end

  if event_type == "message-end" then
    -- Any finish reason ends the stream. Otherwise endings such as "MAX_TOKENS",
    -- "STOP_SEQUENCE" or "ERROR" would leave the caller waiting for an on_stop
    -- call that never comes.
    local finish_reason = json.delta and json.delta.finish_reason
    if finish_reason == "ERROR" then
      opts.on_stop({ reason = "error", error = "Cohere stream ended with finish_reason ERROR" })
    else
      opts.on_stop({ reason = "complete" })
    end
    return
  end

  if event_type == "content-delta" then
    -- `delta` is optional in the event schema; only forward text that is actually present.
    local message = json.delta and json.delta.message
    local text = message and message.content and message.content.text
    if text ~= nil then opts.on_chunk(text) end
  end
end
|
|
|
|
---Assemble the curl request (URL, proxy settings, headers, JSON body) for a Cohere chat call.
---@param provider table provider definition; its config is resolved via `P.parse_config`
---@param code_opts table prompt options forwarded to `M.parse_messages`
---@return table request with `url`, `proxy`, `insecure`, `headers` and `body`
M.parse_curl_args = function(provider, code_opts)
  local base, body_opts = P.parse_config(provider)

  -- Identify the client as "avante.nvim/Neovim/<major>.<minor>.<patch>".
  local nvim = vim.version()
  local client_name = "avante.nvim/Neovim/" .. table.concat({ nvim.major, nvim.minor, nvim.patch }, ".")

  local headers = {
    ["Accept"] = "application/json",
    ["Content-Type"] = "application/json",
    ["X-Client-Name"] = client_name,
  }

  -- Attach credentials only when P.env.require_api_key says this config needs a key.
  if P.env.require_api_key(base) then
    headers["Authorization"] = "Bearer " .. provider.parse_api_key()
  end

  local url = Utils.url_join(base.endpoint, "/chat")

  -- Layering: defaults, then the parsed messages, then the caller's body options ("force" = later wins).
  local defaults = { model = base.model, stream = true }
  local body = vim.tbl_deep_extend("force", defaults, M.parse_messages(code_opts), body_opts)

  return {
    url = url,
    proxy = base.proxy,
    insecure = base.allow_insecure,
    headers = headers,
    body = body,
  }
end
|
|
|
|
---Initialize the Cohere provider.
---
---Resolves the API key setting, loads the tokenizer and sets `vim.g.avante_login`.
function M.setup()
  P.env.parse_envvar(M)

  local tokenizers = require("avante.tokenizers")
  -- The `false` flag is forwarded unchanged; its meaning is defined by avante.tokenizers.
  tokenizers.setup(M.tokenizer_id, false)

  vim.g.avante_login = true
end

return M
|