feat: make tiktoken optional (#245)
This commit is contained in:
parent
3d3a249119
commit
b874045885
@ -23,7 +23,7 @@ Install `avante.nvim` using [lazy.nvim](https://github.com/folke/lazy.nvim):
|
|||||||
{
|
{
|
||||||
"yetone/avante.nvim",
|
"yetone/avante.nvim",
|
||||||
event = "VeryLazy",
|
event = "VeryLazy",
|
||||||
build = "make",
|
build = "make", -- This is Optional, only if you want to use tiktoken_core to calculate tokens count
|
||||||
opts = {
|
opts = {
|
||||||
-- add any opts here
|
-- add any opts here
|
||||||
},
|
},
|
||||||
@ -50,7 +50,7 @@ For Windows users, change the build command to the following:
|
|||||||
{
|
{
|
||||||
"yetone/avante.nvim",
|
"yetone/avante.nvim",
|
||||||
event = "VeryLazy",
|
event = "VeryLazy",
|
||||||
build = "powershell -ExecutionPolicy Bypass -File Build-LuaTiktoken.ps1",
|
build = "powershell -ExecutionPolicy Bypass -File Build-LuaTiktoken.ps1", -- This is Optional, only if you want to use tiktoken_core to calculate tokens count
|
||||||
-- rest of the config
|
-- rest of the config
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
local Utils = require("avante.utils")
|
local Utils = require("avante.utils")
|
||||||
local Tiktoken = require("avante.tiktoken")
|
local Tokens = require("avante.utils.tokens")
|
||||||
local P = require("avante.providers")
|
local P = require("avante.providers")
|
||||||
|
|
||||||
---@class AvanteProviderFunctor
|
---@class AvanteProviderFunctor
|
||||||
@ -13,7 +13,7 @@ M.parse_message = function(opts)
|
|||||||
text = string.format("<code>```%s\n%s```</code>", opts.code_lang, opts.code_content),
|
text = string.format("<code>```%s\n%s```</code>", opts.code_lang, opts.code_content),
|
||||||
}
|
}
|
||||||
|
|
||||||
if Tiktoken.count(code_prompt_obj.text) > 1024 then
|
if Tokens.calculate_tokens(code_prompt_obj.text) > 1024 then
|
||||||
code_prompt_obj.cache_control = { type = "ephemeral" }
|
code_prompt_obj.cache_control = { type = "ephemeral" }
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -31,7 +31,7 @@ M.parse_message = function(opts)
|
|||||||
text = string.format("<code>```%s\n%s```</code>", opts.code_lang, opts.selected_code_content),
|
text = string.format("<code>```%s\n%s```</code>", opts.code_lang, opts.selected_code_content),
|
||||||
}
|
}
|
||||||
|
|
||||||
if Tiktoken.count(selected_code_obj.text) > 1024 then
|
if Tokens.calculate_tokens(selected_code_obj.text) > 1024 then
|
||||||
selected_code_obj.cache_control = { type = "ephemeral" }
|
selected_code_obj.cache_control = { type = "ephemeral" }
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -50,7 +50,7 @@ M.parse_message = function(opts)
|
|||||||
text = user_prompt,
|
text = user_prompt,
|
||||||
}
|
}
|
||||||
|
|
||||||
if Tiktoken.count(user_prompt_obj.text) > 1024 then
|
if Tokens.calculate_tokens(user_prompt_obj.text) > 1024 then
|
||||||
user_prompt_obj.cache_control = { type = "ephemeral" }
|
user_prompt_obj.cache_control = { type = "ephemeral" }
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -79,6 +79,9 @@ M.parse_response = function(data_stream, event_state, opts)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
---@param provider AvanteProviderFunctor
|
||||||
|
---@param code_opts AvantePromptOptions
|
||||||
|
---@return table
|
||||||
M.parse_curl_args = function(provider, code_opts)
|
M.parse_curl_args = function(provider, code_opts)
|
||||||
local base, body_opts = P.parse_config(provider)
|
local base, body_opts = P.parse_config(provider)
|
||||||
|
|
||||||
|
@ -62,6 +62,7 @@ local Dressing = require("avante.ui.dressing")
|
|||||||
---@field parse_response_data AvanteResponseParser
|
---@field parse_response_data AvanteResponseParser
|
||||||
---@field parse_curl_args? AvanteCurlArgsParser
|
---@field parse_curl_args? AvanteCurlArgsParser
|
||||||
---@field parse_stream_data? AvanteStreamParser
|
---@field parse_stream_data? AvanteStreamParser
|
||||||
|
---@field parse_api_key fun(): string | nil
|
||||||
---
|
---
|
||||||
---@class AvanteProviderFunctor
|
---@class AvanteProviderFunctor
|
||||||
---@field parse_message AvanteMessageParser
|
---@field parse_message AvanteMessageParser
|
||||||
|
@ -52,7 +52,6 @@ local M = {}
|
|||||||
function M.setup(model)
|
function M.setup(model)
|
||||||
local ok, core = pcall(require, "tiktoken_core")
|
local ok, core = pcall(require, "tiktoken_core")
|
||||||
if not ok then
|
if not ok then
|
||||||
print("Warn: tiktoken_core is not found!!!!")
|
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
|
56
lua/avante/utils/tokens.lua
Normal file
56
lua/avante/utils/tokens.lua
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
--Taken from https://github.com/jackMort/ChatGPT.nvim/blob/main/lua/chatgpt/flows/chat/tokens.lua
|
||||||
|
local Tiktoken = require("avante.tiktoken")
|
||||||
|
local Tokens = {}
|
||||||
|
|
||||||
|
--[[
|
||||||
|
cost_per_token
|
||||||
|
@param {string} token_name
|
||||||
|
@return {number} cost_per_token
|
||||||
|
]]
|
||||||
|
local cost_per_token = {
|
||||||
|
davinci = 0.000002,
|
||||||
|
}
|
||||||
|
|
||||||
|
--- Calculate the number of tokens in a given text.
|
||||||
|
-- @param text The text to calculate the number of tokens for.
|
||||||
|
-- @return The number of tokens in the given text.
|
||||||
|
function Tokens.calculate_tokens(text)
|
||||||
|
if Tiktoken.available() then
|
||||||
|
return Tiktoken.count(text)
|
||||||
|
end
|
||||||
|
local tokens = 0
|
||||||
|
local current_token = ""
|
||||||
|
for char in text:gmatch(".") do
|
||||||
|
if char == " " or char == "\n" then
|
||||||
|
if current_token ~= "" then
|
||||||
|
tokens = tokens + 1
|
||||||
|
current_token = ""
|
||||||
|
end
|
||||||
|
else
|
||||||
|
current_token = current_token .. char
|
||||||
|
end
|
||||||
|
end
|
||||||
|
if current_token ~= "" then
|
||||||
|
tokens = tokens + 1
|
||||||
|
end
|
||||||
|
return tokens
|
||||||
|
end
|
||||||
|
|
||||||
|
--- Calculate the cost of a given text in dollars.
|
||||||
|
-- @param text The text to calculate the cost of.
|
||||||
|
-- @param model The model to use to calculate the cost.
|
||||||
|
-- @return The cost of the given text in dollars.
|
||||||
|
function Tokens.calculate_usage_in_dollars(text, model)
|
||||||
|
local tokens = Tokens.calculate_tokens(text)
|
||||||
|
return Tokens.usage_in_dollars(tokens, model)
|
||||||
|
end
|
||||||
|
|
||||||
|
--- Calculate the cost of a given number of tokens in dollars.
|
||||||
|
-- @param tokens The number of tokens to calculate the cost of.
|
||||||
|
-- @param model The model to use to calculate the cost.
|
||||||
|
-- @return The cost of the given number of tokens in dollars.
|
||||||
|
function Tokens.usage_in_dollars(tokens, model)
|
||||||
|
return tokens * cost_per_token[model or "davinci"]
|
||||||
|
end
|
||||||
|
|
||||||
|
return Tokens
|
Loading…
x
Reference in New Issue
Block a user