feat: make tiktoken optional ()

This commit is contained in:
yetone 2024-08-27 01:46:05 +08:00 committed by GitHub
parent 3d3a249119
commit b874045885
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 66 additions and 7 deletions

@ -23,7 +23,7 @@ Install `avante.nvim` using [lazy.nvim](https://github.com/folke/lazy.nvim):
{
"yetone/avante.nvim",
event = "VeryLazy",
build = "make",
build = "make", -- This is Optional, only if you want to use tiktoken_core to calculate tokens count
opts = {
-- add any opts here
},
@ -50,7 +50,7 @@ For Windows users, change the build command to the following:
{
"yetone/avante.nvim",
event = "VeryLazy",
build = "powershell -ExecutionPolicy Bypass -File Build-LuaTiktoken.ps1",
build = "powershell -ExecutionPolicy Bypass -File Build-LuaTiktoken.ps1", -- This is Optional, only if you want to use tiktoken_core to calculate tokens count
-- rest of the config
}
```

@ -1,5 +1,5 @@
local Utils = require("avante.utils")
local Tiktoken = require("avante.tiktoken")
local Tokens = require("avante.utils.tokens")
local P = require("avante.providers")
---@class AvanteProviderFunctor
@ -13,7 +13,7 @@ M.parse_message = function(opts)
text = string.format("<code>```%s\n%s```</code>", opts.code_lang, opts.code_content),
}
if Tiktoken.count(code_prompt_obj.text) > 1024 then
if Tokens.calculate_tokens(code_prompt_obj.text) > 1024 then
code_prompt_obj.cache_control = { type = "ephemeral" }
end
@ -31,7 +31,7 @@ M.parse_message = function(opts)
text = string.format("<code>```%s\n%s```</code>", opts.code_lang, opts.selected_code_content),
}
if Tiktoken.count(selected_code_obj.text) > 1024 then
if Tokens.calculate_tokens(selected_code_obj.text) > 1024 then
selected_code_obj.cache_control = { type = "ephemeral" }
end
@ -50,7 +50,7 @@ M.parse_message = function(opts)
text = user_prompt,
}
if Tiktoken.count(user_prompt_obj.text) > 1024 then
if Tokens.calculate_tokens(user_prompt_obj.text) > 1024 then
user_prompt_obj.cache_control = { type = "ephemeral" }
end
@ -79,6 +79,9 @@ M.parse_response = function(data_stream, event_state, opts)
end
end
---@param provider AvanteProviderFunctor
---@param code_opts AvantePromptOptions
---@return table
M.parse_curl_args = function(provider, code_opts)
local base, body_opts = P.parse_config(provider)

@ -62,6 +62,7 @@ local Dressing = require("avante.ui.dressing")
---@field parse_response_data AvanteResponseParser
---@field parse_curl_args? AvanteCurlArgsParser
---@field parse_stream_data? AvanteStreamParser
---@field parse_api_key fun(): string | nil
---
---@class AvanteProviderFunctor
---@field parse_message AvanteMessageParser

@ -52,7 +52,6 @@ local M = {}
function M.setup(model)
local ok, core = pcall(require, "tiktoken_core")
if not ok then
print("Warn: tiktoken_core is not found!!!!")
return
end

@ -0,0 +1,56 @@
--Taken from https://github.com/jackMort/ChatGPT.nvim/blob/main/lua/chatgpt/flows/chat/tokens.lua
local Tiktoken = require("avante.tiktoken")
local Tokens = {}
--[[
cost_per_token
@param {string} token_name
@return {number} cost_per_token
]]
local cost_per_token = {
davinci = 0.000002,
}
--- Calculate the number of tokens in a given text.
-- @param text The text to calculate the number of tokens for.
-- @return The number of tokens in the given text.
function Tokens.calculate_tokens(text)
if Tiktoken.available() then
return Tiktoken.count(text)
end
local tokens = 0
local current_token = ""
for char in text:gmatch(".") do
if char == " " or char == "\n" then
if current_token ~= "" then
tokens = tokens + 1
current_token = ""
end
else
current_token = current_token .. char
end
end
if current_token ~= "" then
tokens = tokens + 1
end
return tokens
end
--- Calculate the cost of a given text in dollars.
-- @param text The text to calculate the cost of.
-- @param model The model to use to calculate the cost.
-- @return The cost of the given text in dollars.
function Tokens.calculate_usage_in_dollars(text, model)
local tokens = Tokens.calculate_tokens(text)
return Tokens.usage_in_dollars(tokens, model)
end
--- Calculate the cost of a given number of tokens in dollars.
-- @param tokens The number of tokens to calculate the cost of.
-- @param model The model to use to calculate the cost.
-- @return The cost of the given number of tokens in dollars.
function Tokens.usage_in_dollars(tokens, model)
return tokens * cost_per_token[model or "davinci"]
end
return Tokens