diff --git a/README.md b/README.md index 9317fbd..cf613bd 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ Install `avante.nvim` using [lazy.nvim](https://github.com/folke/lazy.nvim): { "yetone/avante.nvim", event = "VeryLazy", - build = "make", + build = "make", -- This is Optional, only if you want to use tiktoken_core to calculate tokens count opts = { -- add any opts here }, @@ -50,7 +50,7 @@ For Windows users, change the build command to the following: { "yetone/avante.nvim", event = "VeryLazy", - build = "powershell -ExecutionPolicy Bypass -File Build-LuaTiktoken.ps1", + build = "powershell -ExecutionPolicy Bypass -File Build-LuaTiktoken.ps1", -- This is Optional, only if you want to use tiktoken_core to calculate tokens count -- rest of the config } ``` diff --git a/lua/avante/providers/claude.lua b/lua/avante/providers/claude.lua index 77bf043..efce123 100644 --- a/lua/avante/providers/claude.lua +++ b/lua/avante/providers/claude.lua @@ -1,5 +1,5 @@ local Utils = require("avante.utils") -local Tiktoken = require("avante.tiktoken") +local Tokens = require("avante.utils.tokens") local P = require("avante.providers") ---@class AvanteProviderFunctor @@ -13,7 +13,7 @@ M.parse_message = function(opts) text = string.format("```%s\n%s```", opts.code_lang, opts.code_content), } - if Tiktoken.count(code_prompt_obj.text) > 1024 then + if Tokens.calculate_tokens(code_prompt_obj.text) > 1024 then code_prompt_obj.cache_control = { type = "ephemeral" } end @@ -31,7 +31,7 @@ M.parse_message = function(opts) text = string.format("```%s\n%s```", opts.code_lang, opts.selected_code_content), } - if Tiktoken.count(selected_code_obj.text) > 1024 then + if Tokens.calculate_tokens(selected_code_obj.text) > 1024 then selected_code_obj.cache_control = { type = "ephemeral" } end @@ -50,7 +50,7 @@ M.parse_message = function(opts) text = user_prompt, } - if Tiktoken.count(user_prompt_obj.text) > 1024 then + if Tokens.calculate_tokens(user_prompt_obj.text) > 1024 then user_prompt_obj.cache_control = { type = "ephemeral" } end @@ -79,6 +79,9 @@ M.parse_response = function(data_stream, event_state, opts) end end +---@param provider AvanteProviderFunctor +---@param code_opts AvantePromptOptions +---@return table M.parse_curl_args = function(provider, code_opts) local base, body_opts = P.parse_config(provider) diff --git a/lua/avante/providers/init.lua b/lua/avante/providers/init.lua index e9655f9..e7f298c 100644 --- a/lua/avante/providers/init.lua +++ b/lua/avante/providers/init.lua @@ -62,6 +62,7 @@ local Dressing = require("avante.ui.dressing") ---@field parse_response_data AvanteResponseParser ---@field parse_curl_args? AvanteCurlArgsParser ---@field parse_stream_data? AvanteStreamParser +---@field parse_api_key fun(): string | nil --- ---@class AvanteProviderFunctor ---@field parse_message AvanteMessageParser diff --git a/lua/avante/tiktoken.lua b/lua/avante/tiktoken.lua index 7792662..bb2377a 100644 --- a/lua/avante/tiktoken.lua +++ b/lua/avante/tiktoken.lua @@ -52,7 +52,6 @@ local M = {} function M.setup(model) local ok, core = pcall(require, "tiktoken_core") if not ok then - print("Warn: tiktoken_core is not found!!!!") return end diff --git a/lua/avante/utils/tokens.lua b/lua/avante/utils/tokens.lua new file mode 100644 index 0000000..80c19b3 --- /dev/null +++ b/lua/avante/utils/tokens.lua @@ -0,0 +1,56 @@ +--Taken from https://github.com/jackMort/ChatGPT.nvim/blob/main/lua/chatgpt/flows/chat/tokens.lua +local Tiktoken = require("avante.tiktoken") +local Tokens = {} + +--[[ + cost_per_token + @param {string} token_name + @return {number} cost_per_token +]] +local cost_per_token = { + davinci = 0.000002, +} + +--- Calculate the number of tokens in a given text. +-- @param text The text to calculate the number of tokens for. +-- @return The number of tokens in the given text. +function Tokens.calculate_tokens(text) + if Tiktoken.available() then + return Tiktoken.count(text) + end + local tokens = 0 + local current_token = "" + for char in text:gmatch(".") do + if char == " " or char == "\n" then + if current_token ~= "" then + tokens = tokens + 1 + current_token = "" + end + else + current_token = current_token .. char + end + end + if current_token ~= "" then + tokens = tokens + 1 + end + return tokens +end + +--- Calculate the cost of a given text in dollars. +-- @param text The text to calculate the cost of. +-- @param model The model to use to calculate the cost. +-- @return The cost of the given text in dollars. +function Tokens.calculate_usage_in_dollars(text, model) + local tokens = Tokens.calculate_tokens(text) + return Tokens.usage_in_dollars(tokens, model) +end + +--- Calculate the cost of a given number of tokens in dollars. +-- @param tokens The number of tokens to calculate the cost of. +-- @param model The model to use to calculate the cost. +-- @return The cost of the given number of tokens in dollars. +function Tokens.usage_in_dollars(tokens, model) + return tokens * cost_per_token[model or "davinci"] +end + +return Tokens