feat: enable prompt caching for the Anthropic API
This commit is contained in:
parent 036a0b5f1d
commit 0b6a85ee64
| @@ -4,6 +4,7 @@ local Path = require("plenary.path") | ||||
| local n = require("nui-components") | ||||
| local diff = require("avante.diff") | ||||
| local utils = require("avante.utils") | ||||
| local tiktoken = require("avante.tiktoken") | ||||
| local api = vim.api | ||||
| local fn = vim.fn | ||||
| 
 | ||||
| @@ -140,7 +141,7 @@ local system_prompt = [[ | ||||
| You are an excellent programming expert. | ||||
| ]] | ||||
| 
 | ||||
| local user_prompt_tpl = [[ | ||||
| local base_user_prompt = [[ | ||||
| Your primary task is to suggest code modifications with precise line number ranges. Follow these instructions meticulously: | ||||
| 
 | ||||
| 1. Carefully analyze the original code, paying close attention to its structure and line numbers. Line numbers start from 1 and include ALL lines, even empty ones. | ||||
| @@ -183,87 +184,119 @@ Replace lines: {{start_line}}-{{end_line}} | ||||
|    - Do not show the content after these modifications. | ||||
| 
 | ||||
| Remember: Accurate line numbers are CRITICAL. The range start_line to end_line must include ALL lines to be replaced, from the very first to the very last. Double-check every range before finalizing your response, paying special attention to the start_line to ensure it hasn't shifted down. Ensure that your line numbers perfectly match the original code structure without any overall shift. | ||||
| 
 | ||||
| QUESTION: ${{question}} | ||||
| 
 | ||||
| CODE: | ||||
| ``` | ||||
| ${{code}} | ||||
| ``` | ||||
| ]] | ||||
| 
 | ||||
| local function call_claude_api_stream(prompt, original_content, on_chunk, on_complete) | ||||
| local function call_claude_api_stream(question, code_lang, code_content, on_chunk, on_complete) | ||||
|   local api_key = os.getenv("ANTHROPIC_API_KEY") | ||||
|   if not api_key then | ||||
|     error("ANTHROPIC_API_KEY environment variable is not set") | ||||
|   end | ||||
| 
 | ||||
|   local user_prompt = user_prompt_tpl:gsub("${{question}}", prompt):gsub("${{code}}", original_content) | ||||
|   local user_prompt = base_user_prompt | ||||
| 
 | ||||
|   print("Sending request to Claude API...") | ||||
| 
 | ||||
|   local tokens = M.config.claude.model == "claude-3-5-sonnet-20240620" and 8192 or 4096 | ||||
|   local tokens = M.config.claude.max_tokens | ||||
|   local headers = { | ||||
|     ["Content-Type"] = "application/json", | ||||
|     ["x-api-key"] = api_key, | ||||
|     ["anthropic-version"] = "2023-06-01", | ||||
|     ["anthropic-beta"] = "messages-2023-12-15", | ||||
|     ["anthropic-beta"] = "prompt-caching-2024-07-31", | ||||
|   } | ||||
| 
 | ||||
|   if M.config.claude.model == "claude-3-5-sonnet-20240620" then | ||||
|     headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15" | ||||
|   local code_prompt_obj = { | ||||
|     type = "text", | ||||
|     text = string.format("<code>```%s\n%s```</code>", code_lang, code_content), | ||||
|   } | ||||
| 
 | ||||
|   local user_prompt_obj = { | ||||
|     type = "text", | ||||
|     text = user_prompt, | ||||
|   } | ||||
| 
 | ||||
|   if tiktoken.count(code_prompt_obj.text) > 1024 then | ||||
|     code_prompt_obj.cache_control = { type = "ephemeral" } | ||||
|   end | ||||
| 
 | ||||
|   if tiktoken.count(user_prompt_obj.text) > 1024 then | ||||
|     user_prompt_obj.cache_control = { type = "ephemeral" } | ||||
|   end | ||||
| 
 | ||||
|   local params = { | ||||
|     model = M.config.claude.model, | ||||
|     system = system_prompt, | ||||
|     messages = { | ||||
|       { | ||||
|         role = "user", | ||||
|         content = { | ||||
|           code_prompt_obj, | ||||
|           { | ||||
|             type = "text", | ||||
|             text = string.format("<question>%s</question>", question), | ||||
|           }, | ||||
|           user_prompt_obj, | ||||
|         }, | ||||
|       }, | ||||
|     }, | ||||
|     stream = true, | ||||
|     temperature = M.config.claude.temperature, | ||||
|     max_tokens = tokens, | ||||
|   } | ||||
| 
 | ||||
|   local url = utils.trim_suffix(M.config.claude.endpoint, "/") .. "/v1/messages" | ||||
| 
 | ||||
|   print("Sending request to Claude API...") | ||||
| 
 | ||||
|   curl.post(url, { | ||||
|     ---@diagnostic disable-next-line: unused-local | ||||
|     stream = function(err, data, job) | ||||
|       if err then | ||||
|         error("Error: " .. vim.inspect(err)) | ||||
|         on_complete(err) | ||||
|         return | ||||
|       end | ||||
|       if not data then | ||||
|         return | ||||
|       end | ||||
|       if data then | ||||
|       for line in data:gmatch("[^\r\n]+") do | ||||
|           if line:sub(1, 6) == "data: " then | ||||
|         if line:sub(1, 6) ~= "data: " then | ||||
|           return | ||||
|         end | ||||
|         vim.schedule(function() | ||||
|           local success, parsed = pcall(fn.json_decode, line:sub(7)) | ||||
|               if success and parsed and parsed.type == "content_block_delta" then | ||||
|           if not success then | ||||
|             error("Error: failed to parse json: " .. parsed) | ||||
|             return | ||||
|           end | ||||
|           if parsed and parsed.type == "content_block_delta" then | ||||
|             on_chunk(parsed.delta.text) | ||||
|               elseif success and parsed and parsed.type == "message_stop" then | ||||
|           elseif parsed and parsed.type == "message_stop" then | ||||
|             -- Stream request completed | ||||
|                 on_complete() | ||||
|               elseif success and parsed and parsed.type == "error" then | ||||
|                 print("Error: " .. vim.inspect(parsed)) | ||||
|             on_complete(nil) | ||||
|           elseif parsed and parsed.type == "error" then | ||||
|             -- Stream request completed | ||||
|                 on_complete() | ||||
|             on_complete(parsed) | ||||
|           end | ||||
|         end) | ||||
|       end | ||||
|         end | ||||
|       end | ||||
|     end, | ||||
|     headers = headers, | ||||
|     body = fn.json_encode({ | ||||
|       model = M.config.claude.model, | ||||
|       system = system_prompt, | ||||
|       messages = { | ||||
|         { role = "user", content = user_prompt }, | ||||
|       }, | ||||
|       stream = true, | ||||
|       temperature = M.config.claude.temperature, | ||||
|       max_tokens = tokens, | ||||
|     }), | ||||
|     body = fn.json_encode(params), | ||||
|   }) | ||||
| end | ||||
| 
 | ||||
| local function call_openai_api_stream(prompt, original_content, on_chunk, on_complete) | ||||
| local function call_openai_api_stream(question, code_lang, code_content, on_chunk, on_complete) | ||||
|   local api_key = os.getenv("OPENAI_API_KEY") | ||||
|   if not api_key then | ||||
|     error("OPENAI_API_KEY environment variable is not set") | ||||
|   end | ||||
| 
 | ||||
|   local user_prompt = user_prompt_tpl:gsub("${{question}}", prompt):gsub("${{code}}", original_content) | ||||
|   local user_prompt = base_user_prompt | ||||
|     .. "\n\nQUESTION:\n" | ||||
|     .. question | ||||
|     .. "\n\nCODE:\n" | ||||
|     .. "```" | ||||
|     .. code_lang | ||||
|     .. "\n" | ||||
|     .. code_content | ||||
|     .. "\n```" | ||||
| 
 | ||||
|   local url = utils.trim_suffix(M.config.openai.endpoint, "/") .. "/v1/chat/completions" | ||||
|   if M.config.provider == "azure" then | ||||
| @@ -276,24 +309,30 @@ local function call_openai_api_stream(prompt, original_content, on_chunk, on_com | ||||
|     ---@diagnostic disable-next-line: unused-local | ||||
|     stream = function(err, data, job) | ||||
|       if err then | ||||
|         error("Error: " .. vim.inspect(err)) | ||||
|         on_complete(err) | ||||
|         return | ||||
|       end | ||||
|       if not data then | ||||
|         return | ||||
|       end | ||||
|       if data then | ||||
|       for line in data:gmatch("[^\r\n]+") do | ||||
|           if line:sub(1, 6) == "data: " then | ||||
|         if line:sub(1, 6) ~= "data: " then | ||||
|           return | ||||
|         end | ||||
|         vim.schedule(function() | ||||
|           local success, parsed = pcall(fn.json_decode, line:sub(7)) | ||||
|               if success and parsed and parsed.choices and parsed.choices[1].delta.content then | ||||
|           if not success then | ||||
|             error("Error: failed to parse json: " .. parsed) | ||||
|             return | ||||
|           end | ||||
|           if parsed and parsed.choices and parsed.choices[1].delta.content then | ||||
|             on_chunk(parsed.choices[1].delta.content) | ||||
|               elseif success and parsed and parsed.choices and parsed.choices[1].finish_reason == "stop" then | ||||
|           elseif parsed and parsed.choices and parsed.choices[1].finish_reason == "stop" then | ||||
|             -- Stream request completed | ||||
|                 on_complete() | ||||
|             on_complete(nil) | ||||
|           end | ||||
|         end) | ||||
|       end | ||||
|         end | ||||
|       end | ||||
|     end, | ||||
|     headers = { | ||||
|       ["Content-Type"] = "application/json", | ||||
| @@ -313,11 +352,11 @@ local function call_openai_api_stream(prompt, original_content, on_chunk, on_com | ||||
|   }) | ||||
| end | ||||
| 
 | ||||
| local function call_ai_api_stream(prompt, original_content, on_chunk, on_complete) | ||||
| local function call_ai_api_stream(question, code_lang, code_content, on_chunk, on_complete) | ||||
|   if M.config.provider == "openai" or M.config.provider == "azure" then | ||||
|     call_openai_api_stream(prompt, original_content, on_chunk, on_complete) | ||||
|     call_openai_api_stream(question, code_lang, code_content, on_chunk, on_complete) | ||||
|   elseif M.config.provider == "claude" then | ||||
|     call_claude_api_stream(prompt, original_content, on_chunk, on_complete) | ||||
|     call_claude_api_stream(question, code_lang, code_content, on_chunk, on_complete) | ||||
|   end | ||||
| end | ||||
| 
 | ||||
| @@ -522,7 +561,9 @@ function M.render_sidebar() | ||||
| 
 | ||||
|     signal.is_loading = true | ||||
| 
 | ||||
|     call_ai_api_stream(user_input, content_with_line_numbers, function(chunk) | ||||
|     local filetype = api.nvim_get_option_value("filetype", { buf = code_buf }) | ||||
| 
 | ||||
|     call_ai_api_stream(user_input, filetype, content_with_line_numbers, function(chunk) | ||||
|       full_response = full_response .. chunk | ||||
|       update_result_buf_content( | ||||
|         "## " .. timestamp .. "\n\n> " .. user_input:gsub("\n", "\n> ") .. "\n\n" .. full_response | ||||
| @@ -530,8 +571,23 @@ function M.render_sidebar() | ||||
|       vim.schedule(function() | ||||
|         vim.cmd("redraw") | ||||
|       end) | ||||
|     end, function() | ||||
|     end, function(err) | ||||
|       signal.is_loading = false | ||||
| 
 | ||||
|       if err ~= nil then | ||||
|         update_result_buf_content( | ||||
|           "## " | ||||
|             .. timestamp | ||||
|             .. "\n\n> " | ||||
|             .. user_input:gsub("\n", "\n> ") | ||||
|             .. "\n\n" | ||||
|             .. full_response | ||||
|             .. "\n\n**Error**: " | ||||
|             .. vim.inspect(err) | ||||
|         ) | ||||
|         return | ||||
|       end | ||||
| 
 | ||||
|       -- Execute when the stream request is actually completed | ||||
|       update_result_buf_content( | ||||
|         "## " | ||||
| @@ -687,6 +743,8 @@ function M.setup(opts) | ||||
|     _cur_code_buf = bufnr | ||||
|   end | ||||
| 
 | ||||
|   tiktoken.setup("gpt-4o") | ||||
| 
 | ||||
|   diff.setup({ | ||||
|     debug = false, -- log output to console | ||||
|     default_mappings = M.config.mappings.diff, -- disable buffer local mapping created by this plugin | ||||
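
The caching pieces of this change are spread across several hunks: the `prompt-caching-2024-07-31` beta header, the token-count check, and the structured message content carrying `cache_control`. Below is a minimal sketch, not part of the commit, of the request shape this produces against the Anthropic Messages API as used above; the endpoint, model, and literal strings are illustrative placeholders.

```lua
-- Sketch only: the header and body shape built by call_claude_api_stream
-- after this change (all concrete values are placeholders).
local curl = require("plenary.curl")

local headers = {
  ["Content-Type"] = "application/json",
  ["x-api-key"] = os.getenv("ANTHROPIC_API_KEY"),
  ["anthropic-version"] = "2023-06-01",
  -- Opt in to prompt caching; replaces the earlier messages/max-tokens beta headers.
  ["anthropic-beta"] = "prompt-caching-2024-07-31",
}

-- The code block is the large, relatively stable part of the prompt, so it is
-- the one marked cacheable once it exceeds the 1024-token minimum used above.
local code_prompt_obj = {
  type = "text",
  text = "<code>(buffer content wrapped in a fenced code block)</code>",
  cache_control = { type = "ephemeral" },
}

local body = {
  model = "claude-3-5-sonnet-20240620", -- placeholder for M.config.claude.model
  system = "You are an excellent programming expert.",
  messages = {
    {
      role = "user",
      content = {
        code_prompt_obj,
        { type = "text", text = "<question>placeholder question</question>" },
        { type = "text", text = "placeholder for base_user_prompt" },
      },
    },
  },
  stream = true,
  temperature = 0,    -- placeholder for M.config.claude.temperature
  max_tokens = 4096,  -- placeholder for M.config.claude.max_tokens
}

curl.post("https://api.anthropic.com/v1/messages", {
  headers = headers,
  body = vim.fn.json_encode(body),
})
```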
lua/avante/tiktoken.lua (new file, 103 lines)
| @@ -0,0 +1,103 @@ | ||||
| -- NOTE: this file is copied from: https://github.com/CopilotC-Nvim/CopilotChat.nvim/blob/canary/lua/CopilotChat/tiktoken.lua | ||||
| 
 | ||||
| local curl = require("plenary.curl") | ||||
| local tiktoken_core = nil | ||||
| 
 | ||||
| ---Get the path of the cache directory | ||||
| ---@param fname string | ||||
| ---@return string | ||||
| local function get_cache_path(fname) | ||||
|   return vim.fn.stdpath("cache") .. "/" .. fname | ||||
| end | ||||
| 
 | ||||
| local function file_exists(name) | ||||
|   local f = io.open(name, "r") | ||||
|   if f ~= nil then | ||||
|     io.close(f) | ||||
|     return true | ||||
|   else | ||||
|     return false | ||||
|   end | ||||
| end | ||||
| 
 | ||||
| --- Load tiktoken data from cache or download it | ||||
| local function load_tiktoken_data(done, model) | ||||
|   local tiktoken_url = "https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken" | ||||
|   -- If model is gpt-4o, use o200k_base.tiktoken | ||||
|   if model ~= nil and vim.startswith(model, "gpt-4o") then | ||||
|     tiktoken_url = "https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken" | ||||
|   end | ||||
|   local async | ||||
|   async = vim.loop.new_async(function() | ||||
|     -- Take filename after the last slash of the url | ||||
|     local cache_path = get_cache_path(tiktoken_url:match(".+/(.+)")) | ||||
|     if not file_exists(cache_path) then | ||||
|       vim.schedule(function() | ||||
|         curl.get(tiktoken_url, { | ||||
|           output = cache_path, | ||||
|         }) | ||||
|         done(cache_path) | ||||
|       end) | ||||
|     else | ||||
|       done(cache_path) | ||||
|     end | ||||
|     async:close() | ||||
|   end) | ||||
|   async:send() | ||||
| end | ||||
| 
 | ||||
| local M = {} | ||||
| 
 | ||||
| ---@param model string|nil | ||||
| function M.setup(model) | ||||
|   local ok, core = pcall(require, "tiktoken_core") | ||||
|   if not ok then | ||||
|     print("Warn: tiktoken_core is not found!!!!") | ||||
|     return | ||||
|   end | ||||
| 
 | ||||
|   load_tiktoken_data(function(path) | ||||
|     local special_tokens = {} | ||||
|     special_tokens["<|endoftext|>"] = 100257 | ||||
|     special_tokens["<|fim_prefix|>"] = 100258 | ||||
|     special_tokens["<|fim_middle|>"] = 100259 | ||||
|     special_tokens["<|fim_suffix|>"] = 100260 | ||||
|     special_tokens["<|endofprompt|>"] = 100276 | ||||
|     local pat_str = | ||||
|       "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" | ||||
|     core.new(path, special_tokens, pat_str) | ||||
|     tiktoken_core = core | ||||
|   end, model) | ||||
| end | ||||
| 
 | ||||
| function M.available() | ||||
|   return tiktoken_core ~= nil | ||||
| end | ||||
| 
 | ||||
| function M.encode(prompt) | ||||
|   if not tiktoken_core then | ||||
|     return nil | ||||
|   end | ||||
|   if not prompt or prompt == "" then | ||||
|     return nil | ||||
|   end | ||||
|   -- Check if prompt is a string | ||||
|   if type(prompt) ~= "string" then | ||||
|     error("Prompt must be a string") | ||||
|   end | ||||
|   return tiktoken_core.encode(prompt) | ||||
| end | ||||
| 
 | ||||
| function M.count(prompt) | ||||
|   if not tiktoken_core then | ||||
|     return math.ceil(#prompt * 0.2) -- Fallback to 0.2 character count | ||||
|   end | ||||
| 
 | ||||
|   local tokens = M.encode(prompt) | ||||
|   if not tokens then | ||||
|     return 0 | ||||
|   end | ||||
|   return #tokens | ||||
| end | ||||
| 
 | ||||
| return M | ||||
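
A small usage sketch of this helper, following how the sidebar code appears to use it: `setup()` runs once during plugin setup (it requires the optional `tiktoken_core` native module; without it, or before the async encoding download finishes, `count()` falls back to a rough character-based estimate), and the resulting count decides whether a content block is worth marking cacheable. The 1024-token threshold mirrors the change above; the sample text is illustrative.

```lua
local tiktoken = require("avante.tiktoken")

-- One-time initialization; for gpt-4o-style model names this downloads and
-- caches the o200k_base encoding on first use.
tiktoken.setup("gpt-4o")

local text = string.rep("some code ", 500) -- illustrative payload

-- Falls back to ~0.2 tokens per character when tiktoken_core is unavailable.
local tokens = tiktoken.count(text)

local block = { type = "text", text = text }
if tokens > 1024 then
  -- Large, stable blocks are marked ephemeral so Anthropic can cache them.
  block.cache_control = { type = "ephemeral" }
end

print(("~%d tokens, cacheable: %s"):format(tokens, tostring(block.cache_control ~= nil)))
```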