2024-08-17 15:14:30 +08:00
local fn = vim.fn
2024-08-15 19:04:15 +08:00
local curl = require ( " plenary.curl " )
2024-08-17 15:14:30 +08:00
local Utils = require ( " avante.utils " )
local Config = require ( " avante.config " )
local Tiktoken = require ( " avante.tiktoken " )
---@class avante.AiBot
local M = { }
2024-08-15 19:04:15 +08:00
local system_prompt = [ [
You are an excellent programming expert .
] ]
local base_user_prompt = [ [
Your primary task is to suggest code modifications with precise line number ranges . Follow these instructions meticulously :
1. Carefully analyze the original code , paying close attention to its structure and line numbers . Line numbers start from 1 and include ALL lines , even empty ones .
2. When suggesting modifications :
2024-08-17 22:29:05 +08:00
a . Use the language in the question to reply . If there are non - English parts in the question , use the language of those parts .
b . Explain why the change is necessary or beneficial .
c . Provide the exact code snippet to be replaced using this format :
2024-08-15 19:04:15 +08:00
Replace lines : { { start_line } } - { { end_line } }
` ` ` { { language } }
{ { suggested_code } }
` ` `
3. Crucial guidelines for line numbers :
- The range { { start_line } } - { { end_line } } is INCLUSIVE . Both start_line and end_line are included in the replacement .
- Count EVERY line , including empty lines , comments , and the LAST line of the file .
- For single - line changes , use the same number for start and end lines .
- For multi - line changes , ensure the range covers ALL affected lines , from the very first to the very last .
- Include the entire block ( e.g . , complete function ) when modifying structured code .
- Pay special attention to the start_line , ensuring it ' s not omitted or incorrectly set.
- Double - check that your start_line is correct , especially for changes at the beginning of the file .
- Also , be careful with the end_line , especially when it ' s the last line of the file.
- Double - check that your line numbers align perfectly with the original code structure .
4. Context and verification :
- Show 1 - 2 unchanged lines before and after each modification as context .
- These context lines are NOT included in the replacement range .
- After each suggestion , recount the lines to verify the accuracy of your line numbers .
- Double - check that both the start_line and end_line are correct for each modification .
- Verify that your suggested changes align perfectly with the original code structure .
5. Final check :
- Review all suggestions , ensuring each line number is correct , especially the start_line and end_line .
- Pay extra attention to the start_line of each modification , ensuring it hasn ' t shifted down.
- Confirm that no unrelated code is accidentally modified or deleted .
- Verify that the start_line and end_line correctly include all intended lines for replacement .
- If a modification involves the first or last line of the file , explicitly state this in your explanation .
- Perform a final alignment check to ensure your line numbers haven ' t shifted, especially the start_line.
- Double - check that your line numbers align perfectly with the original code structure .
- Do not show the content after these modifications .
Remember : Accurate line numbers are CRITICAL . The range start_line to end_line must include ALL lines to be replaced , from the very first to the very last . Double - check every range before finalizing your response , paying special attention to the start_line to ensure it hasn ' t shifted down. Ensure that your line numbers perfectly match the original code structure without any overall shift.
] ]
2024-08-17 22:29:05 +08:00
local function call_claude_api_stream ( question , code_lang , code_content , selected_code_content , on_chunk , on_complete )
2024-08-15 19:04:15 +08:00
local api_key = os.getenv ( " ANTHROPIC_API_KEY " )
if not api_key then
error ( " ANTHROPIC_API_KEY environment variable is not set " )
end
2024-08-17 15:14:30 +08:00
local tokens = Config.claude . max_tokens
2024-08-15 19:04:15 +08:00
local headers = {
[ " Content-Type " ] = " application/json " ,
[ " x-api-key " ] = api_key ,
[ " anthropic-version " ] = " 2023-06-01 " ,
[ " anthropic-beta " ] = " prompt-caching-2024-07-31 " ,
}
local code_prompt_obj = {
type = " text " ,
text = string.format ( " <code>```%s \n %s```</code> " , code_lang , code_content ) ,
}
2024-08-17 22:29:05 +08:00
if Tiktoken.count ( code_prompt_obj.text ) > 1024 then
code_prompt_obj.cache_control = { type = " ephemeral " }
end
if selected_code_content then
code_prompt_obj.text = string.format ( " <code_context>```%s \n %s```</code_context> " , code_lang , code_content )
end
local message_content = {
code_prompt_obj ,
}
if selected_code_content then
local selected_code_obj = {
type = " text " ,
text = string.format ( " <code>```%s \n %s```</code> " , code_lang , selected_code_content ) ,
}
if Tiktoken.count ( selected_code_obj.text ) > 1024 then
selected_code_obj.cache_control = { type = " ephemeral " }
end
table.insert ( message_content , selected_code_obj )
end
table.insert ( message_content , {
type = " text " ,
text = string.format ( " <question>%s</question> " , question ) ,
} )
local user_prompt = base_user_prompt
2024-08-15 19:04:15 +08:00
local user_prompt_obj = {
type = " text " ,
text = user_prompt ,
}
2024-08-17 15:14:30 +08:00
if Tiktoken.count ( user_prompt_obj.text ) > 1024 then
2024-08-15 19:04:15 +08:00
user_prompt_obj.cache_control = { type = " ephemeral " }
end
2024-08-17 22:29:05 +08:00
table.insert ( message_content , user_prompt_obj )
2024-08-15 19:04:15 +08:00
local body = {
2024-08-17 15:14:30 +08:00
model = Config.claude . model ,
2024-08-15 19:04:15 +08:00
system = system_prompt ,
messages = {
{
role = " user " ,
2024-08-17 22:29:05 +08:00
content = message_content ,
2024-08-15 19:04:15 +08:00
} ,
} ,
stream = true ,
2024-08-17 15:14:30 +08:00
temperature = Config.claude . temperature ,
2024-08-15 19:04:15 +08:00
max_tokens = tokens ,
}
2024-08-17 15:14:30 +08:00
local url = Utils.trim_suffix ( Config.claude . endpoint , " / " ) .. " /v1/messages "
2024-08-15 19:04:15 +08:00
2024-08-17 15:14:30 +08:00
-- print("Sending request to Claude API...")
2024-08-15 19:04:15 +08:00
curl.post ( url , {
---@diagnostic disable-next-line: unused-local
stream = function ( err , data , job )
if err then
on_complete ( err )
return
end
if not data then
return
end
2024-08-17 22:56:25 +08:00
for _ , line in ipairs ( vim.split ( data , " \n " ) ) do
2024-08-15 19:04:15 +08:00
if line : sub ( 1 , 6 ) ~= " data: " then
return
end
vim.schedule ( function ( )
local success , parsed = pcall ( fn.json_decode , line : sub ( 7 ) )
if not success then
error ( " Error: failed to parse json: " .. parsed )
return
end
if parsed and parsed.type == " content_block_delta " then
on_chunk ( parsed.delta . text )
elseif parsed and parsed.type == " message_stop " then
-- Stream request completed
on_complete ( nil )
elseif parsed and parsed.type == " error " then
-- Stream request completed
on_complete ( parsed )
end
end )
end
end ,
headers = headers ,
body = fn.json_encode ( body ) ,
} )
end
2024-08-17 22:29:05 +08:00
local function call_openai_api_stream ( question , code_lang , code_content , selected_code_content , on_chunk , on_complete )
2024-08-15 19:04:15 +08:00
local api_key = os.getenv ( " OPENAI_API_KEY " )
2024-08-17 15:14:30 +08:00
if not api_key and Config.provider == " openai " then
2024-08-15 19:04:15 +08:00
error ( " OPENAI_API_KEY environment variable is not set " )
end
local user_prompt = base_user_prompt
.. " \n \n CODE: \n "
.. " ``` "
.. code_lang
.. " \n "
.. code_content
.. " \n ``` "
2024-08-17 22:29:05 +08:00
.. " \n \n QUESTION: \n "
.. question
if selected_code_content then
user_prompt = base_user_prompt
.. " \n \n CODE CONTEXT: \n "
.. " ``` "
.. code_lang
.. " \n "
.. code_content
.. " \n ``` "
.. " \n \n CODE: \n "
.. " ``` "
.. code_lang
.. " \n "
.. selected_code_content
.. " \n ``` "
.. " \n \n QUESTION: \n "
.. question
end
2024-08-15 19:04:15 +08:00
local url , headers , body
2024-08-17 15:14:30 +08:00
if Config.provider == " azure " then
2024-08-15 19:04:15 +08:00
api_key = os.getenv ( " AZURE_OPENAI_API_KEY " ) or os.getenv ( " OPENAI_API_KEY " )
if not api_key then
error ( " Azure OpenAI API key is not set. Please set AZURE_OPENAI_API_KEY or OPENAI_API_KEY environment variable. " )
end
2024-08-17 15:14:30 +08:00
url = Config.azure . endpoint
2024-08-15 19:04:15 +08:00
.. " /openai/deployments/ "
2024-08-17 15:14:30 +08:00
.. Config.azure . deployment
2024-08-15 19:04:15 +08:00
.. " /chat/completions?api-version= "
2024-08-17 15:14:30 +08:00
.. Config.azure . api_version
2024-08-15 19:04:15 +08:00
headers = {
[ " Content-Type " ] = " application/json " ,
[ " api-key " ] = api_key ,
}
body = {
messages = {
{ role = " system " , content = system_prompt } ,
{ role = " user " , content = user_prompt } ,
} ,
2024-08-17 15:14:30 +08:00
temperature = Config.azure . temperature ,
max_tokens = Config.azure . max_tokens ,
2024-08-15 19:04:15 +08:00
stream = true ,
}
else
2024-08-17 15:14:30 +08:00
url = Utils.trim_suffix ( Config.openai . endpoint , " / " ) .. " /v1/chat/completions "
2024-08-15 19:04:15 +08:00
headers = {
[ " Content-Type " ] = " application/json " ,
[ " Authorization " ] = " Bearer " .. api_key ,
}
body = {
2024-08-17 15:14:30 +08:00
model = Config.openai . model ,
2024-08-15 19:04:15 +08:00
messages = {
{ role = " system " , content = system_prompt } ,
{ role = " user " , content = user_prompt } ,
} ,
2024-08-17 15:14:30 +08:00
temperature = Config.openai . temperature ,
max_tokens = Config.openai . max_tokens ,
2024-08-15 19:04:15 +08:00
stream = true ,
}
end
2024-08-17 15:14:30 +08:00
-- print("Sending request to " .. (config.get().provider == "azure" and "Azure OpenAI" or "OpenAI") .. " API...")
2024-08-15 19:04:15 +08:00
curl.post ( url , {
---@diagnostic disable-next-line: unused-local
stream = function ( err , data , job )
if err then
on_complete ( err )
return
end
if not data then
return
end
2024-08-17 22:56:25 +08:00
for _ , line in ipairs ( vim.split ( data , " \n " ) ) do
2024-08-15 19:04:15 +08:00
if line : sub ( 1 , 6 ) ~= " data: " then
return
end
vim.schedule ( function ( )
local piece = line : sub ( 7 )
local success , parsed = pcall ( fn.json_decode , piece )
if not success then
if piece == " [DONE] " then
on_complete ( nil )
return
end
error ( " Error: failed to parse json: " .. parsed )
return
end
2024-08-16 13:40:41 +08:00
if parsed and parsed.choices and parsed.choices [ 1 ] then
local choice = parsed.choices [ 1 ]
if choice.finish_reason == " stop " then
on_complete ( nil )
elseif choice.delta and choice.delta . content then
on_chunk ( choice.delta . content )
end
2024-08-15 19:04:15 +08:00
end
end )
end
end ,
headers = headers ,
body = fn.json_encode ( body ) ,
} )
end
2024-08-17 15:14:30 +08:00
---@param question string
---@param code_lang string
---@param code_content string
2024-08-17 22:29:05 +08:00
---@param selected_content_content string | nil
2024-08-17 15:14:30 +08:00
---@param on_chunk fun(chunk: string): any
---@param on_complete fun(err: string|nil): any
2024-08-17 22:29:05 +08:00
function M . call_ai_api_stream ( question , code_lang , code_content , selected_content_content , on_chunk , on_complete )
2024-08-17 15:14:30 +08:00
if Config.provider == " openai " or Config.provider == " azure " then
2024-08-17 22:29:05 +08:00
call_openai_api_stream ( question , code_lang , code_content , selected_content_content , on_chunk , on_complete )
2024-08-17 15:14:30 +08:00
elseif Config.provider == " claude " then
2024-08-17 22:29:05 +08:00
call_claude_api_stream ( question , code_lang , code_content , selected_content_content , on_chunk , on_complete )
2024-08-15 19:04:15 +08:00
end
end
return M