Patch: pass chat history to the LLM providers as structured messages.

What the hunks below do:
  * Drop `question` / `memory_context` from the Rust template context and delete
    the `_memory.avanterules` template.
  * Move the default system prompt text from `config.lua` into
    `planning.avanterules`; the rendered mode template becomes the system prompt.
  * `llm.lua` builds an ordered `messages` list (project context, code context,
    question) and prepends as much recent history as fits the token budget.
  * Every provider replaces `parse_message` with `parse_messages` and gains a
    `role_map` that translates "user"/"assistant" into provider role names.
  * `sidebar.lua` flattens the stored chat history into user/assistant pairs.

Review fixes folded into the added lines:
  1. llm.lua: history is collected newest-first, so the "oldest kept message
     must be a user turn" check has to look at the LAST element, not the first
     (the first is the newest entry, so the latest assistant reply was dropped).
  2. llm.lua: the question wrapper now follows `template_opts.use_xml_format`,
     the same flag the context templates are rendered with.
  3. llm.lua: `tostring(xdg_runtime_dir)` so the error message can be built when
     $XDG_RUNTIME_DIR is unset (the new `not xdg_runtime_dir` branch).
  4. openai.lua: `get_user_message` joins the `content` of each message instead
     of calling `table.concat` on a list of tables.
  5. openai.lua: a string message content is wrapped as a LIST of parts before
     image parts are appended to it.

NOTE(review): line breaks and indentation were reconstructed from a
whitespace-collapsed copy of this patch. Hunk line counts match the headers,
but leading whitespace is inferred -- verify against the tree before applying.

diff --git a/crates/avante-templates/src/lib.rs b/crates/avante-templates/src/lib.rs
index 0f05669..55c18ec 100644
--- a/crates/avante-templates/src/lib.rs
+++ b/crates/avante-templates/src/lib.rs
@@ -19,13 +19,11 @@ impl<'a> State<'a> {
 struct TemplateContext {
     use_xml_format: bool,
     ask: bool,
-    question: String,
     code_lang: String,
     filepath: String,
     file_content: String,
     selected_code: Option<String>,
     project_context: Option<String>,
-    memory_context: Option<String>,
 }
 
 // Given the file name registered after add, the context table in Lua, resulted in a formatted
@@ -44,13 +42,11 @@ fn render(state: &State, template: &str, context: TemplateContext) -> LuaResult<
                 .render(context! {
                   use_xml_format => context.use_xml_format,
                   ask => context.ask,
-                  question => context.question,
                   code_lang => context.code_lang,
                   filepath => context.filepath,
                   file_content => context.file_content,
                   selected_code => context.selected_code,
                   project_context => context.project_context,
-                  memory_context => context.memory_context,
                 })
                 .map_err(LuaError::external)
                 .unwrap())
diff --git a/lua/avante/config.lua b/lua/avante/config.lua
index 2d61112..d80ee56 100644
--- a/lua/avante/config.lua
+++ b/lua/avante/config.lua
@@ -18,15 +18,6 @@ M.defaults = {
   -- For most providers that we support we will determine this automatically.
   -- If you wish to use a given implementation, then you can override it here.
   tokenizer = "tiktoken",
-  ---@alias AvanteSystemPrompt string
-  -- Default system prompt. Users can override this with their own prompt
-  -- You can use `require('avante.config').override({system_prompt = "MY_SYSTEM_PROMPT"}) conditionally
-  -- in your own autocmds to do it per directory, or that fit your needs.
-  system_prompt = [[
-Act as an expert software developer.
-Always use best practices when coding.
-Respect and use existing conventions, libraries, etc that are already present in the code base.
-]],
   ---@type AvanteSupportedProvider
   openai = {
     endpoint = "https://api.openai.com/v1",
@@ -102,6 +93,7 @@ Respect and use existing conventions, libraries, etc that are already present in
     support_paste_from_clipboard = false,
   },
   history = {
+    max_tokens = 4096,
     storage_path = vim.fn.stdpath("state") .. "/avante",
     paste = {
       extension = "png",
@@ -315,6 +307,7 @@ M.BASE_PROVIDER_KEYS = {
   "_shellenv",
   "tokenizer_id",
   "use_xml_format",
+  "role_map",
 }
 
 return M
diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua
index be14af7..8b31d30 100644
--- a/lua/avante/llm.lua
+++ b/lua/avante/llm.lua
@@ -28,7 +28,7 @@ local group = api.nvim_create_augroup("avante_llm", { clear = true })
 ---@field file_content string
 ---@field selected_code string | nil
 ---@field project_context string | nil
----@field memory_context string | nil
+---@field history_messages AvanteLLMMessage[]
 ---
 ---@class StreamOptions: TemplateOptions
 ---@field ask boolean
@@ -44,10 +44,12 @@ M.stream = function(opts)
   local mode = opts.mode or "planning"
   ---@type AvanteProviderFunctor
   local Provider = opts.provider or P[Config.provider]
+  local _, body_opts = P.parse_config(Provider)
+  local max_tokens = body_opts.max_tokens or 4096
 
   -- Check if the instructions contains an image path
   local image_paths = {}
-  local original_instructions = opts.instructions
+  local instructions = opts.instructions
   if opts.instructions:match("image: ") then
     local lines = vim.split(opts.instructions, "\n")
     for i, line in ipairs(lines) do
@@ -57,7 +59,7 @@ M.stream = function(opts)
         table.remove(lines, i)
       end
     end
-    original_instructions = table.concat(lines, "\n")
+    instructions = table.concat(lines, "\n")
   end
 
   Path.prompts.initialize(Path.prompts.get(opts.bufnr))
@@ -67,29 +69,61 @@ M.stream = function(opts)
   local template_opts = {
     use_xml_format = Provider.use_xml_format,
     ask = opts.ask, -- TODO: add mode without ask instruction
-    question = original_instructions,
     code_lang = opts.code_lang,
     filepath = filepath,
     file_content = opts.file_content,
     selected_code = opts.selected_code,
     project_context = opts.project_context,
-    memory_context = opts.memory_context,
   }
 
-  local user_prompts = vim
-    .iter({
-      Path.prompts.render_file("_project.avanterules", template_opts),
-      Path.prompts.render_file("_memory.avanterules", template_opts),
-      Path.prompts.render_file("_context.avanterules", template_opts),
-      Path.prompts.render_mode(mode, template_opts),
-    })
-    :filter(function(k) return k ~= "" end)
-    :totable()
+  local system_prompt = Path.prompts.render_mode(mode, template_opts)
+
+  ---@type AvanteLLMMessage[]
+  local messages = {}
+
+  if opts.project_context ~= nil and opts.project_context ~= "" and opts.project_context ~= "null" then
+    local project_context = Path.prompts.render_file("_project.avanterules", template_opts)
+    if project_context ~= "" then table.insert(messages, { role = "user", content = project_context }) end
+  end
+
+  local code_context = Path.prompts.render_file("_context.avanterules", template_opts)
+  if code_context ~= "" then table.insert(messages, { role = "user", content = code_context }) end
+
+  if template_opts.use_xml_format then
+    table.insert(messages, { role = "user", content = string.format("<question>%s</question>", instructions) })
+  else
+    table.insert(messages, { role = "user", content = string.format("QUESTION:\n%s", instructions) })
+  end
+
+  local remaining_tokens = max_tokens - Utils.tokens.calculate_tokens(system_prompt)
+
+  for _, message in ipairs(messages) do
+    remaining_tokens = remaining_tokens - Utils.tokens.calculate_tokens(message.content)
+  end
+
+  if opts.history_messages then
+    if Config.history.max_tokens > 0 then remaining_tokens = math.min(Config.history.max_tokens, remaining_tokens) end
+    -- Walk the history newest-first, keeping entries until the token budget runs out (so [1] is the newest kept entry)
+    local history_messages = {}
+    for i = #opts.history_messages, 1, -1 do
+      local message = opts.history_messages[i]
+      local tokens = Utils.tokens.calculate_tokens(message.content)
+      remaining_tokens = remaining_tokens - tokens
+      if remaining_tokens > 0 then
+        table.insert(history_messages, message)
+      else
+        break
+      end
+    end
+    if #history_messages > 0 and history_messages[#history_messages].role == "assistant" then table.remove(history_messages) end
+    -- prepend newest-first at index 1, which leaves the history in chronological order ahead of the new messages
+    vim.iter(history_messages):each(function(msg) table.insert(messages, 1, msg) end)
+  end
 
   ---@type AvantePromptOptions
   local code_opts = {
-    system_prompt = Config.system_prompt,
-    user_prompts = user_prompts,
+    system_prompt = system_prompt,
+    messages = messages,
     image_paths = image_paths,
   }
 
@@ -164,7 +198,7 @@ M.stream = function(opts)
     on_error = function(err)
       if err.exit == 23 then
         local xdg_runtime_dir = os.getenv("XDG_RUNTIME_DIR")
-        if fn.isdirectory(xdg_runtime_dir) == 0 then
+        if not xdg_runtime_dir or fn.isdirectory(xdg_runtime_dir) == 0 then
           Utils.error(
             "$XDG_RUNTIME_DIR="
-              .. xdg_runtime_dir
+              .. tostring(xdg_runtime_dir)
diff --git a/lua/avante/providers/azure.lua b/lua/avante/providers/azure.lua
index 1f56d1a..9839e09 100644
--- a/lua/avante/providers/azure.lua
+++ b/lua/avante/providers/azure.lua
@@ -13,7 +13,7 @@ local M = {}
 
 M.api_key_name = "AZURE_OPENAI_API_KEY"
 
-M.parse_message = O.parse_message
+M.parse_messages = O.parse_messages
 M.parse_response = O.parse_response
 
 M.parse_curl_args = function(provider, code_opts)
@@ -34,7 +34,7 @@ M.parse_curl_args = function(provider, code_opts)
     insecure = base.allow_insecure,
     headers = headers,
     body = vim.tbl_deep_extend("force", {
-      messages = M.parse_message(code_opts),
+      messages = M.parse_messages(code_opts),
       stream = true,
     }, body_opts),
   }
diff --git a/lua/avante/providers/claude.lua b/lua/avante/providers/claude.lua
index 7efd891..4b0f818 100644
--- a/lua/avante/providers/claude.lua
+++ b/lua/avante/providers/claude.lua
@@ -13,8 +13,8 @@ local P = require("avante.providers")
 ---@field type "image"
 ---@field source {type: "base64", media_type: string, data: string}
 ---
----@class AvanteClaudeMessage: AvanteBaseMessage
----@field role "user"
+---@class AvanteClaudeMessage
+---@field role "user" | "assistant"
 ---@field content [AvanteClaudeTextMessage | AvanteClaudeImageMessage][]
 
 ---@class AvanteProviderFunctor
@@ -23,11 +23,44 @@ local M = {}
 M.api_key_name = "ANTHROPIC_API_KEY"
 M.use_xml_format = true
 
-M.parse_message = function(opts)
-  ---@type AvanteClaudeMessage[]
-  local message_content = {}
+M.role_map = {
+  user = "user",
+  assistant = "assistant",
+}
 
-  if Clipboard.support_paste_image() and opts.image_paths then
+M.parse_messages = function(opts)
+  ---@type AvanteClaudeMessage[]
+  local messages = {}
+
+  ---@type {idx: integer, length: integer}[]
+  local messages_with_length = {}
+  for idx, message in ipairs(opts.messages) do
+    table.insert(messages_with_length, { idx = idx, length = Utils.tokens.calculate_tokens(message.content) })
+  end
+
+  table.sort(messages_with_length, function(a, b) return a.length > b.length end)
+
+  ---@type table<integer, boolean>
+  local top_three = {}
+  for i = 1, math.min(3, #messages_with_length) do
+    top_three[messages_with_length[i].idx] = true
+  end
+
+  for idx, message in ipairs(opts.messages) do
+    table.insert(messages, {
+      role = M.role_map[message.role],
+      content = {
+        {
+          type = "text",
+          text = message.content,
+          cache_control = top_three[idx] and { type = "ephemeral" } or nil,
+        },
+      },
+    })
+  end
+
+  if Clipboard.support_paste_image() and opts.image_paths and #opts.image_paths > 0 then
+    local message_content = messages[#messages].content
     for _, image_path in ipairs(opts.image_paths) do
       table.insert(message_content, {
         type = "image",
@@ -38,36 +71,10 @@ M.parse_message = function(opts)
         },
       })
     end
+    messages[#messages].content = message_content
   end
 
-  ---@type {idx: integer, length: integer}[]
-  local user_prompts_with_length = {}
-  for idx, user_prompt in ipairs(opts.user_prompts) do
-    table.insert(user_prompts_with_length, { idx = idx, length = Utils.tokens.calculate_tokens(user_prompt) })
-  end
-
-  table.sort(user_prompts_with_length, function(a, b) return a.length > b.length end)
-
-  ---@type table<integer, boolean>
-  local top_three = {}
-  for i = 1, math.min(3, #user_prompts_with_length) do
-    top_three[user_prompts_with_length[i].idx] = true
-  end
-
-  for idx, prompt_data in ipairs(opts.user_prompts) do
-    table.insert(message_content, {
-      type = "text",
-      text = prompt_data,
-      cache_control = top_three[idx] and { type = "ephemeral" } or nil,
-    })
-  end
-
-  return {
-    {
-      role = "user",
-      content = message_content,
-    },
-  }
+  return messages
 end
 
 M.parse_response = function(data_stream, event_state, opts)
@@ -96,7 +103,7 @@ M.parse_curl_args = function(provider, prompt_opts)
   }
   if not P.env.is_local("claude") then headers["x-api-key"] = provider.parse_api_key() end
 
-  local messages = M.parse_message(prompt_opts)
+  local messages = M.parse_messages(prompt_opts)
 
   return {
     url = Utils.trim(base.endpoint, { suffix = "/" }) .. "/v1/messages",
diff --git a/lua/avante/providers/cohere.lua b/lua/avante/providers/cohere.lua
index db020de..c2843f7 100644
--- a/lua/avante/providers/cohere.lua
+++ b/lua/avante/providers/cohere.lua
@@ -42,17 +42,18 @@ local M = {}
 
 M.api_key_name = "CO_API_KEY"
 M.tokenizer_id = "https://storage.googleapis.com/cohere-public/tokenizers/command-r-08-2024.json"
+M.role_map = {
+  user = "user",
+  assistant = "assistant",
+}
 
-M.parse_message = function(opts)
-  ---@type CohereMessage[]
-  local user_content = vim.iter(opts.user_prompts):fold({}, function(acc, prompt)
-    table.insert(acc, { type = "text", text = prompt })
-    return acc
-  end)
+M.parse_messages = function(opts)
   local messages = {
     { role = "system", content = opts.system_prompt },
-    { role = "user", content = user_content },
   }
+  vim
+    .iter(opts.messages)
+    :each(function(msg) table.insert(messages, { role = M.role_map[msg.role], content = msg.content }) end)
   return { messages = messages }
 end
 
@@ -91,7 +92,7 @@ M.parse_curl_args = function(provider, code_opts)
     body = vim.tbl_deep_extend("force", {
       model = base.model,
       stream = true,
-    }, M.parse_message(code_opts), body_opts),
+    }, M.parse_messages(code_opts), body_opts),
   }
 end
 
diff --git a/lua/avante/providers/copilot.lua b/lua/avante/providers/copilot.lua
index eba07c7..1078859 100644
--- a/lua/avante/providers/copilot.lua
+++ b/lua/avante/providers/copilot.lua
@@ -118,12 +118,19 @@ M.state = nil
 
 M.api_key_name = P.AVANTE_INTERNAL_KEY
 M.tokenizer_id = "gpt-4o"
+M.role_map = {
+  user = "user",
+  assistant = "assistant",
+}
 
-M.parse_message = function(opts)
-  return {
+M.parse_messages = function(opts)
+  local messages = {
     { role = "system", content = opts.system_prompt },
-    { role = "user", content = table.concat(opts.user_prompts, "\n") },
   }
+  vim
+    .iter(opts.messages)
+    :each(function(msg) table.insert(messages, { role = M.role_map[msg.role], content = msg.content }) end)
+  return messages
 end
 
 M.parse_response = O.parse_response
@@ -146,7 +153,7 @@ M.parse_curl_args = function(provider, code_opts)
     },
     body = vim.tbl_deep_extend("force", {
       model = base.model,
-      messages = M.parse_message(code_opts),
+      messages = M.parse_messages(code_opts),
       stream = true,
     }, body_opts),
   }
diff --git a/lua/avante/providers/gemini.lua b/lua/avante/providers/gemini.lua
index 7e83ee7..4ad9694 100644
--- a/lua/avante/providers/gemini.lua
+++ b/lua/avante/providers/gemini.lua
@@ -6,10 +6,34 @@ local Clipboard = require("avante.clipboard")
 local M = {}
 
 M.api_key_name = "GEMINI_API_KEY"
+M.role_map = {
+  user = "user",
+  assistant = "model",
+}
 -- M.tokenizer_id = "google/gemma-2b"
 
-M.parse_message = function(opts)
-  local message_content = {}
+M.parse_messages = function(opts)
+  local contents = {}
+  local prev_role = nil
+
+  vim.iter(opts.messages):each(function(message)
+    local role = message.role
+    if role == prev_role then
+      if role == "user" then
+        table.insert(contents, { role = "model", parts = {
+          { text = "Ok, I understand." },
+        } })
+      else
+        table.insert(contents, { role = "user", parts = {
+          { text = "Ok" },
+        } })
+      end
+    end
+    prev_role = role
+    table.insert(contents, { role = M.role_map[role] or role, parts = {
+      { text = message.content },
+    } })
+  end)
 
   if Clipboard.support_paste_image() and opts.image_paths then
     for _, image_path in ipairs(opts.image_paths) do
@@ -20,13 +44,10 @@ M.parse_message = function(opts)
         },
       }
 
-      table.insert(message_content, image_data)
+      table.insert(contents[#contents].parts, image_data)
     end
   end
 
-  -- insert a part into parts
-  table.insert(message_content, { text = table.concat(opts.user_prompts, "\n") })
-
   return {
     systemInstruction = {
       role = "user",
@@ -36,12 +57,7 @@ M.parse_message = function(opts)
         },
       },
     },
-    contents = {
-      {
-        role = "user",
-        parts = message_content,
-      },
-    },
+    contents = contents,
   }
 end
 
@@ -78,7 +94,7 @@ M.parse_curl_args = function(provider, code_opts)
     proxy = base.proxy,
     insecure = base.allow_insecure,
     headers = { ["Content-Type"] = "application/json" },
-    body = vim.tbl_deep_extend("force", {}, M.parse_message(code_opts), body_opts),
+    body = vim.tbl_deep_extend("force", {}, M.parse_messages(code_opts), body_opts),
   }
 end
 
diff --git a/lua/avante/providers/init.lua b/lua/avante/providers/init.lua
index 2debe03..6582648 100644
--- a/lua/avante/providers/init.lua
+++ b/lua/avante/providers/init.lua
@@ -14,22 +14,22 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil }
 ---@field on_chunk AvanteChunkParser
 ---@field on_complete AvanteCompleteParser
 ---
+---@class AvanteLLMMessage
+---@field role "user" | "assistant"
+---@field content string
+---
 ---@class AvantePromptOptions: table<[string], string>
 ---@field system_prompt string
----@field user_prompts string[]
+---@field messages AvanteLLMMessage[]
 ---@field image_paths? string[]
 ---
----@class AvanteBaseMessage
----@field role "user" | "system"
----@field content string
----
 ---@class AvanteGeminiMessage
 ---@field role "user"
 ---@field parts { text: string }[]
 ---
 ---@alias AvanteChatMessage AvanteClaudeMessage | OpenAIMessage | AvanteGeminiMessage
 ---
----@alias AvanteMessageParser fun(opts: AvantePromptOptions): AvanteChatMessage[]
+---@alias AvanteMessagesParser fun(opts: AvantePromptOptions): AvanteChatMessage[]
 ---
 ---@class AvanteCurlOutput: {url: string, proxy: string, insecure: boolean, body: table<string, any> | string, headers: table<string, string>}
 ---@alias AvanteCurlArgsParser fun(opts: AvanteProvider | AvanteProviderFunctor, code_opts: AvantePromptOptions): AvanteCurlOutput
@@ -65,13 +65,14 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil }
 ---@field parse_api_key? fun(): string | nil
 ---
 ---@class AvanteProviderFunctor
----@field parse_message AvanteMessageParser
+---@field role_map table<"user" | "assistant", string>
+---@field parse_messages AvanteMessagesParser
 ---@field parse_response AvanteResponseParser
 ---@field parse_curl_args AvanteCurlArgsParser
 ---@field setup fun(): nil
 ---@field has fun(): boolean
 ---@field api_key_name string
----@field tokenizer_id [string] | "gpt-4o"
+---@field tokenizer_id string | "gpt-4o"
 ---@field use_xml_format boolean
 ---@field model? string
 ---@field parse_api_key fun(): string | nil
diff --git a/lua/avante/providers/openai.lua b/lua/avante/providers/openai.lua
index 57b3ae0..01ffa55 100644
--- a/lua/avante/providers/openai.lua
+++ b/lua/avante/providers/openai.lua
@@ -33,29 +33,15 @@ local M = {}
 
 M.api_key_name = "OPENAI_API_KEY"
 
+M.role_map = {
+  user = "user",
+  assistant = "assistant",
+}
+
 ---@param opts AvantePromptOptions
-M.get_user_message = function(opts) return table.concat(opts.user_prompts, "\n") end
-
-M.parse_message = function(opts)
-  ---@type OpenAIMessage[]
-  local user_content = {}
-  if Config.behaviour.support_paste_from_clipboard and opts.image_paths and #opts.image_paths > 0 then
-    for _, image_path in ipairs(opts.image_paths) do
-      table.insert(user_content, {
-        type = "image_url",
-        image_url = {
-          url = "data:image/png;base64," .. Clipboard.get_base64_content(image_path),
-        },
-      })
-    end
-    vim.iter(opts.user_prompts):each(function(prompt) table.insert(user_content, { type = "text", text = prompt }) end)
-  else
-    user_content = vim.iter(opts.user_prompts):fold({}, function(acc, prompt)
-      table.insert(acc, { type = "text", text = prompt })
-      return acc
-    end)
-  end
+M.get_user_message = function(opts) return vim.iter(opts.messages):map(function(msg) return msg.content end):join("\n") end
 
+M.parse_messages = function(opts)
   local messages = {}
   local provider = P[Config.provider]
   local base, _ = P.parse_config(provider)
@@ -68,8 +54,23 @@ M.parse_message = function(opts)
     table.insert(messages, { role = "system", content = opts.system_prompt })
   end
 
-  -- User message after the prompt
-  table.insert(messages, { role = "user", content = user_content })
+  vim
+    .iter(opts.messages)
+    :each(function(msg) table.insert(messages, { role = M.role_map[msg.role], content = msg.content }) end)
+
+  if Config.behaviour.support_paste_from_clipboard and opts.image_paths and #opts.image_paths > 0 then
+    local message_content = messages[#messages].content
+    if type(message_content) ~= "table" then message_content = { { type = "text", text = message_content } } end
+    for _, image_path in ipairs(opts.image_paths) do
+      table.insert(message_content, {
+        type = "image_url",
+        image_url = {
+          url = "data:image/png;base64," .. Clipboard.get_base64_content(image_path),
+        },
+      })
+    end
+    messages[#messages].content = message_content
+  end
 
   return messages
 end
@@ -128,7 +129,7 @@ M.parse_curl_args = function(provider, code_opts)
     headers = headers,
     body = vim.tbl_deep_extend("force", {
       model = base.model,
-      messages = M.parse_message(code_opts),
+      messages = M.parse_messages(code_opts),
       stream = stream,
     }, body_opts),
   }
diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua
index 06499ea..d280fbd 100644
--- a/lua/avante/sidebar.lua
+++ b/lua/avante/sidebar.lua
@@ -1392,10 +1392,23 @@ function Sidebar:create_input(opts)
 
     local project_context = mentions.enable_project_context and RepoMap.get_repo_map(file_ext) or nil
 
+    local history_messages = vim.tbl_map(
+      function(history)
+        return {
+          { role = "user", content = history.request },
+          { role = "assistant", content = history.original_response },
+        }
+      end,
+      chat_history
+    )
+
+    history_messages = vim.iter(history_messages):flatten():totable()
+
     Llm.stream({
       bufnr = self.code.bufnr,
       ask = opts.ask,
       project_context = vim.json.encode(project_context),
+      history_messages = history_messages,
       file_content = content,
       code_lang = filetype,
       selected_code = selected_code_content,
diff --git a/lua/avante/templates/_memory.avanterules b/lua/avante/templates/_memory.avanterules
deleted file mode 100644
index 69ac2a3..0000000
--- a/lua/avante/templates/_memory.avanterules
+++ /dev/null
@@ -1,12 +0,0 @@
-{%- if use_xml_format -%}
-{%- if memory_context -%}
-<memory_context>
-{{memory_context}}
-</memory_context>
-{%- endif %}
-{%- else -%}
-{%- if memory_context -%}
-MEMORY CONTEXT:
-{{memory_context}}
-{%- endif %}
-{%- endif %}
diff --git a/lua/avante/templates/planning.avanterules b/lua/avante/templates/planning.avanterules
index 8f1a5d3..74b2e31 100644
--- a/lua/avante/templates/planning.avanterules
+++ b/lua/avante/templates/planning.avanterules
@@ -7,10 +7,11 @@
   "file_content": "local Config = require('avante.config')"
 }
 #}
+Act as an expert software developer.
+Always use best practices when coding.
+Respect and use existing conventions, libraries, etc that are already present in the code base.
+
 {%- if ask %}
-{%- if not use_xml_format -%}
-INSTRUCTION:{% else -%}
-<instruction>{% endif -%}
 {% block user_prompt %}
 Take requests for changes to the supplied code.
 If the request is ambiguous, ask questions.
@@ -150,19 +151,4 @@ To rename files which have been added to the chat, use shell commands at the end
 
 ONLY EVER RETURN CODE IN A *SEARCH/REPLACE BLOCK*!
 {% endblock %}
-{%- if use_xml_format -%}
-</instruction>
-
-<question>{{question}}</question>
-{%- else %}
-QUESTION:
-{{question}}
-{%- endif %}
-{% else %}
-{% if use_xml_format -%}
-<question>{{question}}</question>
-{% else %}
-QUESTION:
-{{question}}
-{%- endif %}
 {%- endif %}