feat: make tiktoken optional (#245)
This commit is contained in:
parent
3d3a249119
commit
b874045885
@ -23,7 +23,7 @@ Install `avante.nvim` using [lazy.nvim](https://github.com/folke/lazy.nvim):
|
||||
{
|
||||
"yetone/avante.nvim",
|
||||
event = "VeryLazy",
|
||||
build = "make",
|
||||
build = "make", -- This is Optional, only if you want to use tiktoken_core to calculate tokens count
|
||||
opts = {
|
||||
-- add any opts here
|
||||
},
|
||||
@ -50,7 +50,7 @@ For Windows users, change the build command to the following:
|
||||
{
|
||||
"yetone/avante.nvim",
|
||||
event = "VeryLazy",
|
||||
build = "powershell -ExecutionPolicy Bypass -File Build-LuaTiktoken.ps1",
|
||||
build = "powershell -ExecutionPolicy Bypass -File Build-LuaTiktoken.ps1", -- This is Optional, only if you want to use tiktoken_core to calculate tokens count
|
||||
-- rest of the config
|
||||
}
|
||||
```
|
||||
|
@ -1,5 +1,5 @@
|
||||
local Utils = require("avante.utils")
|
||||
local Tiktoken = require("avante.tiktoken")
|
||||
local Tokens = require("avante.utils.tokens")
|
||||
local P = require("avante.providers")
|
||||
|
||||
---@class AvanteProviderFunctor
|
||||
@ -13,7 +13,7 @@ M.parse_message = function(opts)
|
||||
text = string.format("<code>```%s\n%s```</code>", opts.code_lang, opts.code_content),
|
||||
}
|
||||
|
||||
if Tiktoken.count(code_prompt_obj.text) > 1024 then
|
||||
if Tokens.calculate_tokens(code_prompt_obj.text) > 1024 then
|
||||
code_prompt_obj.cache_control = { type = "ephemeral" }
|
||||
end
|
||||
|
||||
@ -31,7 +31,7 @@ M.parse_message = function(opts)
|
||||
text = string.format("<code>```%s\n%s```</code>", opts.code_lang, opts.selected_code_content),
|
||||
}
|
||||
|
||||
if Tiktoken.count(selected_code_obj.text) > 1024 then
|
||||
if Tokens.calculate_tokens(selected_code_obj.text) > 1024 then
|
||||
selected_code_obj.cache_control = { type = "ephemeral" }
|
||||
end
|
||||
|
||||
@ -50,7 +50,7 @@ M.parse_message = function(opts)
|
||||
text = user_prompt,
|
||||
}
|
||||
|
||||
if Tiktoken.count(user_prompt_obj.text) > 1024 then
|
||||
if Tokens.calculate_tokens(user_prompt_obj.text) > 1024 then
|
||||
user_prompt_obj.cache_control = { type = "ephemeral" }
|
||||
end
|
||||
|
||||
@ -79,6 +79,9 @@ M.parse_response = function(data_stream, event_state, opts)
|
||||
end
|
||||
end
|
||||
|
||||
---@param provider AvanteProviderFunctor
|
||||
---@param code_opts AvantePromptOptions
|
||||
---@return table
|
||||
M.parse_curl_args = function(provider, code_opts)
|
||||
local base, body_opts = P.parse_config(provider)
|
||||
|
||||
|
@ -62,6 +62,7 @@ local Dressing = require("avante.ui.dressing")
|
||||
---@field parse_response_data AvanteResponseParser
|
||||
---@field parse_curl_args? AvanteCurlArgsParser
|
||||
---@field parse_stream_data? AvanteStreamParser
|
||||
---@field parse_api_key fun(): string | nil
|
||||
---
|
||||
---@class AvanteProviderFunctor
|
||||
---@field parse_message AvanteMessageParser
|
||||
|
@ -52,7 +52,6 @@ local M = {}
|
||||
function M.setup(model)
|
||||
local ok, core = pcall(require, "tiktoken_core")
|
||||
if not ok then
|
||||
print("Warn: tiktoken_core is not found!!!!")
|
||||
return
|
||||
end
|
||||
|
||||
|
56
lua/avante/utils/tokens.lua
Normal file
56
lua/avante/utils/tokens.lua
Normal file
@ -0,0 +1,56 @@
|
||||
--Taken from https://github.com/jackMort/ChatGPT.nvim/blob/main/lua/chatgpt/flows/chat/tokens.lua
|
||||
local Tiktoken = require("avante.tiktoken")
|
||||
local Tokens = {}
|
||||
|
||||
--[[
|
||||
cost_per_token
|
||||
@param {string} token_name
|
||||
@return {number} cost_per_token
|
||||
]]
|
||||
local cost_per_token = {
|
||||
davinci = 0.000002,
|
||||
}
|
||||
|
||||
--- Calculate the number of tokens in a given text.
|
||||
-- @param text The text to calculate the number of tokens for.
|
||||
-- @return The number of tokens in the given text.
|
||||
function Tokens.calculate_tokens(text)
|
||||
if Tiktoken.available() then
|
||||
return Tiktoken.count(text)
|
||||
end
|
||||
local tokens = 0
|
||||
local current_token = ""
|
||||
for char in text:gmatch(".") do
|
||||
if char == " " or char == "\n" then
|
||||
if current_token ~= "" then
|
||||
tokens = tokens + 1
|
||||
current_token = ""
|
||||
end
|
||||
else
|
||||
current_token = current_token .. char
|
||||
end
|
||||
end
|
||||
if current_token ~= "" then
|
||||
tokens = tokens + 1
|
||||
end
|
||||
return tokens
|
||||
end
|
||||
|
||||
--- Calculate the cost of a given text in dollars.
|
||||
-- @param text The text to calculate the cost of.
|
||||
-- @param model The model to use to calculate the cost.
|
||||
-- @return The cost of the given text in dollars.
|
||||
function Tokens.calculate_usage_in_dollars(text, model)
|
||||
local tokens = Tokens.calculate_tokens(text)
|
||||
return Tokens.usage_in_dollars(tokens, model)
|
||||
end
|
||||
|
||||
--- Calculate the cost of a given number of tokens in dollars.
|
||||
-- @param tokens The number of tokens to calculate the cost of.
|
||||
-- @param model The model to use to calculate the cost.
|
||||
-- @return The cost of the given number of tokens in dollars.
|
||||
function Tokens.usage_in_dollars(tokens, model)
|
||||
return tokens * cost_per_token[model or "davinci"]
|
||||
end
|
||||
|
||||
return Tokens
|
Loading…
x
Reference in New Issue
Block a user