2024-08-17 15:14:30 +08:00
local fn = vim.fn
2024-08-17 16:04:40 -04:00
local api = vim.api
2024-08-15 19:04:15 +08:00
local curl = require ( " plenary.curl " )
2024-08-17 15:14:30 +08:00
local Utils = require ( " avante.utils " )
local Config = require ( " avante.config " )
local Tiktoken = require ( " avante.tiktoken " )
---@class avante.AiBot
local M = { }
2024-08-15 19:04:15 +08:00
2024-08-17 16:04:40 -04:00
---@class EnvironmentHandler: table<[Provider], string>
local E = {
  ---@type table<Provider, string>
  env = {
    openai = "OPENAI_API_KEY",
    claude = "ANTHROPIC_API_KEY",
    azure = "AZURE_OPENAI_API_KEY",
    deepseek = "DEEPSEEK_API_KEY",
  },
  _once = false,
}

--- Resolve the environment-variable name registered for `provider` in `E.env`.
--- An entry may be a plain name or a list whose first element is the name.
---@param provider Provider|nil
---@return string|nil name, or nil when the provider has no entry
local function env_var_name(provider)
  if provider == nil then
    return nil
  end
  local var = E.env[provider]
  if type(var) == "table" then
    return var[1]
  end
  return var
end

-- `E[provider]` answers "is this provider's API key available in the environment?".
-- Providers without an `E.env` entry report false instead of raising from
-- `os.getenv(nil)`, and an empty value is treated the same as an unset one.
setmetatable(E, {
  ---@param k Provider
  ---@return boolean
  __index = function(_, k)
    local name = env_var_name(k)
    if name == nil then
      return false
    end
    local value = os.getenv(name)
    return value ~= nil and value ~= ""
  end,
})

--- Return the environment variable name for the given provider.
---@param provider? Provider defaults to `Config.provider`
---@return string|nil the envvar key, or nil when the provider has no entry in `E.env`
E.key = function(provider)
  return env_var_name(provider or Config.provider)
end
--- Ask the user for a missing API key through an input popup.
---
--- The popup is deferred (200 ms) until the user is in a buffer whose 'buftype' and
--- 'filetype' are not excluded below. Excluded buffers do not consume the
--- attempt; the next BufEnter/BufWinEnter is tried instead. Once the popup has been
--- shown, the autocmd is deleted, so each call prompts at most once. A confirmed value
--- is exported with `vim.fn.setenv` and makes later calls to `E.setup` no-ops.
---@param var string|nil name of the environment variable to prompt for
E.setup = function(var)
  local Dressing = require("avante.ui.dressing")

  if E._once then
    return
  end

  -- `E.key()` yields nil for a provider without an entry in `E.env`. Prompting for it
  -- would only fail later when building the prompt text, so warn and stop here.
  if var == nil then
    vim.notify_once(
      "Avante: no API key environment variable is known for provider " .. tostring(Config.provider),
      vim.log.levels.WARN
    )
    return
  end

  ---@param value string|nil the entered text; presumably nil when the input is cancelled
  ---@return nil
  local function on_confirm(value)
    if value then
      vim.fn.setenv(var, value)
      E._once = true
    elseif not E[Config.provider] then
      vim.notify_once("Failed to set " .. var .. ". Avante won't work as expected", vim.log.levels.WARN)
    end
  end

  -- Only genuine 'buftype' values belong here (see `:h 'buftype'`); "dashboard",
  -- "alpha" and "qf" are not buftypes and are matched by filetype instead.
  local exclude_buftypes = { "nofile", "quickfix" }
  local exclude_filetypes = {
    "NvimTree",
    "Outline",
    "help",
    "dashboard",
    "alpha",
    "qf",
    "ministarter",
    "TelescopePrompt",
    "gitcommit",
  }

  -- BufEnter and BufWinEnter commonly fire for the same buffer, so two deferred
  -- checks can be pending at once; `prompted` ensures only one popup is shown.
  local prompted = false
  local autocmd_id
  autocmd_id = api.nvim_create_autocmd({ "BufEnter", "BufWinEnter" }, {
    pattern = "*",
    -- Deliberately no `once = true`: if the first buffer entered is excluded
    -- (e.g. a start screen), keep listening instead of never prompting at all.
    callback = function()
      vim.defer_fn(function()
        if prompted then
          return
        end
        if
          vim.tbl_contains(exclude_buftypes, vim.bo.buftype)
          or vim.tbl_contains(exclude_filetypes, vim.bo.filetype)
        then
          return
        end
        prompted = true
        -- The autocmd may already have been removed (e.g. by the user); that is fine.
        pcall(api.nvim_del_autocmd, autocmd_id)
        Dressing.initialize_input_buffer({ opts = { prompt = "Enter " .. var .. ": " }, on_confirm = on_confirm })
      end, 200)
    end,
  })
end
2024-08-15 19:04:15 +08:00
-- System prompt shared by every provider. Claude receives it as the top-level
-- `system` field of the request body; the OpenAI-compatible providers receive it as
-- the "system" chat message.
local system_prompt = [ [
You are an excellent programming expert .
] ]
-- Instruction prompt included in every request. Claude gets it as the last text part
-- of the user message; the OpenAI-compatible providers get it at the start of the user
-- message, followed by the CODE / CODE CONTEXT / QUESTION sections.
-- It tells the model to answer with "Replace lines: {{start_line}}-{{end_line}}"
-- headers followed by a fenced snippet. The `{{...}}` markers are literal template
-- text shown to the model; nothing in this file substitutes them.
-- NOTE(review): guideline 3 asks for "a valid full rewrite of the entire original
-- file", which contradicts "keep your suggested code changes minimal" and "Do not
-- show the full content after these modifications". TODO: decide which one applies.
local base_user_prompt = [ [
Your primary task is to suggest code modifications with precise line number ranges . Follow these instructions meticulously :
1. Carefully analyze the original code , paying close attention to its structure and line numbers . Line numbers start from 1 and include ALL lines , even empty ones .
2. When suggesting modifications :
2024-08-17 22:29:05 +08:00
a . Use the language in the question to reply . If there are non - English parts in the question , use the language of those parts .
b . Explain why the change is necessary or beneficial .
c . Provide the exact code snippet to be replaced using this format :
2024-08-15 19:04:15 +08:00
Replace lines : { { start_line } } - { { end_line } }
` ` ` { { language } }
{ { suggested_code } }
` ` `
2024-08-18 02:46:19 +08:00
3. Crucial guidelines for suggested code snippets :
- Only apply the change ( s ) suggested by the most recent assistant message ( before your generation ) .
- Do not make any unrelated changes to the code .
- Produce a valid full rewrite of the entire original file without skipping any lines . Do not be lazy !
- Do not arbitrarily delete pre - existing comments / empty Lines .
- Do not omit large parts of the original file for no reason .
- Do not omit any needed changes from the requisite messages / code blocks .
- If there is a clicked code block , bias towards just applying that ( and applying other changes implied ) .
2024-08-18 02:52:27 +08:00
- Please keep your suggested code changes minimal , and do not include irrelevant lines in the code snippet .
2024-08-18 02:46:19 +08:00
4. Crucial guidelines for line numbers :
2024-08-18 04:47:47 +08:00
- The content regarding line numbers must strictly follow the format " Replace lines: {{start_line}}-{{end_line}} " .
2024-08-15 19:04:15 +08:00
- The range { { start_line } } - { { end_line } } is INCLUSIVE . Both start_line and end_line are included in the replacement .
2024-08-18 04:47:47 +08:00
- Count EVERY line , including empty lines and comments lines , comments . Do not be lazy !
2024-08-15 19:04:15 +08:00
- For single - line changes , use the same number for start and end lines .
- For multi - line changes , ensure the range covers ALL affected lines , from the very first to the very last .
- Double - check that your line numbers align perfectly with the original code structure .
5. Final check :
- Review all suggestions , ensuring each line number is correct , especially the start_line and end_line .
- Confirm that no unrelated code is accidentally modified or deleted .
- Verify that the start_line and end_line correctly include all intended lines for replacement .
- Perform a final alignment check to ensure your line numbers haven ' t shifted, especially the start_line.
- Double - check that your line numbers align perfectly with the original code structure .
2024-08-18 02:52:27 +08:00
- Do not show the full content after these modifications .
2024-08-15 19:04:15 +08:00
Remember : Accurate line numbers are CRITICAL . The range start_line to end_line must include ALL lines to be replaced , from the very first to the very last . Double - check every range before finalizing your response , paying special attention to the start_line to ensure it hasn ' t shifted down. Ensure that your line numbers perfectly match the original code structure without any overall shift.
] ]
2024-08-17 22:29:05 +08:00
--- Stream a completion from Anthropic's Messages API (`<endpoint>/v1/messages`).
---
--- The single user message is built from these parts, in order:
--- 1. The buffer: as `<code>`, or as `<code_context>` when a selection is given.
--- 2. The selection, as `<code>`.
--- 3. The question, as `<question>`.
--- 4. `base_user_prompt`.
--- Parts whose `Tiktoken.count` exceeds 1024 are tagged with
--- `cache_control = { type = "ephemeral" }` for the prompt-caching beta enabled in
--- the headers.
---@param question string
---@param code_lang string language tag used for the code fences
---@param code_content string full buffer content
---@param selected_code_content string|nil currently selected code, if any
---@param on_chunk fun(chunk: string): any receives each streamed text delta
---@param on_complete fun(err: any): any nil on `message_stop`; the curl error or the decoded `error` event otherwise
local function call_claude_api_stream(question, code_lang, code_content, selected_code_content, on_chunk, on_complete)
  local api_key = os.getenv(E.key("claude"))

  local tokens = Config.claude.max_tokens
  local headers = {
    ["Content-Type"] = "application/json",
    ["x-api-key"] = api_key,
    ["anthropic-version"] = "2023-06-01",
    ["anthropic-beta"] = "prompt-caching-2024-07-31",
  }

  --- Mark a text part for prompt caching when it is large enough to be worth it.
  ---@param obj table a `{ type = "text", text = ... }` content part
  ---@return table obj the same table, possibly with `cache_control` added
  local function maybe_cache(obj)
    if Tiktoken.count(obj.text) > 1024 then
      obj.cache_control = { type = "ephemeral" }
    end
    return obj
  end

  -- With a selection, the whole buffer becomes surrounding context and the selection
  -- itself is the <code> block. The final text is chosen before the cache decision,
  -- so the token count reflects what is actually sent.
  local code_text
  if selected_code_content then
    code_text = string.format("<code_context>```%s\n%s```</code_context>", code_lang, code_content)
  else
    code_text = string.format("<code>```%s\n%s```</code>", code_lang, code_content)
  end

  local message_content = {
    maybe_cache({ type = "text", text = code_text }),
  }
  if selected_code_content then
    table.insert(
      message_content,
      maybe_cache({
        type = "text",
        text = string.format("<code>```%s\n%s```</code>", code_lang, selected_code_content),
      })
    )
  end
  table.insert(message_content, {
    type = "text",
    text = string.format("<question>%s</question>", question),
  })
  table.insert(message_content, maybe_cache({ type = "text", text = base_user_prompt }))

  local body = {
    model = Config.claude.model,
    system = system_prompt,
    messages = {
      {
        role = "user",
        content = message_content,
      },
    },
    stream = true,
    temperature = Config.claude.temperature,
    max_tokens = tokens,
  }

  local url = Utils.trim_suffix(Config.claude.endpoint, "/") .. "/v1/messages"

  curl.post(url, {
    ---@diagnostic disable-next-line: unused-local
    stream = function(err, data, job)
      if err then
        on_complete(err)
        return
      end
      if not data then
        return
      end
      for _, line in ipairs(vim.split(data, "\n")) do
        -- The SSE stream interleaves `event:` and blank lines with the `data: `
        -- payloads. Skip those lines individually: returning here would drop any
        -- payload that follows in the same `data` chunk.
        if line:sub(1, 6) == "data: " then
          vim.schedule(function()
            local success, parsed = pcall(fn.json_decode, line:sub(7))
            if not success then
              error("Error: failed to parse json: " .. parsed)
            end
            if not parsed then
              return
            end
            if parsed.type == "content_block_delta" then
              -- Only text deltas carry `delta.text`; ignore anything else.
              local text = parsed.delta and parsed.delta.text
              if type(text) == "string" then
                on_chunk(text)
              end
            elseif parsed.type == "message_stop" then
              -- Stream request completed
              on_complete(nil)
            elseif parsed.type == "error" then
              -- Stream request failed; hand the decoded error event to the caller
              on_complete(parsed)
            end
          end)
        end
      end
    end,
    headers = headers,
    body = fn.json_encode(body),
  })
end
2024-08-17 22:29:05 +08:00
--- Stream a chat completion from an OpenAI-compatible endpoint.
---
--- `Config.provider` selects the target:
--- - "azure": an Azure OpenAI deployment.
--- - "deepseek": DeepSeek.
--- - anything else: treated as OpenAI.
--- The buffer, optional selection and question are appended to `base_user_prompt` as
--- plain-text CODE CONTEXT / CODE / QUESTION sections.
---@param question string
---@param code_lang string language tag used for the code fences
---@param code_content string full buffer content
---@param selected_code_content string|nil currently selected code, if any
---@param on_chunk fun(chunk: string): any receives each streamed content delta
---@param on_complete fun(err: any): any called at most once: nil on success, an error otherwise
local function call_openai_api_stream(question, code_lang, code_content, selected_code_content, on_chunk, on_complete)
  --- Wrap `content` in a fenced code block tagged with `code_lang`.
  ---@param content string
  ---@return string
  local function fenced(content)
    return "```" .. code_lang .. "\n" .. content .. "\n```"
  end

  local sections = { base_user_prompt }
  if selected_code_content then
    table.insert(sections, "\n\nCODE CONTEXT:\n" .. fenced(code_content))
    table.insert(sections, "\n\nCODE:\n" .. fenced(selected_code_content))
  else
    table.insert(sections, "\n\nCODE:\n" .. fenced(code_content))
  end
  table.insert(sections, "\n\nQUESTION:\n" .. question)
  local user_prompt = table.concat(sections)

  local messages = {
    { role = "system", content = system_prompt },
    { role = "user", content = user_prompt },
  }

  -- OpenAI-style streams end with a chunk whose finish_reason is "stop" and then a
  -- separate `data: [DONE]` line. Route every completion path through `finish` so
  -- the caller's `on_complete` runs only once.
  local completed = false
  local function finish(err)
    if completed then
      return
    end
    completed = true
    on_complete(err)
  end

  local url, headers, body
  if Config.provider == "azure" then
    local api_key = os.getenv(E.key("azure"))
    -- Trim a trailing slash, as the other providers do, to avoid "//openai/...".
    url = Utils.trim_suffix(Config.azure.endpoint, "/")
      .. "/openai/deployments/"
      .. Config.azure.deployment
      .. "/chat/completions?api-version="
      .. Config.azure.api_version
    headers = {
      ["Content-Type"] = "application/json",
      ["api-key"] = api_key,
    }
    body = {
      messages = messages,
      temperature = Config.azure.temperature,
      max_tokens = Config.azure.max_tokens,
      stream = true,
    }
  else
    -- DeepSeek speaks the OpenAI wire format but serves it without the /v1 prefix.
    local provider, path = "openai", "/v1/chat/completions"
    if Config.provider == "deepseek" then
      provider, path = "deepseek", "/chat/completions"
    end
    local opts = Config[provider]
    local key_name = E.key(provider)
    local api_key = os.getenv(key_name)
    if api_key == nil then
      -- "Bearer " .. nil would otherwise raise an opaque concatenation error.
      finish(key_name .. " is not set")
      return
    end
    url = Utils.trim_suffix(opts.endpoint, "/") .. path
    headers = {
      ["Content-Type"] = "application/json",
      ["Authorization"] = "Bearer " .. api_key,
    }
    body = {
      model = opts.model,
      messages = messages,
      temperature = opts.temperature,
      max_tokens = opts.max_tokens,
      stream = true,
    }
  end

  curl.post(url, {
    ---@diagnostic disable-next-line: unused-local
    stream = function(err, data, job)
      if err then
        finish(err)
        return
      end
      if not data then
        return
      end
      for _, line in ipairs(vim.split(data, "\n")) do
        -- Skip blank separator lines and anything else that is not a `data: `
        -- payload, but keep scanning the rest of `data`.
        if line:sub(1, 6) == "data: " then
          vim.schedule(function()
            local piece = line:sub(7)
            if piece == "[DONE]" then
              finish(nil)
              return
            end
            local success, parsed = pcall(fn.json_decode, piece)
            if not success then
              error("Error: failed to parse json: " .. parsed)
            end
            local choice = parsed and parsed.choices and parsed.choices[1]
            if not choice then
              return
            end
            -- `content` may be absent, or JSON null (decoded as `vim.NIL`), on
            -- role-only or final chunks. Emit it before handling finish_reason so
            -- text carried by the final chunk is not lost.
            local content = choice.delta and choice.delta.content
            if type(content) == "string" then
              on_chunk(content)
            end
            if choice.finish_reason == "stop" then
              finish(nil)
            end
          end)
        end
      end
    end,
    headers = headers,
    body = fn.json_encode(body),
  })
end
2024-08-17 15:14:30 +08:00
--- Stream an answer to `question` from the provider selected by `Config.provider`.
---
--- "openai", "azure" and "deepseek" go through the OpenAI-compatible client;
--- "claude" goes through the Anthropic client. Any other provider is reported via
--- `on_complete`, so callers waiting for completion are not left hanging.
---@param question string
---@param code_lang string
---@param code_content string
---@param selected_code_content string | nil
---@param on_chunk fun(chunk: string): any
---@param on_complete fun(err: string|nil): any
function M.call_ai_api_stream(question, code_lang, code_content, selected_code_content, on_chunk, on_complete)
  local provider = Config.provider
  if provider == "openai" or provider == "azure" or provider == "deepseek" then
    call_openai_api_stream(question, code_lang, code_content, selected_code_content, on_chunk, on_complete)
  elseif provider == "claude" then
    call_claude_api_stream(question, code_lang, code_content, selected_code_content, on_chunk, on_complete)
  else
    on_complete("Unsupported provider: " .. tostring(provider))
  end
end
2024-08-17 16:04:40 -04:00
--- Ensure the active provider has an API key: when `E[Config.provider]` reports it
--- missing, start the interactive prompt for the variable named by `E.key()`.
function M.setup()
  if E[Config.provider] then
    return
  end
  E.setup(E.key())
end
2024-08-15 19:04:15 +08:00
return M