refactor: chat history based on project (#867)
This commit is contained in:
parent
87885a4530
commit
3b390040f5
@ -93,9 +93,9 @@ M._stream = function(opts, Provider)
|
|||||||
break
|
break
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
if #history_messages > 0 and history_messages[1].role == "assistant" then table.remove(history_messages, 1) end
|
|
||||||
-- prepend the history messages to the messages table
|
-- prepend the history messages to the messages table
|
||||||
vim.iter(history_messages):each(function(msg) table.insert(messages, 1, msg) end)
|
vim.iter(history_messages):each(function(msg) table.insert(messages, 1, msg) end)
|
||||||
|
if #messages > 0 and messages[1].role == "assistant" then table.remove(messages, 1) end
|
||||||
end
|
end
|
||||||
|
|
||||||
---@type AvantePromptOptions
|
---@type AvantePromptOptions
|
||||||
|
@ -1,45 +1,52 @@
|
|||||||
local fn, api = vim.fn, vim.api
|
local fn = vim.fn
|
||||||
local Utils = require("avante.utils")
|
local Utils = require("avante.utils")
|
||||||
local LRUCache = require("avante.utils.lru_cache")
|
local LRUCache = require("avante.utils.lru_cache")
|
||||||
local Path = require("plenary.path")
|
local Path = require("plenary.path")
|
||||||
local Scan = require("plenary.scandir")
|
local Scan = require("plenary.scandir")
|
||||||
local Config = require("avante.config")
|
local Config = require("avante.config")
|
||||||
|
|
||||||
|
---@class avante.ChatHistoryEntry
|
||||||
|
---@field timestamp string
|
||||||
|
---@field provider string
|
||||||
|
---@field model string
|
||||||
|
---@field request string
|
||||||
|
---@field response string
|
||||||
|
---@field original_response string
|
||||||
|
---@field selected_file {filepath: string}?
|
||||||
|
---@field selected_code {filetype: string, content: string}?
|
||||||
|
---@field reset_memory boolean?
|
||||||
|
|
||||||
---@class avante.Path
|
---@class avante.Path
|
||||||
---@field history_path Path
|
---@field history_path Path
|
||||||
---@field cache_path Path
|
---@field cache_path Path
|
||||||
local P = {}
|
local P = {}
|
||||||
|
|
||||||
-- Helpers
|
|
||||||
local H = {}
|
|
||||||
|
|
||||||
-- Get a chat history file name given a buffer
|
|
||||||
---@param bufnr integer
|
|
||||||
---@return string
|
|
||||||
H.filename = function(bufnr)
|
|
||||||
local code_buf_name = api.nvim_buf_get_name(bufnr)
|
|
||||||
-- Replace path separators with double underscores
|
|
||||||
local path_with_separators = fn.substitute(code_buf_name, "/", "__", "g")
|
|
||||||
-- Replace other non-alphanumeric characters with single underscores
|
|
||||||
return fn.substitute(path_with_separators, "[^A-Za-z0-9._]", "_", "g") .. ".json"
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Given a mode, return the file name for the custom prompt.
|
|
||||||
---@param mode LlmMode
|
|
||||||
H.get_mode_file = function(mode) return string.format("custom.%s.avanterules", mode) end
|
|
||||||
|
|
||||||
local history_file_cache = LRUCache:new(12)
|
local history_file_cache = LRUCache:new(12)
|
||||||
|
|
||||||
-- History path
|
-- History path
|
||||||
local History = {}
|
local History = {}
|
||||||
|
|
||||||
|
-- Get a chat history file name given a buffer
|
||||||
|
---@param bufnr integer
|
||||||
|
---@return string
|
||||||
|
History.filename = function(bufnr)
|
||||||
|
local project_root = Utils.root.get({
|
||||||
|
buf = bufnr,
|
||||||
|
})
|
||||||
|
-- Replace path separators with double underscores
|
||||||
|
local path_with_separators = fn.substitute(project_root, "/", "__", "g")
|
||||||
|
-- Replace other non-alphanumeric characters with single underscores
|
||||||
|
return fn.substitute(path_with_separators, "[^A-Za-z0-9._]", "_", "g") .. ".json"
|
||||||
|
end
|
||||||
|
|
||||||
-- Returns the Path to the chat history file for the given buffer.
|
-- Returns the Path to the chat history file for the given buffer.
|
||||||
---@param bufnr integer
|
---@param bufnr integer
|
||||||
---@return Path
|
---@return Path
|
||||||
History.get = function(bufnr) return Path:new(Config.history.storage_path):joinpath(H.filename(bufnr)) end
|
History.get = function(bufnr) return Path:new(Config.history.storage_path):joinpath(History.filename(bufnr)) end
|
||||||
|
|
||||||
-- Loads the chat history for the given buffer.
|
-- Loads the chat history for the given buffer.
|
||||||
---@param bufnr integer
|
---@param bufnr integer
|
||||||
|
---@return avante.ChatHistoryEntry[]
|
||||||
History.load = function(bufnr)
|
History.load = function(bufnr)
|
||||||
local history_file = History.get(bufnr)
|
local history_file = History.get(bufnr)
|
||||||
local cached_key = tostring(history_file:absolute())
|
local cached_key = tostring(history_file:absolute())
|
||||||
@ -56,7 +63,7 @@ end
|
|||||||
|
|
||||||
-- Saves the chat history for the given buffer.
|
-- Saves the chat history for the given buffer.
|
||||||
---@param bufnr integer
|
---@param bufnr integer
|
||||||
---@param history table
|
---@param history avante.ChatHistoryEntry[]
|
||||||
History.save = vim.schedule_wrap(function(bufnr, history)
|
History.save = vim.schedule_wrap(function(bufnr, history)
|
||||||
local history_file = History.get(bufnr)
|
local history_file = History.get(bufnr)
|
||||||
local cached_key = tostring(history_file:absolute())
|
local cached_key = tostring(history_file:absolute())
|
||||||
@ -69,6 +76,10 @@ P.history = History
|
|||||||
-- Prompt path
|
-- Prompt path
|
||||||
local Prompt = {}
|
local Prompt = {}
|
||||||
|
|
||||||
|
-- Given a mode, return the file name for the custom prompt.
|
||||||
|
---@param mode LlmMode
|
||||||
|
Prompt.get_mode_file = function(mode) return string.format("custom.%s.avanterules", mode) end
|
||||||
|
|
||||||
---@class AvanteTemplates
|
---@class AvanteTemplates
|
||||||
---@field initialize fun(directory: string): nil
|
---@field initialize fun(directory: string): nil
|
||||||
---@field render fun(template: string, context: TemplateOptions): string
|
---@field render fun(template: string, context: TemplateOptions): string
|
||||||
@ -110,7 +121,7 @@ Prompt.get = function(bufnr)
|
|||||||
:copy({ destination = cache_prompt_dir, recursive = true })
|
:copy({ destination = cache_prompt_dir, recursive = true })
|
||||||
|
|
||||||
vim.iter(Prompt.templates):filter(function(_, v) return v ~= nil end):each(function(k, v)
|
vim.iter(Prompt.templates):filter(function(_, v) return v ~= nil end):each(function(k, v)
|
||||||
local f = cache_prompt_dir:joinpath(H.get_mode_file(k))
|
local f = cache_prompt_dir:joinpath(Prompt.get_mode_file(k))
|
||||||
f:write(v, "w")
|
f:write(v, "w")
|
||||||
end)
|
end)
|
||||||
|
|
||||||
@ -119,7 +130,7 @@ end
|
|||||||
|
|
||||||
---@param mode LlmMode
|
---@param mode LlmMode
|
||||||
Prompt.get_file = function(mode)
|
Prompt.get_file = function(mode)
|
||||||
if Prompt.templates[mode] ~= nil then return H.get_mode_file(mode) end
|
if Prompt.templates[mode] ~= nil then return Prompt.get_mode_file(mode) end
|
||||||
return string.format("%s.avanterules", mode)
|
return string.format("%s.avanterules", mode)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -1039,7 +1039,7 @@ function Sidebar:update_content(content, opts)
|
|||||||
opts = vim.tbl_deep_extend("force", { focus = true, scroll = true, stream = false, callback = nil }, opts or {})
|
opts = vim.tbl_deep_extend("force", { focus = true, scroll = true, stream = false, callback = nil }, opts or {})
|
||||||
if not opts.ignore_history then
|
if not opts.ignore_history then
|
||||||
local chat_history = Path.history.load(self.code.bufnr)
|
local chat_history = Path.history.load(self.code.bufnr)
|
||||||
content = self:get_history_content(chat_history) .. "---\n\n" .. content
|
content = self:render_history_content(chat_history) .. "---\n\n" .. content
|
||||||
end
|
end
|
||||||
if opts.stream then
|
if opts.stream then
|
||||||
local scroll_to_bottom = function()
|
local scroll_to_bottom = function()
|
||||||
@ -1095,19 +1095,28 @@ end
|
|||||||
-- Function to get current timestamp
|
-- Function to get current timestamp
|
||||||
local function get_timestamp() return os.date("%Y-%m-%d %H:%M:%S") end
|
local function get_timestamp() return os.date("%Y-%m-%d %H:%M:%S") end
|
||||||
|
|
||||||
local function get_chat_record_prefix(timestamp, provider, model, request)
|
---@param timestamp string|osdate
|
||||||
|
---@param provider string
|
||||||
|
---@param model string
|
||||||
|
---@param request string
|
||||||
|
---@param selected_file {filepath: string, content: string}?
|
||||||
|
---@param selected_code {filetype: string, content: string}?
|
||||||
|
---@return string
|
||||||
|
local function render_chat_record_prefix(timestamp, provider, model, request, selected_file, selected_code)
|
||||||
provider = provider or "unknown"
|
provider = provider or "unknown"
|
||||||
model = model or "unknown"
|
model = model or "unknown"
|
||||||
return "- Datetime: "
|
local res = "- Datetime: " .. timestamp .. "\n\n" .. "- Model: " .. provider .. "/" .. model
|
||||||
.. timestamp
|
if selected_file ~= nil then res = res .. "\n\n- Selected file: " .. selected_file.filepath end
|
||||||
.. "\n\n"
|
if selected_code ~= nil then
|
||||||
.. "- Model: "
|
res = res
|
||||||
.. provider
|
.. "\n\n- Selected code: "
|
||||||
.. "/"
|
.. "\n\n```"
|
||||||
.. model
|
.. selected_code.filetype
|
||||||
.. "\n\n> "
|
.. "\n"
|
||||||
.. request:gsub("\n", "\n> "):gsub("([%w-_]+)%b[]", "`%0`")
|
.. selected_code.content
|
||||||
.. "\n\n"
|
.. "\n```"
|
||||||
|
end
|
||||||
|
return res .. "\n\n> " .. request:gsub("\n", "\n> "):gsub("([%w-_]+)%b[]", "`%0`") .. "\n\n"
|
||||||
end
|
end
|
||||||
|
|
||||||
local function calculate_config_window_position()
|
local function calculate_config_window_position()
|
||||||
@ -1132,20 +1141,34 @@ function Sidebar:get_layout()
|
|||||||
return vim.tbl_contains({ "left", "right" }, calculate_config_window_position()) and "vertical" or "horizontal"
|
return vim.tbl_contains({ "left", "right" }, calculate_config_window_position()) and "vertical" or "horizontal"
|
||||||
end
|
end
|
||||||
|
|
||||||
function Sidebar:get_history_content(history)
|
---@param history avante.ChatHistoryEntry[]
|
||||||
|
---@return string
|
||||||
|
function Sidebar:render_history_content(history)
|
||||||
local content = ""
|
local content = ""
|
||||||
for idx, entry in ipairs(history) do
|
for idx, entry in ipairs(history) do
|
||||||
local prefix =
|
if entry.reset_memory then
|
||||||
get_chat_record_prefix(entry.timestamp, entry.provider, entry.model, entry.request or entry.requirement or "")
|
content = content .. "***MEMORY RESET***\n\n"
|
||||||
|
if idx < #history then content = content .. "---\n\n" end
|
||||||
|
goto continue
|
||||||
|
end
|
||||||
|
local prefix = render_chat_record_prefix(
|
||||||
|
entry.timestamp,
|
||||||
|
entry.provider,
|
||||||
|
entry.model,
|
||||||
|
entry.request or "",
|
||||||
|
entry.selected_file,
|
||||||
|
entry.selected_code
|
||||||
|
)
|
||||||
content = content .. prefix
|
content = content .. prefix
|
||||||
content = content .. entry.response .. "\n\n"
|
content = content .. entry.response .. "\n\n"
|
||||||
if idx < #history then content = content .. "---\n\n" end
|
if idx < #history then content = content .. "---\n\n" end
|
||||||
|
::continue::
|
||||||
end
|
end
|
||||||
return content
|
return content
|
||||||
end
|
end
|
||||||
|
|
||||||
function Sidebar:update_content_with_history(history)
|
function Sidebar:update_content_with_history(history)
|
||||||
local content = self:get_history_content(history)
|
local content = self:render_history_content(history)
|
||||||
self:update_content(content, { ignore_history = true })
|
self:update_content(content, { ignore_history = true })
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -1184,7 +1207,7 @@ function Sidebar:get_content_between_separators()
|
|||||||
return content, start_line
|
return content, start_line
|
||||||
end
|
end
|
||||||
|
|
||||||
---@alias AvanteSlashCommands "clear" | "help" | "lines"
|
---@alias AvanteSlashCommands "clear" | "help" | "lines" | "reset"
|
||||||
---@alias AvanteSlashCallback fun(args: string, cb?: fun(args: string): nil): nil
|
---@alias AvanteSlashCallback fun(args: string, cb?: fun(args: string): nil): nil
|
||||||
---@alias AvanteSlash {description: string, command: AvanteSlashCommands, details: string, shorthelp?: string, callback?: AvanteSlashCallback}
|
---@alias AvanteSlash {description: string, command: AvanteSlashCommands, details: string, shorthelp?: string, callback?: AvanteSlashCallback}
|
||||||
---@return AvanteSlash[]
|
---@return AvanteSlash[]
|
||||||
@ -1203,6 +1226,7 @@ function Sidebar:get_commands()
|
|||||||
local items = {
|
local items = {
|
||||||
{ description = "Show help message", command = "help" },
|
{ description = "Show help message", command = "help" },
|
||||||
{ description = "Clear chat history", command = "clear" },
|
{ description = "Clear chat history", command = "clear" },
|
||||||
|
{ description = "Reset memory", command = "reset" },
|
||||||
{
|
{
|
||||||
shorthelp = "Ask a question about specific lines",
|
shorthelp = "Ask a question about specific lines",
|
||||||
description = "/lines <start>-<end> <question>",
|
description = "/lines <start>-<end> <question>",
|
||||||
@ -1223,13 +1247,31 @@ function Sidebar:get_commands()
|
|||||||
chat_history = {}
|
chat_history = {}
|
||||||
Path.history.save(self.code.bufnr, chat_history)
|
Path.history.save(self.code.bufnr, chat_history)
|
||||||
self:update_content("Chat history cleared", { focus = false, scroll = false })
|
self:update_content("Chat history cleared", { focus = false, scroll = false })
|
||||||
vim.defer_fn(function()
|
|
||||||
self:close()
|
|
||||||
if cb then cb(args) end
|
if cb then cb(args) end
|
||||||
end, 1000)
|
|
||||||
else
|
else
|
||||||
self:update_content("Chat history is already empty", { focus = false, scroll = false })
|
self:update_content("Chat history is already empty", { focus = false, scroll = false })
|
||||||
vim.defer_fn(function() self:close() end, 1000)
|
end
|
||||||
|
end,
|
||||||
|
reset = function(args, cb)
|
||||||
|
local chat_history = Path.history.load(self.code.bufnr)
|
||||||
|
if next(chat_history) ~= nil then
|
||||||
|
table.insert(chat_history, {
|
||||||
|
timestamp = get_timestamp(),
|
||||||
|
provider = Config.provider,
|
||||||
|
model = Config.get_provider(Config.provider).model,
|
||||||
|
request = "",
|
||||||
|
response = "",
|
||||||
|
original_response = "",
|
||||||
|
selected_file = nil,
|
||||||
|
selected_code = nil,
|
||||||
|
reset_memory = true,
|
||||||
|
})
|
||||||
|
Path.history.save(self.code.bufnr, chat_history)
|
||||||
|
local history_content = self:render_history_content(chat_history)
|
||||||
|
self:update_content(history_content, { focus = false, scroll = true })
|
||||||
|
if cb then cb(args) end
|
||||||
|
else
|
||||||
|
self:update_content("Chat history is already empty", { focus = false, scroll = false })
|
||||||
end
|
end
|
||||||
end,
|
end,
|
||||||
lines = function(args, cb)
|
lines = function(args, cb)
|
||||||
@ -1300,7 +1342,22 @@ function Sidebar:create_input(opts)
|
|||||||
|
|
||||||
local timestamp = get_timestamp()
|
local timestamp = get_timestamp()
|
||||||
|
|
||||||
local content_prefix = get_chat_record_prefix(timestamp, Config.provider, model, request)
|
local filetype = api.nvim_get_option_value("filetype", { buf = self.code.bufnr })
|
||||||
|
|
||||||
|
local selected_file = {
|
||||||
|
filepath = api.nvim_buf_get_name(self.code.bufnr),
|
||||||
|
}
|
||||||
|
|
||||||
|
local selected_code = nil
|
||||||
|
if self.code.selection ~= nil then
|
||||||
|
selected_code = {
|
||||||
|
filetype = filetype,
|
||||||
|
content = self.code.selection.content,
|
||||||
|
}
|
||||||
|
end
|
||||||
|
|
||||||
|
local content_prefix =
|
||||||
|
render_chat_record_prefix(timestamp, Config.provider, model, request, selected_file, selected_code)
|
||||||
|
|
||||||
--- HACK: we need to set focus to true and scroll to false to
|
--- HACK: we need to set focus to true and scroll to false to
|
||||||
--- prevent the cursor from jumping to the bottom of the
|
--- prevent the cursor from jumping to the bottom of the
|
||||||
@ -1310,8 +1367,6 @@ function Sidebar:create_input(opts)
|
|||||||
|
|
||||||
local content = table.concat(Utils.get_buf_lines(0, -1, self.code.bufnr), "\n")
|
local content = table.concat(Utils.get_buf_lines(0, -1, self.code.bufnr), "\n")
|
||||||
|
|
||||||
local filetype = api.nvim_get_option_value("filetype", { buf = self.code.bufnr })
|
|
||||||
|
|
||||||
local selected_code_content = nil
|
local selected_code_content = nil
|
||||||
if self.code.selection ~= nil then selected_code_content = self.code.selection.content end
|
if self.code.selection ~= nil then selected_code_content = self.code.selection.content end
|
||||||
|
|
||||||
@ -1409,6 +1464,8 @@ function Sidebar:create_input(opts)
|
|||||||
request = request,
|
request = request,
|
||||||
response = displayed_response,
|
response = displayed_response,
|
||||||
original_response = original_response,
|
original_response = original_response,
|
||||||
|
selected_file = selected_file,
|
||||||
|
selected_code = selected_code,
|
||||||
})
|
})
|
||||||
Path.history.save(self.code.bufnr, chat_history)
|
Path.history.save(self.code.bufnr, chat_history)
|
||||||
end
|
end
|
||||||
@ -1420,26 +1477,34 @@ function Sidebar:create_input(opts)
|
|||||||
|
|
||||||
local project_context = mentions.enable_project_context and RepoMap.get_repo_map(file_ext) or nil
|
local project_context = mentions.enable_project_context and RepoMap.get_repo_map(file_ext) or nil
|
||||||
|
|
||||||
local history_messages = vim
|
local history_messages = {}
|
||||||
.iter(chat_history)
|
for i = #chat_history, 1, -1 do
|
||||||
:filter(
|
local entry = chat_history[i]
|
||||||
function(history)
|
if entry.reset_memory then break end
|
||||||
return history.request ~= nil
|
if
|
||||||
and history.original_response ~= nil
|
entry.request == nil
|
||||||
and history.request ~= ""
|
or entry.original_response == nil
|
||||||
and history.original_response ~= ""
|
or entry.request == ""
|
||||||
|
or entry.original_response == ""
|
||||||
|
then
|
||||||
|
break
|
||||||
end
|
end
|
||||||
)
|
table.insert(history_messages, 1, { role = "assistant", content = entry.original_response })
|
||||||
:map(
|
local user_content = ""
|
||||||
function(history)
|
if entry.selected_file ~= nil then
|
||||||
return {
|
user_content = user_content .. "SELECTED FILE: " .. entry.selected_file.filepath .. "\n\n"
|
||||||
{ role = "user", content = history.request },
|
end
|
||||||
{ role = "assistant", content = history.original_response },
|
if entry.selected_code ~= nil then
|
||||||
}
|
user_content = user_content
|
||||||
|
.. "SELECTED CODE:\n\n```"
|
||||||
|
.. entry.selected_code.filetype
|
||||||
|
.. "\n"
|
||||||
|
.. entry.selected_code.content
|
||||||
|
.. "\n```\n\n"
|
||||||
|
end
|
||||||
|
user_content = user_content .. "USER PROMPT:\n\n" .. entry.request
|
||||||
|
table.insert(history_messages, 1, { role = "user", content = user_content })
|
||||||
end
|
end
|
||||||
)
|
|
||||||
:flatten()
|
|
||||||
:totable()
|
|
||||||
|
|
||||||
Llm.stream({
|
Llm.stream({
|
||||||
bufnr = self.code.bufnr,
|
bufnr = self.code.bufnr,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user