feat(tokens): add token count display to sidebar (#956)
* feat (tokens) add token count display to sidebar * refactor: calculate the real tokens and reuse input hints to avoid occlusion --------- Co-authored-by: yetone <yetoneful@gmail.com>
This commit is contained in:
parent
e612ad7566
commit
e98fa46bec
@ -18,7 +18,7 @@ impl<'a> State<'a> {
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct SelectedFile {
|
||||
path: String,
|
||||
content: String,
|
||||
content: Option<String>,
|
||||
file_type: String,
|
||||
}
|
||||
|
||||
|
@ -18,10 +18,10 @@ M.CANCEL_PATTERN = "AvanteLLMEscape"
|
||||
|
||||
local group = api.nvim_create_augroup("avante_llm", { clear = true })
|
||||
|
||||
---@param opts StreamOptions
|
||||
---@param Provider AvanteProviderFunctor
|
||||
M._stream = function(opts, Provider)
|
||||
-- print opts
|
||||
---@param opts GeneratePromptsOptions
|
||||
---@return AvantePromptOptions
|
||||
M.generate_prompts = function(opts)
|
||||
local Provider = opts.provider or P[Config.provider]
|
||||
local mode = opts.mode or "planning"
|
||||
---@type AvanteProviderFunctor
|
||||
local _, body_opts = P.parse_config(Provider)
|
||||
@ -42,7 +42,8 @@ M._stream = function(opts, Provider)
|
||||
instructions = table.concat(lines, "\n")
|
||||
end
|
||||
|
||||
Path.prompts.initialize(Path.prompts.get(opts.bufnr))
|
||||
local project_root = Utils.root.get()
|
||||
Path.prompts.initialize(Path.prompts.get(project_root))
|
||||
|
||||
local template_opts = {
|
||||
use_xml_format = Provider.use_xml_format,
|
||||
@ -104,11 +105,30 @@ M._stream = function(opts, Provider)
|
||||
end
|
||||
|
||||
---@type AvantePromptOptions
|
||||
local code_opts = {
|
||||
return {
|
||||
system_prompt = system_prompt,
|
||||
messages = messages,
|
||||
image_paths = image_paths,
|
||||
}
|
||||
end
|
||||
|
||||
---@param opts GeneratePromptsOptions
|
||||
---@return integer
|
||||
M.calculate_tokens = function(opts)
|
||||
local code_opts = M.generate_prompts(opts)
|
||||
local tokens = Utils.tokens.calculate_tokens(code_opts.system_prompt)
|
||||
for _, message in ipairs(code_opts.messages) do
|
||||
tokens = tokens + Utils.tokens.calculate_tokens(message.content)
|
||||
end
|
||||
return tokens
|
||||
end
|
||||
|
||||
---@param opts StreamOptions
|
||||
M._stream = function(opts)
|
||||
local Provider = opts.provider or P[Config.provider]
|
||||
|
||||
local code_opts = M.generate_prompts(opts)
|
||||
|
||||
---@type string
|
||||
local current_event_state = nil
|
||||
|
||||
@ -248,7 +268,7 @@ M._stream = function(opts, Provider)
|
||||
return active_job
|
||||
end
|
||||
|
||||
local function _merge_response(first_response, second_response, opts, Provider)
|
||||
local function _merge_response(first_response, second_response, opts)
|
||||
local prompt = "\n" .. Config.dual_boost.prompt
|
||||
prompt = prompt
|
||||
:gsub("{{[%s]*provider1_output[%s]*}}", first_response)
|
||||
@ -259,28 +279,28 @@ local function _merge_response(first_response, second_response, opts, Provider)
|
||||
-- append this reference prompt to the code_opts messages at last
|
||||
opts.instructions = opts.instructions .. prompt
|
||||
|
||||
M._stream(opts, Provider)
|
||||
M._stream(opts)
|
||||
end
|
||||
|
||||
local function _collector_process_responses(collector, opts, Provider)
|
||||
local function _collector_process_responses(collector, opts)
|
||||
if not collector[1] or not collector[2] then
|
||||
Utils.error("One or both responses failed to complete")
|
||||
return
|
||||
end
|
||||
_merge_response(collector[1], collector[2], opts, Provider)
|
||||
_merge_response(collector[1], collector[2], opts)
|
||||
end
|
||||
|
||||
local function _collector_add_response(collector, index, response, opts, Provider)
|
||||
local function _collector_add_response(collector, index, response, opts)
|
||||
collector[index] = response
|
||||
collector.count = collector.count + 1
|
||||
|
||||
if collector.count == 2 then
|
||||
collector.timer:stop()
|
||||
_collector_process_responses(collector, opts, Provider)
|
||||
_collector_process_responses(collector, opts)
|
||||
end
|
||||
end
|
||||
|
||||
M._dual_boost_stream = function(opts, Provider, Provider1, Provider2)
|
||||
M._dual_boost_stream = function(opts, Provider1, Provider2)
|
||||
Utils.debug("Starting Dual Boost Stream")
|
||||
|
||||
local collector = {
|
||||
@ -299,7 +319,7 @@ M._dual_boost_stream = function(opts, Provider, Provider1, Provider2)
|
||||
Utils.warn("Dual boost stream timeout reached")
|
||||
collector.timer:stop()
|
||||
-- Process whatever responses we have
|
||||
_collector_process_responses(collector, opts, Provider)
|
||||
_collector_process_responses(collector, opts)
|
||||
end
|
||||
end)
|
||||
)
|
||||
@ -317,15 +337,19 @@ M._dual_boost_stream = function(opts, Provider, Provider1, Provider2)
|
||||
return
|
||||
end
|
||||
Utils.debug(string.format("Response %d completed", index))
|
||||
_collector_add_response(collector, index, response, opts, Provider)
|
||||
_collector_add_response(collector, index, response, opts)
|
||||
end,
|
||||
})
|
||||
end
|
||||
|
||||
-- Start both streams
|
||||
local success, err = xpcall(function()
|
||||
M._stream(create_stream_opts(1), Provider1)
|
||||
M._stream(create_stream_opts(2), Provider2)
|
||||
local opts1 = create_stream_opts(1)
|
||||
opts1.provider = Provider1
|
||||
M._stream(opts1)
|
||||
local opts2 = create_stream_opts(2)
|
||||
opts2.provider = Provider2
|
||||
M._stream(opts2)
|
||||
end, function(err) return err end)
|
||||
if not success then Utils.error("Failed to start dual_boost streams: " .. tostring(err)) end
|
||||
end
|
||||
@ -348,12 +372,13 @@ end
|
||||
---@field diagnostics string | nil
|
||||
---@field history_messages AvanteLLMMessage[]
|
||||
---
|
||||
---@class StreamOptions: TemplateOptions
|
||||
---@class GeneratePromptsOptions: TemplateOptions
|
||||
---@field ask boolean
|
||||
---@field bufnr integer
|
||||
---@field instructions string
|
||||
---@field mode LlmMode
|
||||
---@field provider AvanteProviderFunctor | nil
|
||||
---
|
||||
---@class StreamOptions: GeneratePromptsOptions
|
||||
---@field on_chunk AvanteChunkParser
|
||||
---@field on_complete AvanteCompleteParser
|
||||
|
||||
@ -375,11 +400,10 @@ M.stream = function(opts)
|
||||
return original_on_complete(err)
|
||||
end)
|
||||
end
|
||||
local Provider = opts.provider or P[Config.provider]
|
||||
if Config.dual_boost.enabled then
|
||||
M._dual_boost_stream(opts, Provider, P[Config.dual_boost.first_provider], P[Config.dual_boost.second_provider])
|
||||
M._dual_boost_stream(opts, P[Config.dual_boost.first_provider], P[Config.dual_boost.second_provider])
|
||||
else
|
||||
M._stream(opts, Provider)
|
||||
M._stream(opts)
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -88,16 +88,15 @@ local templates = nil
|
||||
|
||||
Prompt.templates = { planning = nil, editing = nil, suggesting = nil }
|
||||
|
||||
-- Creates a directory in the cache path for the given buffer and copies the custom prompts to it.
|
||||
-- We need to do this beacuse the prompt template engine requires a given directory to load all required files.
|
||||
-- PERF: Hmm instead of copy to cache, we can also load in globals context, but it requires some work on bindings. (eh maybe?)
|
||||
---@param bufnr number
|
||||
---@param project_root string
|
||||
---@return string the resulted cache_directory to be loaded with avante_templates
|
||||
Prompt.get = function(bufnr)
|
||||
Prompt.get = function(project_root)
|
||||
if not P.available() then error("Make sure to build avante (missing avante_templates)", 2) end
|
||||
|
||||
-- get root directory of given bufnr
|
||||
local directory = Path:new(Utils.root.get({ buf = bufnr }))
|
||||
local directory = Path:new(project_root)
|
||||
if Utils.get_os_name() == "windows" then directory = Path:new(directory:absolute():gsub("^%a:", "")[1]) end
|
||||
---@cast directory Path
|
||||
---@type Path
|
||||
|
@ -206,7 +206,6 @@ function Selection:create_editing_input()
|
||||
local diagnostics = Utils.get_current_selection_diagnostics(code_bufnr, self.selection)
|
||||
|
||||
Llm.stream({
|
||||
bufnr = code_bufnr,
|
||||
ask = true,
|
||||
project_context = vim.json.encode(project_context),
|
||||
diagnostics = vim.json.encode(diagnostics),
|
||||
|
@ -1491,6 +1491,74 @@ function Sidebar:create_input_container(opts)
|
||||
|
||||
local chat_history = Path.history.load(self.code.bufnr)
|
||||
|
||||
---@param request string
|
||||
---@return GeneratePromptsOptions
|
||||
local function get_generate_prompts_options(request)
|
||||
local filetype = api.nvim_get_option_value("filetype", { buf = self.code.bufnr })
|
||||
|
||||
local selected_code_content = nil
|
||||
if self.code.selection ~= nil then selected_code_content = self.code.selection.content end
|
||||
|
||||
local mentions = Utils.extract_mentions(request)
|
||||
request = mentions.new_content
|
||||
|
||||
local file_ext = api.nvim_buf_get_name(self.code.bufnr):match("^.+%.(.+)$")
|
||||
|
||||
local project_context = mentions.enable_project_context and RepoMap.get_repo_map(file_ext) or nil
|
||||
|
||||
local selected_files_contents = self.file_selector:get_selected_files_contents()
|
||||
|
||||
local diagnostics = nil
|
||||
if mentions.enable_diagnostics then
|
||||
if self.code ~= nil and self.code.bufnr ~= nil and self.code.selection ~= nil then
|
||||
diagnostics = Utils.get_current_selection_diagnostics(self.code.bufnr, self.code.selection)
|
||||
else
|
||||
diagnostics = Utils.get_diagnostics(self.code.bufnr)
|
||||
end
|
||||
end
|
||||
|
||||
local history_messages = {}
|
||||
for i = #chat_history, 1, -1 do
|
||||
local entry = chat_history[i]
|
||||
if entry.reset_memory then break end
|
||||
if
|
||||
entry.request == nil
|
||||
or entry.original_response == nil
|
||||
or entry.request == ""
|
||||
or entry.original_response == ""
|
||||
then
|
||||
break
|
||||
end
|
||||
table.insert(history_messages, 1, { role = "assistant", content = entry.original_response })
|
||||
local user_content = ""
|
||||
if entry.selected_file ~= nil then
|
||||
user_content = user_content .. "SELECTED FILE: " .. entry.selected_file.filepath .. "\n\n"
|
||||
end
|
||||
if entry.selected_code ~= nil then
|
||||
user_content = user_content
|
||||
.. "SELECTED CODE:\n\n```"
|
||||
.. entry.selected_code.filetype
|
||||
.. "\n"
|
||||
.. entry.selected_code.content
|
||||
.. "\n```\n\n"
|
||||
end
|
||||
user_content = user_content .. "USER PROMPT:\n\n" .. entry.request
|
||||
table.insert(history_messages, 1, { role = "user", content = user_content })
|
||||
end
|
||||
|
||||
return {
|
||||
ask = opts.ask,
|
||||
project_context = vim.json.encode(project_context),
|
||||
selected_files = selected_files_contents,
|
||||
diagnostics = vim.json.encode(diagnostics),
|
||||
history_messages = history_messages,
|
||||
code_lang = filetype,
|
||||
selected_code = selected_code_content,
|
||||
instructions = request,
|
||||
mode = "planning",
|
||||
}
|
||||
end
|
||||
|
||||
---@param request string
|
||||
local function handle_submit(request)
|
||||
local model = Config.has_provider(Config.provider) and Config.get_provider(Config.provider).model or "default"
|
||||
@ -1518,9 +1586,6 @@ function Sidebar:create_input_container(opts)
|
||||
self:update_content("", { focus = true, scroll = false })
|
||||
self:update_content(content_prefix .. generating_text)
|
||||
|
||||
local selected_code_content = nil
|
||||
if self.code.selection ~= nil then selected_code_content = self.code.selection.content end
|
||||
|
||||
if request:sub(1, 1) == "/" then
|
||||
local command, args = request:match("^/(%S+)%s*(.*)")
|
||||
if command == nil then
|
||||
@ -1542,8 +1607,6 @@ function Sidebar:create_input_container(opts)
|
||||
Utils.error("Invalid end line number", { once = true, title = "Avante" })
|
||||
return
|
||||
end
|
||||
selected_code_content =
|
||||
table.concat(api.nvim_buf_get_lines(self.code.bufnr, start_line - 1, end_line, false), "\n")
|
||||
request = question
|
||||
end)
|
||||
else
|
||||
@ -1632,67 +1695,15 @@ function Sidebar:create_input_container(opts)
|
||||
Path.history.save(self.code.bufnr, chat_history)
|
||||
end
|
||||
|
||||
local mentions = Utils.extract_mentions(request)
|
||||
request = mentions.new_content
|
||||
|
||||
local file_ext = api.nvim_buf_get_name(self.code.bufnr):match("^.+%.(.+)$")
|
||||
|
||||
local project_context = mentions.enable_project_context and RepoMap.get_repo_map(file_ext) or nil
|
||||
|
||||
local selected_files_contents = self.file_selector:get_selected_files_contents()
|
||||
|
||||
local diagnostics = nil
|
||||
if mentions.enable_diagnostics then
|
||||
if self.code ~= nil and self.code.bufnr ~= nil and self.code.selection ~= nil then
|
||||
diagnostics = Utils.get_current_selection_diagnostics(self.code.bufnr, self.code.selection)
|
||||
else
|
||||
diagnostics = Utils.get_diagnostics(self.code.bufnr)
|
||||
end
|
||||
end
|
||||
|
||||
local history_messages = {}
|
||||
for i = #chat_history, 1, -1 do
|
||||
local entry = chat_history[i]
|
||||
if entry.reset_memory then break end
|
||||
if
|
||||
entry.request == nil
|
||||
or entry.original_response == nil
|
||||
or entry.request == ""
|
||||
or entry.original_response == ""
|
||||
then
|
||||
break
|
||||
end
|
||||
table.insert(history_messages, 1, { role = "assistant", content = entry.original_response })
|
||||
local user_content = ""
|
||||
if entry.selected_file ~= nil then
|
||||
user_content = user_content .. "SELECTED FILE: " .. entry.selected_file.filepath .. "\n\n"
|
||||
end
|
||||
if entry.selected_code ~= nil then
|
||||
user_content = user_content
|
||||
.. "SELECTED CODE:\n\n```"
|
||||
.. entry.selected_code.filetype
|
||||
.. "\n"
|
||||
.. entry.selected_code.content
|
||||
.. "\n```\n\n"
|
||||
end
|
||||
user_content = user_content .. "USER PROMPT:\n\n" .. entry.request
|
||||
table.insert(history_messages, 1, { role = "user", content = user_content })
|
||||
end
|
||||
|
||||
Llm.stream({
|
||||
bufnr = self.code.bufnr,
|
||||
ask = opts.ask,
|
||||
project_context = vim.json.encode(project_context),
|
||||
selected_files = selected_files_contents,
|
||||
diagnostics = vim.json.encode(diagnostics),
|
||||
history_messages = history_messages,
|
||||
code_lang = filetype,
|
||||
selected_code = selected_code_content,
|
||||
instructions = request,
|
||||
mode = "planning",
|
||||
local generate_prompts_options = get_generate_prompts_options(request)
|
||||
---@type StreamOptions
|
||||
---@diagnostic disable-next-line: assign-type-mismatch
|
||||
local stream_options = vim.tbl_deep_extend("force", generate_prompts_options, {
|
||||
on_chunk = on_chunk,
|
||||
on_complete = on_complete,
|
||||
})
|
||||
|
||||
Llm.stream(stream_options)
|
||||
end
|
||||
|
||||
local get_position = function()
|
||||
@ -1827,7 +1838,15 @@ function Sidebar:create_input_container(opts)
|
||||
local function show_hint()
|
||||
close_hint() -- Close the existing hint window
|
||||
|
||||
local hint_text = (fn.mode() ~= "i" and Config.mappings.submit.normal or Config.mappings.submit.insert)
|
||||
local input_value = table.concat(api.nvim_buf_get_lines(self.input_container.bufnr, 0, -1, false), "\n")
|
||||
|
||||
local generate_prompts_options = get_generate_prompts_options(input_value)
|
||||
local tokens = Llm.calculate_tokens(generate_prompts_options)
|
||||
|
||||
local hint_text = "Tokens: "
|
||||
.. tostring(tokens)
|
||||
.. "; "
|
||||
.. (fn.mode() ~= "i" and Config.mappings.submit.normal or Config.mappings.submit.insert)
|
||||
.. ": submit"
|
||||
|
||||
local buf = api.nvim_create_buf(false, true)
|
||||
|
@ -69,7 +69,6 @@ function Suggestion:suggest()
|
||||
|
||||
Llm.stream({
|
||||
provider = provider,
|
||||
bufnr = bufnr,
|
||||
ask = true,
|
||||
selected_files = { { content = code_content, file_type = filetype, path = "" } },
|
||||
code_lang = filetype,
|
||||
|
Loading…
x
Reference in New Issue
Block a user