feat: memory 🧠 (#793)
This commit is contained in:
parent
579ef12f76
commit
1e8abbf798
@ -19,13 +19,11 @@ impl<'a> State<'a> {
|
|||||||
struct TemplateContext {
|
struct TemplateContext {
|
||||||
use_xml_format: bool,
|
use_xml_format: bool,
|
||||||
ask: bool,
|
ask: bool,
|
||||||
question: String,
|
|
||||||
code_lang: String,
|
code_lang: String,
|
||||||
filepath: String,
|
filepath: String,
|
||||||
file_content: String,
|
file_content: String,
|
||||||
selected_code: Option<String>,
|
selected_code: Option<String>,
|
||||||
project_context: Option<String>,
|
project_context: Option<String>,
|
||||||
memory_context: Option<String>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Given the file name registered after add, the context table in Lua, resulted in a formatted
|
// Given the file name registered after add, the context table in Lua, resulted in a formatted
|
||||||
@ -44,13 +42,11 @@ fn render(state: &State, template: &str, context: TemplateContext) -> LuaResult<
|
|||||||
.render(context! {
|
.render(context! {
|
||||||
use_xml_format => context.use_xml_format,
|
use_xml_format => context.use_xml_format,
|
||||||
ask => context.ask,
|
ask => context.ask,
|
||||||
question => context.question,
|
|
||||||
code_lang => context.code_lang,
|
code_lang => context.code_lang,
|
||||||
filepath => context.filepath,
|
filepath => context.filepath,
|
||||||
file_content => context.file_content,
|
file_content => context.file_content,
|
||||||
selected_code => context.selected_code,
|
selected_code => context.selected_code,
|
||||||
project_context => context.project_context,
|
project_context => context.project_context,
|
||||||
memory_context => context.memory_context,
|
|
||||||
})
|
})
|
||||||
.map_err(LuaError::external)
|
.map_err(LuaError::external)
|
||||||
.unwrap())
|
.unwrap())
|
||||||
|
@ -18,15 +18,6 @@ M.defaults = {
|
|||||||
-- For most providers that we support we will determine this automatically.
|
-- For most providers that we support we will determine this automatically.
|
||||||
-- If you wish to use a given implementation, then you can override it here.
|
-- If you wish to use a given implementation, then you can override it here.
|
||||||
tokenizer = "tiktoken",
|
tokenizer = "tiktoken",
|
||||||
---@alias AvanteSystemPrompt string
|
|
||||||
-- Default system prompt. Users can override this with their own prompt
|
|
||||||
-- You can use `require('avante.config').override({system_prompt = "MY_SYSTEM_PROMPT"}) conditionally
|
|
||||||
-- in your own autocmds to do it per directory, or that fit your needs.
|
|
||||||
system_prompt = [[
|
|
||||||
Act as an expert software developer.
|
|
||||||
Always use best practices when coding.
|
|
||||||
Respect and use existing conventions, libraries, etc that are already present in the code base.
|
|
||||||
]],
|
|
||||||
---@type AvanteSupportedProvider
|
---@type AvanteSupportedProvider
|
||||||
openai = {
|
openai = {
|
||||||
endpoint = "https://api.openai.com/v1",
|
endpoint = "https://api.openai.com/v1",
|
||||||
@ -102,6 +93,7 @@ Respect and use existing conventions, libraries, etc that are already present in
|
|||||||
support_paste_from_clipboard = false,
|
support_paste_from_clipboard = false,
|
||||||
},
|
},
|
||||||
history = {
|
history = {
|
||||||
|
max_tokens = 4096,
|
||||||
storage_path = vim.fn.stdpath("state") .. "/avante",
|
storage_path = vim.fn.stdpath("state") .. "/avante",
|
||||||
paste = {
|
paste = {
|
||||||
extension = "png",
|
extension = "png",
|
||||||
@ -315,6 +307,7 @@ M.BASE_PROVIDER_KEYS = {
|
|||||||
"_shellenv",
|
"_shellenv",
|
||||||
"tokenizer_id",
|
"tokenizer_id",
|
||||||
"use_xml_format",
|
"use_xml_format",
|
||||||
|
"role_map",
|
||||||
}
|
}
|
||||||
|
|
||||||
return M
|
return M
|
||||||
|
@ -28,7 +28,7 @@ local group = api.nvim_create_augroup("avante_llm", { clear = true })
|
|||||||
---@field file_content string
|
---@field file_content string
|
||||||
---@field selected_code string | nil
|
---@field selected_code string | nil
|
||||||
---@field project_context string | nil
|
---@field project_context string | nil
|
||||||
---@field memory_context string | nil
|
---@field history_messages AvanteLLMMessage[]
|
||||||
---
|
---
|
||||||
---@class StreamOptions: TemplateOptions
|
---@class StreamOptions: TemplateOptions
|
||||||
---@field ask boolean
|
---@field ask boolean
|
||||||
@ -44,10 +44,12 @@ M.stream = function(opts)
|
|||||||
local mode = opts.mode or "planning"
|
local mode = opts.mode or "planning"
|
||||||
---@type AvanteProviderFunctor
|
---@type AvanteProviderFunctor
|
||||||
local Provider = opts.provider or P[Config.provider]
|
local Provider = opts.provider or P[Config.provider]
|
||||||
|
local _, body_opts = P.parse_config(Provider)
|
||||||
|
local max_tokens = body_opts.max_tokens or 4096
|
||||||
|
|
||||||
-- Check if the instructions contains an image path
|
-- Check if the instructions contains an image path
|
||||||
local image_paths = {}
|
local image_paths = {}
|
||||||
local original_instructions = opts.instructions
|
local instructions = opts.instructions
|
||||||
if opts.instructions:match("image: ") then
|
if opts.instructions:match("image: ") then
|
||||||
local lines = vim.split(opts.instructions, "\n")
|
local lines = vim.split(opts.instructions, "\n")
|
||||||
for i, line in ipairs(lines) do
|
for i, line in ipairs(lines) do
|
||||||
@ -57,7 +59,7 @@ M.stream = function(opts)
|
|||||||
table.remove(lines, i)
|
table.remove(lines, i)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
original_instructions = table.concat(lines, "\n")
|
instructions = table.concat(lines, "\n")
|
||||||
end
|
end
|
||||||
|
|
||||||
Path.prompts.initialize(Path.prompts.get(opts.bufnr))
|
Path.prompts.initialize(Path.prompts.get(opts.bufnr))
|
||||||
@ -67,29 +69,61 @@ M.stream = function(opts)
|
|||||||
local template_opts = {
|
local template_opts = {
|
||||||
use_xml_format = Provider.use_xml_format,
|
use_xml_format = Provider.use_xml_format,
|
||||||
ask = opts.ask, -- TODO: add mode without ask instruction
|
ask = opts.ask, -- TODO: add mode without ask instruction
|
||||||
question = original_instructions,
|
|
||||||
code_lang = opts.code_lang,
|
code_lang = opts.code_lang,
|
||||||
filepath = filepath,
|
filepath = filepath,
|
||||||
file_content = opts.file_content,
|
file_content = opts.file_content,
|
||||||
selected_code = opts.selected_code,
|
selected_code = opts.selected_code,
|
||||||
project_context = opts.project_context,
|
project_context = opts.project_context,
|
||||||
memory_context = opts.memory_context,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
local user_prompts = vim
|
local system_prompt = Path.prompts.render_mode(mode, template_opts)
|
||||||
.iter({
|
|
||||||
Path.prompts.render_file("_project.avanterules", template_opts),
|
---@type AvanteLLMMessage[]
|
||||||
Path.prompts.render_file("_memory.avanterules", template_opts),
|
local messages = {}
|
||||||
Path.prompts.render_file("_context.avanterules", template_opts),
|
|
||||||
Path.prompts.render_mode(mode, template_opts),
|
if opts.project_context ~= nil and opts.project_context ~= "" and opts.project_context ~= "null" then
|
||||||
})
|
local project_context = Path.prompts.render_file("_project.avanterules", template_opts)
|
||||||
:filter(function(k) return k ~= "" end)
|
if project_context ~= "" then table.insert(messages, { role = "user", content = project_context }) end
|
||||||
:totable()
|
end
|
||||||
|
|
||||||
|
local code_context = Path.prompts.render_file("_context.avanterules", template_opts)
|
||||||
|
if code_context ~= "" then table.insert(messages, { role = "user", content = code_context }) end
|
||||||
|
|
||||||
|
if opts.use_xml_format then
|
||||||
|
table.insert(messages, { role = "user", content = string.format("<question>%s</question>", instructions) })
|
||||||
|
else
|
||||||
|
table.insert(messages, { role = "user", content = string.format("QUESTION:\n%s", instructions) })
|
||||||
|
end
|
||||||
|
|
||||||
|
local remaining_tokens = max_tokens - Utils.tokens.calculate_tokens(system_prompt)
|
||||||
|
|
||||||
|
for _, message in ipairs(messages) do
|
||||||
|
remaining_tokens = remaining_tokens - Utils.tokens.calculate_tokens(message.content)
|
||||||
|
end
|
||||||
|
|
||||||
|
if opts.history_messages then
|
||||||
|
if Config.history.max_tokens > 0 then remaining_tokens = math.min(Config.history.max_tokens, remaining_tokens) end
|
||||||
|
-- Traverse the history in reverse, keeping only the latest history until the remaining tokens are exhausted and the first message role is "user"
|
||||||
|
local history_messages = {}
|
||||||
|
for i = #opts.history_messages, 1, -1 do
|
||||||
|
local message = opts.history_messages[i]
|
||||||
|
local tokens = Utils.tokens.calculate_tokens(message.content)
|
||||||
|
remaining_tokens = remaining_tokens - tokens
|
||||||
|
if remaining_tokens > 0 then
|
||||||
|
table.insert(history_messages, message)
|
||||||
|
else
|
||||||
|
break
|
||||||
|
end
|
||||||
|
end
|
||||||
|
if #history_messages > 0 and history_messages[1].role == "assistant" then table.remove(history_messages, 1) end
|
||||||
|
-- prepend the history messages to the messages table
|
||||||
|
vim.iter(history_messages):each(function(msg) table.insert(messages, 1, msg) end)
|
||||||
|
end
|
||||||
|
|
||||||
---@type AvantePromptOptions
|
---@type AvantePromptOptions
|
||||||
local code_opts = {
|
local code_opts = {
|
||||||
system_prompt = Config.system_prompt,
|
system_prompt = system_prompt,
|
||||||
user_prompts = user_prompts,
|
messages = messages,
|
||||||
image_paths = image_paths,
|
image_paths = image_paths,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -164,7 +198,7 @@ M.stream = function(opts)
|
|||||||
on_error = function(err)
|
on_error = function(err)
|
||||||
if err.exit == 23 then
|
if err.exit == 23 then
|
||||||
local xdg_runtime_dir = os.getenv("XDG_RUNTIME_DIR")
|
local xdg_runtime_dir = os.getenv("XDG_RUNTIME_DIR")
|
||||||
if fn.isdirectory(xdg_runtime_dir) == 0 then
|
if not xdg_runtime_dir or fn.isdirectory(xdg_runtime_dir) == 0 then
|
||||||
Utils.error(
|
Utils.error(
|
||||||
"$XDG_RUNTIME_DIR="
|
"$XDG_RUNTIME_DIR="
|
||||||
.. xdg_runtime_dir
|
.. xdg_runtime_dir
|
||||||
|
@ -13,7 +13,7 @@ local M = {}
|
|||||||
|
|
||||||
M.api_key_name = "AZURE_OPENAI_API_KEY"
|
M.api_key_name = "AZURE_OPENAI_API_KEY"
|
||||||
|
|
||||||
M.parse_message = O.parse_message
|
M.parse_messages = O.parse_messages
|
||||||
M.parse_response = O.parse_response
|
M.parse_response = O.parse_response
|
||||||
|
|
||||||
M.parse_curl_args = function(provider, code_opts)
|
M.parse_curl_args = function(provider, code_opts)
|
||||||
@ -34,7 +34,7 @@ M.parse_curl_args = function(provider, code_opts)
|
|||||||
insecure = base.allow_insecure,
|
insecure = base.allow_insecure,
|
||||||
headers = headers,
|
headers = headers,
|
||||||
body = vim.tbl_deep_extend("force", {
|
body = vim.tbl_deep_extend("force", {
|
||||||
messages = M.parse_message(code_opts),
|
messages = M.parse_messages(code_opts),
|
||||||
stream = true,
|
stream = true,
|
||||||
}, body_opts),
|
}, body_opts),
|
||||||
}
|
}
|
||||||
|
@ -13,8 +13,8 @@ local P = require("avante.providers")
|
|||||||
---@field type "image"
|
---@field type "image"
|
||||||
---@field source {type: "base64", media_type: string, data: string}
|
---@field source {type: "base64", media_type: string, data: string}
|
||||||
---
|
---
|
||||||
---@class AvanteClaudeMessage: AvanteBaseMessage
|
---@class AvanteClaudeMessage
|
||||||
---@field role "user"
|
---@field role "user" | "assistant"
|
||||||
---@field content [AvanteClaudeTextMessage | AvanteClaudeImageMessage][]
|
---@field content [AvanteClaudeTextMessage | AvanteClaudeImageMessage][]
|
||||||
|
|
||||||
---@class AvanteProviderFunctor
|
---@class AvanteProviderFunctor
|
||||||
@ -23,11 +23,44 @@ local M = {}
|
|||||||
M.api_key_name = "ANTHROPIC_API_KEY"
|
M.api_key_name = "ANTHROPIC_API_KEY"
|
||||||
M.use_xml_format = true
|
M.use_xml_format = true
|
||||||
|
|
||||||
M.parse_message = function(opts)
|
M.role_map = {
|
||||||
---@type AvanteClaudeMessage[]
|
user = "user",
|
||||||
local message_content = {}
|
assistant = "assistant",
|
||||||
|
}
|
||||||
|
|
||||||
if Clipboard.support_paste_image() and opts.image_paths then
|
M.parse_messages = function(opts)
|
||||||
|
---@type AvanteClaudeMessage[]
|
||||||
|
local messages = {}
|
||||||
|
|
||||||
|
---@type {idx: integer, length: integer}[]
|
||||||
|
local messages_with_length = {}
|
||||||
|
for idx, message in ipairs(opts.messages) do
|
||||||
|
table.insert(messages_with_length, { idx = idx, length = Utils.tokens.calculate_tokens(message.content) })
|
||||||
|
end
|
||||||
|
|
||||||
|
table.sort(messages_with_length, function(a, b) return a.length > b.length end)
|
||||||
|
|
||||||
|
---@type table<integer, boolean>
|
||||||
|
local top_three = {}
|
||||||
|
for i = 1, math.min(3, #messages_with_length) do
|
||||||
|
top_three[messages_with_length[i].idx] = true
|
||||||
|
end
|
||||||
|
|
||||||
|
for idx, message in ipairs(opts.messages) do
|
||||||
|
table.insert(messages, {
|
||||||
|
role = M.role_map[message.role],
|
||||||
|
content = {
|
||||||
|
{
|
||||||
|
type = "text",
|
||||||
|
text = message.content,
|
||||||
|
cache_control = top_three[idx] and { type = "ephemeral" } or nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
end
|
||||||
|
|
||||||
|
if Clipboard.support_paste_image() and opts.image_paths and #opts.image_paths > 0 then
|
||||||
|
local message_content = messages[#messages].content
|
||||||
for _, image_path in ipairs(opts.image_paths) do
|
for _, image_path in ipairs(opts.image_paths) do
|
||||||
table.insert(message_content, {
|
table.insert(message_content, {
|
||||||
type = "image",
|
type = "image",
|
||||||
@ -38,36 +71,10 @@ M.parse_message = function(opts)
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
end
|
end
|
||||||
|
messages[#messages].content = message_content
|
||||||
end
|
end
|
||||||
|
|
||||||
---@type {idx: integer, length: integer}[]
|
return messages
|
||||||
local user_prompts_with_length = {}
|
|
||||||
for idx, user_prompt in ipairs(opts.user_prompts) do
|
|
||||||
table.insert(user_prompts_with_length, { idx = idx, length = Utils.tokens.calculate_tokens(user_prompt) })
|
|
||||||
end
|
|
||||||
|
|
||||||
table.sort(user_prompts_with_length, function(a, b) return a.length > b.length end)
|
|
||||||
|
|
||||||
---@type table<integer, boolean>
|
|
||||||
local top_three = {}
|
|
||||||
for i = 1, math.min(3, #user_prompts_with_length) do
|
|
||||||
top_three[user_prompts_with_length[i].idx] = true
|
|
||||||
end
|
|
||||||
|
|
||||||
for idx, prompt_data in ipairs(opts.user_prompts) do
|
|
||||||
table.insert(message_content, {
|
|
||||||
type = "text",
|
|
||||||
text = prompt_data,
|
|
||||||
cache_control = top_three[idx] and { type = "ephemeral" } or nil,
|
|
||||||
})
|
|
||||||
end
|
|
||||||
|
|
||||||
return {
|
|
||||||
{
|
|
||||||
role = "user",
|
|
||||||
content = message_content,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
end
|
end
|
||||||
|
|
||||||
M.parse_response = function(data_stream, event_state, opts)
|
M.parse_response = function(data_stream, event_state, opts)
|
||||||
@ -96,7 +103,7 @@ M.parse_curl_args = function(provider, prompt_opts)
|
|||||||
}
|
}
|
||||||
if not P.env.is_local("claude") then headers["x-api-key"] = provider.parse_api_key() end
|
if not P.env.is_local("claude") then headers["x-api-key"] = provider.parse_api_key() end
|
||||||
|
|
||||||
local messages = M.parse_message(prompt_opts)
|
local messages = M.parse_messages(prompt_opts)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
url = Utils.trim(base.endpoint, { suffix = "/" }) .. "/v1/messages",
|
url = Utils.trim(base.endpoint, { suffix = "/" }) .. "/v1/messages",
|
||||||
|
@ -42,17 +42,18 @@ local M = {}
|
|||||||
|
|
||||||
M.api_key_name = "CO_API_KEY"
|
M.api_key_name = "CO_API_KEY"
|
||||||
M.tokenizer_id = "https://storage.googleapis.com/cohere-public/tokenizers/command-r-08-2024.json"
|
M.tokenizer_id = "https://storage.googleapis.com/cohere-public/tokenizers/command-r-08-2024.json"
|
||||||
|
M.role_map = {
|
||||||
|
user = "user",
|
||||||
|
assistant = "assistant",
|
||||||
|
}
|
||||||
|
|
||||||
M.parse_message = function(opts)
|
M.parse_messages = function(opts)
|
||||||
---@type CohereMessage[]
|
|
||||||
local user_content = vim.iter(opts.user_prompts):fold({}, function(acc, prompt)
|
|
||||||
table.insert(acc, { type = "text", text = prompt })
|
|
||||||
return acc
|
|
||||||
end)
|
|
||||||
local messages = {
|
local messages = {
|
||||||
{ role = "system", content = opts.system_prompt },
|
{ role = "system", content = opts.system_prompt },
|
||||||
{ role = "user", content = user_content },
|
|
||||||
}
|
}
|
||||||
|
vim
|
||||||
|
.iter(opts.messages)
|
||||||
|
:each(function(msg) table.insert(messages, { role = M.role_map[msg.role], content = msg.content }) end)
|
||||||
return { messages = messages }
|
return { messages = messages }
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -91,7 +92,7 @@ M.parse_curl_args = function(provider, code_opts)
|
|||||||
body = vim.tbl_deep_extend("force", {
|
body = vim.tbl_deep_extend("force", {
|
||||||
model = base.model,
|
model = base.model,
|
||||||
stream = true,
|
stream = true,
|
||||||
}, M.parse_message(code_opts), body_opts),
|
}, M.parse_messages(code_opts), body_opts),
|
||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -118,12 +118,19 @@ M.state = nil
|
|||||||
|
|
||||||
M.api_key_name = P.AVANTE_INTERNAL_KEY
|
M.api_key_name = P.AVANTE_INTERNAL_KEY
|
||||||
M.tokenizer_id = "gpt-4o"
|
M.tokenizer_id = "gpt-4o"
|
||||||
|
M.role_map = {
|
||||||
|
user = "user",
|
||||||
|
assistant = "assistant",
|
||||||
|
}
|
||||||
|
|
||||||
M.parse_message = function(opts)
|
M.parse_messages = function(opts)
|
||||||
return {
|
local messages = {
|
||||||
{ role = "system", content = opts.system_prompt },
|
{ role = "system", content = opts.system_prompt },
|
||||||
{ role = "user", content = table.concat(opts.user_prompts, "\n") },
|
|
||||||
}
|
}
|
||||||
|
vim
|
||||||
|
.iter(opts.messages)
|
||||||
|
:each(function(msg) table.insert(messages, { role = M.role_map[msg.role], content = msg.content }) end)
|
||||||
|
return messages
|
||||||
end
|
end
|
||||||
|
|
||||||
M.parse_response = O.parse_response
|
M.parse_response = O.parse_response
|
||||||
@ -146,7 +153,7 @@ M.parse_curl_args = function(provider, code_opts)
|
|||||||
},
|
},
|
||||||
body = vim.tbl_deep_extend("force", {
|
body = vim.tbl_deep_extend("force", {
|
||||||
model = base.model,
|
model = base.model,
|
||||||
messages = M.parse_message(code_opts),
|
messages = M.parse_messages(code_opts),
|
||||||
stream = true,
|
stream = true,
|
||||||
}, body_opts),
|
}, body_opts),
|
||||||
}
|
}
|
||||||
|
@ -6,10 +6,34 @@ local Clipboard = require("avante.clipboard")
|
|||||||
local M = {}
|
local M = {}
|
||||||
|
|
||||||
M.api_key_name = "GEMINI_API_KEY"
|
M.api_key_name = "GEMINI_API_KEY"
|
||||||
|
M.role_map = {
|
||||||
|
user = "user",
|
||||||
|
assistant = "model",
|
||||||
|
}
|
||||||
-- M.tokenizer_id = "google/gemma-2b"
|
-- M.tokenizer_id = "google/gemma-2b"
|
||||||
|
|
||||||
M.parse_message = function(opts)
|
M.parse_messages = function(opts)
|
||||||
local message_content = {}
|
local contents = {}
|
||||||
|
local prev_role = nil
|
||||||
|
|
||||||
|
vim.iter(opts.messages):each(function(message)
|
||||||
|
local role = message.role
|
||||||
|
if role == prev_role then
|
||||||
|
if role == "user" then
|
||||||
|
table.insert(contents, { role = "model", parts = {
|
||||||
|
{ text = "Ok, I understand." },
|
||||||
|
} })
|
||||||
|
else
|
||||||
|
table.insert(contents, { role = "user", parts = {
|
||||||
|
{ text = "Ok" },
|
||||||
|
} })
|
||||||
|
end
|
||||||
|
end
|
||||||
|
prev_role = role
|
||||||
|
table.insert(contents, { role = M.role_map[role] or role, parts = {
|
||||||
|
{ text = message.content },
|
||||||
|
} })
|
||||||
|
end)
|
||||||
|
|
||||||
if Clipboard.support_paste_image() and opts.image_paths then
|
if Clipboard.support_paste_image() and opts.image_paths then
|
||||||
for _, image_path in ipairs(opts.image_paths) do
|
for _, image_path in ipairs(opts.image_paths) do
|
||||||
@ -20,13 +44,10 @@ M.parse_message = function(opts)
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
table.insert(message_content, image_data)
|
table.insert(contents[#contents].parts, image_data)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
-- insert a part into parts
|
|
||||||
table.insert(message_content, { text = table.concat(opts.user_prompts, "\n") })
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
systemInstruction = {
|
systemInstruction = {
|
||||||
role = "user",
|
role = "user",
|
||||||
@ -36,12 +57,7 @@ M.parse_message = function(opts)
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
contents = {
|
contents = contents,
|
||||||
{
|
|
||||||
role = "user",
|
|
||||||
parts = message_content,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -78,7 +94,7 @@ M.parse_curl_args = function(provider, code_opts)
|
|||||||
proxy = base.proxy,
|
proxy = base.proxy,
|
||||||
insecure = base.allow_insecure,
|
insecure = base.allow_insecure,
|
||||||
headers = { ["Content-Type"] = "application/json" },
|
headers = { ["Content-Type"] = "application/json" },
|
||||||
body = vim.tbl_deep_extend("force", {}, M.parse_message(code_opts), body_opts),
|
body = vim.tbl_deep_extend("force", {}, M.parse_messages(code_opts), body_opts),
|
||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -14,22 +14,22 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil }
|
|||||||
---@field on_chunk AvanteChunkParser
|
---@field on_chunk AvanteChunkParser
|
||||||
---@field on_complete AvanteCompleteParser
|
---@field on_complete AvanteCompleteParser
|
||||||
---
|
---
|
||||||
|
---@class AvanteLLMMessage
|
||||||
|
---@field role "user" | "assistant"
|
||||||
|
---@field content string
|
||||||
|
---
|
||||||
---@class AvantePromptOptions: table<[string], string>
|
---@class AvantePromptOptions: table<[string], string>
|
||||||
---@field system_prompt string
|
---@field system_prompt string
|
||||||
---@field user_prompts string[]
|
---@field messages AvanteLLMMessage[]
|
||||||
---@field image_paths? string[]
|
---@field image_paths? string[]
|
||||||
---
|
---
|
||||||
---@class AvanteBaseMessage
|
|
||||||
---@field role "user" | "system"
|
|
||||||
---@field content string
|
|
||||||
---
|
|
||||||
---@class AvanteGeminiMessage
|
---@class AvanteGeminiMessage
|
||||||
---@field role "user"
|
---@field role "user"
|
||||||
---@field parts { text: string }[]
|
---@field parts { text: string }[]
|
||||||
---
|
---
|
||||||
---@alias AvanteChatMessage AvanteClaudeMessage | OpenAIMessage | AvanteGeminiMessage
|
---@alias AvanteChatMessage AvanteClaudeMessage | OpenAIMessage | AvanteGeminiMessage
|
||||||
---
|
---
|
||||||
---@alias AvanteMessageParser fun(opts: AvantePromptOptions): AvanteChatMessage[]
|
---@alias AvanteMessagesParser fun(opts: AvantePromptOptions): AvanteChatMessage[]
|
||||||
---
|
---
|
||||||
---@class AvanteCurlOutput: {url: string, proxy: string, insecure: boolean, body: table<string, any> | string, headers: table<string, string>}
|
---@class AvanteCurlOutput: {url: string, proxy: string, insecure: boolean, body: table<string, any> | string, headers: table<string, string>}
|
||||||
---@alias AvanteCurlArgsParser fun(opts: AvanteProvider | AvanteProviderFunctor, code_opts: AvantePromptOptions): AvanteCurlOutput
|
---@alias AvanteCurlArgsParser fun(opts: AvanteProvider | AvanteProviderFunctor, code_opts: AvantePromptOptions): AvanteCurlOutput
|
||||||
@ -65,13 +65,14 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil }
|
|||||||
---@field parse_api_key? fun(): string | nil
|
---@field parse_api_key? fun(): string | nil
|
||||||
---
|
---
|
||||||
---@class AvanteProviderFunctor
|
---@class AvanteProviderFunctor
|
||||||
---@field parse_message AvanteMessageParser
|
---@field role_map table<"user" | "assistant", string>
|
||||||
|
---@field parse_messages AvanteMessagesParser
|
||||||
---@field parse_response AvanteResponseParser
|
---@field parse_response AvanteResponseParser
|
||||||
---@field parse_curl_args AvanteCurlArgsParser
|
---@field parse_curl_args AvanteCurlArgsParser
|
||||||
---@field setup fun(): nil
|
---@field setup fun(): nil
|
||||||
---@field has fun(): boolean
|
---@field has fun(): boolean
|
||||||
---@field api_key_name string
|
---@field api_key_name string
|
||||||
---@field tokenizer_id [string] | "gpt-4o"
|
---@field tokenizer_id string | "gpt-4o"
|
||||||
---@field use_xml_format boolean
|
---@field use_xml_format boolean
|
||||||
---@field model? string
|
---@field model? string
|
||||||
---@field parse_api_key fun(): string | nil
|
---@field parse_api_key fun(): string | nil
|
||||||
|
@ -33,29 +33,15 @@ local M = {}
|
|||||||
|
|
||||||
M.api_key_name = "OPENAI_API_KEY"
|
M.api_key_name = "OPENAI_API_KEY"
|
||||||
|
|
||||||
|
M.role_map = {
|
||||||
|
user = "user",
|
||||||
|
assistant = "assistant",
|
||||||
|
}
|
||||||
|
|
||||||
---@param opts AvantePromptOptions
|
---@param opts AvantePromptOptions
|
||||||
M.get_user_message = function(opts) return table.concat(opts.user_prompts, "\n") end
|
M.get_user_message = function(opts) return table.concat(opts.messages, "\n") end
|
||||||
|
|
||||||
M.parse_message = function(opts)
|
|
||||||
---@type OpenAIMessage[]
|
|
||||||
local user_content = {}
|
|
||||||
if Config.behaviour.support_paste_from_clipboard and opts.image_paths and #opts.image_paths > 0 then
|
|
||||||
for _, image_path in ipairs(opts.image_paths) do
|
|
||||||
table.insert(user_content, {
|
|
||||||
type = "image_url",
|
|
||||||
image_url = {
|
|
||||||
url = "data:image/png;base64," .. Clipboard.get_base64_content(image_path),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
end
|
|
||||||
vim.iter(opts.user_prompts):each(function(prompt) table.insert(user_content, { type = "text", text = prompt }) end)
|
|
||||||
else
|
|
||||||
user_content = vim.iter(opts.user_prompts):fold({}, function(acc, prompt)
|
|
||||||
table.insert(acc, { type = "text", text = prompt })
|
|
||||||
return acc
|
|
||||||
end)
|
|
||||||
end
|
|
||||||
|
|
||||||
|
M.parse_messages = function(opts)
|
||||||
local messages = {}
|
local messages = {}
|
||||||
local provider = P[Config.provider]
|
local provider = P[Config.provider]
|
||||||
local base, _ = P.parse_config(provider)
|
local base, _ = P.parse_config(provider)
|
||||||
@ -68,8 +54,23 @@ M.parse_message = function(opts)
|
|||||||
table.insert(messages, { role = "system", content = opts.system_prompt })
|
table.insert(messages, { role = "system", content = opts.system_prompt })
|
||||||
end
|
end
|
||||||
|
|
||||||
-- User message after the prompt
|
vim
|
||||||
table.insert(messages, { role = "user", content = user_content })
|
.iter(opts.messages)
|
||||||
|
:each(function(msg) table.insert(messages, { role = M.role_map[msg.role], content = msg.content }) end)
|
||||||
|
|
||||||
|
if Config.behaviour.support_paste_from_clipboard and opts.image_paths and #opts.image_paths > 0 then
|
||||||
|
local message_content = messages[#messages].content
|
||||||
|
if type(message_content) ~= "table" then message_content = { type = "text", text = message_content } end
|
||||||
|
for _, image_path in ipairs(opts.image_paths) do
|
||||||
|
table.insert(message_content, {
|
||||||
|
type = "image_url",
|
||||||
|
image_url = {
|
||||||
|
url = "data:image/png;base64," .. Clipboard.get_base64_content(image_path),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
end
|
||||||
|
messages[#messages].content = message_content
|
||||||
|
end
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
end
|
end
|
||||||
@ -128,7 +129,7 @@ M.parse_curl_args = function(provider, code_opts)
|
|||||||
headers = headers,
|
headers = headers,
|
||||||
body = vim.tbl_deep_extend("force", {
|
body = vim.tbl_deep_extend("force", {
|
||||||
model = base.model,
|
model = base.model,
|
||||||
messages = M.parse_message(code_opts),
|
messages = M.parse_messages(code_opts),
|
||||||
stream = stream,
|
stream = stream,
|
||||||
}, body_opts),
|
}, body_opts),
|
||||||
}
|
}
|
||||||
|
@ -1392,10 +1392,23 @@ function Sidebar:create_input(opts)
|
|||||||
|
|
||||||
local project_context = mentions.enable_project_context and RepoMap.get_repo_map(file_ext) or nil
|
local project_context = mentions.enable_project_context and RepoMap.get_repo_map(file_ext) or nil
|
||||||
|
|
||||||
|
local history_messages = vim.tbl_map(
|
||||||
|
function(history)
|
||||||
|
return {
|
||||||
|
{ role = "user", content = history.request },
|
||||||
|
{ role = "assistant", content = history.original_response },
|
||||||
|
}
|
||||||
|
end,
|
||||||
|
chat_history
|
||||||
|
)
|
||||||
|
|
||||||
|
history_messages = vim.iter(history_messages):flatten():totable()
|
||||||
|
|
||||||
Llm.stream({
|
Llm.stream({
|
||||||
bufnr = self.code.bufnr,
|
bufnr = self.code.bufnr,
|
||||||
ask = opts.ask,
|
ask = opts.ask,
|
||||||
project_context = vim.json.encode(project_context),
|
project_context = vim.json.encode(project_context),
|
||||||
|
history_messages = history_messages,
|
||||||
file_content = content,
|
file_content = content,
|
||||||
code_lang = filetype,
|
code_lang = filetype,
|
||||||
selected_code = selected_code_content,
|
selected_code = selected_code_content,
|
||||||
|
@ -1,12 +0,0 @@
|
|||||||
{%- if use_xml_format -%}
|
|
||||||
{%- if memory_context -%}
|
|
||||||
<memory_context>
|
|
||||||
{{memory_context}}
|
|
||||||
</memory_context>
|
|
||||||
{%- endif %}
|
|
||||||
{%- else -%}
|
|
||||||
{%- if memory_context -%}
|
|
||||||
MEMORY CONTEXT:
|
|
||||||
{{memory_context}}
|
|
||||||
{%- endif %}
|
|
||||||
{%- endif %}
|
|
@ -7,10 +7,11 @@
|
|||||||
"file_content": "local Config = require('avante.config')"
|
"file_content": "local Config = require('avante.config')"
|
||||||
}
|
}
|
||||||
#}
|
#}
|
||||||
|
Act as an expert software developer.
|
||||||
|
Always use best practices when coding.
|
||||||
|
Respect and use existing conventions, libraries, etc that are already present in the code base.
|
||||||
|
|
||||||
{%- if ask %}
|
{%- if ask %}
|
||||||
{%- if not use_xml_format -%}
|
|
||||||
INSTRUCTION:{% else -%}
|
|
||||||
<instruction>{% endif -%}
|
|
||||||
{% block user_prompt %}
|
{% block user_prompt %}
|
||||||
Take requests for changes to the supplied code.
|
Take requests for changes to the supplied code.
|
||||||
If the request is ambiguous, ask questions.
|
If the request is ambiguous, ask questions.
|
||||||
@ -150,19 +151,4 @@ To rename files which have been added to the chat, use shell commands at the end
|
|||||||
|
|
||||||
ONLY EVER RETURN CODE IN A *SEARCH/REPLACE BLOCK*!
|
ONLY EVER RETURN CODE IN A *SEARCH/REPLACE BLOCK*!
|
||||||
{% endblock %}
|
{% endblock %}
|
||||||
{%- if use_xml_format -%}
|
|
||||||
</instruction>
|
|
||||||
|
|
||||||
<question>{{question}}</question>
|
|
||||||
{%- else %}
|
|
||||||
QUESTION:
|
|
||||||
{{question}}
|
|
||||||
{%- endif %}
|
|
||||||
{% else %}
|
|
||||||
{% if use_xml_format -%}
|
|
||||||
<question>{{question}}</question>
|
|
||||||
{% else %}
|
|
||||||
QUESTION:
|
|
||||||
{{question}}
|
|
||||||
{%- endif %}
|
|
||||||
{%- endif %}
|
{%- endif %}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user