2024-08-17 16:04:40 -04:00
local api = vim.api
2024-08-15 19:04:15 +08:00
local curl = require ( " plenary.curl " )
2024-08-17 15:14:30 +08:00
local Utils = require ( " avante.utils " )
local Config = require ( " avante.config " )
local Tiktoken = require ( " avante.tiktoken " )
2024-08-18 15:03:25 -04:00
local Dressing = require ( " avante.ui.dressing " )
---@private
---@class AvanteAiBotInternal
local H = {} -- per-provider message builders, response parsers and curl-args builders

---@class avante.AiBot
local M = {
  --- `User` autocmd pattern that, when fired, cancels the in-flight LLM request.
  CANCEL_PATTERN = "AvanteAiBotEscape",
}
2024-08-17 16:04:40 -04:00
---@class EnvironmentHandler: table<[Provider], string>
local E = {
  --- Environment variable that holds the API key for each supported provider.
  ---@type table<Provider, string>
  env = {
    openai = "OPENAI_API_KEY",
    claude = "ANTHROPIC_API_KEY",
    azure = "AZURE_OPENAI_API_KEY",
    deepseek = "DEEPSEEK_API_KEY",
    groq = "GROQ_API_KEY",
  },
}

-- Indexing E with a provider name (e.g. `E[Config.provider]`) answers
-- "is this provider's API-key environment variable set?".
E = setmetatable(E, {
  ---@param k Provider
  ---@return boolean
  __index = function(_, k)
    local name = rawget(E, "env")[k]
    if name == nil then
      -- Unknown provider: there is no variable to look up. Calling
      -- os.getenv(nil) would raise, so report "not set" instead.
      return false
    end
    return os.getenv(name) ~= nil
  end,
})

--- Set to true once `E.setup` has registered the first-run key prompt.
E._once = false

--- Return the environment variable name for the given provider.
---@param provider? Provider defaults to `Config.provider`
---@return string|nil the envvar key, or nil for an unsupported provider
E.key = function(provider)
  return E.env[provider or Config.provider]
end

--- Return the API key currently stored in the provider's environment variable.
---@param provider? Provider defaults to `Config.provider`
---@return string|nil the key, or nil when unset or the provider is unsupported
E.value = function(provider)
  local name = E.key(provider)
  if name == nil then
    return nil
  end
  return os.getenv(name)
end
2024-08-18 06:07:29 -04:00
2024-08-18 15:03:25 -04:00
--- Initialize the API-key environment variable for the current Neovim session.
--- With `refresh`, the input UI is opened after a 200 ms delay. Without it, the
--- prompt is registered only once per session (guarded by `E._once`). It is shown
--- in the first buffer that is not a start screen, file tree, quickfix, picker,
--- commit message, and so on.
---@param var string name of the environment variable to set (as returned by `E.key`)
---@param refresh? boolean open the input immediately instead of waiting for a buffer
E.setup = function(var, refresh)
  ---@param value string|nil text entered by the user (nil when the input was dismissed)
  ---@return nil
  local function on_confirm(value)
    if value then
      vim.fn.setenv(var, value)
    elseif not E[Config.provider] then
      vim.notify_once("Failed to set " .. var .. ". Avante won't work as expected", vim.log.levels.WARN)
    end
  end

  local function open_input()
    Dressing.initialize_input_buffer({ opts = { prompt = "Enter " .. var .. ": " }, on_confirm = on_confirm })
  end

  if refresh then
    vim.defer_fn(open_input, 200)
    return
  end
  if E._once then
    return
  end
  E._once = true

  -- Buffers in which the input must not be mounted. Note: "ministarter" is a
  -- filetype, not a buftype, so it is only listed in exclude_filetypes.
  local exclude_buftypes = { "dashboard", "alpha", "qf", "nofile" }
  local exclude_filetypes = {
    "NvimTree",
    "Outline",
    "help",
    "dashboard",
    "alpha",
    "qf",
    "ministarter",
    "TelescopePrompt",
    "gitcommit",
    "gitrebase",
  }

  -- Keep listening until the input has actually been mounted. A `once = true`
  -- autocmd would be consumed by the first (possibly excluded) buffer, and the
  -- prompt would then never appear. The flag also stops BufEnter and
  -- BufWinEnter from both mounting an input.
  local prompted = false
  api.nvim_create_autocmd({ "BufEnter", "BufWinEnter" }, {
    pattern = "*",
    callback = function()
      if prompted then
        return true -- input already shown: returning true deletes this autocmd
      end
      vim.defer_fn(function()
        if prompted then
          return
        end
        if
          vim.tbl_contains(exclude_buftypes, vim.bo.buftype)
          or vim.tbl_contains(exclude_filetypes, vim.bo.filetype)
        then
          return -- wait for a later, suitable buffer
        end
        prompted = true
        open_input()
      end, 200)
    end,
  })
end
2024-08-18 15:03:25 -04:00
------------------------------Prompt and type------------------------------
--- System message sent with every request.
local system_prompt = [[
You are an excellent programming expert.
]]
--- Instructions prepended to every user message. The {{...}} placeholders are
--- literal text shown to the model as the expected answer format.
-- NOTE(review): "Produce a valid full rewrite of the entire original file" conflicts
-- with the line-range snippet format and with "Do not show the full content after
-- these modifications"; confirm which instruction is intended.
local base_user_prompt = [[
Your primary task is to suggest code modifications with precise line number ranges. Follow these instructions meticulously:
1. Carefully analyze the original code, paying close attention to its structure and line numbers. Line numbers start from 1 and include ALL lines, even empty ones.
2. When suggesting modifications:
   a. Use the language in the question to reply. If there are non-English parts in the question, use the language of those parts.
   b. Explain why the change is necessary or beneficial.
   c. Provide the exact code snippet to be replaced using this format:
Replace lines: {{start_line}}-{{end_line}}
```{{language}}
{{suggested_code}}
```
3. Crucial guidelines for suggested code snippets:
   - Only apply the change(s) suggested by the most recent assistant message (before your generation).
   - Do not make any unrelated changes to the code.
   - Produce a valid full rewrite of the entire original file without skipping any lines. Do not be lazy!
   - Do not arbitrarily delete pre-existing comments/empty lines.
   - Do not omit large parts of the original file for no reason.
   - Do not omit any needed changes from the requisite messages/code blocks.
   - If there is a clicked code block, bias towards just applying that (and applying other changes implied).
   - Please keep your suggested code changes minimal, and do not include irrelevant lines in the code snippet.
4. Crucial guidelines for line numbers:
   - The content regarding line numbers MUST strictly follow the format "Replace lines: {{start_line}}-{{end_line}}". Do not be lazy!
   - The range {{start_line}}-{{end_line}} is INCLUSIVE. Both start_line and end_line are included in the replacement.
   - Count EVERY line, including empty lines and comment lines. Do not be lazy!
   - For single-line changes, use the same number for start and end lines.
   - For multi-line changes, ensure the range covers ALL affected lines, from the very first to the very last.
   - Double-check that your line numbers align perfectly with the original code structure.
5. Final check:
   - Review all suggestions, ensuring each line number is correct, especially the start_line and end_line.
   - Confirm that no unrelated code is accidentally modified or deleted.
   - Verify that the start_line and end_line correctly include all intended lines for replacement.
   - Perform a final alignment check to ensure your line numbers haven't shifted, especially the start_line.
   - Double-check that your line numbers align perfectly with the original code structure.
   - Do not show the full content after these modifications.
Remember: Accurate line numbers are CRITICAL. The range start_line to end_line must include ALL lines to be replaced, from the very first to the very last. Double-check every range before finalizing your response, paying special attention to the start_line to ensure it hasn't shifted down. Ensure that your line numbers perfectly match the original code structure without any overall shift.
]]
2024-08-18 15:03:25 -04:00
---@class AvantePromptOptions: table<[string], string>
---@field question string
---@field code_lang string
---@field code_content string
---@field selected_code_content? string
---
---@alias AvanteAiMessageBuilder fun(opts: AvantePromptOptions): {role: "user" | "system", content: string | table<string, any>}[]
---
---@class AvanteCurlOutput: {url: string, body: table<string, any> | string, headers: table<string, string>}
---@alias AvanteCurlArgsBuilder fun(code_opts: AvantePromptOptions): AvanteCurlOutput
---
---@class ResponseParser
---@field event_state string
---@field on_chunk fun(chunk: string): any
---@field on_complete fun(err: string|nil): any
---@field on_error? fun(err_type: string): nil
---@alias AvanteAiResponseParser fun(data_stream: string, opts: ResponseParser): nil
------------------------------Anthropic------------------------------
--- Build an Anthropic text content block for `text`. Blocks above 1024 tokens
--- (per Tiktoken.count) are marked `cache_control = ephemeral` so the API may
--- cache them (the request opts into the prompt-caching beta).
---@param text string
---@return table
local function claude_text_block(text)
  local block = { type = "text", text = text }
  if Tiktoken.count(text) > 1024 then
    block.cache_control = { type = "ephemeral" }
  end
  return block
end

--- Build the single user message sent to Claude. Content blocks, in order:
--- the code (or file context plus selected code), the question, and the base
--- instructions.
---@type AvanteAiMessageBuilder
H.make_claude_message = function(opts)
  local message_content = {}
  if opts.selected_code_content then
    -- The whole file becomes context; the selection is the code to work on.
    -- The cache decision is made on the text actually sent. Previously it was
    -- counted on the `<code>` variant and then replaced by `<code_context>`.
    table.insert(
      message_content,
      claude_text_block(string.format("<code_context>```%s\n%s```</code_context>", opts.code_lang, opts.code_content))
    )
    table.insert(
      message_content,
      claude_text_block(string.format("<code>```%s\n%s```</code>", opts.code_lang, opts.selected_code_content))
    )
  else
    table.insert(
      message_content,
      claude_text_block(string.format("<code>```%s\n%s```</code>", opts.code_lang, opts.code_content))
    )
  end
  -- The question changes on every request, so it is never marked cacheable.
  table.insert(message_content, {
    type = "text",
    text = string.format("<question>%s</question>", opts.question),
  })
  table.insert(message_content, claude_text_block(base_user_prompt))
  return {
    {
      role = "user",
      content = message_content,
    },
  }
end
2024-08-15 19:04:15 +08:00
2024-08-18 15:03:25 -04:00
--- Handle one `data:` payload of Anthropic's SSE stream. `opts.event_state` holds
--- the most recent `event:` name seen on the stream.
---@type AvanteAiResponseParser
H.parse_claude_response = function(data_stream, opts)
  if opts.event_state == "content_block_delta" then
    local json = vim.json.decode(data_stream)
    -- Only text deltas carry a string `delta.text`; skip other delta kinds
    -- rather than forwarding nil to on_chunk.
    if type(json) == "table" and type(json.delta) == "table" and type(json.delta.text) == "string" then
      opts.on_chunk(json.delta.text)
    end
  elseif opts.event_state == "message_stop" then
    opts.on_complete(nil)
  elseif opts.event_state == "error" then
    -- on_complete is typed `string|nil`, so pass the error message rather than
    -- the decoded table. Fall back to the raw payload when no message is found.
    local ok, json = pcall(vim.json.decode, data_stream)
    local message = ok and type(json) == "table" and type(json.error) == "table" and json.error.message
    opts.on_complete(type(message) == "string" and message or data_stream)
  end
end
2024-08-15 19:04:15 +08:00
2024-08-18 15:03:25 -04:00
--- Assemble the streaming POST to Anthropic's `/v1/messages` endpoint, with the
--- prompt-caching beta header enabled.
---@type AvanteCurlArgsBuilder
H.make_claude_curl_args = function(code_opts)
  local claude = Config.claude
  local url = Utils.trim(claude.endpoint, { suffix = "/" }) .. "/v1/messages"

  local headers = {}
  headers["Content-Type"] = "application/json"
  headers["x-api-key"] = E.value("claude")
  headers["anthropic-version"] = "2023-06-01"
  headers["anthropic-beta"] = "prompt-caching-2024-07-31"

  local body = {}
  body.model = claude.model
  body.system = system_prompt
  body.stream = true
  body.messages = H.make_claude_message(code_opts)
  body.temperature = claude.temperature
  body.max_tokens = claude.max_tokens

  return { url = url, headers = headers, body = body }
end
2024-08-18 15:03:25 -04:00
------------------------------OpenAI------------------------------
--- Build the system and user messages for OpenAI-compatible chat APIs.
--- The user message is the base instructions followed by fenced code
--- (file context plus selected code when a selection exists) and the question.
---@type AvanteAiMessageBuilder
H.make_openai_message = function(opts)
  ---@param code string
  ---@return string the code wrapped in a ``` fence tagged with the language
  local function fenced(code)
    return "```" .. opts.code_lang .. "\n" .. code .. "\n```"
  end

  local user_prompt = base_user_prompt
  if opts.selected_code_content == nil then
    user_prompt = user_prompt .. "\n\nCODE:\n" .. fenced(opts.code_content)
  else
    user_prompt = user_prompt
      .. "\n\nCODE CONTEXT:\n"
      .. fenced(opts.code_content)
      .. "\n\nCODE:\n"
      .. fenced(opts.selected_code_content)
  end
  user_prompt = user_prompt .. "\n\nQUESTION:\n" .. opts.question

  local system_message = { role = "system", content = system_prompt }
  local user_message = { role = "user", content = user_prompt }
  return { system_message, user_message }
end
--- Handle one `data:` payload of an OpenAI-compatible chat-completions stream.
--- Text deltas go to `on_chunk`; the `[DONE]` sentinel ends the stream.
---@type AvanteAiResponseParser
H.parse_openai_response = function(data_stream, opts)
  -- The stream terminates with the literal payload `[DONE]` (not JSON).
  -- The old pattern '"%[DONE%]":' could never match it. Completion now keys on
  -- the sentinel, so streams ending with finish_reason "length" (or any reason
  -- other than "stop") also complete, exactly once.
  if data_stream:match("^%s*%[DONE%]%s*$") then
    opts.on_complete(nil)
    return
  end
  if not data_stream:match('"delta":') then
    return
  end
  local json = vim.json.decode(data_stream)
  if type(json) ~= "table" or type(json.choices) ~= "table" then
    return
  end
  local choice = json.choices[1]
  if type(choice) ~= "table" or type(choice.delta) ~= "table" then
    return
  end
  -- JSON null decodes to vim.NIL (truthy userdata), so require a real string.
  -- Content carried by the final (finish_reason) chunk is forwarded too.
  local content = choice.delta.content
  if type(content) == "string" then
    opts.on_chunk(content)
  end
end
2024-08-15 19:04:15 +08:00
2024-08-18 15:03:25 -04:00
--- Assemble the streaming POST to OpenAI's `/v1/chat/completions` endpoint.
---@type AvanteCurlArgsBuilder
H.make_openai_curl_args = function(code_opts)
  local openai = Config.openai
  local endpoint = Utils.trim(openai.endpoint, { suffix = "/" })
  local auth = "Bearer " .. E.value("openai")
  local messages = H.make_openai_message(code_opts)
  return {
    url = endpoint .. "/v1/chat/completions",
    headers = {
      ["Content-Type"] = "application/json",
      ["Authorization"] = auth,
    },
    body = {
      model = openai.model,
      messages = messages,
      temperature = openai.temperature,
      max_tokens = openai.max_tokens,
      stream = true,
    },
  }
end
------------------------------Azure------------------------------
-- Azure OpenAI speaks the OpenAI chat-completions wire format, so message
-- building and stream parsing are shared with the OpenAI provider.
---@type AvanteAiMessageBuilder
H.make_azure_message = H.make_openai_message
---@type AvanteAiResponseParser
H.parse_azure_response = H.parse_openai_response

--- Assemble the streaming POST to an Azure OpenAI deployment. The model is
--- selected by `Config.azure.deployment` in the URL, so the body has no `model`.
---@type AvanteCurlArgsBuilder
H.make_azure_curl_args = function(code_opts)
  return {
    -- Trim a trailing "/" like every other provider does, so an endpoint
    -- configured as "https://x.openai.azure.com/" does not yield "//openai/...".
    url = Utils.trim(Config.azure.endpoint, { suffix = "/" })
      .. "/openai/deployments/"
      .. Config.azure.deployment
      .. "/chat/completions?api-version="
      .. Config.azure.api_version,
    headers = {
      ["Content-Type"] = "application/json",
      ["api-key"] = E.value("azure"),
    },
    body = {
      messages = H.make_azure_message(code_opts),
      temperature = Config.azure.temperature,
      max_tokens = Config.azure.max_tokens,
      stream = true,
    },
  }
end
------------------------------Deepseek------------------------------
-- DeepSeek exposes an OpenAI-compatible API: reuse the OpenAI message builder
-- and stream parser.
---@type AvanteAiMessageBuilder
H.make_deepseek_message = H.make_openai_message
---@type AvanteAiResponseParser
H.parse_deepseek_response = H.parse_openai_response

--- Assemble the streaming POST to DeepSeek's `/chat/completions` endpoint.
---@type AvanteCurlArgsBuilder
H.make_deepseek_curl_args = function(code_opts)
  local settings = Config.deepseek
  local url = Utils.trim(settings.endpoint, { suffix = "/" }) .. "/chat/completions"
  local headers = {
    ["Content-Type"] = "application/json",
    ["Authorization"] = "Bearer " .. E.value("deepseek"),
  }
  local body = {
    model = settings.model,
    messages = H.make_openai_message(code_opts),
    temperature = settings.temperature,
    max_tokens = settings.max_tokens,
    stream = true,
  }
  return { url = url, headers = headers, body = body }
end
------------------------------Groq------------------------------
-- Groq serves an OpenAI-compatible API under `/openai/v1`: the OpenAI message
-- builder and stream parser are reused as-is.
---@type AvanteAiMessageBuilder
H.make_groq_message = H.make_openai_message
---@type AvanteAiResponseParser
H.parse_groq_response = H.parse_openai_response

--- Assemble the streaming POST to Groq's OpenAI-compatible chat endpoint.
---@type AvanteCurlArgsBuilder
H.make_groq_curl_args = function(code_opts)
  local groq = Config.groq
  local spec = {}
  spec.url = Utils.trim(groq.endpoint, { suffix = "/" }) .. "/openai/v1/chat/completions"
  spec.headers = {
    ["Content-Type"] = "application/json",
    ["Authorization"] = "Bearer " .. E.value("groq"),
  }
  spec.body = {
    model = groq.model,
    messages = H.make_openai_message(code_opts),
    temperature = groq.temperature,
    max_tokens = groq.max_tokens,
    stream = true,
  }
  return spec
end
------------------------------Logic------------------------------
local group = vim.api.nvim_create_augroup("AvanteAiBot", { clear = true })

--- The in-flight curl job started by `M.invoke_llm_stream`, or nil when idle.
local active_job = nil

--- Set once the cancel handler below has been registered.
local cancel_autocmd_registered = false

--- Register, at most once per module load, the `User M.CANCEL_PATTERN` handler
--- that shuts down the in-flight request. Previously a new handler was added to
--- the group on every request, so handlers accumulated.
local function ensure_cancel_autocmd()
  if cancel_autocmd_registered then
    return
  end
  cancel_autocmd_registered = true
  api.nvim_create_autocmd("User", {
    group = group,
    pattern = M.CANCEL_PATTERN,
    callback = function()
      if active_job then
        active_job:shutdown()
        vim.notify("LLM request cancelled", vim.log.levels.DEBUG)
        active_job = nil
      end
    end,
  })
end

--- Stream a completion from the provider named by `Config.provider`.
--- Any previous in-flight request is shut down first.
---@param question string
---@param code_lang string
---@param code_content string
---@param selected_content_content string | nil
---@param on_chunk fun(chunk: string): any called on the main loop with each text delta
---@param on_complete fun(err: string|nil): any called on the main loop when the stream ends or fails
---@return table the job returned by `curl.post`
M.invoke_llm_stream = function(question, code_lang, code_content, selected_content_content, on_chunk, on_complete)
  local provider = Config.provider
  local make_curl_args = H["make_" .. provider .. "_curl_args"]
  local parse_response = H["parse_" .. provider .. "_response"]
  if not make_curl_args or not parse_response then
    error("avante: unsupported provider: " .. tostring(provider))
  end

  -- Name of the most recent SSE `event:` line; used by event-typed streams (Claude).
  local event_state = nil
  ---@type AvanteCurlOutput
  local spec = make_curl_args({
    question = question,
    code_lang = code_lang,
    code_content = code_content,
    selected_code_content = selected_content_content,
  })

  ---@param line string one line of the SSE stream
  local function parse_and_call(line)
    local event = line:match("^event: (.+)$")
    if event then
      event_state = event
      return
    end
    local data_match = line:match("^data: (.+)$")
    if data_match then
      -- A fresh options table per line. The previous vim.deepcopy added nothing,
      -- since deepcopy returns functions by reference.
      parse_response(data_match, { on_chunk = on_chunk, on_complete = on_complete, event_state = event_state })
    end
  end

  if active_job then
    active_job:shutdown()
    active_job = nil
  end

  local job
  job = curl.post(spec.url, {
    headers = spec.headers,
    body = vim.json.encode(spec.body),
    -- The chunk path was already wrapped in vim.schedule, which suggests these
    -- callbacks can run outside the main loop. Every user callback, including
    -- the error paths, is therefore scheduled; this also keeps their order.
    stream = function(err, data, _)
      if err then
        vim.schedule(function()
          on_complete(err)
        end)
        return
      end
      if not data then
        return
      end
      vim.schedule(function()
        parse_and_call(data)
      end)
    end,
    on_error = function(err)
      vim.schedule(function()
        on_complete(err)
      end)
    end,
    callback = function(_)
      -- Clear the slot only if it still holds this job. A newer request may
      -- already have replaced it, and clearing that would make it uncancellable.
      if active_job == job then
        active_job = nil
      end
    end,
  })
  active_job = job
  ensure_cancel_autocmd()
  return job
end
2024-08-17 16:04:40 -04:00
--- Module entry point. If the current provider's API-key variable is missing,
--- arrange for the user to be prompted for it; then register the user commands.
function M.setup()
  local key_present = E[Config.provider]
  if key_present then
    M.commands()
    return
  end
  E.setup(E.key())
  M.commands()
end
--- Switch the active provider. If its API-key variable is not set, the user is
--- prompted for it. Unsupported provider names are rejected with an error
--- notification instead of crashing on a nil environment-variable name.
---@param provider Provider
function M.refresh(provider)
  local var = E.key(provider)
  if var == nil then
    vim.notify("Unsupported provider: " .. tostring(provider), vim.log.levels.ERROR)
    return
  end
  if E[provider] then
    vim.notify_once("Switch to provider: " .. provider, vim.log.levels.INFO)
  else
    E.setup(var, true)
  end
  require("avante").setup({ provider = provider })
end
--- Register the `:AvanteSwitchProvider {provider}` user command, with completion
--- over the provider names in `E.env`.
M.commands = function()
  api.nvim_create_user_command("AvanteSwitchProvider", function(args)
    local cmd = vim.trim(args.args or "")
    M.refresh(cmd)
  end, {
    nargs = 1,
    desc = "avante: switch provider",
    ---@param arg_lead string the partially typed argument under the cursor
    ---@param line string the whole command line
    complete = function(arg_lead, line)
      -- The command takes a single argument: offer nothing once one has been
      -- typed and followed by whitespace. The old guard (`... %w`) also fired
      -- while the first argument was still being typed, which disabled prefix
      -- completion entirely.
      if line:match("^%s*%S+%s+%S+%s") then
        return {}
      end
      local providers = vim.tbl_filter(function(key)
        return vim.startswith(key, arg_lead)
      end, vim.tbl_keys(E.env))
      table.sort(providers) -- tbl_keys order is unspecified
      return providers
    end,
  })
end
2024-08-15 19:04:15 +08:00
return M