chore(llm): expose types for support functions (#113)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
parent
00f1e296b0
commit
94adc992a6
@ -7,10 +7,6 @@ local Config = require("avante.config")
|
||||
local Tiktoken = require("avante.tiktoken")
|
||||
local Dressing = require("avante.ui.dressing")
|
||||
|
||||
---@private
|
||||
---@class AvanteLLMInternal
|
||||
local H = {}
|
||||
|
||||
---@class avante.LLM
|
||||
local M = {}
|
||||
|
||||
@ -248,6 +244,7 @@ Remember: Accurate line numbers are CRITICAL. The range start_line to end_line m
|
||||
---@field max_tokens number
|
||||
---
|
||||
---@class AvanteProvider: AvanteDefaultBaseProvider
|
||||
---@field model? string
|
||||
---@field api_key_name string
|
||||
---@field parse_response_data AvanteResponseParser
|
||||
---@field parse_curl_args fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput
|
||||
@ -259,7 +256,7 @@ Remember: Accurate line numbers are CRITICAL. The range start_line to end_line m
|
||||
|
||||
---@param opts AvantePromptOptions
|
||||
---@return AvanteClaudeMessage[]
|
||||
H.make_claude_message = function(opts)
|
||||
M.make_claude_message = function(opts)
|
||||
local code_prompt_obj = {
|
||||
type = "text",
|
||||
text = string.format("<code>```%s\n%s```</code>", opts.code_lang, opts.code_content),
|
||||
@ -317,7 +314,7 @@ H.make_claude_message = function(opts)
|
||||
end
|
||||
|
||||
---@type AvanteResponseParser
|
||||
H.parse_claude_response = function(data_stream, event_state, opts)
|
||||
M.parse_claude_response = function(data_stream, event_state, opts)
|
||||
if event_state == "content_block_delta" then
|
||||
local json = vim.json.decode(data_stream)
|
||||
opts.on_chunk(json.delta.text)
|
||||
@ -330,7 +327,7 @@ H.parse_claude_response = function(data_stream, event_state, opts)
|
||||
end
|
||||
|
||||
---@type AvanteCurlArgsBuilder
|
||||
H.make_claude_curl_args = function(code_opts)
|
||||
M.make_claude_curl_args = function(code_opts)
|
||||
return {
|
||||
url = Utils.trim(Config.claude.endpoint, { suffix = "/" }) .. "/v1/messages",
|
||||
headers = {
|
||||
@ -343,7 +340,7 @@ H.make_claude_curl_args = function(code_opts)
|
||||
model = Config.claude.model,
|
||||
system = system_prompt,
|
||||
stream = true,
|
||||
messages = H.make_claude_message(code_opts),
|
||||
messages = M.make_claude_message(code_opts),
|
||||
temperature = Config.claude.temperature,
|
||||
max_tokens = Config.claude.max_tokens,
|
||||
},
|
||||
@ -354,7 +351,7 @@ end
|
||||
|
||||
---@param opts AvantePromptOptions
|
||||
---@return AvanteOpenAIMessage[]
|
||||
H.make_openai_message = function(opts)
|
||||
M.make_openai_message = function(opts)
|
||||
local user_prompt = base_user_prompt
|
||||
.. "\n\nCODE:\n"
|
||||
.. "```"
|
||||
@ -390,7 +387,7 @@ H.make_openai_message = function(opts)
|
||||
end
|
||||
|
||||
---@type AvanteResponseParser
|
||||
H.parse_openai_response = function(data_stream, _, opts)
|
||||
M.parse_openai_response = function(data_stream, _, opts)
|
||||
if data_stream:match('"%[DONE%]":') then
|
||||
opts.on_complete(nil)
|
||||
return
|
||||
@ -409,7 +406,7 @@ H.parse_openai_response = function(data_stream, _, opts)
|
||||
end
|
||||
|
||||
---@type AvanteCurlArgsBuilder
|
||||
H.make_openai_curl_args = function(code_opts)
|
||||
M.make_openai_curl_args = function(code_opts)
|
||||
return {
|
||||
url = Utils.trim(Config.openai.endpoint, { suffix = "/" }) .. "/v1/chat/completions",
|
||||
headers = {
|
||||
@ -418,7 +415,7 @@ H.make_openai_curl_args = function(code_opts)
|
||||
},
|
||||
body = {
|
||||
model = Config.openai.model,
|
||||
messages = H.make_openai_message(code_opts),
|
||||
messages = M.make_openai_message(code_opts),
|
||||
temperature = Config.openai.temperature,
|
||||
max_tokens = Config.openai.max_tokens,
|
||||
stream = true,
|
||||
@ -429,13 +426,13 @@ end
|
||||
------------------------------Azure------------------------------
|
||||
|
||||
---@type AvanteAiMessageBuilder
|
||||
H.make_azure_message = H.make_openai_message
|
||||
M.make_azure_message = M.make_openai_message
|
||||
|
||||
---@type AvanteResponseParser
|
||||
H.parse_azure_response = H.parse_openai_response
|
||||
M.parse_azure_response = M.parse_openai_response
|
||||
|
||||
---@type AvanteCurlArgsBuilder
|
||||
H.make_azure_curl_args = function(code_opts)
|
||||
M.make_azure_curl_args = function(code_opts)
|
||||
return {
|
||||
url = Config.azure.endpoint
|
||||
.. "/openai/deployments/"
|
||||
@ -447,7 +444,7 @@ H.make_azure_curl_args = function(code_opts)
|
||||
["api-key"] = E.value("azure"),
|
||||
},
|
||||
body = {
|
||||
messages = H.make_openai_message(code_opts),
|
||||
messages = M.make_openai_message(code_opts),
|
||||
temperature = Config.azure.temperature,
|
||||
max_tokens = Config.azure.max_tokens,
|
||||
stream = true,
|
||||
@ -458,13 +455,13 @@ end
|
||||
------------------------------Deepseek------------------------------
|
||||
|
||||
---@type AvanteAiMessageBuilder
|
||||
H.make_deepseek_message = H.make_openai_message
|
||||
M.make_deepseek_message = M.make_openai_message
|
||||
|
||||
---@type AvanteResponseParser
|
||||
H.parse_deepseek_response = H.parse_openai_response
|
||||
M.parse_deepseek_response = M.parse_openai_response
|
||||
|
||||
---@type AvanteCurlArgsBuilder
|
||||
H.make_deepseek_curl_args = function(code_opts)
|
||||
M.make_deepseek_curl_args = function(code_opts)
|
||||
return {
|
||||
url = Utils.trim(Config.deepseek.endpoint, { suffix = "/" }) .. "/chat/completions",
|
||||
headers = {
|
||||
@ -473,7 +470,7 @@ H.make_deepseek_curl_args = function(code_opts)
|
||||
},
|
||||
body = {
|
||||
model = Config.deepseek.model,
|
||||
messages = H.make_openai_message(code_opts),
|
||||
messages = M.make_openai_message(code_opts),
|
||||
temperature = Config.deepseek.temperature,
|
||||
max_tokens = Config.deepseek.max_tokens,
|
||||
stream = true,
|
||||
@ -484,13 +481,13 @@ end
|
||||
------------------------------Grok------------------------------
|
||||
|
||||
---@type AvanteAiMessageBuilder
|
||||
H.make_groq_message = H.make_openai_message
|
||||
M.make_groq_message = M.make_openai_message
|
||||
|
||||
---@type AvanteResponseParser
|
||||
H.parse_groq_response = H.parse_openai_response
|
||||
M.parse_groq_response = M.parse_openai_response
|
||||
|
||||
---@type AvanteCurlArgsBuilder
|
||||
H.make_groq_curl_args = function(code_opts)
|
||||
M.make_groq_curl_args = function(code_opts)
|
||||
return {
|
||||
url = Utils.trim(Config.groq.endpoint, { suffix = "/" }) .. "/openai/v1/chat/completions",
|
||||
headers = {
|
||||
@ -499,7 +496,7 @@ H.make_groq_curl_args = function(code_opts)
|
||||
},
|
||||
body = {
|
||||
model = Config.groq.model,
|
||||
messages = H.make_openai_message(code_opts),
|
||||
messages = M.make_openai_message(code_opts),
|
||||
temperature = Config.groq.temperature,
|
||||
max_tokens = Config.groq.max_tokens,
|
||||
stream = true,
|
||||
@ -537,7 +534,7 @@ M.stream = function(question, code_lang, code_content, selected_content_content,
|
||||
local ProviderConfig = nil
|
||||
|
||||
if E.is_default(provider) then
|
||||
spec = H["make_" .. provider .. "_curl_args"](code_opts)
|
||||
spec = M["make_" .. provider .. "_curl_args"](code_opts)
|
||||
else
|
||||
ProviderConfig = Config.vendors[provider]
|
||||
spec = ProviderConfig.parse_curl_args(ProviderConfig, code_opts)
|
||||
@ -555,7 +552,7 @@ M.stream = function(question, code_lang, code_content, selected_content_content,
|
||||
if ProviderConfig ~= nil then
|
||||
ProviderConfig.parse_response_data(data_match, current_event_state, handler_opts)
|
||||
else
|
||||
H["parse_" .. provider .. "_response"](data_match, current_event_state, handler_opts)
|
||||
M["parse_" .. provider .. "_response"](data_match, current_event_state, handler_opts)
|
||||
end
|
||||
end
|
||||
end
|
||||
@ -603,6 +600,7 @@ M.stream = function(question, code_lang, code_content, selected_content_content,
|
||||
return active_job
|
||||
end
|
||||
|
||||
---@private
|
||||
function M.setup()
|
||||
local has = E[Config.provider]
|
||||
if not has then
|
||||
@ -623,6 +621,7 @@ function M.refresh(provider)
|
||||
require("avante.config").override({ provider = provider })
|
||||
end
|
||||
|
||||
---@private
|
||||
M.commands = function()
|
||||
api.nvim_create_user_command("AvanteSwitchProvider", function(args)
|
||||
local cmd = vim.trim(args.args or "")
|
||||
@ -647,16 +646,4 @@ end
|
||||
M.SYSTEM_PROMPT = system_prompt
|
||||
M.BASE_PROMPT = base_user_prompt
|
||||
|
||||
return setmetatable(M, {
|
||||
__index = function(t, k)
|
||||
local h = H[k]
|
||||
if h then
|
||||
return H[k]
|
||||
end
|
||||
local v = t[k]
|
||||
if v then
|
||||
return t[k]
|
||||
end
|
||||
error("Failed to find key: " .. k)
|
||||
end,
|
||||
})
|
||||
return M
|
||||
|
Loading…
x
Reference in New Issue
Block a user