chore(llm): expose types for support functions (#113)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
parent
00f1e296b0
commit
94adc992a6
@ -7,10 +7,6 @@ local Config = require("avante.config")
|
|||||||
local Tiktoken = require("avante.tiktoken")
|
local Tiktoken = require("avante.tiktoken")
|
||||||
local Dressing = require("avante.ui.dressing")
|
local Dressing = require("avante.ui.dressing")
|
||||||
|
|
||||||
---@private
|
|
||||||
---@class AvanteLLMInternal
|
|
||||||
local H = {}
|
|
||||||
|
|
||||||
---@class avante.LLM
|
---@class avante.LLM
|
||||||
local M = {}
|
local M = {}
|
||||||
|
|
||||||
@ -248,6 +244,7 @@ Remember: Accurate line numbers are CRITICAL. The range start_line to end_line m
|
|||||||
---@field max_tokens number
|
---@field max_tokens number
|
||||||
---
|
---
|
||||||
---@class AvanteProvider: AvanteDefaultBaseProvider
|
---@class AvanteProvider: AvanteDefaultBaseProvider
|
||||||
|
---@field model? string
|
||||||
---@field api_key_name string
|
---@field api_key_name string
|
||||||
---@field parse_response_data AvanteResponseParser
|
---@field parse_response_data AvanteResponseParser
|
||||||
---@field parse_curl_args fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput
|
---@field parse_curl_args fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput
|
||||||
@ -259,7 +256,7 @@ Remember: Accurate line numbers are CRITICAL. The range start_line to end_line m
|
|||||||
|
|
||||||
---@param opts AvantePromptOptions
|
---@param opts AvantePromptOptions
|
||||||
---@return AvanteClaudeMessage[]
|
---@return AvanteClaudeMessage[]
|
||||||
H.make_claude_message = function(opts)
|
M.make_claude_message = function(opts)
|
||||||
local code_prompt_obj = {
|
local code_prompt_obj = {
|
||||||
type = "text",
|
type = "text",
|
||||||
text = string.format("<code>```%s\n%s```</code>", opts.code_lang, opts.code_content),
|
text = string.format("<code>```%s\n%s```</code>", opts.code_lang, opts.code_content),
|
||||||
@ -317,7 +314,7 @@ H.make_claude_message = function(opts)
|
|||||||
end
|
end
|
||||||
|
|
||||||
---@type AvanteResponseParser
|
---@type AvanteResponseParser
|
||||||
H.parse_claude_response = function(data_stream, event_state, opts)
|
M.parse_claude_response = function(data_stream, event_state, opts)
|
||||||
if event_state == "content_block_delta" then
|
if event_state == "content_block_delta" then
|
||||||
local json = vim.json.decode(data_stream)
|
local json = vim.json.decode(data_stream)
|
||||||
opts.on_chunk(json.delta.text)
|
opts.on_chunk(json.delta.text)
|
||||||
@ -330,7 +327,7 @@ H.parse_claude_response = function(data_stream, event_state, opts)
|
|||||||
end
|
end
|
||||||
|
|
||||||
---@type AvanteCurlArgsBuilder
|
---@type AvanteCurlArgsBuilder
|
||||||
H.make_claude_curl_args = function(code_opts)
|
M.make_claude_curl_args = function(code_opts)
|
||||||
return {
|
return {
|
||||||
url = Utils.trim(Config.claude.endpoint, { suffix = "/" }) .. "/v1/messages",
|
url = Utils.trim(Config.claude.endpoint, { suffix = "/" }) .. "/v1/messages",
|
||||||
headers = {
|
headers = {
|
||||||
@ -343,7 +340,7 @@ H.make_claude_curl_args = function(code_opts)
|
|||||||
model = Config.claude.model,
|
model = Config.claude.model,
|
||||||
system = system_prompt,
|
system = system_prompt,
|
||||||
stream = true,
|
stream = true,
|
||||||
messages = H.make_claude_message(code_opts),
|
messages = M.make_claude_message(code_opts),
|
||||||
temperature = Config.claude.temperature,
|
temperature = Config.claude.temperature,
|
||||||
max_tokens = Config.claude.max_tokens,
|
max_tokens = Config.claude.max_tokens,
|
||||||
},
|
},
|
||||||
@ -354,7 +351,7 @@ end
|
|||||||
|
|
||||||
---@param opts AvantePromptOptions
|
---@param opts AvantePromptOptions
|
||||||
---@return AvanteOpenAIMessage[]
|
---@return AvanteOpenAIMessage[]
|
||||||
H.make_openai_message = function(opts)
|
M.make_openai_message = function(opts)
|
||||||
local user_prompt = base_user_prompt
|
local user_prompt = base_user_prompt
|
||||||
.. "\n\nCODE:\n"
|
.. "\n\nCODE:\n"
|
||||||
.. "```"
|
.. "```"
|
||||||
@ -390,7 +387,7 @@ H.make_openai_message = function(opts)
|
|||||||
end
|
end
|
||||||
|
|
||||||
---@type AvanteResponseParser
|
---@type AvanteResponseParser
|
||||||
H.parse_openai_response = function(data_stream, _, opts)
|
M.parse_openai_response = function(data_stream, _, opts)
|
||||||
if data_stream:match('"%[DONE%]":') then
|
if data_stream:match('"%[DONE%]":') then
|
||||||
opts.on_complete(nil)
|
opts.on_complete(nil)
|
||||||
return
|
return
|
||||||
@ -409,7 +406,7 @@ H.parse_openai_response = function(data_stream, _, opts)
|
|||||||
end
|
end
|
||||||
|
|
||||||
---@type AvanteCurlArgsBuilder
|
---@type AvanteCurlArgsBuilder
|
||||||
H.make_openai_curl_args = function(code_opts)
|
M.make_openai_curl_args = function(code_opts)
|
||||||
return {
|
return {
|
||||||
url = Utils.trim(Config.openai.endpoint, { suffix = "/" }) .. "/v1/chat/completions",
|
url = Utils.trim(Config.openai.endpoint, { suffix = "/" }) .. "/v1/chat/completions",
|
||||||
headers = {
|
headers = {
|
||||||
@ -418,7 +415,7 @@ H.make_openai_curl_args = function(code_opts)
|
|||||||
},
|
},
|
||||||
body = {
|
body = {
|
||||||
model = Config.openai.model,
|
model = Config.openai.model,
|
||||||
messages = H.make_openai_message(code_opts),
|
messages = M.make_openai_message(code_opts),
|
||||||
temperature = Config.openai.temperature,
|
temperature = Config.openai.temperature,
|
||||||
max_tokens = Config.openai.max_tokens,
|
max_tokens = Config.openai.max_tokens,
|
||||||
stream = true,
|
stream = true,
|
||||||
@ -429,13 +426,13 @@ end
|
|||||||
------------------------------Azure------------------------------
|
------------------------------Azure------------------------------
|
||||||
|
|
||||||
---@type AvanteAiMessageBuilder
|
---@type AvanteAiMessageBuilder
|
||||||
H.make_azure_message = H.make_openai_message
|
M.make_azure_message = M.make_openai_message
|
||||||
|
|
||||||
---@type AvanteResponseParser
|
---@type AvanteResponseParser
|
||||||
H.parse_azure_response = H.parse_openai_response
|
M.parse_azure_response = M.parse_openai_response
|
||||||
|
|
||||||
---@type AvanteCurlArgsBuilder
|
---@type AvanteCurlArgsBuilder
|
||||||
H.make_azure_curl_args = function(code_opts)
|
M.make_azure_curl_args = function(code_opts)
|
||||||
return {
|
return {
|
||||||
url = Config.azure.endpoint
|
url = Config.azure.endpoint
|
||||||
.. "/openai/deployments/"
|
.. "/openai/deployments/"
|
||||||
@ -447,7 +444,7 @@ H.make_azure_curl_args = function(code_opts)
|
|||||||
["api-key"] = E.value("azure"),
|
["api-key"] = E.value("azure"),
|
||||||
},
|
},
|
||||||
body = {
|
body = {
|
||||||
messages = H.make_openai_message(code_opts),
|
messages = M.make_openai_message(code_opts),
|
||||||
temperature = Config.azure.temperature,
|
temperature = Config.azure.temperature,
|
||||||
max_tokens = Config.azure.max_tokens,
|
max_tokens = Config.azure.max_tokens,
|
||||||
stream = true,
|
stream = true,
|
||||||
@ -458,13 +455,13 @@ end
|
|||||||
------------------------------Deepseek------------------------------
|
------------------------------Deepseek------------------------------
|
||||||
|
|
||||||
---@type AvanteAiMessageBuilder
|
---@type AvanteAiMessageBuilder
|
||||||
H.make_deepseek_message = H.make_openai_message
|
M.make_deepseek_message = M.make_openai_message
|
||||||
|
|
||||||
---@type AvanteResponseParser
|
---@type AvanteResponseParser
|
||||||
H.parse_deepseek_response = H.parse_openai_response
|
M.parse_deepseek_response = M.parse_openai_response
|
||||||
|
|
||||||
---@type AvanteCurlArgsBuilder
|
---@type AvanteCurlArgsBuilder
|
||||||
H.make_deepseek_curl_args = function(code_opts)
|
M.make_deepseek_curl_args = function(code_opts)
|
||||||
return {
|
return {
|
||||||
url = Utils.trim(Config.deepseek.endpoint, { suffix = "/" }) .. "/chat/completions",
|
url = Utils.trim(Config.deepseek.endpoint, { suffix = "/" }) .. "/chat/completions",
|
||||||
headers = {
|
headers = {
|
||||||
@ -473,7 +470,7 @@ H.make_deepseek_curl_args = function(code_opts)
|
|||||||
},
|
},
|
||||||
body = {
|
body = {
|
||||||
model = Config.deepseek.model,
|
model = Config.deepseek.model,
|
||||||
messages = H.make_openai_message(code_opts),
|
messages = M.make_openai_message(code_opts),
|
||||||
temperature = Config.deepseek.temperature,
|
temperature = Config.deepseek.temperature,
|
||||||
max_tokens = Config.deepseek.max_tokens,
|
max_tokens = Config.deepseek.max_tokens,
|
||||||
stream = true,
|
stream = true,
|
||||||
@ -484,13 +481,13 @@ end
|
|||||||
------------------------------Grok------------------------------
|
------------------------------Grok------------------------------
|
||||||
|
|
||||||
---@type AvanteAiMessageBuilder
|
---@type AvanteAiMessageBuilder
|
||||||
H.make_groq_message = H.make_openai_message
|
M.make_groq_message = M.make_openai_message
|
||||||
|
|
||||||
---@type AvanteResponseParser
|
---@type AvanteResponseParser
|
||||||
H.parse_groq_response = H.parse_openai_response
|
M.parse_groq_response = M.parse_openai_response
|
||||||
|
|
||||||
---@type AvanteCurlArgsBuilder
|
---@type AvanteCurlArgsBuilder
|
||||||
H.make_groq_curl_args = function(code_opts)
|
M.make_groq_curl_args = function(code_opts)
|
||||||
return {
|
return {
|
||||||
url = Utils.trim(Config.groq.endpoint, { suffix = "/" }) .. "/openai/v1/chat/completions",
|
url = Utils.trim(Config.groq.endpoint, { suffix = "/" }) .. "/openai/v1/chat/completions",
|
||||||
headers = {
|
headers = {
|
||||||
@ -499,7 +496,7 @@ H.make_groq_curl_args = function(code_opts)
|
|||||||
},
|
},
|
||||||
body = {
|
body = {
|
||||||
model = Config.groq.model,
|
model = Config.groq.model,
|
||||||
messages = H.make_openai_message(code_opts),
|
messages = M.make_openai_message(code_opts),
|
||||||
temperature = Config.groq.temperature,
|
temperature = Config.groq.temperature,
|
||||||
max_tokens = Config.groq.max_tokens,
|
max_tokens = Config.groq.max_tokens,
|
||||||
stream = true,
|
stream = true,
|
||||||
@ -537,7 +534,7 @@ M.stream = function(question, code_lang, code_content, selected_content_content,
|
|||||||
local ProviderConfig = nil
|
local ProviderConfig = nil
|
||||||
|
|
||||||
if E.is_default(provider) then
|
if E.is_default(provider) then
|
||||||
spec = H["make_" .. provider .. "_curl_args"](code_opts)
|
spec = M["make_" .. provider .. "_curl_args"](code_opts)
|
||||||
else
|
else
|
||||||
ProviderConfig = Config.vendors[provider]
|
ProviderConfig = Config.vendors[provider]
|
||||||
spec = ProviderConfig.parse_curl_args(ProviderConfig, code_opts)
|
spec = ProviderConfig.parse_curl_args(ProviderConfig, code_opts)
|
||||||
@ -555,7 +552,7 @@ M.stream = function(question, code_lang, code_content, selected_content_content,
|
|||||||
if ProviderConfig ~= nil then
|
if ProviderConfig ~= nil then
|
||||||
ProviderConfig.parse_response_data(data_match, current_event_state, handler_opts)
|
ProviderConfig.parse_response_data(data_match, current_event_state, handler_opts)
|
||||||
else
|
else
|
||||||
H["parse_" .. provider .. "_response"](data_match, current_event_state, handler_opts)
|
M["parse_" .. provider .. "_response"](data_match, current_event_state, handler_opts)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
@ -603,6 +600,7 @@ M.stream = function(question, code_lang, code_content, selected_content_content,
|
|||||||
return active_job
|
return active_job
|
||||||
end
|
end
|
||||||
|
|
||||||
|
---@private
|
||||||
function M.setup()
|
function M.setup()
|
||||||
local has = E[Config.provider]
|
local has = E[Config.provider]
|
||||||
if not has then
|
if not has then
|
||||||
@ -623,6 +621,7 @@ function M.refresh(provider)
|
|||||||
require("avante.config").override({ provider = provider })
|
require("avante.config").override({ provider = provider })
|
||||||
end
|
end
|
||||||
|
|
||||||
|
---@private
|
||||||
M.commands = function()
|
M.commands = function()
|
||||||
api.nvim_create_user_command("AvanteSwitchProvider", function(args)
|
api.nvim_create_user_command("AvanteSwitchProvider", function(args)
|
||||||
local cmd = vim.trim(args.args or "")
|
local cmd = vim.trim(args.args or "")
|
||||||
@ -647,16 +646,4 @@ end
|
|||||||
M.SYSTEM_PROMPT = system_prompt
|
M.SYSTEM_PROMPT = system_prompt
|
||||||
M.BASE_PROMPT = base_user_prompt
|
M.BASE_PROMPT = base_user_prompt
|
||||||
|
|
||||||
return setmetatable(M, {
|
return M
|
||||||
__index = function(t, k)
|
|
||||||
local h = H[k]
|
|
||||||
if h then
|
|
||||||
return H[k]
|
|
||||||
end
|
|
||||||
local v = t[k]
|
|
||||||
if v then
|
|
||||||
return t[k]
|
|
||||||
end
|
|
||||||
error("Failed to find key: " .. k)
|
|
||||||
end,
|
|
||||||
})
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user