2024-08-17 16:04:40 -04:00
local api = vim.api
2024-08-15 19:04:15 +08:00
local curl = require ( " plenary.curl " )
2024-08-17 15:14:30 +08:00
local Utils = require ( " avante.utils " )
local Config = require ( " avante.config " )
local Tiktoken = require ( " avante.tiktoken " )
2024-08-18 15:03:25 -04:00
local Dressing = require ( " avante.ui.dressing " )
2024-08-18 22:20:29 -04:00
---@class avante.LLM
local M = {
  --- Pattern of the `User` autocmd that cancels the in-flight request started by `M.stream`.
  CANCEL_PATTERN = "AvanteLLMEscape",
}

--- Copilot API token as decoded from GitHub's `copilot_internal/v2/token` response.
---@class CopilotToken
---@field annotations_enabled boolean
---@field chat_enabled boolean
---@field chat_jetbrains_enabled boolean
---@field code_quote_enabled boolean
---@field codesearch boolean
---@field copilotignore_enabled boolean
---@field endpoints {api: string, ["origin-tracker"]: string, proxy: string, telemetry: string}
---@field expires_at integer
---@field individual boolean
---@field nes_enabled boolean
---@field prompt_8k boolean
---@field public_suggestions string
---@field refresh_in integer
---@field sku string
---@field snippy_load_test_enabled boolean
---@field telemetry string
---@field token string
---@field tracking_id string
---@field vsc_electron_fetcher boolean
---@field xcode boolean
---@field xcode_chat boolean
---
---@private
--- Per-session Copilot state; nil until something initialises it (M.setup does so when
--- copilot is the configured provider).
---@class AvanteCopilot: table<string, any>
---@field proxy string
---@field allow_insecure boolean
---@field token? CopilotToken
---@field github_token? string
---@field sessionid? string
---@field machineid? string
M.copilot = nil
2024-08-17 16:04:40 -04:00
--- Describes how the credentials of each built-in provider are detected.
---@class EnvironmentHandler: table<[Provider], string>
local E = {
  --- Built-in providers. A string value is the name of the environment variable that holds
  --- the API key; a function value runs a custom availability check and returns a boolean.
  ---@type table<Provider, string | fun(): boolean>
  env = {
    openai = "OPENAI_API_KEY",
    claude = "ANTHROPIC_API_KEY",
    azure = "AZURE_OPENAI_API_KEY",
    deepseek = "DEEPSEEK_API_KEY",
    groq = "GROQ_API_KEY",
    gemini = "GEMINI_API_KEY",
    -- Copilot has no API-key envvar; it is usable when copilot.lua / copilot.vim is present
    -- or a Copilot config path can be found.
    copilot = function()
      local usable = Utils.has("copilot.lua") or Utils.has("copilot.vim") or Utils.copilot.find_config_path()
      if not usable then
        Utils.warn("copilot is not setup correctly. Please use copilot.lua or copilot.vim for authentication.")
        return false
      end
      return true
    end,
  },
}
2024-08-19 08:35:36 -04:00
-- Indexing E with a provider name (e.g. `E.claude`) answers "is this provider usable?":
--   * local providers are always usable;
--   * built-ins with a function entry defer to that function;
--   * built-ins with a string entry are usable when that envvar is set;
--   * vendors are usable when their `api_key_name` envvar is set;
--   * anything else yields nil.
setmetatable(E, {
  ---@param provider Provider
  __index = function(_, provider)
    if E.is_local(provider) then
      return true
    end

    local detector = E.env[provider]
    if type(detector) == "function" then
      return detector()
    elseif detector then
      return os.getenv(detector) ~= nil
    end

    ---@type AvanteProvider | nil
    local vendor = Config.vendors[provider]
    if vendor then
      return os.getenv(vendor.api_key_name) ~= nil
    end
    return nil
  end,
})
2024-08-18 22:20:29 -04:00
---@private
--- Set by E.setup the first time it registers the deferred envvar prompt, so the prompt is
--- scheduled at most once per session.
E._once = false

--- Whether `provider` is one of the built-in providers listed in `E.env`.
---@param provider Provider
---@return boolean
E.is_default = function(provider)
  return E.env[provider] ~= nil
end
2024-08-20 14:24:33 -04:00
--- Sentinel "envvar name" returned by E.key for providers whose availability is decided by a
--- function (e.g. copilot) rather than an environment variable; E.setup skips prompting for it.
local AVANTE_INTERNAL_KEY = "__avante_internal"

--- Return the environment variable name that holds the API key for `provider`.
---@param provider? Provider defaults to `Config.provider`
---@return string the envvar key (AVANTE_INTERNAL_KEY for function-checked built-ins)
E.key = function(provider)
  provider = provider or Config.provider
  if E.is_default(provider) then
    local entry = E.env[provider]
    return type(entry) == "function" and AVANTE_INTERNAL_KEY or entry
  end
  ---@type AvanteProvider | nil
  local external = Config.vendors[provider]
  if external then
    return external.api_key_name
  end
  -- tostring(): with a nil provider the bare `..` would raise an unrelated concatenation
  -- error instead of this message.
  error("Failed to find provider: " .. tostring(provider), 2)
end
2024-08-19 08:35:36 -04:00
--- Whether `provider` is configured as a local model (its options set `["local"]`).
--- Local providers need no API key: the E metatable reports them usable and E.value
--- returns a placeholder for them.
--- Built-in options take precedence over a vendor entry of the same name.
---@param provider Provider
E.is_local = function(provider)
  local provider_opts = Config.options[provider] or Config.vendors[provider]
  if not provider_opts then
    return false
  end
  return provider_opts["local"]
end
2024-08-18 15:03:25 -04:00
--- Return the API key for `provider` (default: Config.provider) from its environment
--- variable, or a placeholder string for local providers that need no key.
---@param provider? Provider
---@return string|nil
E.value = function(provider)
  local name = provider or Config.provider
  if E.is_local(name) then
    return "__avante_dummy"
  end
  return os.getenv(E.key(name))
end
2024-08-18 06:07:29 -04:00
2024-08-18 15:03:25 -04:00
--- Initialize the API-key environment variable for the current Neovim session.
--- Opens an input prompt asking the user for `var` and exports the answer with vim.fn.setenv.
--- Without `refresh`, the prompt is registered only once per session and opens on the first
--- buffer enter whose buftype/filetype is not in the exclusion lists below.
---@param var string environment variable name (as returned by E.key)
---@param refresh? boolean open the prompt after a short delay instead of waiting for a buffer enter
---@private
E.setup = function(var, refresh)
  -- Function-checked providers (e.g. copilot) have no envvar to ask for.
  if var == AVANTE_INTERNAL_KEY then
    return
  end
  refresh = refresh or false

  ---@param value string
  ---@return nil
  local function on_confirm(value)
    if value then
      vim.fn.setenv(var, value)
    elseif not E[Config.provider] then
      Utils.warn("Failed to set " .. var .. ". Avante won't work as expected", { once = true, title = "Avante" })
    end
  end

  local function open_prompt()
    Dressing.initialize_input_buffer({
      opts = { prompt = "Enter " .. var .. ": " },
      on_confirm = on_confirm,
    })
  end

  if refresh then
    vim.defer_fn(open_prompt, 200)
    return
  end
  if E._once then
    return
  end
  E._once = true

  -- Buffers in which the prompt must not pop up (dashboards, start screens, pickers,
  -- quickfix, commit/rebase messages, the prompt's own input buffer, ...).
  local exclude_buftypes = { "dashboard", "alpha", "qf", "nofile" }
  local exclude_filetypes = {
    "NvimTree",
    "Outline",
    "help",
    "dashboard",
    "alpha",
    "qf",
    "ministarter",
    "TelescopePrompt",
    "gitcommit",
    "gitrebase",
    "DressingInput",
  }

  api.nvim_create_autocmd({ "BufEnter", "BufWinEnter" }, {
    pattern = "*",
    once = true,
    callback = function()
      vim.defer_fn(function()
        local excluded = vim.tbl_contains(exclude_buftypes, vim.bo.buftype)
          or vim.tbl_contains(exclude_filetypes, vim.bo.filetype)
        if not excluded then
          open_prompt()
        end
      end, 200)
    end,
  })
end
2024-08-18 15:03:25 -04:00
------------------------------Prompt and type------------------------------
2024-08-15 19:04:15 +08:00
-- System prompt sent with every request: Claude `system` field, the "system" message of
-- OpenAI-style providers, and (prefixed to base_user_prompt) Gemini's systemInstruction.
local system_prompt = [ [
You are an excellent programming expert .
] ]
-- Instruction block sent with every request: appended as the last content part of the Claude
-- user turn, prepended to the OpenAI-style user message, and placed in Gemini's
-- systemInstruction. It asks the model to answer with "Replace lines: {{start_line}}-{{end_line}}"
-- fenced snippets. NOTE(review): the code that parses that reply format lives outside this
-- module; keep the two in sync when editing this text.
local base_user_prompt = [ [
Your primary task is to suggest code modifications with precise line number ranges . Follow these instructions meticulously :
1. Carefully analyze the original code , paying close attention to its structure and line numbers . Line numbers start from 1 and include ALL lines , even empty ones .
2. When suggesting modifications :
2024-08-17 22:29:05 +08:00
a . Use the language in the question to reply . If there are non - English parts in the question , use the language of those parts .
b . Explain why the change is necessary or beneficial .
c . Provide the exact code snippet to be replaced using this format :
2024-08-15 19:04:15 +08:00
Replace lines : { { start_line } } - { { end_line } }
` ` ` { { language } }
{ { suggested_code } }
` ` `
2024-08-18 02:46:19 +08:00
3. Crucial guidelines for suggested code snippets :
- Only apply the change ( s ) suggested by the most recent assistant message ( before your generation ) .
- Do not make any unrelated changes to the code .
- Produce a valid full rewrite of the entire original file without skipping any lines . Do not be lazy !
- Do not arbitrarily delete pre - existing comments / empty Lines .
- Do not omit large parts of the original file for no reason .
- Do not omit any needed changes from the requisite messages / code blocks .
- If there is a clicked code block , bias towards just applying that ( and applying other changes implied ) .
2024-08-18 02:52:27 +08:00
- Please keep your suggested code changes minimal , and do not include irrelevant lines in the code snippet .
2024-08-18 02:46:19 +08:00
4. Crucial guidelines for line numbers :
2024-08-19 00:05:13 +08:00
- The content regarding line numbers MUST strictly follow the format " Replace lines: {{start_line}}-{{end_line}} " . Do not be lazy !
2024-08-15 19:04:15 +08:00
- The range { { start_line } } - { { end_line } } is INCLUSIVE . Both start_line and end_line are included in the replacement .
2024-08-18 04:47:47 +08:00
- Count EVERY line , including empty lines and comments lines , comments . Do not be lazy !
2024-08-15 19:04:15 +08:00
- For single - line changes , use the same number for start and end lines .
- For multi - line changes , ensure the range covers ALL affected lines , from the very first to the very last .
- Double - check that your line numbers align perfectly with the original code structure .
5. Final check :
- Review all suggestions , ensuring each line number is correct , especially the start_line and end_line .
- Confirm that no unrelated code is accidentally modified or deleted .
- Verify that the start_line and end_line correctly include all intended lines for replacement .
- Perform a final alignment check to ensure your line numbers haven ' t shifted, especially the start_line.
- Double - check that your line numbers align perfectly with the original code structure .
2024-08-18 02:52:27 +08:00
- Do not show the full content after these modifications .
2024-08-15 19:04:15 +08:00
Remember : Accurate line numbers are CRITICAL . The range start_line to end_line must include ALL lines to be replaced , from the very first to the very last . Double - check every range before finalizing your response , paying special attention to the start_line to ensure it hasn ' t shifted down. Ensure that your line numbers perfectly match the original code structure without any overall shift.
] ]
2024-08-21 14:50:40 +08:00
---@class AvanteHandlerOptions: table<[string], string>
---@field on_chunk AvanteChunkParser
---@field on_complete AvanteCompleteParser
---
2024-08-18 15:03:25 -04:00
---@class AvantePromptOptions: table<[string], string>
---@field question string
---@field code_lang string
---@field code_content string
---@field selected_code_content? string
---
2024-08-18 22:20:29 -04:00
---@class AvanteBaseMessage
---@field role "user" | "system"
---@field content string
---
---@class AvanteClaudeMessage: AvanteBaseMessage
---@field role "user"
---@field content {type: "text", text: string, cache_control?: {type: "ephemeral"}}[]
---
---@alias AvanteOpenAIMessage AvanteBaseMessage
---
2024-08-21 17:52:25 +01:00
---@class AvanteGeminiMessage
---@field role "user"
---@field parts { text: string }[]
---
---@alias AvanteChatMessage AvanteClaudeMessage | AvanteOpenAIMessage | AvanteGeminiMessage
2024-08-18 22:20:29 -04:00
---
---@alias AvanteAiMessageBuilder fun(opts: AvantePromptOptions): AvanteChatMessage[]
2024-08-18 15:03:25 -04:00
---
---@class AvanteCurlOutput: {url: string, body: table<string, any> | string, headers: table<string, string>}
---@alias AvanteCurlArgsBuilder fun(code_opts: AvantePromptOptions): AvanteCurlOutput
---
---@class ResponseParser
---@field on_chunk fun(chunk: string): any
---@field on_complete fun(err: string|nil): any
2024-08-19 06:11:02 -04:00
---@alias AvanteResponseParser fun(data_stream: string, event_state: string, opts: ResponseParser): nil
2024-08-18 22:20:29 -04:00
---
2024-08-19 08:35:36 -04:00
---@class AvanteDefaultBaseProvider
2024-08-18 22:20:29 -04:00
---@field endpoint string
2024-08-19 08:35:36 -04:00
---@field local? boolean
---
---@class AvanteSupportedProvider: AvanteDefaultBaseProvider
2024-08-18 22:20:29 -04:00
---@field model string
2024-08-19 08:35:36 -04:00
---@field temperature number
---@field max_tokens number
---
---@class AvanteAzureProvider: AvanteDefaultBaseProvider
---@field deployment string
---@field api_version string
---@field temperature number
---@field max_tokens number
---
2024-08-20 14:24:33 -04:00
---@class AvanteCopilotProvider: AvanteSupportedProvider
---@field proxy string | nil
---@field allow_insecure boolean
---@field timeout number
---
2024-08-21 17:52:25 +01:00
---@class AvanteGeminiProvider: AvanteDefaultBaseProvider
---@field model string
---@field type string
---@field options table
---
2024-08-19 08:35:36 -04:00
---@class AvanteProvider: AvanteDefaultBaseProvider
2024-08-20 07:43:53 -04:00
---@field model? string
2024-08-18 22:20:29 -04:00
---@field api_key_name string
---@field parse_response_data AvanteResponseParser
---@field parse_curl_args fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput
2024-08-21 10:14:30 -04:00
---@field parse_stream_data? fun(line: string, handler_opts: AvanteHandlerOptions): nil
2024-08-19 06:11:02 -04:00
---
---@alias AvanteChunkParser fun(chunk: string): any
---@alias AvanteCompleteParser fun(err: string|nil): nil
2024-08-18 15:03:25 -04:00
------------------------------Anthropic------------------------------
2024-08-18 22:20:29 -04:00
--- Wrap `text` in an Anthropic text content block. When the text is longer than 1024 tokens
--- (as counted by Tiktoken), it is marked as an ephemeral prompt-cache breakpoint.
---@param text string
---@return {type: "text", text: string, cache_control?: {type: "ephemeral"}}
local function claude_cacheable_text(text)
  local block = { type = "text", text = text }
  if Tiktoken.count(text) > 1024 then
    block.cache_control = { type = "ephemeral" }
  end
  return block
end

--- Build the single Claude user turn: the code (or, with a selection, the whole file as
--- `<code_context>` followed by the selection as `<code>`), the question, then the base
--- instructions. Long code and instruction blocks carry cache_control markers.
---@param opts AvantePromptOptions
---@return AvanteClaudeMessage[]
M.make_claude_message = function(opts)
  -- Decide the final wrapper *before* measuring, so the cache decision is made on the exact
  -- text that is sent.
  local code_text
  if opts.selected_code_content then
    code_text = string.format("<code_context>```%s\n%s```</code_context>", opts.code_lang, opts.code_content)
  else
    code_text = string.format("<code>```%s\n%s```</code>", opts.code_lang, opts.code_content)
  end

  local message_content = { claude_cacheable_text(code_text) }

  if opts.selected_code_content then
    local selected_text = string.format("<code>```%s\n%s```</code>", opts.code_lang, opts.selected_code_content)
    table.insert(message_content, claude_cacheable_text(selected_text))
  end

  -- The question changes every request, so it is never marked for caching.
  table.insert(message_content, {
    type = "text",
    text = string.format("<question>%s</question>", opts.question),
  })

  table.insert(message_content, claude_cacheable_text(base_user_prompt))

  return {
    {
      role = "user",
      content = message_content,
    },
  }
end
2024-08-15 19:04:15 +08:00
2024-08-18 22:20:29 -04:00
--- Handle one SSE `data:` payload from the Anthropic Messages stream.
--- `content_block_delta` events forward their text to `opts.on_chunk`.
--- `message_stop` completes successfully.
--- `error` completes with the decoded error payload, or the raw text if it is not valid JSON.
---@type AvanteResponseParser
M.parse_claude_response = function(data_stream, event_state, opts)
  if event_state == "content_block_delta" then
    local ok, json = pcall(vim.json.decode, data_stream)
    if not ok or type(json) ~= "table" or type(json.delta) ~= "table" then
      return
    end
    -- Deltas without text (non-text block types) must not reach on_chunk as nil.
    local text = json.delta.text
    if type(text) == "string" then
      opts.on_chunk(text)
    end
  elseif event_state == "message_stop" then
    opts.on_complete(nil)
    return
  elseif event_state == "error" then
    -- Decoding is guarded: this runs inside vim.schedule, where an unhandled decode error
    -- would be raised and on_complete would never fire.
    local ok, json = pcall(vim.json.decode, data_stream)
    opts.on_complete(ok and json or data_stream)
  end
end
2024-08-15 19:04:15 +08:00
2024-08-18 15:03:25 -04:00
--- Build the streaming request for Anthropic's Messages API.
---@type AvanteCurlArgsBuilder
M.make_claude_curl_args = function(code_opts)
  local cfg = Config.claude
  local url = Utils.trim(cfg.endpoint, { suffix = "/" }) .. "/v1/messages"
  local headers = {
    ["Content-Type"] = "application/json",
    ["x-api-key"] = E.value("claude"),
    ["anthropic-version"] = "2023-06-01",
    -- prompt-caching beta header; make_claude_message attaches cache_control markers
    ["anthropic-beta"] = "prompt-caching-2024-07-31",
  }
  local body = {
    model = cfg.model,
    system = system_prompt,
    stream = true,
    messages = M.make_claude_message(code_opts),
    temperature = cfg.temperature,
    max_tokens = cfg.max_tokens,
  }
  return { url = url, headers = headers, body = body }
end
2024-08-18 15:03:25 -04:00
------------------------------OpenAI------------------------------
2024-08-18 22:20:29 -04:00
--- Build the OpenAI-style chat messages: the system prompt plus one user message made of
--- the base instructions, the fenced code section(s) and the question.
--- With a selection, the whole file is sent as "CODE CONTEXT" and the selection as "CODE".
---@param opts AvantePromptOptions
---@return AvanteOpenAIMessage[]
M.make_openai_message = function(opts)
  --- Render one labelled, fenced code section of the user prompt.
  local function section(label, code)
    return "\n\n" .. label .. ":\n```" .. opts.code_lang .. "\n" .. code .. "\n```"
  end

  local pieces = { base_user_prompt }
  if opts.selected_code_content ~= nil then
    pieces[#pieces + 1] = section("CODE CONTEXT", opts.code_content)
    pieces[#pieces + 1] = section("CODE", opts.selected_code_content)
  else
    pieces[#pieces + 1] = section("CODE", opts.code_content)
  end
  pieces[#pieces + 1] = "\n\nQUESTION:\n" .. opts.question

  return {
    { role = "system", content = system_prompt },
    { role = "user", content = table.concat(pieces) },
  }
end
2024-08-18 22:20:29 -04:00
--- Handle one SSE `data:` payload from an OpenAI-compatible chat-completions stream.
--- Text deltas go to `opts.on_chunk`; `finish_reason == "stop"` completes the request.
---@type AvanteResponseParser
M.parse_openai_response = function(data_stream, _, opts)
  -- NOTE(review): OpenAI documents the stream terminator as the bare line `data: [DONE]`,
  -- which this quoted pattern does not match. Completion is normally signalled by
  -- finish_reason == "stop" below. The check is left unchanged so on_complete is not fired twice.
  if data_stream:match('"%[DONE%]":') then
    opts.on_complete(nil)
    return
  end
  if not data_stream:match('"delta":') then
    return
  end

  -- Guarded like the Claude parser: a malformed line must not raise inside vim.schedule.
  local ok, json = pcall(vim.json.decode, data_stream)
  if not ok or type(json) ~= "table" then
    return
  end
  local choice = type(json.choices) == "table" and json.choices[1] or nil
  if type(choice) ~= "table" then
    return
  end

  if choice.finish_reason == "stop" then
    opts.on_complete(nil)
  elseif type(choice.delta) == "table" and type(choice.delta.content) == "string" then
    -- JSON `null` content decodes to vim.NIL (truthy userdata); only real strings are forwarded.
    opts.on_chunk(choice.delta.content)
  end
end
2024-08-15 19:04:15 +08:00
2024-08-18 15:03:25 -04:00
--- Build the streaming request for OpenAI's chat-completions API.
---@type AvanteCurlArgsBuilder
M.make_openai_curl_args = function(code_opts)
  local cfg = Config.openai
  local url = Utils.trim(cfg.endpoint, { suffix = "/" }) .. "/v1/chat/completions"
  local headers = {
    ["Content-Type"] = "application/json",
    ["Authorization"] = "Bearer " .. E.value("openai"),
  }
  local body = {
    model = cfg.model,
    messages = M.make_openai_message(code_opts),
    temperature = cfg.temperature,
    max_tokens = cfg.max_tokens,
    stream = true,
  }
  return { url = url, headers = headers, body = body }
end
2024-08-20 14:24:33 -04:00
------------------------------Copilot------------------------------
---@type AvanteAiMessageBuilder
M.make_copilot_message = M.make_openai_message
---@type AvanteResponseParser
M.parse_copilot_response = M.parse_openai_response

--- Exchange the GitHub token for a Copilot API token and store it, together with a fresh
--- session id, on M.copilot (which must already exist).
---@param github_token string
local function refresh_copilot_token(github_token)
  local sessionid = Utils.copilot.uuid() .. tostring(math.floor(os.time() * 1000))
  local headers = {
    ["Authorization"] = "token " .. github_token,
    ["Accept"] = "application/json",
  }
  for key, value in pairs(Utils.copilot.version_headers) do
    headers[key] = value
  end
  local response = curl.get("https://api.github.com/copilot_internal/v2/token", {
    timeout = Config.copilot.timeout,
    headers = headers,
    proxy = M.copilot.proxy,
    insecure = M.copilot.allow_insecure,
    on_error = function(err)
      error("Failed to get response: " .. vim.inspect(err))
    end,
  })
  -- A non-200 reply carries an error object rather than a token. Decoding it as a token would
  -- leave token.token nil for generate_headers, so fail loudly here instead.
  if not response or response.status ~= 200 then
    error(
      "Failed to get Copilot token (status "
        .. tostring(response and response.status)
        .. "): "
        .. tostring(response and response.body)
    )
  end
  M.copilot.sessionid = sessionid
  M.copilot.token = vim.json.decode(response.body)
end

--- Build the streaming request for the Copilot chat-completions endpoint.
--- Refreshes the Copilot API token when it is missing or expired.
---@type AvanteCurlArgsBuilder
M.make_copilot_curl_args = function(code_opts)
  local github_token = Utils.copilot.cached_token()
  if not github_token then
    error(
      "No GitHub token found, please use `:Copilot auth` to setup with `copilot.lua` or `:Copilot setup` with `copilot.vim`"
    )
  end

  -- M.setup() creates this state only when copilot is the provider at startup. Switching to
  -- copilot later would otherwise leave M.copilot nil and crash on the token lookup below.
  if not M.copilot then
    M.copilot = {
      proxy = Config.copilot.proxy,
      allow_insecure = Config.copilot.allow_insecure,
      github_token = github_token,
      sessionid = nil,
      token = nil,
      machineid = Utils.copilot.machine_id(),
    }
  end

  local token = M.copilot.token
  if not token or (token.expires_at and token.expires_at <= math.floor(os.time())) then
    refresh_copilot_token(github_token)
  end

  return {
    url = Utils.trim(Config.copilot.endpoint, { suffix = "/" }) .. "/chat/completions",
    proxy = Config.copilot.proxy,
    insecure = Config.copilot.allow_insecure,
    headers = Utils.copilot.generate_headers(M.copilot.token.token, M.copilot.sessionid, M.copilot.machineid),
    body = {
      -- Field name is `model`, as in the other chat-completions bodies in this file. The former
      -- `mode` key meant the configured model was never sent.
      model = Config.copilot.model,
      n = 1,
      top_p = 1,
      stream = true,
      temperature = Config.copilot.temperature,
      max_tokens = Config.copilot.max_tokens,
      messages = M.make_copilot_message(code_opts),
    },
  }
end
2024-08-18 15:03:25 -04:00
------------------------------Azure------------------------------
---@type AvanteAiMessageBuilder
M.make_azure_message = M.make_openai_message

---@type AvanteResponseParser
M.parse_azure_response = M.parse_openai_response

--- Build the streaming request for an Azure OpenAI deployment.
--- Azure authenticates with an `api-key` header and addresses the model via the deployment
--- name in the URL, so no `model` field is sent.
---@type AvanteCurlArgsBuilder
M.make_azure_curl_args = function(code_opts)
  return {
    -- Trim a trailing "/" like every other provider here does; otherwise an endpoint such as
    -- "https://x.openai.azure.com/" produced ".../.com//openai/...".
    url = Utils.trim(Config.azure.endpoint, { suffix = "/" })
      .. "/openai/deployments/"
      .. Config.azure.deployment
      .. "/chat/completions?api-version="
      .. Config.azure.api_version,
    headers = {
      ["Content-Type"] = "application/json",
      ["api-key"] = E.value("azure"),
    },
    body = {
      messages = M.make_openai_message(code_opts),
      temperature = Config.azure.temperature,
      max_tokens = Config.azure.max_tokens,
      stream = true,
    },
  }
end
------------------------------Deepseek------------------------------
---@type AvanteAiMessageBuilder
M.make_deepseek_message = M.make_openai_message

---@type AvanteResponseParser
M.parse_deepseek_response = M.parse_openai_response

--- Build the streaming request for DeepSeek. It reuses the OpenAI message builder and parser;
--- the path has no `/v1` prefix.
---@type AvanteCurlArgsBuilder
M.make_deepseek_curl_args = function(code_opts)
  local cfg = Config.deepseek
  local url = Utils.trim(cfg.endpoint, { suffix = "/" }) .. "/chat/completions"
  local headers = {
    ["Content-Type"] = "application/json",
    ["Authorization"] = "Bearer " .. E.value("deepseek"),
  }
  local body = {
    model = cfg.model,
    messages = M.make_openai_message(code_opts),
    temperature = cfg.temperature,
    max_tokens = cfg.max_tokens,
    stream = true,
  }
  return { url = url, headers = headers, body = body }
end
------------------------------Groq------------------------------
---@type AvanteAiMessageBuilder
M.make_groq_message = M.make_openai_message

---@type AvanteResponseParser
M.parse_groq_response = M.parse_openai_response

--- Build the streaming request for Groq's OpenAI-compatible endpoint (`/openai/v1/...`).
---@type AvanteCurlArgsBuilder
M.make_groq_curl_args = function(code_opts)
  local cfg = Config.groq
  local url = Utils.trim(cfg.endpoint, { suffix = "/" }) .. "/openai/v1/chat/completions"
  local headers = {
    ["Content-Type"] = "application/json",
    ["Authorization"] = "Bearer " .. E.value("groq"),
  }
  local body = {
    model = cfg.model,
    messages = M.make_openai_message(code_opts),
    temperature = cfg.temperature,
    max_tokens = cfg.max_tokens,
    stream = true,
  }
  return { url = url, headers = headers, body = body }
end
2024-08-21 17:52:25 +01:00
------------------------------Gemini------------------------------
--- Build the Gemini `contents` array: a single user turn whose parts are the code (or, with
--- a selection, the whole file as `<code_context>` followed by the selection as `<code>`) and
--- then the question. base_user_prompt is not repeated here; make_gemini_curl_args sends it
--- as the systemInstruction.
---@param opts AvantePromptOptions
---@return AvanteGeminiMessage[]
M.make_gemini_message = function(opts)
  local lang = opts.code_lang
  local parts = {}
  if opts.selected_code_content then
    parts[1] = { text = string.format("<code_context>```%s\n%s```</code_context>", lang, opts.code_content) }
    parts[2] = { text = string.format("<code>```%s\n%s```</code>", lang, opts.selected_code_content) }
  else
    parts[1] = { text = string.format("<code>```%s\n%s```</code>", lang, opts.code_content) }
  end
  parts[#parts + 1] = { text = string.format("<question>%s</question>", opts.question) }

  return {
    {
      role = "user",
      parts = parts,
    },
  }
end
--- Handle one SSE `data:` payload from Gemini's `streamGenerateContent?alt=sse` stream.
--- Every text part of the first candidate goes to `opts.on_chunk`. The request completes once
--- the candidate reports a finishReason. An `error` object in the payload completes with that
--- error. Malformed payloads are ignored.
---@type AvanteResponseParser
M.parse_gemini_response = function(data_stream, _, opts)
  -- Guarded: this runs inside vim.schedule, where a decode error would otherwise be raised.
  local ok, json = pcall(vim.json.decode, data_stream)
  if not ok or type(json) ~= "table" then
    return
  end
  if type(json.error) == "table" then
    opts.on_complete(json.error)
    return
  end

  local candidate = type(json.candidates) == "table" and json.candidates[1] or nil
  if type(candidate) ~= "table" then
    return
  end

  -- The final chunk may carry no content (e.g. only a finishReason), so every level is checked.
  local content = candidate.content
  if type(content) == "table" and type(content.parts) == "table" then
    for _, part in ipairs(content.parts) do
      if type(part) == "table" and type(part.text) == "string" then
        opts.on_chunk(part.text)
      end
    end
  end

  -- finishReason is empty while the model is still generating. Once it is set (e.g. "STOP"),
  -- the stream is done and on_complete must fire, or the caller waits forever.
  if type(candidate.finishReason) == "string" and candidate.finishReason ~= "" then
    opts.on_complete(nil)
  end
end
--- Base URL of the Gemini models collection, used when `Config.gemini.endpoint` is empty.
local GEMINI_DEFAULT_ENDPOINT = "https://generativelanguage.googleapis.com/v1beta/models"

--- Build the streaming (SSE) request for Gemini's `streamGenerateContent`.
--- The API key travels in the `key` query parameter. `Config.gemini.options` entries are
--- merged into the top level of the request body.
---@type AvanteCurlArgsBuilder
M.make_gemini_curl_args = function(code_opts)
  -- Previously a non-empty endpoint left the URL as "" and every request failed.
  -- NOTE(review): a custom endpoint is assumed to address the `.../models` collection,
  -- like the default; confirm against the config documentation.
  local base = Config.gemini.endpoint
  if base == nil or base == "" then
    base = GEMINI_DEFAULT_ENDPOINT
  end
  local url = Utils.trim(base, { suffix = "/" })
    .. "/"
    .. Config.gemini.model
    .. ":streamGenerateContent?alt=sse&key="
    .. E.value("gemini")

  local body = {
    systemInstruction = {
      role = "user",
      parts = {
        {
          text = system_prompt .. base_user_prompt,
        },
      },
    },
    contents = M.make_gemini_message(code_opts),
  }
  -- Extra top-level request fields from the user config (presumably e.g. generationConfig).
  -- `or {}` lets a missing options table mean "nothing to merge" instead of crashing.
  for key, value in pairs(Config.gemini.options or {}) do
    body[key] = value
  end

  return {
    url = url,
    headers = {
      ["Content-Type"] = "application/json",
    },
    body = body,
  }
end
2024-08-18 15:03:25 -04:00
------------------------------Logic------------------------------
2024-08-18 22:20:29 -04:00
local group = vim.api.nvim_create_augroup("AvanteLLM", { clear = true })
--- plenary.curl job of the request currently streaming, if any.
local active_job = nil

--- Send `question` (plus the surrounding code) to the configured provider and stream the reply.
--- Starting a new request shuts down the previous one. Firing the `User` autocmd
--- `M.CANCEL_PATTERN` cancels the active request.
---@param question string
---@param code_lang string
---@param code_content string
---@param selected_content_content string | nil
---@param on_chunk AvanteChunkParser called with each streamed text fragment
---@param on_complete AvanteCompleteParser called with nil on success, or with an error value
---@return table|nil job the plenary.curl job of this request
M.stream = function(question, code_lang, code_content, selected_content_content, on_chunk, on_complete)
  local provider = Config.provider
  local code_opts = {
    question = question,
    code_lang = code_lang,
    code_content = code_content,
    selected_code_content = selected_content_content,
  }
  local current_event_state = nil
  local handler_opts = { on_chunk = on_chunk, on_complete = on_complete }

  ---@type AvanteCurlOutput
  local spec = nil
  ---@type AvanteProvider
  local ProviderConfig = nil
  if E.is_default(provider) then
    spec = M["make_" .. provider .. "_curl_args"](code_opts)
  else
    ProviderConfig = Config.vendors[provider]
    spec = ProviderConfig.parse_curl_args(ProviderConfig, code_opts)
  end

  --- Dispatch one SSE line. `event:` lines update the current event name; `data:` payloads
  --- go to the vendor's or built-in provider's response parser.
  ---@param line string
  local function parse_and_call(line)
    local event = line:match("^event: (.+)$")
    if event then
      current_event_state = event
      return
    end
    local data_match = line:match("^data: (.+)$")
    if data_match then
      if ProviderConfig ~= nil then
        ProviderConfig.parse_response_data(data_match, current_event_state, handler_opts)
      else
        M["parse_" .. provider .. "_response"](data_match, current_event_state, handler_opts)
      end
    end
  end

  if active_job then
    active_job:shutdown()
    active_job = nil
  end

  local job
  job = curl.post(spec.url, {
    headers = spec.headers,
    body = vim.json.encode(spec.body),
    -- Builders may ask for a proxy / relaxed TLS (make_copilot_curl_args sets both), so they
    -- are forwarded to curl. Nil values leave curl's defaults untouched.
    proxy = spec.proxy,
    insecure = spec.insecure,
    stream = function(err, data, _)
      if err then
        on_complete(err)
        return
      end
      if not data then
        return
      end
      vim.schedule(function()
        if ProviderConfig ~= nil and ProviderConfig.parse_stream_data ~= nil then
          if ProviderConfig.parse_response_data ~= nil then
            Utils.warn(
              "parse_stream_data and parse_response_data are mutually exclusive, and thus parse_response_data will be ignored. Make sure that you handle the incoming data correctly.",
              { once = true }
            )
          end
          ProviderConfig.parse_stream_data(data, handler_opts)
        else
          parse_and_call(data)
        end
      end)
    end,
    on_error = function(err)
      on_complete(err)
    end,
    callback = function(_)
      -- Clear the slot only if it still refers to this request, so a superseded job that
      -- finishes late cannot drop the handle of a newer one.
      if active_job == job then
        active_job = nil
      end
    end,
  })
  active_job = job

  -- Replace rather than add the cancel hook. Otherwise one more autocmd would pile up in the
  -- group on every request.
  api.nvim_clear_autocmds({ group = group, event = "User", pattern = M.CANCEL_PATTERN })
  api.nvim_create_autocmd("User", {
    group = group,
    pattern = M.CANCEL_PATTERN,
    callback = function()
      if active_job then
        active_job:shutdown()
        Utils.debug("LLM request cancelled", { title = "Avante" })
        active_job = nil
      end
    end,
  })

  return active_job
end
2024-08-21 21:28:17 +08:00
--- Initialise the LLM module:
---   * create the Copilot session state when copilot is the configured provider;
---   * prompt for the provider's API key when it is not available;
---   * register the :AvanteSwitchProvider command.
---@public
function M.setup()
  local wants_copilot = Config.provider == "copilot"
  if wants_copilot and not M.copilot then
    M.copilot = {
      proxy = Config.copilot.proxy,
      allow_insecure = Config.copilot.allow_insecure,
      github_token = Utils.copilot.cached_token(),
      sessionid = nil,
      token = nil,
      machineid = Utils.copilot.machine_id(),
    }
  end

  if not E[Config.provider] then
    E.setup(E.key())
  end

  M.commands()
end
--- Switch the active provider. If the provider's API key is not available, prompt for it;
--- otherwise notify about the switch. Unknown provider names are rejected with a warning.
---@param provider Provider
function M.refresh(provider)
  -- Accept built-ins, configured vendors and local providers. Anything else would make E.key
  -- raise "Failed to find provider" straight out of the :AvanteSwitchProvider command.
  local known = E.is_default(provider) or Config.vendors[provider] ~= nil or E.is_local(provider)
  if not known then
    Utils.warn("Unknown provider: " .. tostring(provider), { title = "Avante" })
    return
  end

  if E[provider] then
    Utils.info("Switch to provider: " .. provider, { once = true, title = "Avante" })
  else
    E.setup(E.key(provider), true)
  end
  Config.override({ provider = provider })
end
2024-08-20 07:43:53 -04:00
--- Register the :AvanteSwitchProvider user command, with completion over the built-in
--- providers and the configured vendors.
---@private
M.commands = function()
  api.nvim_create_user_command("AvanteSwitchProvider", function(args)
    local cmd = vim.trim(args.args or "")
    M.refresh(cmd)
  end, {
    nargs = 1,
    desc = "avante: switch provider",
    complete = function(arg_lead, line)
      -- nargs = 1: once a full argument followed by whitespace is on the line, offer nothing.
      -- (The old check returned {} as soon as one character was typed, so prefix completion
      -- never produced candidates.)
      if line:match("^%s*AvanteSwitchProvider%s+%S+%s") then
        return {}
      end
      -- built-in providers followed by user-defined vendors
      local keys = vim.list_extend(vim.tbl_keys(E.env), vim.tbl_keys(Config.vendors))
      return vim.tbl_filter(function(key)
        return vim.startswith(key, arg_lead)
      end, keys)
    end,
  })
end
2024-08-19 16:10:06 -04:00
-- Expose the prompt templates on the module table.
M.SYSTEM_PROMPT, M.BASE_PROMPT = system_prompt, base_user_prompt

return M