From 0b6a85ee649dcfc1dece430b6ceeed4f1dcfa6c9 Mon Sep 17 00:00:00 2001
From: yetone <yetoneful@gmail.com>
Date: Thu, 15 Aug 2024 11:42:56 +0800
Subject: [PATCH] feat: enable prompt caching for the Anthropic API

---
 lua/avante/init.lua     | 188 ++++++++++++++++++++++++++--------------
 lua/avante/tiktoken.lua | 103 ++++++++++++++++++++++
 2 files changed, 226 insertions(+), 65 deletions(-)
 create mode 100644 lua/avante/tiktoken.lua

diff --git a/lua/avante/init.lua b/lua/avante/init.lua
index 01350fe..8631983 100644
--- a/lua/avante/init.lua
+++ b/lua/avante/init.lua
@@ -4,6 +4,7 @@ local Path = require("plenary.path")
 local n = require("nui-components")
 local diff = require("avante.diff")
 local utils = require("avante.utils")
+local tiktoken = require("avante.tiktoken")
 local api = vim.api
 local fn = vim.fn
 
@@ -140,7 +141,7 @@ local system_prompt = [[
 You are an excellent programming expert.
 ]]
 
-local user_prompt_tpl = [[
+local base_user_prompt = [[
 Your primary task is to suggest code modifications with precise line number ranges. Follow these instructions meticulously:
 
 1. Carefully analyze the original code, paying close attention to its structure and line numbers. Line numbers start from 1 and include ALL lines, even empty ones.
@@ -183,87 +184,119 @@ Replace lines: {{start_line}}-{{end_line}}
    - Do not show the content after these modifications.
 
 Remember: Accurate line numbers are CRITICAL. The range start_line to end_line must include ALL lines to be replaced, from the very first to the very last. Double-check every range before finalizing your response, paying special attention to the start_line to ensure it hasn't shifted down. Ensure that your line numbers perfectly match the original code structure without any overall shift.
-
-QUESTION: ${{question}}
-
-CODE:
-```
-${{code}}
-```
 ]]
 
-local function call_claude_api_stream(prompt, original_content, on_chunk, on_complete)
+local function call_claude_api_stream(question, code_lang, code_content, on_chunk, on_complete)
   local api_key = os.getenv("ANTHROPIC_API_KEY")
   if not api_key then
     error("ANTHROPIC_API_KEY environment variable is not set")
   end
 
-  local user_prompt = user_prompt_tpl:gsub("${{question}}", prompt):gsub("${{code}}", original_content)
+  local user_prompt = base_user_prompt
 
-  print("Sending request to Claude API...")
-
-  local tokens = M.config.claude.model == "claude-3-5-sonnet-20240620" and 8192 or 4096
+  local tokens = M.config.claude.max_tokens
   local headers = {
     ["Content-Type"] = "application/json",
     ["x-api-key"] = api_key,
     ["anthropic-version"] = "2023-06-01",
-    ["anthropic-beta"] = "messages-2023-12-15",
+    ["anthropic-beta"] = "prompt-caching-2024-07-31",
   }
 
-  if M.config.claude.model == "claude-3-5-sonnet-20240620" then
-    headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15"
+  local code_prompt_obj = {
+    type = "text",
+    text = string.format("<code>```%s\n%s```</code>", code_lang, code_content),
+  }
+
+  local user_prompt_obj = {
+    type = "text",
+    text = user_prompt,
+  }
+
+  if tiktoken.count(code_prompt_obj.text) > 1024 then
+    code_prompt_obj.cache_control = { type = "ephemeral" }
   end
 
+  if tiktoken.count(user_prompt_obj.text) > 1024 then
+    user_prompt_obj.cache_control = { type = "ephemeral" }
+  end
+
+  local params = {
+    model = M.config.claude.model,
+    system = system_prompt,
+    messages = {
+      {
+        role = "user",
+        content = {
+          code_prompt_obj,
+          {
+            type = "text",
+            text = string.format("<question>%s</question>", question),
+          },
+          user_prompt_obj,
+        },
+      },
+    },
+    stream = true,
+    temperature = M.config.claude.temperature,
+    max_tokens = tokens,
+  }
+
   local url = utils.trim_suffix(M.config.claude.endpoint, "/") .. "/v1/messages"
 
+  print("Sending request to Claude API...")
+
   curl.post(url, {
     ---@diagnostic disable-next-line: unused-local
     stream = function(err, data, job)
       if err then
-        error("Error: " .. vim.inspect(err))
+        on_complete(err)
         return
       end
-      if data then
-        for line in data:gmatch("[^\r\n]+") do
-          if line:sub(1, 6) == "data: " then
-            vim.schedule(function()
-              local success, parsed = pcall(fn.json_decode, line:sub(7))
-              if success and parsed and parsed.type == "content_block_delta" then
-                on_chunk(parsed.delta.text)
-              elseif success and parsed and parsed.type == "message_stop" then
-                -- Stream request completed
-                on_complete()
-              elseif success and parsed and parsed.type == "error" then
-                print("Error: " .. vim.inspect(parsed))
-                -- Stream request completed
-                on_complete()
-              end
-            end)
-          end
+      if not data then
+        return
+      end
+      for line in data:gmatch("[^\r\n]+") do
+        if line:sub(1, 6) ~= "data: " then
+          return
         end
+        vim.schedule(function()
+          local success, parsed = pcall(fn.json_decode, line:sub(7))
+          if not success then
+            error("Error: failed to parse json: " .. parsed)
+            return
+          end
+          if parsed and parsed.type == "content_block_delta" then
+            on_chunk(parsed.delta.text)
+          elseif parsed and parsed.type == "message_stop" then
+            -- Stream request completed
+            on_complete(nil)
+          elseif parsed and parsed.type == "error" then
+            -- Stream request completed
+            on_complete(parsed)
+          end
+        end)
       end
     end,
     headers = headers,
-    body = fn.json_encode({
-      model = M.config.claude.model,
-      system = system_prompt,
-      messages = {
-        { role = "user", content = user_prompt },
-      },
-      stream = true,
-      temperature = M.config.claude.temperature,
-      max_tokens = tokens,
-    }),
+    body = fn.json_encode(params),
   })
 end
 
-local function call_openai_api_stream(prompt, original_content, on_chunk, on_complete)
+local function call_openai_api_stream(question, code_lang, code_content, on_chunk, on_complete)
   local api_key = os.getenv("OPENAI_API_KEY")
   if not api_key then
     error("OPENAI_API_KEY environment variable is not set")
   end
 
-  local user_prompt = user_prompt_tpl:gsub("${{question}}", prompt):gsub("${{code}}", original_content)
+  local user_prompt = base_user_prompt
+    .. "\n\nQUESTION:\n"
+    .. question
+    .. "\n\nCODE:\n"
+    .. "```"
+    .. code_lang
+    .. "\n"
+    .. code_content
+    .. "\n```"
 
   local url = utils.trim_suffix(M.config.openai.endpoint, "/") .. "/v1/chat/completions"
   if M.config.provider == "azure" then
@@ -276,23 +309,29 @@ local function call_openai_api_stream(prompt, original_content, on_chunk, on_com
     ---@diagnostic disable-next-line: unused-local
     stream = function(err, data, job)
       if err then
-        error("Error: " .. vim.inspect(err))
+        on_complete(err)
         return
       end
-      if data then
-        for line in data:gmatch("[^\r\n]+") do
-          if line:sub(1, 6) == "data: " then
-            vim.schedule(function()
-              local success, parsed = pcall(fn.json_decode, line:sub(7))
-              if success and parsed and parsed.choices and parsed.choices[1].delta.content then
-                on_chunk(parsed.choices[1].delta.content)
-              elseif success and parsed and parsed.choices and parsed.choices[1].finish_reason == "stop" then
-                -- Stream request completed
-                on_complete()
-              end
-            end)
-          end
+      if not data then
+        return
+      end
+      for line in data:gmatch("[^\r\n]+") do
+        if line:sub(1, 6) ~= "data: " then
+          return
         end
+        vim.schedule(function()
+          local success, parsed = pcall(fn.json_decode, line:sub(7))
+          if not success then
+            error("Error: failed to parse json: " .. parsed)
+            return
+          end
+          if parsed and parsed.choices and parsed.choices[1].delta.content then
+            on_chunk(parsed.choices[1].delta.content)
+          elseif parsed and parsed.choices and parsed.choices[1].finish_reason == "stop" then
+            -- Stream request completed
+            on_complete(nil)
+          end
+        end)
       end
     end,
     headers = {
@@ -313,11 +352,11 @@ local function call_openai_api_stream(prompt, original_content, on_chunk, on_com
   })
 end
 
-local function call_ai_api_stream(prompt, original_content, on_chunk, on_complete)
+local function call_ai_api_stream(question, code_lang, code_content, on_chunk, on_complete)
   if M.config.provider == "openai" or M.config.provider == "azure" then
-    call_openai_api_stream(prompt, original_content, on_chunk, on_complete)
+    call_openai_api_stream(question, code_lang, code_content, on_chunk, on_complete)
   elseif M.config.provider == "claude" then
-    call_claude_api_stream(prompt, original_content, on_chunk, on_complete)
+    call_claude_api_stream(question, code_lang, code_content, on_chunk, on_complete)
   end
 end
 
@@ -522,7 +561,9 @@ function M.render_sidebar()
 
     signal.is_loading = true
 
-    call_ai_api_stream(user_input, content_with_line_numbers, function(chunk)
+    local filetype = api.nvim_get_option_value("filetype", { buf = code_buf })
+
+    call_ai_api_stream(user_input, filetype, content_with_line_numbers, function(chunk)
       full_response = full_response .. chunk
       update_result_buf_content(
         "## " .. timestamp .. "\n\n> " .. user_input:gsub("\n", "\n> ") .. "\n\n" .. full_response
@@ -530,8 +571,23 @@ function M.render_sidebar()
       vim.schedule(function()
         vim.cmd("redraw")
       end)
-    end, function()
+    end, function(err)
       signal.is_loading = false
+
+      if err ~= nil then
+        update_result_buf_content(
+          "## "
+            .. timestamp
+            .. "\n\n> "
+            .. user_input:gsub("\n", "\n> ")
+            .. "\n\n"
+            .. full_response
+            .. "\n\n**Error**: "
+            .. vim.inspect(err)
+        )
+        return
+      end
+
       -- Execute when the stream request is actually completed
       update_result_buf_content(
         "## "
@@ -687,6 +743,8 @@ function M.setup(opts)
     _cur_code_buf = bufnr
   end
 
+  tiktoken.setup("gpt-4o")
+
   diff.setup({
     debug = false, -- log output to console
     default_mappings = M.config.mappings.diff, -- disable buffer local mapping created by this plugin
diff --git a/lua/avante/tiktoken.lua b/lua/avante/tiktoken.lua
new file mode 100644
index 0000000..7792662
--- /dev/null
+++ b/lua/avante/tiktoken.lua
@@ -0,0 +1,103 @@
+-- NOTE: this file is copied from: https://github.com/CopilotC-Nvim/CopilotChat.nvim/blob/canary/lua/CopilotChat/tiktoken.lua
+
+local curl = require("plenary.curl")
+local tiktoken_core = nil
+
+---Get the path of the cache directory
+---@param fname string
+---@return string
+local function get_cache_path(fname)
+  return vim.fn.stdpath("cache") .. "/" .. fname
+end
+
+local function file_exists(name)
+  local f = io.open(name, "r")
+  if f ~= nil then
+    io.close(f)
+    return true
+  else
+    return false
+  end
+end
+
+--- Load tiktoken data from cache or download it
+local function load_tiktoken_data(done, model)
+  local tiktoken_url = "https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken"
+  -- If model is gpt-4o, use o200k_base.tiktoken
+  if model ~= nil and vim.startswith(model, "gpt-4o") then
+    tiktoken_url = "https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken"
+  end
+  local async
+  async = vim.loop.new_async(function()
+    -- Take filename after the last slash of the url
+    local cache_path = get_cache_path(tiktoken_url:match(".+/(.+)"))
+    if not file_exists(cache_path) then
+      vim.schedule(function()
+        curl.get(tiktoken_url, {
+          output = cache_path,
+        })
+        done(cache_path)
+      end)
+    else
+      done(cache_path)
+    end
+    async:close()
+  end)
+  async:send()
+end
+
+local M = {}
+
+---@param model string|nil
+function M.setup(model)
+  local ok, core = pcall(require, "tiktoken_core")
+  if not ok then
+    print("Warn: tiktoken_core is not found!!!!")
+    return
+  end
+
+  load_tiktoken_data(function(path)
+    local special_tokens = {}
+    special_tokens["<|endoftext|>"] = 100257
+    special_tokens["<|fim_prefix|>"] = 100258
+    special_tokens["<|fim_middle|>"] = 100259
+    special_tokens["<|fim_suffix|>"] = 100260
+    special_tokens["<|endofprompt|>"] = 100276
+    local pat_str =
+      "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
+    core.new(path, special_tokens, pat_str)
+    tiktoken_core = core
+  end, model)
+end
+
+function M.available()
+  return tiktoken_core ~= nil
+end
+
+function M.encode(prompt)
+  if not tiktoken_core then
+    return nil
+  end
+  if not prompt or prompt == "" then
+    return nil
+  end
+  -- Check if prompt is a string
+  if type(prompt) ~= "string" then
+    error("Prompt must be a string")
+  end
+  return tiktoken_core.encode(prompt)
+end
+
+function M.count(prompt)
+  if not tiktoken_core then
+    return math.ceil(#prompt * 0.2) -- Fallback to 0.2 character count
+  end
+
+  local tokens = M.encode(prompt)
+  if not tokens then
+    return 0
+  end
+  return #tokens
+end
+
+return M