feat: memory 🧠 (#793)

This commit is contained in:
yetone 2024-11-04 16:20:28 +08:00 committed by GitHub
parent 579ef12f76
commit 1e8abbf798
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 197 additions and 154 deletions

View File

@ -19,13 +19,11 @@ impl<'a> State<'a> {
struct TemplateContext {
use_xml_format: bool,
ask: bool,
question: String,
code_lang: String,
filepath: String,
file_content: String,
selected_code: Option<String>,
project_context: Option<String>,
memory_context: Option<String>,
}
// Given the file name registered after add, the context table in Lua, resulted in a formatted
@ -44,13 +42,11 @@ fn render(state: &State, template: &str, context: TemplateContext) -> LuaResult<
.render(context! {
use_xml_format => context.use_xml_format,
ask => context.ask,
question => context.question,
code_lang => context.code_lang,
filepath => context.filepath,
file_content => context.file_content,
selected_code => context.selected_code,
project_context => context.project_context,
memory_context => context.memory_context,
})
.map_err(LuaError::external)
.unwrap())

View File

@ -18,15 +18,6 @@ M.defaults = {
-- For most providers that we support we will determine this automatically.
-- If you wish to use a given implementation, then you can override it here.
tokenizer = "tiktoken",
---@alias AvanteSystemPrompt string
-- Default system prompt. Users can override this with their own prompt
-- You can use `require('avante.config').override({system_prompt = "MY_SYSTEM_PROMPT"}) conditionally
-- in your own autocmds to do it per directory, or that fit your needs.
system_prompt = [[
Act as an expert software developer.
Always use best practices when coding.
Respect and use existing conventions, libraries, etc that are already present in the code base.
]],
---@type AvanteSupportedProvider
openai = {
endpoint = "https://api.openai.com/v1",
@ -102,6 +93,7 @@ Respect and use existing conventions, libraries, etc that are already present in
support_paste_from_clipboard = false,
},
history = {
max_tokens = 4096,
storage_path = vim.fn.stdpath("state") .. "/avante",
paste = {
extension = "png",
@ -315,6 +307,7 @@ M.BASE_PROVIDER_KEYS = {
"_shellenv",
"tokenizer_id",
"use_xml_format",
"role_map",
}
return M

View File

@ -28,7 +28,7 @@ local group = api.nvim_create_augroup("avante_llm", { clear = true })
---@field file_content string
---@field selected_code string | nil
---@field project_context string | nil
---@field memory_context string | nil
---@field history_messages AvanteLLMMessage[]
---
---@class StreamOptions: TemplateOptions
---@field ask boolean
@ -44,10 +44,12 @@ M.stream = function(opts)
local mode = opts.mode or "planning"
---@type AvanteProviderFunctor
local Provider = opts.provider or P[Config.provider]
local _, body_opts = P.parse_config(Provider)
local max_tokens = body_opts.max_tokens or 4096
-- Check if the instructions contains an image path
local image_paths = {}
local original_instructions = opts.instructions
local instructions = opts.instructions
if opts.instructions:match("image: ") then
local lines = vim.split(opts.instructions, "\n")
for i, line in ipairs(lines) do
@ -57,7 +59,7 @@ M.stream = function(opts)
table.remove(lines, i)
end
end
original_instructions = table.concat(lines, "\n")
instructions = table.concat(lines, "\n")
end
Path.prompts.initialize(Path.prompts.get(opts.bufnr))
@ -67,29 +69,61 @@ M.stream = function(opts)
local template_opts = {
use_xml_format = Provider.use_xml_format,
ask = opts.ask, -- TODO: add mode without ask instruction
question = original_instructions,
code_lang = opts.code_lang,
filepath = filepath,
file_content = opts.file_content,
selected_code = opts.selected_code,
project_context = opts.project_context,
memory_context = opts.memory_context,
}
local user_prompts = vim
.iter({
Path.prompts.render_file("_project.avanterules", template_opts),
Path.prompts.render_file("_memory.avanterules", template_opts),
Path.prompts.render_file("_context.avanterules", template_opts),
Path.prompts.render_mode(mode, template_opts),
})
:filter(function(k) return k ~= "" end)
:totable()
local system_prompt = Path.prompts.render_mode(mode, template_opts)
---@type AvanteLLMMessage[]
local messages = {}
if opts.project_context ~= nil and opts.project_context ~= "" and opts.project_context ~= "null" then
local project_context = Path.prompts.render_file("_project.avanterules", template_opts)
if project_context ~= "" then table.insert(messages, { role = "user", content = project_context }) end
end
local code_context = Path.prompts.render_file("_context.avanterules", template_opts)
if code_context ~= "" then table.insert(messages, { role = "user", content = code_context }) end
if opts.use_xml_format then
table.insert(messages, { role = "user", content = string.format("<question>%s</question>", instructions) })
else
table.insert(messages, { role = "user", content = string.format("QUESTION:\n%s", instructions) })
end
local remaining_tokens = max_tokens - Utils.tokens.calculate_tokens(system_prompt)
for _, message in ipairs(messages) do
remaining_tokens = remaining_tokens - Utils.tokens.calculate_tokens(message.content)
end
if opts.history_messages then
if Config.history.max_tokens > 0 then remaining_tokens = math.min(Config.history.max_tokens, remaining_tokens) end
-- Traverse the history in reverse, keeping only the latest history until the remaining tokens are exhausted and the first message role is "user"
local history_messages = {}
for i = #opts.history_messages, 1, -1 do
local message = opts.history_messages[i]
local tokens = Utils.tokens.calculate_tokens(message.content)
remaining_tokens = remaining_tokens - tokens
if remaining_tokens > 0 then
table.insert(history_messages, message)
else
break
end
end
if #history_messages > 0 and history_messages[1].role == "assistant" then table.remove(history_messages, 1) end
-- prepend the history messages to the messages table
vim.iter(history_messages):each(function(msg) table.insert(messages, 1, msg) end)
end
---@type AvantePromptOptions
local code_opts = {
system_prompt = Config.system_prompt,
user_prompts = user_prompts,
system_prompt = system_prompt,
messages = messages,
image_paths = image_paths,
}
@ -164,7 +198,7 @@ M.stream = function(opts)
on_error = function(err)
if err.exit == 23 then
local xdg_runtime_dir = os.getenv("XDG_RUNTIME_DIR")
if fn.isdirectory(xdg_runtime_dir) == 0 then
if not xdg_runtime_dir or fn.isdirectory(xdg_runtime_dir) == 0 then
Utils.error(
"$XDG_RUNTIME_DIR="
.. xdg_runtime_dir

View File

@ -13,7 +13,7 @@ local M = {}
M.api_key_name = "AZURE_OPENAI_API_KEY"
M.parse_message = O.parse_message
M.parse_messages = O.parse_messages
M.parse_response = O.parse_response
M.parse_curl_args = function(provider, code_opts)
@ -34,7 +34,7 @@ M.parse_curl_args = function(provider, code_opts)
insecure = base.allow_insecure,
headers = headers,
body = vim.tbl_deep_extend("force", {
messages = M.parse_message(code_opts),
messages = M.parse_messages(code_opts),
stream = true,
}, body_opts),
}

View File

@ -13,8 +13,8 @@ local P = require("avante.providers")
---@field type "image"
---@field source {type: "base64", media_type: string, data: string}
---
---@class AvanteClaudeMessage: AvanteBaseMessage
---@field role "user"
---@class AvanteClaudeMessage
---@field role "user" | "assistant"
---@field content [AvanteClaudeTextMessage | AvanteClaudeImageMessage][]
---@class AvanteProviderFunctor
@ -23,11 +23,44 @@ local M = {}
M.api_key_name = "ANTHROPIC_API_KEY"
M.use_xml_format = true
M.parse_message = function(opts)
---@type AvanteClaudeMessage[]
local message_content = {}
M.role_map = {
user = "user",
assistant = "assistant",
}
if Clipboard.support_paste_image() and opts.image_paths then
M.parse_messages = function(opts)
---@type AvanteClaudeMessage[]
local messages = {}
---@type {idx: integer, length: integer}[]
local messages_with_length = {}
for idx, message in ipairs(opts.messages) do
table.insert(messages_with_length, { idx = idx, length = Utils.tokens.calculate_tokens(message.content) })
end
table.sort(messages_with_length, function(a, b) return a.length > b.length end)
---@type table<integer, boolean>
local top_three = {}
for i = 1, math.min(3, #messages_with_length) do
top_three[messages_with_length[i].idx] = true
end
for idx, message in ipairs(opts.messages) do
table.insert(messages, {
role = M.role_map[message.role],
content = {
{
type = "text",
text = message.content,
cache_control = top_three[idx] and { type = "ephemeral" } or nil,
},
},
})
end
if Clipboard.support_paste_image() and opts.image_paths and #opts.image_paths > 0 then
local message_content = messages[#messages].content
for _, image_path in ipairs(opts.image_paths) do
table.insert(message_content, {
type = "image",
@ -38,36 +71,10 @@ M.parse_message = function(opts)
},
})
end
messages[#messages].content = message_content
end
---@type {idx: integer, length: integer}[]
local user_prompts_with_length = {}
for idx, user_prompt in ipairs(opts.user_prompts) do
table.insert(user_prompts_with_length, { idx = idx, length = Utils.tokens.calculate_tokens(user_prompt) })
end
table.sort(user_prompts_with_length, function(a, b) return a.length > b.length end)
---@type table<integer, boolean>
local top_three = {}
for i = 1, math.min(3, #user_prompts_with_length) do
top_three[user_prompts_with_length[i].idx] = true
end
for idx, prompt_data in ipairs(opts.user_prompts) do
table.insert(message_content, {
type = "text",
text = prompt_data,
cache_control = top_three[idx] and { type = "ephemeral" } or nil,
})
end
return {
{
role = "user",
content = message_content,
},
}
return messages
end
M.parse_response = function(data_stream, event_state, opts)
@ -96,7 +103,7 @@ M.parse_curl_args = function(provider, prompt_opts)
}
if not P.env.is_local("claude") then headers["x-api-key"] = provider.parse_api_key() end
local messages = M.parse_message(prompt_opts)
local messages = M.parse_messages(prompt_opts)
return {
url = Utils.trim(base.endpoint, { suffix = "/" }) .. "/v1/messages",

View File

@ -42,17 +42,18 @@ local M = {}
M.api_key_name = "CO_API_KEY"
M.tokenizer_id = "https://storage.googleapis.com/cohere-public/tokenizers/command-r-08-2024.json"
M.role_map = {
user = "user",
assistant = "assistant",
}
M.parse_message = function(opts)
---@type CohereMessage[]
local user_content = vim.iter(opts.user_prompts):fold({}, function(acc, prompt)
table.insert(acc, { type = "text", text = prompt })
return acc
end)
M.parse_messages = function(opts)
local messages = {
{ role = "system", content = opts.system_prompt },
{ role = "user", content = user_content },
}
vim
.iter(opts.messages)
:each(function(msg) table.insert(messages, { role = M.role_map[msg.role], content = msg.content }) end)
return { messages = messages }
end
@ -91,7 +92,7 @@ M.parse_curl_args = function(provider, code_opts)
body = vim.tbl_deep_extend("force", {
model = base.model,
stream = true,
}, M.parse_message(code_opts), body_opts),
}, M.parse_messages(code_opts), body_opts),
}
end

View File

@ -118,12 +118,19 @@ M.state = nil
M.api_key_name = P.AVANTE_INTERNAL_KEY
M.tokenizer_id = "gpt-4o"
M.role_map = {
user = "user",
assistant = "assistant",
}
M.parse_message = function(opts)
return {
M.parse_messages = function(opts)
local messages = {
{ role = "system", content = opts.system_prompt },
{ role = "user", content = table.concat(opts.user_prompts, "\n") },
}
vim
.iter(opts.messages)
:each(function(msg) table.insert(messages, { role = M.role_map[msg.role], content = msg.content }) end)
return messages
end
M.parse_response = O.parse_response
@ -146,7 +153,7 @@ M.parse_curl_args = function(provider, code_opts)
},
body = vim.tbl_deep_extend("force", {
model = base.model,
messages = M.parse_message(code_opts),
messages = M.parse_messages(code_opts),
stream = true,
}, body_opts),
}

View File

@ -6,10 +6,34 @@ local Clipboard = require("avante.clipboard")
local M = {}
M.api_key_name = "GEMINI_API_KEY"
M.role_map = {
user = "user",
assistant = "model",
}
-- M.tokenizer_id = "google/gemma-2b"
M.parse_message = function(opts)
local message_content = {}
M.parse_messages = function(opts)
local contents = {}
local prev_role = nil
vim.iter(opts.messages):each(function(message)
local role = message.role
if role == prev_role then
if role == "user" then
table.insert(contents, { role = "model", parts = {
{ text = "Ok, I understand." },
} })
else
table.insert(contents, { role = "user", parts = {
{ text = "Ok" },
} })
end
end
prev_role = role
table.insert(contents, { role = M.role_map[role] or role, parts = {
{ text = message.content },
} })
end)
if Clipboard.support_paste_image() and opts.image_paths then
for _, image_path in ipairs(opts.image_paths) do
@ -20,13 +44,10 @@ M.parse_message = function(opts)
},
}
table.insert(message_content, image_data)
table.insert(contents[#contents].parts, image_data)
end
end
-- insert a part into parts
table.insert(message_content, { text = table.concat(opts.user_prompts, "\n") })
return {
systemInstruction = {
role = "user",
@ -36,12 +57,7 @@ M.parse_message = function(opts)
},
},
},
contents = {
{
role = "user",
parts = message_content,
},
},
contents = contents,
}
end
@ -78,7 +94,7 @@ M.parse_curl_args = function(provider, code_opts)
proxy = base.proxy,
insecure = base.allow_insecure,
headers = { ["Content-Type"] = "application/json" },
body = vim.tbl_deep_extend("force", {}, M.parse_message(code_opts), body_opts),
body = vim.tbl_deep_extend("force", {}, M.parse_messages(code_opts), body_opts),
}
end

View File

@ -14,22 +14,22 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil }
---@field on_chunk AvanteChunkParser
---@field on_complete AvanteCompleteParser
---
---@class AvanteLLMMessage
---@field role "user" | "assistant"
---@field content string
---
---@class AvantePromptOptions: table<[string], string>
---@field system_prompt string
---@field user_prompts string[]
---@field messages AvanteLLMMessage[]
---@field image_paths? string[]
---
---@class AvanteBaseMessage
---@field role "user" | "system"
---@field content string
---
---@class AvanteGeminiMessage
---@field role "user"
---@field parts { text: string }[]
---
---@alias AvanteChatMessage AvanteClaudeMessage | OpenAIMessage | AvanteGeminiMessage
---
---@alias AvanteMessageParser fun(opts: AvantePromptOptions): AvanteChatMessage[]
---@alias AvanteMessagesParser fun(opts: AvantePromptOptions): AvanteChatMessage[]
---
---@class AvanteCurlOutput: {url: string, proxy: string, insecure: boolean, body: table<string, any> | string, headers: table<string, string>}
---@alias AvanteCurlArgsParser fun(opts: AvanteProvider | AvanteProviderFunctor, code_opts: AvantePromptOptions): AvanteCurlOutput
@ -65,13 +65,14 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil }
---@field parse_api_key? fun(): string | nil
---
---@class AvanteProviderFunctor
---@field parse_message AvanteMessageParser
---@field role_map table<"user" | "assistant", string>
---@field parse_messages AvanteMessagesParser
---@field parse_response AvanteResponseParser
---@field parse_curl_args AvanteCurlArgsParser
---@field setup fun(): nil
---@field has fun(): boolean
---@field api_key_name string
---@field tokenizer_id [string] | "gpt-4o"
---@field tokenizer_id string | "gpt-4o"
---@field use_xml_format boolean
---@field model? string
---@field parse_api_key fun(): string | nil

View File

@ -33,29 +33,15 @@ local M = {}
M.api_key_name = "OPENAI_API_KEY"
M.role_map = {
user = "user",
assistant = "assistant",
}
---@param opts AvantePromptOptions
M.get_user_message = function(opts) return table.concat(opts.user_prompts, "\n") end
M.parse_message = function(opts)
---@type OpenAIMessage[]
local user_content = {}
if Config.behaviour.support_paste_from_clipboard and opts.image_paths and #opts.image_paths > 0 then
for _, image_path in ipairs(opts.image_paths) do
table.insert(user_content, {
type = "image_url",
image_url = {
url = "data:image/png;base64," .. Clipboard.get_base64_content(image_path),
},
})
end
vim.iter(opts.user_prompts):each(function(prompt) table.insert(user_content, { type = "text", text = prompt }) end)
else
user_content = vim.iter(opts.user_prompts):fold({}, function(acc, prompt)
table.insert(acc, { type = "text", text = prompt })
return acc
end)
end
M.get_user_message = function(opts) return table.concat(opts.messages, "\n") end
M.parse_messages = function(opts)
local messages = {}
local provider = P[Config.provider]
local base, _ = P.parse_config(provider)
@ -68,8 +54,23 @@ M.parse_message = function(opts)
table.insert(messages, { role = "system", content = opts.system_prompt })
end
-- User message after the prompt
table.insert(messages, { role = "user", content = user_content })
vim
.iter(opts.messages)
:each(function(msg) table.insert(messages, { role = M.role_map[msg.role], content = msg.content }) end)
if Config.behaviour.support_paste_from_clipboard and opts.image_paths and #opts.image_paths > 0 then
local message_content = messages[#messages].content
if type(message_content) ~= "table" then message_content = { type = "text", text = message_content } end
for _, image_path in ipairs(opts.image_paths) do
table.insert(message_content, {
type = "image_url",
image_url = {
url = "data:image/png;base64," .. Clipboard.get_base64_content(image_path),
},
})
end
messages[#messages].content = message_content
end
return messages
end
@ -128,7 +129,7 @@ M.parse_curl_args = function(provider, code_opts)
headers = headers,
body = vim.tbl_deep_extend("force", {
model = base.model,
messages = M.parse_message(code_opts),
messages = M.parse_messages(code_opts),
stream = stream,
}, body_opts),
}

View File

@ -1392,10 +1392,23 @@ function Sidebar:create_input(opts)
local project_context = mentions.enable_project_context and RepoMap.get_repo_map(file_ext) or nil
local history_messages = vim.tbl_map(
function(history)
return {
{ role = "user", content = history.request },
{ role = "assistant", content = history.original_response },
}
end,
chat_history
)
history_messages = vim.iter(history_messages):flatten():totable()
Llm.stream({
bufnr = self.code.bufnr,
ask = opts.ask,
project_context = vim.json.encode(project_context),
history_messages = history_messages,
file_content = content,
code_lang = filetype,
selected_code = selected_code_content,

View File

@ -1,12 +0,0 @@
{%- if use_xml_format -%}
{%- if memory_context -%}
<memory_context>
{{memory_context}}
</memory_context>
{%- endif %}
{%- else -%}
{%- if memory_context -%}
MEMORY CONTEXT:
{{memory_context}}
{%- endif %}
{%- endif %}

View File

@ -7,10 +7,11 @@
"file_content": "local Config = require('avante.config')"
}
#}
Act as an expert software developer.
Always use best practices when coding.
Respect and use existing conventions, libraries, etc that are already present in the code base.
{%- if ask %}
{%- if not use_xml_format -%}
INSTRUCTION:{% else -%}
<instruction>{% endif -%}
{% block user_prompt %}
Take requests for changes to the supplied code.
If the request is ambiguous, ask questions.
@ -150,19 +151,4 @@ To rename files which have been added to the chat, use shell commands at the end
ONLY EVER RETURN CODE IN A *SEARCH/REPLACE BLOCK*!
{% endblock %}
{%- if use_xml_format -%}
</instruction>
<question>{{question}}</question>
{%- else %}
QUESTION:
{{question}}
{%- endif %}
{% else %}
{% if use_xml_format -%}
<question>{{question}}</question>
{% else %}
QUESTION:
{{question}}
{%- endif %}
{%- endif %}