From 2b89f0d529add18fd39cee37073c17c8179b2358 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Wed, 4 Sep 2024 03:19:33 -0400 Subject: [PATCH] perf(anthropic): prompt-caching (#517) bring back prompt caching support on Anthropic Signed-off-by: Aaron Pham --- lua/avante/llm.lua | 19 ++++++-- lua/avante/path.lua | 6 ++- lua/avante/providers/claude.lua | 26 ++++++++--- lua/avante/providers/cohere.lua | 2 +- lua/avante/providers/copilot.lua | 2 +- lua/avante/providers/gemini.lua | 2 +- lua/avante/providers/init.lua | 2 +- lua/avante/providers/openai.lua | 14 +++--- lua/avante/templates/_context.avanterules | 37 +++++++++++++++ lua/avante/templates/_memory.avanterules | 12 +++++ lua/avante/templates/_project.avanterules | 12 +++++ lua/avante/templates/planning.avanterules | 57 ++--------------------- 12 files changed, 116 insertions(+), 75 deletions(-) create mode 100644 lua/avante/templates/_context.avanterules create mode 100644 lua/avante/templates/_memory.avanterules create mode 100644 lua/avante/templates/_project.avanterules diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index 2846523..21e7a62 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -57,7 +57,8 @@ M.stream = function(opts) end Path.prompts.initialize(Path.prompts.get(opts.bufnr)) - local user_prompt = Path.prompts.render(mode, { + + local template_opts = { use_xml_format = Provider.use_xml_format, ask = true, -- TODO: add mode without ask instruction question = original_instructions, @@ -66,14 +67,24 @@ M.stream = function(opts) selected_code = opts.selected_code, project_context = opts.project_context, memory_context = opts.memory_context, - }) + } - Utils.debug(user_prompt) + local user_prompts = vim + .iter({ + Path.prompts.render_file("_context.avanterules", template_opts), + Path.prompts.render_file("_project.avanterules", template_opts), + Path.prompts.render_file("_memory.avanterules", template_opts), + Path.prompts.render_mode(mode, template_opts), + }) + :filter(function(k) return k ~= "" end) + :totable() + + Utils.debug(user_prompts) ---@type AvantePromptOptions local code_opts = { system_prompt = Config.system_prompt, - user_prompt = user_prompt, + user_prompts = user_prompts, image_paths = image_paths, } diff --git a/lua/avante/path.lua b/lua/avante/path.lua index dbc727a..4327abe 100644 --- a/lua/avante/path.lua +++ b/lua/avante/path.lua @@ -110,9 +110,13 @@ N.get_file = function(mode) return string.format("%s.avanterules", mode) end +---@param path string +---@param opts TemplateOptions +N.render_file = function(path, opts) return templates.render(path, opts) end + ---@param mode LlmMode ---@param opts TemplateOptions -N.render = function(mode, opts) return templates.render(N.get_file(mode), opts) end +N.render_mode = function(mode, opts) return templates.render(N.get_file(mode), opts) end N.initialize = function(directory) templates.initialize(directory) end diff --git a/lua/avante/providers/claude.lua b/lua/avante/providers/claude.lua index 6e8c181..aa999c3 100644 --- a/lua/avante/providers/claude.lua +++ b/lua/avante/providers/claude.lua @@ -24,13 +24,27 @@ M.parse_message = function(opts) end end - local user_prompt_obj = { - type = "text", - text = opts.user_prompt, - } - if Utils.tokens.calculate_tokens(opts.user_prompt) then user_prompt_obj.cache_control = { type = "ephemeral" } end + ---@type {idx: integer, length: integer}[] + local user_prompts_with_length = {} + for idx, user_prompt in ipairs(opts.user_prompts) do + table.insert(user_prompts_with_length, { idx = idx, length = Utils.tokens.calculate_tokens(user_prompt) }) + end - table.insert(message_content, user_prompt_obj) + table.sort(user_prompts_with_length, function(a, b) return a.length > b.length end) + + ---@type table + local top_three = {} + for i = 1, math.min(3, #user_prompts_with_length) do + top_three[user_prompts_with_length[i].idx] = true + end + + for idx, prompt_data in ipairs(opts.user_prompts) do + table.insert(message_content, { + type = "text", + text = prompt_data, + cache_control = top_three[idx] and { type = "ephemeral" } or nil, + }) + end return { { diff --git a/lua/avante/providers/cohere.lua b/lua/avante/providers/cohere.lua index d24f2e7..eab6d00 100644 --- a/lua/avante/providers/cohere.lua +++ b/lua/avante/providers/cohere.lua @@ -34,7 +34,7 @@ M.tokenizer_id = "CohereForAI/c4ai-command-r-plus-08-2024" M.parse_message = function(opts) return { preamble = opts.system_prompt, - message = opts.user_prompt, + message = table.concat(opts.user_prompts, "\n"), } end diff --git a/lua/avante/providers/copilot.lua b/lua/avante/providers/copilot.lua index 2b25d0f..fc7f2b3 100644 --- a/lua/avante/providers/copilot.lua +++ b/lua/avante/providers/copilot.lua @@ -118,7 +118,7 @@ M.tokenizer_id = "gpt-4o" M.parse_message = function(opts) return { { role = "system", content = opts.system_prompt }, - { role = "user", content = opts.user_prompt }, + { role = "user", content = table.concat(opts.user_prompts, "\n") }, } end diff --git a/lua/avante/providers/gemini.lua b/lua/avante/providers/gemini.lua index 06a2b50..7e83ee7 100644 --- a/lua/avante/providers/gemini.lua +++ b/lua/avante/providers/gemini.lua @@ -25,7 +25,7 @@ M.parse_message = function(opts) end -- insert a part into parts - table.insert(message_content, { text = opts.user_prompt }) + table.insert(message_content, { text = table.concat(opts.user_prompts, "\n") }) return { systemInstruction = { diff --git a/lua/avante/providers/init.lua b/lua/avante/providers/init.lua index d0f9482..9cb8874 100644 --- a/lua/avante/providers/init.lua +++ b/lua/avante/providers/init.lua @@ -16,7 +16,7 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil } --- ---@class AvantePromptOptions: table<[string], string> ---@field system_prompt string ----@field user_prompt string +---@field user_prompts string[] ---@field image_paths? string[] --- ---@class AvanteBaseMessage diff --git a/lua/avante/providers/openai.lua b/lua/avante/providers/openai.lua index d35c9fa..52e62b1 100644 --- a/lua/avante/providers/openai.lua +++ b/lua/avante/providers/openai.lua @@ -28,13 +28,12 @@ local M = {} M.api_key_name = "OPENAI_API_KEY" ---@param opts AvantePromptOptions -M.get_user_message = function(opts) return opts.user_prompt end +M.get_user_message = function(opts) return table.concat(opts.user_prompts, "\n") end M.parse_message = function(opts) - ---@type string | OpenAIMessage[] - local user_content + ---@type OpenAIMessage[] + local user_content = {} if Config.behaviour.support_paste_from_clipboard and opts.image_paths and #opts.image_paths > 0 then - user_content = {} for _, image_path in ipairs(opts.image_paths) do table.insert(user_content, { type = "image_url", @@ -43,9 +42,12 @@ M.parse_message = function(opts) }, }) end - table.insert(user_content, { type = "text", text = opts.user_prompt }) + vim.iter(opts.user_prompts):each(function(prompt) table.insert(user_content, { type = "text", text = prompt }) end) else - user_content = opts.user_prompt + user_content = vim.iter(opts.user_prompts):fold({}, function(acc, prompt) + table.insert(acc, { type = "text", text = prompt }) + return acc + end) end return { diff --git a/lua/avante/templates/_context.avanterules b/lua/avante/templates/_context.avanterules new file mode 100644 index 0000000..6f60bdb --- /dev/null +++ b/lua/avante/templates/_context.avanterules @@ -0,0 +1,37 @@ +{%- if use_xml_format -%} +{%- if selected_code -%} + +```{{code_lang}} +{{file_content}} +``` + + + +```{{code_lang}} +{{selected_code}} +``` + +{%- else -%} + +```{{code_lang}} +{{file_content}} +``` + +{%- endif %} +{% else %} +{%- if selected_code -%} +CONTEXT: +```{{code_lang}} +{{file_content}} +``` + +CODE: +```{{code_lang}} +{{selected_code}} +``` +{%- else -%} +CODE: +```{{code_lang}} +{{file_content}} +``` +{%- endif %}{%- endif %} diff --git a/lua/avante/templates/_memory.avanterules b/lua/avante/templates/_memory.avanterules new file mode 100644 index 0000000..69ac2a3 --- /dev/null +++ b/lua/avante/templates/_memory.avanterules @@ -0,0 +1,12 @@ +{%- if use_xml_format -%} +{%- if memory_context -%} + +{{memory_context}} + +{%- endif %} +{%- else -%} +{%- if memory_context -%} +MEMORY CONTEXT: +{{memory_context}} +{%- endif %} +{%- endif %} diff --git a/lua/avante/templates/_project.avanterules b/lua/avante/templates/_project.avanterules new file mode 100644 index 0000000..c92c579 --- /dev/null +++ b/lua/avante/templates/_project.avanterules @@ -0,0 +1,12 @@ +{%- if use_xml_format -%} +{%- if project_context -%} + +{{project_context}} + +{%- endif %} +{%- else -%} +{%- if project_context -%} +PROJECT CONTEXT: +{{project_context}} +{%- endif %} +{%- endif %} diff --git a/lua/avante/templates/planning.avanterules b/lua/avante/templates/planning.avanterules index c63effb..fcc2989 100644 --- a/lua/avante/templates/planning.avanterules +++ b/lua/avante/templates/planning.avanterules @@ -7,60 +7,9 @@ "file_content": "local Config = require('avante.config')" } #} -{%- if use_xml_format -%} -{%- if selected_code -%} - -```{{code_lang}} -{{file_content}} -``` - - - -```{{code_lang}} -{{selected_code}} -``` - -{%- else -%} - -```{{code_lang}} -{{file_content}} -``` - -{%- endif %}{%- if project_context -%} - -{{project_context}} - -{%- endif %}{%- if memory_context -%} - -{{memory_context}} - -{%- endif %} -{% else %} -{%- if selected_code -%} -CONTEXT: -```{{code_lang}} -{{file_content}} -``` - -CODE: -```{{code_lang}} -{{selected_code}} -``` -{%- else -%} -CODE: -```{{code_lang}} -{{file_content}} -``` -{%- endif %}{%- if project_context -%} -PROJECT CONTEXT: -{{project_context}} -{%- endif %}{%- if memory_context -%} -MEMORY CONTEXT: -{{memory_context}} -{%- endif %}{%- endif %}{%- if ask %} -{%- if not use_xml_format %} - -INSTRUCTION: {% else %} +{%- if ask %} +{%- if not use_xml_format -%} +INSTRUCTION:{% else -%} {% endif -%} {% block user_prompt %} Your primary task is to suggest code modifications with precise line number ranges. Follow these instructions meticulously: