diff --git a/README.md b/README.md
index 9317fbd..cf613bd 100644
--- a/README.md
+++ b/README.md
@@ -23,7 +23,7 @@ Install `avante.nvim` using [lazy.nvim](https://github.com/folke/lazy.nvim):
{
"yetone/avante.nvim",
event = "VeryLazy",
- build = "make",
+ build = "make", -- This is Optional, only if you want to use tiktoken_core to calculate tokens count
opts = {
-- add any opts here
},
@@ -50,7 +50,7 @@ For Windows users, change the build command to the following:
{
"yetone/avante.nvim",
event = "VeryLazy",
- build = "powershell -ExecutionPolicy Bypass -File Build-LuaTiktoken.ps1",
+ build = "powershell -ExecutionPolicy Bypass -File Build-LuaTiktoken.ps1", -- This is Optional, only if you want to use tiktoken_core to calculate tokens count
-- rest of the config
}
```
diff --git a/lua/avante/providers/claude.lua b/lua/avante/providers/claude.lua
index 77bf043..efce123 100644
--- a/lua/avante/providers/claude.lua
+++ b/lua/avante/providers/claude.lua
@@ -1,5 +1,5 @@
local Utils = require("avante.utils")
-local Tiktoken = require("avante.tiktoken")
+local Tokens = require("avante.utils.tokens")
local P = require("avante.providers")
---@class AvanteProviderFunctor
@@ -13,7 +13,7 @@ M.parse_message = function(opts)
text = string.format("```%s\n%s```
", opts.code_lang, opts.code_content),
}
- if Tiktoken.count(code_prompt_obj.text) > 1024 then
+ if Tokens.calculate_tokens(code_prompt_obj.text) > 1024 then
code_prompt_obj.cache_control = { type = "ephemeral" }
end
@@ -31,7 +31,7 @@ M.parse_message = function(opts)
text = string.format("```%s\n%s```
", opts.code_lang, opts.selected_code_content),
}
- if Tiktoken.count(selected_code_obj.text) > 1024 then
+ if Tokens.calculate_tokens(selected_code_obj.text) > 1024 then
selected_code_obj.cache_control = { type = "ephemeral" }
end
@@ -50,7 +50,7 @@ M.parse_message = function(opts)
text = user_prompt,
}
- if Tiktoken.count(user_prompt_obj.text) > 1024 then
+ if Tokens.calculate_tokens(user_prompt_obj.text) > 1024 then
user_prompt_obj.cache_control = { type = "ephemeral" }
end
@@ -79,6 +79,9 @@ M.parse_response = function(data_stream, event_state, opts)
end
end
+---@param provider AvanteProviderFunctor
+---@param code_opts AvantePromptOptions
+---@return table
M.parse_curl_args = function(provider, code_opts)
local base, body_opts = P.parse_config(provider)
diff --git a/lua/avante/providers/init.lua b/lua/avante/providers/init.lua
index e9655f9..e7f298c 100644
--- a/lua/avante/providers/init.lua
+++ b/lua/avante/providers/init.lua
@@ -62,6 +62,7 @@ local Dressing = require("avante.ui.dressing")
---@field parse_response_data AvanteResponseParser
---@field parse_curl_args? AvanteCurlArgsParser
---@field parse_stream_data? AvanteStreamParser
+---@field parse_api_key fun(): string | nil
---
---@class AvanteProviderFunctor
---@field parse_message AvanteMessageParser
diff --git a/lua/avante/tiktoken.lua b/lua/avante/tiktoken.lua
index 7792662..bb2377a 100644
--- a/lua/avante/tiktoken.lua
+++ b/lua/avante/tiktoken.lua
@@ -52,7 +52,6 @@ local M = {}
function M.setup(model)
local ok, core = pcall(require, "tiktoken_core")
if not ok then
- print("Warn: tiktoken_core is not found!!!!")
return
end
diff --git a/lua/avante/utils/tokens.lua b/lua/avante/utils/tokens.lua
new file mode 100644
index 0000000..80c19b3
--- /dev/null
+++ b/lua/avante/utils/tokens.lua
@@ -0,0 +1,56 @@
+--Taken from https://github.com/jackMort/ChatGPT.nvim/blob/main/lua/chatgpt/flows/chat/tokens.lua
+local Tiktoken = require("avante.tiktoken")
+local Tokens = {}
+
+--[[
+ cost_per_token
+ @param {string} token_name
+ @return {number} cost_per_token
+]]
+local cost_per_token = {
+ davinci = 0.000002,
+}
+
+--- Calculate the number of tokens in a given text.
+-- @param text The text to calculate the number of tokens for.
+-- @return The number of tokens in the given text.
+function Tokens.calculate_tokens(text)
+ if Tiktoken.available() then
+ return Tiktoken.count(text)
+ end
+ local tokens = 0
+ local current_token = ""
+ for char in text:gmatch(".") do
+ if char == " " or char == "\n" then
+ if current_token ~= "" then
+ tokens = tokens + 1
+ current_token = ""
+ end
+ else
+ current_token = current_token .. char
+ end
+ end
+ if current_token ~= "" then
+ tokens = tokens + 1
+ end
+ return tokens
+end
+
+--- Calculate the cost of a given text in dollars.
+-- @param text The text to calculate the cost of.
+-- @param model The model to use to calculate the cost.
+-- @return The cost of the given text in dollars.
+function Tokens.calculate_usage_in_dollars(text, model)
+ local tokens = Tokens.calculate_tokens(text)
+ return Tokens.usage_in_dollars(tokens, model)
+end
+
+--- Calculate the cost of a given number of tokens in dollars.
+-- @param tokens The number of tokens to calculate the cost of.
+-- @param model The model to use to calculate the cost.
+-- @return The cost of the given number of tokens in dollars.
+function Tokens.usage_in_dollars(tokens, model)
+ return tokens * cost_per_token[model or "davinci"]
+end
+
+return Tokens