2024-08-17 16:04:40 -04:00
local api = vim.api
2024-10-07 17:53:18 +02:00
local fn = vim.fn
2024-11-03 14:14:12 +05:30
local uv = vim.uv
2024-08-15 19:04:15 +08:00
2024-09-03 05:12:07 -04:00
local curl = require ( " plenary.curl " )
2024-08-15 19:04:15 +08:00
2024-09-03 05:12:07 -04:00
local Utils = require ( " avante.utils " )
local Config = require ( " avante.config " )
local Path = require ( " avante.path " )
local P = require ( " avante.providers " )
2024-08-18 15:03:25 -04:00
2024-08-18 22:20:29 -04:00
---@class avante.LLM
---@field CANCEL_PATTERN string
local M = {
  -- Pattern of the `User` autocommand that cancels in-flight requests.
  CANCEL_PATTERN = "AvanteLLMEscape",
}

------------------------------Prompt and type------------------------------

-- Augroup for the cancel autocommands registered per request; `clear = true`
-- drops any left over from a previous load of this module.
local group = api.nvim_create_augroup("avante_llm", { clear = true })
2024-08-18 15:03:25 -04:00
2024-12-17 14:43:25 +02:00
---Build the prompt payload (system prompt, messages, image paths) for a request.
---
---Lines of `opts.instructions` that begin with `image:` are removed from the
---question, and the text after the prefix is collected into `image_paths`.
---Project context, diagnostics and code context are rendered from their
---templates as separate user messages, followed by the question itself.
---When `opts.history_messages` is given, the newest history messages that fit
---in the remaining token budget (capped by `Config.history.max_tokens` when it
---is positive) are prepended in chronological order, and a leading assistant
---message is dropped.
---@param opts GeneratePromptsOptions
---@return AvantePromptOptions
M.generate_prompts = function(opts)
  local Provider = opts.provider or P[Config.provider]
  local mode = opts.mode or "planning"
  ---@type AvanteProviderFunctor
  local _, body_opts = P.parse_config(Provider)
  local max_tokens = body_opts.max_tokens or 4096

  -- Split `image:` lines out of the instructions.
  local image_paths = {}
  local instructions = opts.instructions
  if opts.instructions:match("image:") then
    -- Collect the remaining lines into a new table. Calling table.remove while
    -- iterating with ipairs skips the line after every removal, which let
    -- consecutive `image:` lines leak into the question.
    local kept_lines = {}
    for _, line in ipairs(vim.split(opts.instructions, "\n")) do
      if line:match("^image:") then
        local image_path = line:gsub("^image:", "")
        table.insert(image_paths, image_path)
      else
        table.insert(kept_lines, line)
      end
    end
    instructions = table.concat(kept_lines, "\n")
  end

  local project_root = Utils.root.get()
  Path.prompts.initialize(Path.prompts.get(project_root))

  local template_opts = {
    use_xml_format = Provider.use_xml_format,
    ask = opts.ask, -- TODO: add mode without ask instruction
    code_lang = opts.code_lang,
    selected_files = opts.selected_files,
    selected_code = opts.selected_code,
    project_context = opts.project_context,
    diagnostics = opts.diagnostics,
  }

  local system_prompt = Path.prompts.render_mode(mode, template_opts)

  ---@type AvanteLLMMessage[]
  local messages = {}

  if opts.project_context ~= nil and opts.project_context ~= "" and opts.project_context ~= "null" then
    local project_context = Path.prompts.render_file("_project.avanterules", template_opts)
    if project_context ~= "" then table.insert(messages, { role = "user", content = project_context }) end
  end

  if opts.diagnostics ~= nil and opts.diagnostics ~= "" and opts.diagnostics ~= "null" then
    local diagnostics = Path.prompts.render_file("_diagnostics.avanterules", template_opts)
    if diagnostics ~= "" then table.insert(messages, { role = "user", content = diagnostics }) end
  end

  -- `selected_files` is optional (`SelectedFiles[] | nil`); guard before taking its length.
  local has_selected_files = opts.selected_files ~= nil and #opts.selected_files > 0
  if has_selected_files or opts.selected_code ~= nil then
    local code_context = Path.prompts.render_file("_context.avanterules", template_opts)
    if code_context ~= "" then table.insert(messages, { role = "user", content = code_context }) end
  end

  if opts.use_xml_format then
    table.insert(messages, { role = "user", content = string.format("<question>%s</question>", instructions) })
  else
    table.insert(messages, { role = "user", content = string.format("QUESTION:\n%s", instructions) })
  end

  -- Token budget left after the system prompt and the messages built so far.
  local remaining_tokens = max_tokens - Utils.tokens.calculate_tokens(system_prompt)
  for _, message in ipairs(messages) do
    remaining_tokens = remaining_tokens - Utils.tokens.calculate_tokens(message.content)
  end

  if opts.history_messages then
    if Config.history.max_tokens > 0 then remaining_tokens = math.min(Config.history.max_tokens, remaining_tokens) end
    -- Walk the history newest-first, keeping messages while budget remains.
    local history_messages = {}
    for i = #opts.history_messages, 1, -1 do
      local message = opts.history_messages[i]
      local tokens = Utils.tokens.calculate_tokens(message.content)
      remaining_tokens = remaining_tokens - tokens
      if remaining_tokens > 0 then
        table.insert(history_messages, message)
      else
        break
      end
    end
    -- Inserting each kept message (newest first) at position 1 restores
    -- chronological order ahead of the context and question messages.
    vim.iter(history_messages):each(function(msg) table.insert(messages, 1, msg) end)
    -- Drop a leading assistant message so the conversation starts with the user.
    if #messages > 0 and messages[1].role == "assistant" then table.remove(messages, 1) end
  end

  ---@type AvantePromptOptions
  return {
    system_prompt = system_prompt,
    messages = messages,
    image_paths = image_paths,
  }
end
---Count the tokens of the whole prompt that `M.generate_prompts` builds for
---`opts`: the system prompt plus the content of every message.
---@param opts GeneratePromptsOptions
---@return integer
M.calculate_tokens = function(opts)
  local prompts = M.generate_prompts(opts)
  local count_tokens = Utils.tokens.calculate_tokens
  local total = count_tokens(prompts.system_prompt)
  local messages = prompts.messages
  for idx = 1, #messages do
    total = total + count_tokens(messages[idx].content)
  end
  return total
end
---Explain a curl write failure when `$XDG_RUNTIME_DIR` is set but unusable.
---Does nothing when the variable is unset. The previous inline check
---concatenated the nil value into its message and raised inside `on_error`.
local function report_xdg_runtime_dir_problem()
  local xdg_runtime_dir = os.getenv("XDG_RUNTIME_DIR")
  if not xdg_runtime_dir then return end
  -- uv.fs_stat rather than vim.fn.isdirectory: the curl callbacks may run in a
  -- libuv callback context, where vim.fn functions are not allowed.
  local stat = uv.fs_stat(xdg_runtime_dir)
  if not stat or stat.type ~= "directory" then
    Utils.error(
      "$XDG_RUNTIME_DIR="
        .. xdg_runtime_dir
        .. " is set but does not exist. curl could not write output. Please make sure it exists, or unset.",
      { title = "Avante" }
    )
  elseif not uv.fs_access(xdg_runtime_dir, "w") then
    Utils.error(
      "$XDG_RUNTIME_DIR="
        .. xdg_runtime_dir
        .. " exists but is not writable. curl could not write output. Please make sure it is writable, or unset.",
      { title = "Avante" }
    )
  end
end

---Send a single request to `opts.provider` (default: the configured provider).
---
---The JSON body is written to a temp file whose path is handed to curl. Output
---is routed through the provider's `parse_stream_data` / `parse_response`
---hooks into `opts.on_chunk` / `opts.on_complete`. A one-shot `User`
---autocommand on `M.CANCEL_PATTERN` shuts the job down.
---@param opts StreamOptions
M._stream = function(opts)
  local Provider = opts.provider or P[Config.provider]
  local code_opts = M.generate_prompts(opts)
  ---@type string
  local current_event_state = nil

  ---@type AvanteHandlerOptions
  local handler_opts = { on_chunk = opts.on_chunk, on_complete = opts.on_complete }
  ---@type AvanteCurlOutput
  local spec = Provider.parse_curl_args(Provider, code_opts)

  ---Handle one `event:` / `data:` line of the response stream.
  ---@param line string
  local function parse_stream_data(line)
    local event = line:match("^event: (.+)$")
    if event then
      current_event_state = event
      return
    end
    local data_match = line:match("^data: (.+)$")
    if data_match then Provider.parse_response(data_match, current_event_state, handler_opts) end
  end

  local function parse_response_without_stream(data)
    Provider.parse_response_without_stream(data, current_event_state, handler_opts)
  end

  local completed = false
  local active_job

  local curl_body_file = fn.tempname() .. ".json"
  local json_content = vim.json.encode(spec.body)
  fn.writefile(vim.split(json_content, "\n"), curl_body_file)
  Utils.debug("curl body file:", curl_body_file)
  -- Delete the body file unless debugging, where it is kept for inspection.
  local function cleanup()
    if Config.debug then return end
    vim.schedule(function() fn.delete(curl_body_file) end)
  end

  active_job = curl.post(spec.url, {
    headers = spec.headers,
    proxy = spec.proxy,
    insecure = spec.insecure,
    body = curl_body_file,
    raw = spec.rawArgs,
    stream = function(err, data, _)
      if err then
        completed = true
        opts.on_complete(err)
        return
      end
      if not data then return end
      vim.schedule(function()
        if Config.options[Config.provider] == nil and Provider.parse_stream_data ~= nil then
          if Provider.parse_response ~= nil then
            Utils.warn(
              "parse_stream_data and parse_response are mutually exclusive, and thus parse_response will be ignored. Make sure that you handle the incoming data correctly.",
              { once = true }
            )
          end
          Provider.parse_stream_data(data, handler_opts)
        else
          if Provider.parse_stream_data ~= nil then
            Provider.parse_stream_data(data, handler_opts)
          else
            parse_stream_data(data)
          end
        end
      end)
    end,
    on_error = function(err)
      -- curl exit code 23 is a write error; check for an unusable XDG_RUNTIME_DIR.
      if err.exit == 23 then report_xdg_runtime_dir_problem() end
      active_job = nil
      completed = true
      cleanup()
      opts.on_complete(err)
    end,
    callback = function(result)
      active_job = nil
      cleanup()
      if result.status >= 400 then
        if Provider.on_error then
          Provider.on_error(result)
        else
          Utils.error("API request failed with status " .. result.status, { once = true, title = "Avante" })
        end
        vim.schedule(function()
          if not completed then
            completed = true
            opts.on_complete(
              "API request failed with status " .. result.status .. ". Body: " .. vim.inspect(result.body)
            )
          end
        end)
      end
      -- If stream is not enabled, then handle the response here
      if spec.body.stream == false and result.status == 200 then
        vim.schedule(function()
          completed = true
          parse_response_without_stream(result.body)
        end)
      end
    end,
  })

  api.nvim_create_autocmd("User", {
    group = group,
    pattern = M.CANCEL_PATTERN,
    once = true,
    callback = function()
      if active_job then
        -- shutdown() can raise "cannot resume dead coroutine"; the error is swallowed.
        xpcall(function() active_job:shutdown() end, function(err) return err end)
        Utils.debug("LLM request cancelled")
        active_job = nil
      end
    end,
  })

  return active_job
end
2024-12-17 14:43:25 +02:00
---Fill the dual-boost prompt template with both provider outputs, append it to
---the instructions and stream the combined request.
---@param first_response string
---@param second_response string
---@param opts StreamOptions
local function _merge_response(first_response, second_response, opts)
  -- Function replacements insert the responses literally. As plain replacement
  -- strings, any `%` in an LLM answer would be parsed as an escape or capture
  -- reference, corrupting the text or raising an error.
  local prompt = ("\n" .. Config.dual_boost.prompt)
    :gsub("{{[%s]*provider1_output[%s]*}}", function() return first_response end)
    :gsub("{{[%s]*provider2_output[%s]*}}", function() return second_response end)
  prompt = prompt .. "\n"
  -- Append the reference prompt to a copy so the caller's opts are not mutated.
  M._stream(vim.tbl_extend("force", opts, { instructions = opts.instructions .. prompt }))
end

---Finish the dual-boost collection exactly once.
---Stops and closes the timeout timer, then either merges both responses or,
---when one is missing, reports the failure via `Utils.error` and `opts.on_complete`.
local function _collector_process_responses(collector, opts)
  if collector.processed then return end
  collector.processed = true
  -- Stopping a libuv timer does not free it; the handle must also be closed.
  collector.timer:stop()
  collector.timer:close()
  if not collector[1] or not collector[2] then
    local msg = "One or both responses failed to complete"
    Utils.error(msg)
    -- Without this the caller's completion callback would never fire.
    if opts.on_complete then opts.on_complete(msg) end
    return
  end
  _merge_response(collector[1], collector[2], opts)
end

---Record the response of stream `index`; processes the collection once both
---streams have answered.
local function _collector_add_response(collector, index, response, opts)
  -- Guard against a stream reporting completion more than once.
  if collector[index] ~= nil then return end
  collector[index] = response
  collector.count = collector.count + 1
  if collector.count == 2 then _collector_process_responses(collector, opts) end
end

---Query two providers in parallel. Once both have answered, their outputs are
---merged into a final request streamed with `opts`. If the configured timeout
---fires or a stream fails first, the failure is reported instead.
---@param opts StreamOptions
---@param Provider1 AvanteProviderFunctor
---@param Provider2 AvanteProviderFunctor
M._dual_boost_stream = function(opts, Provider1, Provider2)
  Utils.debug("Starting Dual Boost Stream")
  local collector = {
    count = 0,
    processed = false,
    timer = uv.new_timer(),
    timeout_ms = Config.dual_boost.timeout,
  }
  -- Setup timeout
  collector.timer:start(
    collector.timeout_ms,
    0,
    vim.schedule_wrap(function()
      -- The collection may have finished while this callback was queued.
      if collector.processed or collector.count >= 2 then return end
      Utils.warn("Dual boost stream timeout reached")
      -- Process whatever responses we have
      _collector_process_responses(collector, opts)
    end)
  )
  -- Create options for both streams
  local function create_stream_opts(index)
    local response = ""
    return vim.tbl_extend("force", opts, {
      on_chunk = function(chunk)
        if chunk then response = response .. chunk end
      end,
      on_complete = function(err)
        if err then
          Utils.error(string.format("Stream %d failed: %s", index, err))
          -- The merge needs both answers, so fail now instead of at the timeout.
          _collector_process_responses(collector, opts)
          return
        end
        Utils.debug(string.format("Response %d completed", index))
        _collector_add_response(collector, index, response, opts)
      end,
    })
  end
  -- Start both streams
  local success, err = xpcall(function()
    local opts1 = create_stream_opts(1)
    opts1.provider = Provider1
    M._stream(opts1)
    local opts2 = create_stream_opts(2)
    opts2.provider = Provider2
    M._stream(opts2)
  end, function(e) return e end)
  if not success then Utils.error("Failed to start dual_boost streams: " .. tostring(err)) end
end
---@alias LlmMode "planning" | "editing" | "suggesting"
---
2024-12-12 03:29:10 +10:00
---@class SelectedFiles
---@field path string
---@field content string
---@field file_type string
---
2024-11-17 00:54:01 -07:00
---@class TemplateOptions
---@field use_xml_format boolean
---@field ask boolean
---@field question string
---@field code_lang string
---@field selected_code string | nil
---@field project_context string | nil
2024-12-12 03:29:10 +10:00
---@field selected_files SelectedFiles[] | nil
2024-11-23 21:49:33 +08:00
---@field diagnostics string | nil
2024-11-17 00:54:01 -07:00
---@field history_messages AvanteLLMMessage[]
---
2024-12-17 14:43:25 +02:00
---@class GeneratePromptsOptions: TemplateOptions
2024-11-17 00:54:01 -07:00
---@field ask boolean
---@field instructions string
---@field mode LlmMode
---@field provider AvanteProviderFunctor | nil
2024-12-17 14:43:25 +02:00
---
---@class StreamOptions: GeneratePromptsOptions
2024-11-17 00:54:01 -07:00
---@field on_chunk AvanteChunkParser
---@field on_complete AvanteCompleteParser
---Public entry point for an LLM request.
---
---`opts.on_chunk` / `opts.on_complete` are replaced in place with
---`vim.schedule_wrap`ped versions. Completion is delivered at most once, and
---chunks handled after completion are ignored. With dual boost enabled, both
---configured providers are queried; otherwise a single stream is started.
---@param opts StreamOptions
M.stream = function(opts)
  local is_completed = false
  local user_on_chunk, user_on_complete = opts.on_chunk, opts.on_complete

  if user_on_chunk ~= nil then
    opts.on_chunk = vim.schedule_wrap(function(chunk)
      if not is_completed then return user_on_chunk(chunk) end
    end)
  end

  if user_on_complete ~= nil then
    opts.on_complete = vim.schedule_wrap(function(err)
      if is_completed then return end
      is_completed = true
      return user_on_complete(err)
    end)
  end

  if not Config.dual_boost.enabled then
    M._stream(opts)
    return
  end
  local first = P[Config.dual_boost.first_provider]
  local second = P[Config.dual_boost.second_provider]
  M._dual_boost_stream(opts, first, second)
end
2024-09-03 04:19:54 -04:00
---Cancel in-flight requests by firing the `User` autocommand (pattern
---`M.CANCEL_PATTERN`) that each `M._stream` call registers.
function M.cancel_inflight_request()
  local pattern = M.CANCEL_PATTERN
  api.nvim_exec_autocmds("User", { pattern = pattern })
end

return M