From e98fa46becb7d6cd108a94c1fbf4348d8c44a8b1 Mon Sep 17 00:00:00 2001 From: Michael Gendy <50384638+Mng-dev-ai@users.noreply.github.com> Date: Tue, 17 Dec 2024 14:43:25 +0200 Subject: [PATCH] feat(tokens): add token count display to sidebar (#956) * feat (tokens) add token count display to sidebar * refactor: calculate the real tokens and reuse input hints to avoid occlusion --------- Co-authored-by: yetone --- crates/avante-templates/src/lib.rs | 2 +- lua/avante/llm.lua | 68 ++++++++----- lua/avante/path.lua | 7 +- lua/avante/selection.lua | 1 - lua/avante/sidebar.lua | 147 ++++++++++++++++------------- lua/avante/suggestion.lua | 1 - 6 files changed, 133 insertions(+), 93 deletions(-) diff --git a/crates/avante-templates/src/lib.rs b/crates/avante-templates/src/lib.rs index 4e2d7a8..144666c 100644 --- a/crates/avante-templates/src/lib.rs +++ b/crates/avante-templates/src/lib.rs @@ -18,7 +18,7 @@ impl<'a> State<'a> { #[derive(Debug, Serialize, Deserialize)] struct SelectedFile { path: String, - content: String, + content: Option, file_type: String, } diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index 1b8a95a..a7dbfaf 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -18,10 +18,10 @@ M.CANCEL_PATTERN = "AvanteLLMEscape" local group = api.nvim_create_augroup("avante_llm", { clear = true }) ----@param opts StreamOptions ----@param Provider AvanteProviderFunctor -M._stream = function(opts, Provider) - -- print opts +---@param opts GeneratePromptsOptions +---@return AvantePromptOptions +M.generate_prompts = function(opts) + local Provider = opts.provider or P[Config.provider] local mode = opts.mode or "planning" ---@type AvanteProviderFunctor local _, body_opts = P.parse_config(Provider) @@ -42,7 +42,8 @@ M._stream = function(opts, Provider) instructions = table.concat(lines, "\n") end - Path.prompts.initialize(Path.prompts.get(opts.bufnr)) + local project_root = Utils.root.get() + Path.prompts.initialize(Path.prompts.get(project_root)) local template_opts = { use_xml_format = Provider.use_xml_format, @@ -104,11 +105,30 @@ M._stream = function(opts, Provider) end ---@type AvantePromptOptions - local code_opts = { + return { system_prompt = system_prompt, messages = messages, image_paths = image_paths, } +end + +---@param opts GeneratePromptsOptions +---@return integer +M.calculate_tokens = function(opts) + local code_opts = M.generate_prompts(opts) + local tokens = Utils.tokens.calculate_tokens(code_opts.system_prompt) + for _, message in ipairs(code_opts.messages) do + tokens = tokens + Utils.tokens.calculate_tokens(message.content) + end + return tokens +end + +---@param opts StreamOptions +M._stream = function(opts) + local Provider = opts.provider or P[Config.provider] + + local code_opts = M.generate_prompts(opts) + ---@type string local current_event_state = nil @@ -248,7 +268,7 @@ M._stream = function(opts, Provider) return active_job end -local function _merge_response(first_response, second_response, opts, Provider) +local function _merge_response(first_response, second_response, opts) local prompt = "\n" .. Config.dual_boost.prompt prompt = prompt :gsub("{{[%s]*provider1_output[%s]*}}", first_response) @@ -259,28 +279,28 @@ local function _merge_response(first_response, second_response, opts, Provider) -- append this reference prompt to the code_opts messages at last opts.instructions = opts.instructions .. prompt - M._stream(opts, Provider) + M._stream(opts) end -local function _collector_process_responses(collector, opts, Provider) +local function _collector_process_responses(collector, opts) if not collector[1] or not collector[2] then Utils.error("One or both responses failed to complete") return end - _merge_response(collector[1], collector[2], opts, Provider) + _merge_response(collector[1], collector[2], opts) end -local function _collector_add_response(collector, index, response, opts, Provider) +local function _collector_add_response(collector, index, response, opts) collector[index] = response collector.count = collector.count + 1 if collector.count == 2 then collector.timer:stop() - _collector_process_responses(collector, opts, Provider) + _collector_process_responses(collector, opts) end end -M._dual_boost_stream = function(opts, Provider, Provider1, Provider2) +M._dual_boost_stream = function(opts, Provider1, Provider2) Utils.debug("Starting Dual Boost Stream") local collector = { @@ -299,7 +319,7 @@ M._dual_boost_stream = function(opts, Provider, Provider1, Provider2) Utils.warn("Dual boost stream timeout reached") collector.timer:stop() -- Process whatever responses we have - _collector_process_responses(collector, opts, Provider) + _collector_process_responses(collector, opts) end end) ) @@ -317,15 +337,19 @@ M._dual_boost_stream = function(opts, Provider, Provider1, Provider2) return end Utils.debug(string.format("Response %d completed", index)) - _collector_add_response(collector, index, response, opts, Provider) + _collector_add_response(collector, index, response, opts) end, }) end -- Start both streams local success, err = xpcall(function() - M._stream(create_stream_opts(1), Provider1) - M._stream(create_stream_opts(2), Provider2) + local opts1 = create_stream_opts(1) + opts1.provider = Provider1 + M._stream(opts1) + local opts2 = create_stream_opts(2) + opts2.provider = Provider2 + M._stream(opts2) end, function(err) return err end) if not success then Utils.error("Failed to start dual_boost streams: " .. tostring(err)) end end @@ -348,12 +372,13 @@ end ---@field diagnostics string | nil ---@field history_messages AvanteLLMMessage[] --- ----@class StreamOptions: TemplateOptions +---@class GeneratePromptsOptions: TemplateOptions ---@field ask boolean ----@field bufnr integer ---@field instructions string ---@field mode LlmMode ---@field provider AvanteProviderFunctor | nil +--- +---@class StreamOptions: GeneratePromptsOptions ---@field on_chunk AvanteChunkParser ---@field on_complete AvanteCompleteParser @@ -375,11 +400,10 @@ M.stream = function(opts) return original_on_complete(err) end) end - local Provider = opts.provider or P[Config.provider] if Config.dual_boost.enabled then - M._dual_boost_stream(opts, Provider, P[Config.dual_boost.first_provider], P[Config.dual_boost.second_provider]) + M._dual_boost_stream(opts, P[Config.dual_boost.first_provider], P[Config.dual_boost.second_provider]) else - M._stream(opts, Provider) + M._stream(opts) end end diff --git a/lua/avante/path.lua b/lua/avante/path.lua index 25ab88c..d0ff1d0 100644 --- a/lua/avante/path.lua +++ b/lua/avante/path.lua @@ -88,16 +88,15 @@ local templates = nil Prompt.templates = { planning = nil, editing = nil, suggesting = nil } --- Creates a directory in the cache path for the given buffer and copies the custom prompts to it. -- We need to do this beacuse the prompt template engine requires a given directory to load all required files. -- PERF: Hmm instead of copy to cache, we can also load in globals context, but it requires some work on bindings. (eh maybe?) ----@param bufnr number +---@param project_root string ---@return string the resulted cache_directory to be loaded with avante_templates -Prompt.get = function(bufnr) +Prompt.get = function(project_root) if not P.available() then error("Make sure to build avante (missing avante_templates)", 2) end -- get root directory of given bufnr - local directory = Path:new(Utils.root.get({ buf = bufnr })) + local directory = Path:new(project_root) if Utils.get_os_name() == "windows" then directory = Path:new(directory:absolute():gsub("^%a:", "")[1]) end ---@cast directory Path ---@type Path diff --git a/lua/avante/selection.lua b/lua/avante/selection.lua index 698c2dd..ad9610a 100644 --- a/lua/avante/selection.lua +++ b/lua/avante/selection.lua @@ -206,7 +206,6 @@ function Selection:create_editing_input() local diagnostics = Utils.get_current_selection_diagnostics(code_bufnr, self.selection) Llm.stream({ - bufnr = code_bufnr, ask = true, project_context = vim.json.encode(project_context), diagnostics = vim.json.encode(diagnostics), diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index 1fa613a..d6bb353 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -1491,6 +1491,74 @@ function Sidebar:create_input_container(opts) local chat_history = Path.history.load(self.code.bufnr) + ---@param request string + ---@return GeneratePromptsOptions + local function get_generate_prompts_options(request) + local filetype = api.nvim_get_option_value("filetype", { buf = self.code.bufnr }) + + local selected_code_content = nil + if self.code.selection ~= nil then selected_code_content = self.code.selection.content end + + local mentions = Utils.extract_mentions(request) + request = mentions.new_content + + local file_ext = api.nvim_buf_get_name(self.code.bufnr):match("^.+%.(.+)$") + + local project_context = mentions.enable_project_context and RepoMap.get_repo_map(file_ext) or nil + + local selected_files_contents = self.file_selector:get_selected_files_contents() + + local diagnostics = nil + if mentions.enable_diagnostics then + if self.code ~= nil and self.code.bufnr ~= nil and self.code.selection ~= nil then + diagnostics = Utils.get_current_selection_diagnostics(self.code.bufnr, self.code.selection) + else + diagnostics = Utils.get_diagnostics(self.code.bufnr) + end + end + + local history_messages = {} + for i = #chat_history, 1, -1 do + local entry = chat_history[i] + if entry.reset_memory then break end + if + entry.request == nil + or entry.original_response == nil + or entry.request == "" + or entry.original_response == "" + then + break + end + table.insert(history_messages, 1, { role = "assistant", content = entry.original_response }) + local user_content = "" + if entry.selected_file ~= nil then + user_content = user_content .. "SELECTED FILE: " .. entry.selected_file.filepath .. "\n\n" + end + if entry.selected_code ~= nil then + user_content = user_content + .. "SELECTED CODE:\n\n```" + .. entry.selected_code.filetype + .. "\n" + .. entry.selected_code.content + .. "\n```\n\n" + end + user_content = user_content .. "USER PROMPT:\n\n" .. entry.request + table.insert(history_messages, 1, { role = "user", content = user_content }) + end + + return { + ask = opts.ask, + project_context = vim.json.encode(project_context), + selected_files = selected_files_contents, + diagnostics = vim.json.encode(diagnostics), + history_messages = history_messages, + code_lang = filetype, + selected_code = selected_code_content, + instructions = request, + mode = "planning", + } + end + ---@param request string local function handle_submit(request) local model = Config.has_provider(Config.provider) and Config.get_provider(Config.provider).model or "default" @@ -1518,9 +1586,6 @@ function Sidebar:create_input_container(opts) self:update_content("", { focus = true, scroll = false }) self:update_content(content_prefix .. generating_text) - local selected_code_content = nil - if self.code.selection ~= nil then selected_code_content = self.code.selection.content end - if request:sub(1, 1) == "/" then local command, args = request:match("^/(%S+)%s*(.*)") if command == nil then @@ -1542,8 +1607,6 @@ function Sidebar:create_input_container(opts) Utils.error("Invalid end line number", { once = true, title = "Avante" }) return end - selected_code_content = - table.concat(api.nvim_buf_get_lines(self.code.bufnr, start_line - 1, end_line, false), "\n") request = question end) else @@ -1632,67 +1695,15 @@ function Sidebar:create_input_container(opts) Path.history.save(self.code.bufnr, chat_history) end - local mentions = Utils.extract_mentions(request) - request = mentions.new_content - - local file_ext = api.nvim_buf_get_name(self.code.bufnr):match("^.+%.(.+)$") - - local project_context = mentions.enable_project_context and RepoMap.get_repo_map(file_ext) or nil - - local selected_files_contents = self.file_selector:get_selected_files_contents() - - local diagnostics = nil - if mentions.enable_diagnostics then - if self.code ~= nil and self.code.bufnr ~= nil and self.code.selection ~= nil then - diagnostics = Utils.get_current_selection_diagnostics(self.code.bufnr, self.code.selection) - else - diagnostics = Utils.get_diagnostics(self.code.bufnr) - end - end - - local history_messages = {} - for i = #chat_history, 1, -1 do - local entry = chat_history[i] - if entry.reset_memory then break end - if - entry.request == nil - or entry.original_response == nil - or entry.request == "" - or entry.original_response == "" - then - break - end - table.insert(history_messages, 1, { role = "assistant", content = entry.original_response }) - local user_content = "" - if entry.selected_file ~= nil then - user_content = user_content .. "SELECTED FILE: " .. entry.selected_file.filepath .. "\n\n" - end - if entry.selected_code ~= nil then - user_content = user_content - .. "SELECTED CODE:\n\n```" - .. entry.selected_code.filetype - .. "\n" - .. entry.selected_code.content - .. "\n```\n\n" - end - user_content = user_content .. "USER PROMPT:\n\n" .. entry.request - table.insert(history_messages, 1, { role = "user", content = user_content }) - end - - Llm.stream({ - bufnr = self.code.bufnr, - ask = opts.ask, - project_context = vim.json.encode(project_context), - selected_files = selected_files_contents, - diagnostics = vim.json.encode(diagnostics), - history_messages = history_messages, - code_lang = filetype, - selected_code = selected_code_content, - instructions = request, - mode = "planning", + local generate_prompts_options = get_generate_prompts_options(request) + ---@type StreamOptions + ---@diagnostic disable-next-line: assign-type-mismatch + local stream_options = vim.tbl_deep_extend("force", generate_prompts_options, { on_chunk = on_chunk, on_complete = on_complete, }) + + Llm.stream(stream_options) end local get_position = function() @@ -1827,7 +1838,15 @@ function Sidebar:create_input_container(opts) local function show_hint() close_hint() -- Close the existing hint window - local hint_text = (fn.mode() ~= "i" and Config.mappings.submit.normal or Config.mappings.submit.insert) + local input_value = table.concat(api.nvim_buf_get_lines(self.input_container.bufnr, 0, -1, false), "\n") + + local generate_prompts_options = get_generate_prompts_options(input_value) + local tokens = Llm.calculate_tokens(generate_prompts_options) + + local hint_text = "Tokens: " + .. tostring(tokens) + .. "; " + .. (fn.mode() ~= "i" and Config.mappings.submit.normal or Config.mappings.submit.insert) .. ": submit" local buf = api.nvim_create_buf(false, true) diff --git a/lua/avante/suggestion.lua b/lua/avante/suggestion.lua index 82db64b..ac9c70d 100644 --- a/lua/avante/suggestion.lua +++ b/lua/avante/suggestion.lua @@ -69,7 +69,6 @@ function Suggestion:suggest() Llm.stream({ provider = provider, - bufnr = bufnr, ask = true, selected_files = { { content = code_content, file_type = filetype, path = "" } }, code_lang = filetype,