From 3b390040f5b206e3b21176ac065eaa475512cd29 Mon Sep 17 00:00:00 2001 From: yetone Date: Mon, 18 Nov 2024 18:07:33 +0800 Subject: [PATCH] refactor: chat history based on project (#867) --- lua/avante/llm.lua | 2 +- lua/avante/path.lua | 57 +++++++++------ lua/avante/sidebar.lua | 155 +++++++++++++++++++++++++++++------------ 3 files changed, 145 insertions(+), 69 deletions(-) diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index 20e1703..3a63a53 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -93,9 +93,9 @@ M._stream = function(opts, Provider) break end end - if #history_messages > 0 and history_messages[1].role == "assistant" then table.remove(history_messages, 1) end -- prepend the history messages to the messages table vim.iter(history_messages):each(function(msg) table.insert(messages, 1, msg) end) + if #messages > 0 and messages[1].role == "assistant" then table.remove(messages, 1) end end ---@type AvantePromptOptions diff --git a/lua/avante/path.lua b/lua/avante/path.lua index 7b1249b..93b9b11 100644 --- a/lua/avante/path.lua +++ b/lua/avante/path.lua @@ -1,45 +1,52 @@ -local fn, api = vim.fn, vim.api +local fn = vim.fn local Utils = require("avante.utils") local LRUCache = require("avante.utils.lru_cache") local Path = require("plenary.path") local Scan = require("plenary.scandir") local Config = require("avante.config") +---@class avante.ChatHistoryEntry +---@field timestamp string +---@field provider string +---@field model string +---@field request string +---@field response string +---@field original_response string +---@field selected_file {filepath: string}? +---@field selected_code {filetype: string, content: string}? +---@field reset_memory boolean? + ---@class avante.Path ---@field history_path Path ---@field cache_path Path local P = {} --- Helpers -local H = {} - --- Get a chat history file name given a buffer ----@param bufnr integer ----@return string -H.filename = function(bufnr) - local code_buf_name = api.nvim_buf_get_name(bufnr) - -- Replace path separators with double underscores - local path_with_separators = fn.substitute(code_buf_name, "/", "__", "g") - -- Replace other non-alphanumeric characters with single underscores - return fn.substitute(path_with_separators, "[^A-Za-z0-9._]", "_", "g") .. ".json" -end - --- Given a mode, return the file name for the custom prompt. ----@param mode LlmMode -H.get_mode_file = function(mode) return string.format("custom.%s.avanterules", mode) end - local history_file_cache = LRUCache:new(12) -- History path local History = {} +-- Get a chat history file name given a buffer +---@param bufnr integer +---@return string +History.filename = function(bufnr) + local project_root = Utils.root.get({ + buf = bufnr, + }) + -- Replace path separators with double underscores + local path_with_separators = fn.substitute(project_root, "/", "__", "g") + -- Replace other non-alphanumeric characters with single underscores + return fn.substitute(path_with_separators, "[^A-Za-z0-9._]", "_", "g") .. ".json" +end + -- Returns the Path to the chat history file for the given buffer. ---@param bufnr integer ---@return Path -History.get = function(bufnr) return Path:new(Config.history.storage_path):joinpath(H.filename(bufnr)) end +History.get = function(bufnr) return Path:new(Config.history.storage_path):joinpath(History.filename(bufnr)) end -- Loads the chat history for the given buffer. ---@param bufnr integer +---@return avante.ChatHistoryEntry[] History.load = function(bufnr) local history_file = History.get(bufnr) local cached_key = tostring(history_file:absolute()) @@ -56,7 +63,7 @@ end -- Saves the chat history for the given buffer. ---@param bufnr integer ----@param history table +---@param history avante.ChatHistoryEntry[] History.save = vim.schedule_wrap(function(bufnr, history) local history_file = History.get(bufnr) local cached_key = tostring(history_file:absolute()) @@ -69,6 +76,10 @@ P.history = History -- Prompt path local Prompt = {} +-- Given a mode, return the file name for the custom prompt. +---@param mode LlmMode +Prompt.get_mode_file = function(mode) return string.format("custom.%s.avanterules", mode) end + ---@class AvanteTemplates ---@field initialize fun(directory: string): nil ---@field render fun(template: string, context: TemplateOptions): string @@ -110,7 +121,7 @@ Prompt.get = function(bufnr) :copy({ destination = cache_prompt_dir, recursive = true }) vim.iter(Prompt.templates):filter(function(_, v) return v ~= nil end):each(function(k, v) - local f = cache_prompt_dir:joinpath(H.get_mode_file(k)) + local f = cache_prompt_dir:joinpath(Prompt.get_mode_file(k)) f:write(v, "w") end) @@ -119,7 +130,7 @@ end ---@param mode LlmMode Prompt.get_file = function(mode) - if Prompt.templates[mode] ~= nil then return H.get_mode_file(mode) end + if Prompt.templates[mode] ~= nil then return Prompt.get_mode_file(mode) end return string.format("%s.avanterules", mode) end diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index 44eb655..68cde47 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -1039,7 +1039,7 @@ function Sidebar:update_content(content, opts) opts = vim.tbl_deep_extend("force", { focus = true, scroll = true, stream = false, callback = nil }, opts or {}) if not opts.ignore_history then local chat_history = Path.history.load(self.code.bufnr) - content = self:get_history_content(chat_history) .. "---\n\n" .. content + content = self:render_history_content(chat_history) .. "---\n\n" .. content end if opts.stream then local scroll_to_bottom = function() @@ -1095,19 +1095,28 @@ end -- Function to get current timestamp local function get_timestamp() return os.date("%Y-%m-%d %H:%M:%S") end -local function get_chat_record_prefix(timestamp, provider, model, request) +---@param timestamp string|osdate +---@param provider string +---@param model string +---@param request string +---@param selected_file {filepath: string, content: string}? +---@param selected_code {filetype: string, content: string}? +---@return string +local function render_chat_record_prefix(timestamp, provider, model, request, selected_file, selected_code) provider = provider or "unknown" model = model or "unknown" - return "- Datetime: " - .. timestamp - .. "\n\n" - .. "- Model: " - .. provider - .. "/" - .. model - .. "\n\n> " - .. request:gsub("\n", "\n> "):gsub("([%w-_]+)%b[]", "`%0`") - .. "\n\n" + local res = "- Datetime: " .. timestamp .. "\n\n" .. "- Model: " .. provider .. "/" .. model + if selected_file ~= nil then res = res .. "\n\n- Selected file: " .. selected_file.filepath end + if selected_code ~= nil then + res = res + .. "\n\n- Selected code: " + .. "\n\n```" + .. selected_code.filetype + .. "\n" + .. selected_code.content + .. "\n```" + end + return res .. "\n\n> " .. request:gsub("\n", "\n> "):gsub("([%w-_]+)%b[]", "`%0`") .. "\n\n" end local function calculate_config_window_position() @@ -1132,20 +1141,34 @@ function Sidebar:get_layout() return vim.tbl_contains({ "left", "right" }, calculate_config_window_position()) and "vertical" or "horizontal" end -function Sidebar:get_history_content(history) +---@param history avante.ChatHistoryEntry[] +---@return string +function Sidebar:render_history_content(history) local content = "" for idx, entry in ipairs(history) do - local prefix = - get_chat_record_prefix(entry.timestamp, entry.provider, entry.model, entry.request or entry.requirement or "") + if entry.reset_memory then + content = content .. "***MEMORY RESET***\n\n" + if idx < #history then content = content .. "---\n\n" end + goto continue + end + local prefix = render_chat_record_prefix( + entry.timestamp, + entry.provider, + entry.model, + entry.request or "", + entry.selected_file, + entry.selected_code + ) content = content .. prefix content = content .. entry.response .. "\n\n" if idx < #history then content = content .. "---\n\n" end + ::continue:: end return content end function Sidebar:update_content_with_history(history) - local content = self:get_history_content(history) + local content = self:render_history_content(history) self:update_content(content, { ignore_history = true }) end @@ -1184,7 +1207,7 @@ function Sidebar:get_content_between_separators() return content, start_line end ----@alias AvanteSlashCommands "clear" | "help" | "lines" +---@alias AvanteSlashCommands "clear" | "help" | "lines" | "reset" ---@alias AvanteSlashCallback fun(args: string, cb?: fun(args: string): nil): nil ---@alias AvanteSlash {description: string, command: AvanteSlashCommands, details: string, shorthelp?: string, callback?: AvanteSlashCallback} ---@return AvanteSlash[] @@ -1203,6 +1226,7 @@ function Sidebar:get_commands() local items = { { description = "Show help message", command = "help" }, { description = "Clear chat history", command = "clear" }, + { description = "Reset memory", command = "reset" }, { shorthelp = "Ask a question about specific lines", description = "/lines - ", @@ -1223,13 +1247,31 @@ function Sidebar:get_commands() chat_history = {} Path.history.save(self.code.bufnr, chat_history) self:update_content("Chat history cleared", { focus = false, scroll = false }) - vim.defer_fn(function() - self:close() - if cb then cb(args) end - end, 1000) + if cb then cb(args) end + else + self:update_content("Chat history is already empty", { focus = false, scroll = false }) + end + end, + reset = function(args, cb) + local chat_history = Path.history.load(self.code.bufnr) + if next(chat_history) ~= nil then + table.insert(chat_history, { + timestamp = get_timestamp(), + provider = Config.provider, + model = Config.get_provider(Config.provider).model, + request = "", + response = "", + original_response = "", + selected_file = nil, + selected_code = nil, + reset_memory = true, + }) + Path.history.save(self.code.bufnr, chat_history) + local history_content = self:render_history_content(chat_history) + self:update_content(history_content, { focus = false, scroll = true }) + if cb then cb(args) end else self:update_content("Chat history is already empty", { focus = false, scroll = false }) - vim.defer_fn(function() self:close() end, 1000) end end, lines = function(args, cb) @@ -1300,7 +1342,22 @@ function Sidebar:create_input(opts) local timestamp = get_timestamp() - local content_prefix = get_chat_record_prefix(timestamp, Config.provider, model, request) + local filetype = api.nvim_get_option_value("filetype", { buf = self.code.bufnr }) + + local selected_file = { + filepath = api.nvim_buf_get_name(self.code.bufnr), + } + + local selected_code = nil + if self.code.selection ~= nil then + selected_code = { + filetype = filetype, + content = self.code.selection.content, + } + end + + local content_prefix = + render_chat_record_prefix(timestamp, Config.provider, model, request, selected_file, selected_code) --- HACK: we need to set focus to true and scroll to false to --- prevent the cursor from jumping to the bottom of the @@ -1310,8 +1367,6 @@ function Sidebar:create_input(opts) local content = table.concat(Utils.get_buf_lines(0, -1, self.code.bufnr), "\n") - local filetype = api.nvim_get_option_value("filetype", { buf = self.code.bufnr }) - local selected_code_content = nil if self.code.selection ~= nil then selected_code_content = self.code.selection.content end @@ -1409,6 +1464,8 @@ function Sidebar:create_input(opts) request = request, response = displayed_response, original_response = original_response, + selected_file = selected_file, + selected_code = selected_code, }) Path.history.save(self.code.bufnr, chat_history) end @@ -1420,26 +1477,34 @@ function Sidebar:create_input(opts) local project_context = mentions.enable_project_context and RepoMap.get_repo_map(file_ext) or nil - local history_messages = vim - .iter(chat_history) - :filter( - function(history) - return history.request ~= nil - and history.original_response ~= nil - and history.request ~= "" - and history.original_response ~= "" - end - ) - :map( - function(history) - return { - { role = "user", content = history.request }, - { role = "assistant", content = history.original_response }, - } - end - ) - :flatten() - :totable() + local history_messages = {} + for i = #chat_history, 1, -1 do + local entry = chat_history[i] + if entry.reset_memory then break end + if + entry.request == nil + or entry.original_response == nil + or entry.request == "" + or entry.original_response == "" + then + break + end + table.insert(history_messages, 1, { role = "assistant", content = entry.original_response }) + local user_content = "" + if entry.selected_file ~= nil then + user_content = user_content .. "SELECTED FILE: " .. entry.selected_file.filepath .. "\n\n" + end + if entry.selected_code ~= nil then + user_content = user_content + .. "SELECTED CODE:\n\n```" + .. entry.selected_code.filetype + .. "\n" + .. entry.selected_code.content + .. "\n```\n\n" + end + user_content = user_content .. "USER PROMPT:\n\n" .. entry.request + table.insert(history_messages, 1, { role = "user", content = user_content }) + end Llm.stream({ bufnr = self.code.bufnr,