feat: make tiktoken optional (#245)
This commit is contained in:
		
							parent
							
								
									3d3a249119
								
							
						
					
					
						commit
						b874045885
					
				@ -23,7 +23,7 @@ Install `avante.nvim` using [lazy.nvim](https://github.com/folke/lazy.nvim):
 | 
				
			|||||||
{
 | 
					{
 | 
				
			||||||
  "yetone/avante.nvim",
 | 
					  "yetone/avante.nvim",
 | 
				
			||||||
  event = "VeryLazy",
 | 
					  event = "VeryLazy",
 | 
				
			||||||
  build = "make",
 | 
					  build = "make", -- This is Optional, only if you want to use tiktoken_core to calculate tokens count
 | 
				
			||||||
  opts = {
 | 
					  opts = {
 | 
				
			||||||
    -- add any opts here
 | 
					    -- add any opts here
 | 
				
			||||||
  },
 | 
					  },
 | 
				
			||||||
@ -50,7 +50,7 @@ For Windows users, change the build command to the following:
 | 
				
			|||||||
{
 | 
					{
 | 
				
			||||||
  "yetone/avante.nvim",
 | 
					  "yetone/avante.nvim",
 | 
				
			||||||
  event = "VeryLazy",
 | 
					  event = "VeryLazy",
 | 
				
			||||||
  build = "powershell -ExecutionPolicy Bypass -File Build-LuaTiktoken.ps1",
 | 
					  build = "powershell -ExecutionPolicy Bypass -File Build-LuaTiktoken.ps1", -- This is Optional, only if you want to use tiktoken_core to calculate tokens count
 | 
				
			||||||
  -- rest of the config
 | 
					  -- rest of the config
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
```
 | 
					```
 | 
				
			||||||
 | 
				
			|||||||
@ -1,5 +1,5 @@
 | 
				
			|||||||
local Utils = require("avante.utils")
 | 
					local Utils = require("avante.utils")
 | 
				
			||||||
local Tiktoken = require("avante.tiktoken")
 | 
					local Tokens = require("avante.utils.tokens")
 | 
				
			||||||
local P = require("avante.providers")
 | 
					local P = require("avante.providers")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
---@class AvanteProviderFunctor
 | 
					---@class AvanteProviderFunctor
 | 
				
			||||||
@ -13,7 +13,7 @@ M.parse_message = function(opts)
 | 
				
			|||||||
    text = string.format("<code>```%s\n%s```</code>", opts.code_lang, opts.code_content),
 | 
					    text = string.format("<code>```%s\n%s```</code>", opts.code_lang, opts.code_content),
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  if Tiktoken.count(code_prompt_obj.text) > 1024 then
 | 
					  if Tokens.calculate_tokens(code_prompt_obj.text) > 1024 then
 | 
				
			||||||
    code_prompt_obj.cache_control = { type = "ephemeral" }
 | 
					    code_prompt_obj.cache_control = { type = "ephemeral" }
 | 
				
			||||||
  end
 | 
					  end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -31,7 +31,7 @@ M.parse_message = function(opts)
 | 
				
			|||||||
      text = string.format("<code>```%s\n%s```</code>", opts.code_lang, opts.selected_code_content),
 | 
					      text = string.format("<code>```%s\n%s```</code>", opts.code_lang, opts.selected_code_content),
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if Tiktoken.count(selected_code_obj.text) > 1024 then
 | 
					    if Tokens.calculate_tokens(selected_code_obj.text) > 1024 then
 | 
				
			||||||
      selected_code_obj.cache_control = { type = "ephemeral" }
 | 
					      selected_code_obj.cache_control = { type = "ephemeral" }
 | 
				
			||||||
    end
 | 
					    end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -50,7 +50,7 @@ M.parse_message = function(opts)
 | 
				
			|||||||
    text = user_prompt,
 | 
					    text = user_prompt,
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  if Tiktoken.count(user_prompt_obj.text) > 1024 then
 | 
					  if Tokens.calculate_tokens(user_prompt_obj.text) > 1024 then
 | 
				
			||||||
    user_prompt_obj.cache_control = { type = "ephemeral" }
 | 
					    user_prompt_obj.cache_control = { type = "ephemeral" }
 | 
				
			||||||
  end
 | 
					  end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -79,6 +79,9 @@ M.parse_response = function(data_stream, event_state, opts)
 | 
				
			|||||||
  end
 | 
					  end
 | 
				
			||||||
end
 | 
					end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					---@param provider AvanteProviderFunctor
 | 
				
			||||||
 | 
					---@param code_opts AvantePromptOptions
 | 
				
			||||||
 | 
					---@return table
 | 
				
			||||||
M.parse_curl_args = function(provider, code_opts)
 | 
					M.parse_curl_args = function(provider, code_opts)
 | 
				
			||||||
  local base, body_opts = P.parse_config(provider)
 | 
					  local base, body_opts = P.parse_config(provider)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -62,6 +62,7 @@ local Dressing = require("avante.ui.dressing")
 | 
				
			|||||||
---@field parse_response_data AvanteResponseParser
 | 
					---@field parse_response_data AvanteResponseParser
 | 
				
			||||||
---@field parse_curl_args? AvanteCurlArgsParser
 | 
					---@field parse_curl_args? AvanteCurlArgsParser
 | 
				
			||||||
---@field parse_stream_data? AvanteStreamParser
 | 
					---@field parse_stream_data? AvanteStreamParser
 | 
				
			||||||
 | 
					---@field parse_api_key fun(): string | nil
 | 
				
			||||||
---
 | 
					---
 | 
				
			||||||
---@class AvanteProviderFunctor
 | 
					---@class AvanteProviderFunctor
 | 
				
			||||||
---@field parse_message AvanteMessageParser
 | 
					---@field parse_message AvanteMessageParser
 | 
				
			||||||
 | 
				
			|||||||
@ -52,7 +52,6 @@ local M = {}
 | 
				
			|||||||
function M.setup(model)
 | 
					function M.setup(model)
 | 
				
			||||||
  local ok, core = pcall(require, "tiktoken_core")
 | 
					  local ok, core = pcall(require, "tiktoken_core")
 | 
				
			||||||
  if not ok then
 | 
					  if not ok then
 | 
				
			||||||
    print("Warn: tiktoken_core is not found!!!!")
 | 
					 | 
				
			||||||
    return
 | 
					    return
 | 
				
			||||||
  end
 | 
					  end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										56
									
								
								lua/avante/utils/tokens.lua
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										56
									
								
								lua/avante/utils/tokens.lua
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,56 @@
 | 
				
			|||||||
 | 
					--Taken from https://github.com/jackMort/ChatGPT.nvim/blob/main/lua/chatgpt/flows/chat/tokens.lua
 | 
				
			||||||
 | 
					local Tiktoken = require("avante.tiktoken")
 | 
				
			||||||
 | 
					local Tokens = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					--[[
 | 
				
			||||||
 | 
					  cost_per_token
 | 
				
			||||||
 | 
					  @param {string} token_name
 | 
				
			||||||
 | 
					  @return {number} cost_per_token
 | 
				
			||||||
 | 
					]]
 | 
				
			||||||
 | 
					local cost_per_token = {
 | 
				
			||||||
 | 
					  davinci = 0.000002,
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					--- Calculate the number of tokens in a given text.
 | 
				
			||||||
 | 
					-- @param text The text to calculate the number of tokens for.
 | 
				
			||||||
 | 
					-- @return The number of tokens in the given text.
 | 
				
			||||||
 | 
					function Tokens.calculate_tokens(text)
 | 
				
			||||||
 | 
					  if Tiktoken.available() then
 | 
				
			||||||
 | 
					    return Tiktoken.count(text)
 | 
				
			||||||
 | 
					  end
 | 
				
			||||||
 | 
					  local tokens = 0
 | 
				
			||||||
 | 
					  local current_token = ""
 | 
				
			||||||
 | 
					  for char in text:gmatch(".") do
 | 
				
			||||||
 | 
					    if char == " " or char == "\n" then
 | 
				
			||||||
 | 
					      if current_token ~= "" then
 | 
				
			||||||
 | 
					        tokens = tokens + 1
 | 
				
			||||||
 | 
					        current_token = ""
 | 
				
			||||||
 | 
					      end
 | 
				
			||||||
 | 
					    else
 | 
				
			||||||
 | 
					      current_token = current_token .. char
 | 
				
			||||||
 | 
					    end
 | 
				
			||||||
 | 
					  end
 | 
				
			||||||
 | 
					  if current_token ~= "" then
 | 
				
			||||||
 | 
					    tokens = tokens + 1
 | 
				
			||||||
 | 
					  end
 | 
				
			||||||
 | 
					  return tokens
 | 
				
			||||||
 | 
					end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					--- Calculate the cost of a given text in dollars.
 | 
				
			||||||
 | 
					-- @param text The text to calculate the cost of.
 | 
				
			||||||
 | 
					-- @param model The model to use to calculate the cost.
 | 
				
			||||||
 | 
					-- @return The cost of the given text in dollars.
 | 
				
			||||||
 | 
					function Tokens.calculate_usage_in_dollars(text, model)
 | 
				
			||||||
 | 
					  local tokens = Tokens.calculate_tokens(text)
 | 
				
			||||||
 | 
					  return Tokens.usage_in_dollars(tokens, model)
 | 
				
			||||||
 | 
					end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					--- Calculate the cost of a given number of tokens in dollars.
 | 
				
			||||||
 | 
					-- @param tokens The number of tokens to calculate the cost of.
 | 
				
			||||||
 | 
					-- @param model The model to use to calculate the cost.
 | 
				
			||||||
 | 
					-- @return The cost of the given number of tokens in dollars.
 | 
				
			||||||
 | 
					function Tokens.usage_in_dollars(tokens, model)
 | 
				
			||||||
 | 
					  return tokens * cost_per_token[model or "davinci"]
 | 
				
			||||||
 | 
					end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					return Tokens
 | 
				
			||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user