From eb1bc657a174cb10fe854a3966aecdaccfb91eff Mon Sep 17 00:00:00 2001 From: Alexander Muratov Date: Fri, 13 Dec 2024 20:00:43 +0500 Subject: [PATCH] refactor & fix: improve libraries initialization (#921) * refactor(libs): extract libraries initialization Extract initialization logic into separate functions for better error handling and reusability. * fix(libs): improve core libraries init This change helps prevent runtime errors from uninitialized libraries. --- lua/avante/path.lua | 26 ++++++++++++++++++-------- lua/avante/repo_map.lua | 34 +++++++++++++++++++++------------- lua/avante/tokenizers.lua | 38 ++++++++++++++++++++++++++++---------- 3 files changed, 67 insertions(+), 31 deletions(-) diff --git a/lua/avante/path.lua b/lua/avante/path.lua index 864d2ba..09c2718 100644 --- a/lua/avante/path.lua +++ b/lua/avante/path.lua @@ -178,6 +178,22 @@ end P.repo_map = RepoMap +---@return AvanteTemplates|nil +P._init_templates_lib = function() + if templates ~= nil then + return templates + end + local ok, module = pcall(require, "avante_templates") + ---@cast module AvanteTemplates + ---@cast ok boolean + if not ok then + return nil + end + templates = module + + return templates +end + P.setup = function() local history_path = Path:new(Config.history.storage_path) if not history_path:exists() then history_path:mkdir({ parents = true }) end @@ -191,16 +207,10 @@ P.setup = function() if not data_path:exists() then data_path:mkdir({ parents = true }) end P.data_path = data_path - vim.defer_fn(function() - local ok, module = pcall(require, "avante_templates") - ---@cast module AvanteTemplates - ---@cast ok boolean - if not ok then return end - if templates == nil then templates = module end - end, 1000) + vim.defer_fn(P._init_templates_lib, 1000) end -P.available = function() return templates ~= nil end +P.available = function() return P._init_templates_lib() ~= nil end P.clear = function() P.cache_path:rm({ recursive = true }) diff --git a/lua/avante/repo_map.lua b/lua/avante/repo_map.lua index b7bd477..fd974ed 100644 --- a/lua/avante/repo_map.lua +++ b/lua/avante/repo_map.lua @@ -15,16 +15,23 @@ local repo_map_lib = nil ---@class avante.utils.repo_map local RepoMap = {} -function RepoMap.setup() - vim.defer_fn(function() - local ok, core = pcall(require, "avante_repo_map") - if not ok then - error("Failed to load avante_repo_map") - return - end +---@return AvanteRepoMap|nil +function RepoMap._init_repo_map_lib() + if repo_map_lib ~= nil then + return repo_map_lib + end - if repo_map_lib == nil then repo_map_lib = core end - end, 1000) + local ok, core = pcall(require, "avante_repo_map") + if not ok then + return nil + end + + repo_map_lib = core + return repo_map_lib +end + +function RepoMap.setup() + vim.defer_fn(RepoMap._init_repo_map_lib, 1000) end function RepoMap.get_ts_lang(filepath) @@ -51,12 +58,13 @@ function RepoMap._build_repo_map(project_root, file_ext) local negate_patterns = vim.list_extend(gitignore_negate_patterns, Config.repo_map.negate_patterns) local filepaths = Utils.scan_directory(project_root, ignore_patterns, negate_patterns) + if filepaths and not RepoMap._init_repo_map_lib() then + -- or just throw an error if we don't want to execute request without codebase + Utils.error("Failed to load avante_repo_map") + return + end vim.iter(filepaths):each(function(filepath) if not Utils.is_same_file_ext(file_ext, filepath) then return end - if not repo_map_lib then - Utils.error("Failed to load avante_repo_map") - return - end local filetype = RepoMap.get_ts_lang(filepath) local definitions = filetype and repo_map_lib.stringify_definitions(filetype, Utils.file.read_content(filepath) or "") diff --git a/lua/avante/tokenizers.lua b/lua/avante/tokenizers.lua index 6dc2513..2e41d5a 100644 --- a/lua/avante/tokenizers.lua +++ b/lua/avante/tokenizers.lua @@ -5,20 +5,38 @@ local Utils = require("avante.utils") ---@field encode fun(string): integer[] local tokenizers = nil +---@type "gpt-4o" | string +local current_model = "gpt-4o" + local M = {} +---@param model "gpt-4o" | string +---@return AvanteTokenizer|nil +M._init_tokenizers_lib = function(model) + if tokenizers ~= nil then + return tokenizers + end + + local ok, core = pcall(require, "avante_tokenizers") + if not ok then + return nil + end + + ---@cast core AvanteTokenizer + tokenizers = core + + core.from_pretrained(model) + + return tokenizers +end + ---@param model "gpt-4o" | string ---@param warning? boolean M.setup = function(model, warning) + current_model = model warning = warning or true vim.defer_fn(function() - local ok, core = pcall(require, "avante_tokenizers") - if not ok then return end - - ---@cast core AvanteTokenizer - if tokenizers == nil then tokenizers = core end - - core.from_pretrained(model) + M._init_tokenizers_lib(model) end, 1000) if warning then @@ -32,11 +50,11 @@ M.setup = function(model, warning) end end -M.available = function() return tokenizers ~= nil end +M.available = function() return M._init_tokenizers_lib(current_model) ~= nil end ---@param prompt string M.encode = function(prompt) - if not tokenizers then return nil end + if not M.available() then return nil end if not prompt or prompt == "" then return nil end if type(prompt) ~= "string" then error("Prompt is not type string", 2) end @@ -45,7 +63,7 @@ end ---@param prompt string M.count = function(prompt) - if not tokenizers then return math.ceil(#prompt * 0.5) end + if not M.available() then return math.ceil(#prompt * 0.5) end local tokens = M.encode(prompt) if not tokens then return 0 end