chore(llm): expose types for support functions (#113)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
Aaron Pham 2024-08-20 07:43:53 -04:00 committed by GitHub
parent 00f1e296b0
commit 94adc992a6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -7,10 +7,6 @@ local Config = require("avante.config")
local Tiktoken = require("avante.tiktoken") local Tiktoken = require("avante.tiktoken")
local Dressing = require("avante.ui.dressing") local Dressing = require("avante.ui.dressing")
---@private
---@class AvanteLLMInternal
local H = {}
---@class avante.LLM ---@class avante.LLM
local M = {} local M = {}
@ -248,6 +244,7 @@ Remember: Accurate line numbers are CRITICAL. The range start_line to end_line m
---@field max_tokens number ---@field max_tokens number
--- ---
---@class AvanteProvider: AvanteDefaultBaseProvider ---@class AvanteProvider: AvanteDefaultBaseProvider
---@field model? string
---@field api_key_name string ---@field api_key_name string
---@field parse_response_data AvanteResponseParser ---@field parse_response_data AvanteResponseParser
---@field parse_curl_args fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput ---@field parse_curl_args fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput
@ -259,7 +256,7 @@ Remember: Accurate line numbers are CRITICAL. The range start_line to end_line m
---@param opts AvantePromptOptions ---@param opts AvantePromptOptions
---@return AvanteClaudeMessage[] ---@return AvanteClaudeMessage[]
H.make_claude_message = function(opts) M.make_claude_message = function(opts)
local code_prompt_obj = { local code_prompt_obj = {
type = "text", type = "text",
text = string.format("<code>```%s\n%s```</code>", opts.code_lang, opts.code_content), text = string.format("<code>```%s\n%s```</code>", opts.code_lang, opts.code_content),
@ -317,7 +314,7 @@ H.make_claude_message = function(opts)
end end
---@type AvanteResponseParser ---@type AvanteResponseParser
H.parse_claude_response = function(data_stream, event_state, opts) M.parse_claude_response = function(data_stream, event_state, opts)
if event_state == "content_block_delta" then if event_state == "content_block_delta" then
local json = vim.json.decode(data_stream) local json = vim.json.decode(data_stream)
opts.on_chunk(json.delta.text) opts.on_chunk(json.delta.text)
@ -330,7 +327,7 @@ H.parse_claude_response = function(data_stream, event_state, opts)
end end
---@type AvanteCurlArgsBuilder ---@type AvanteCurlArgsBuilder
H.make_claude_curl_args = function(code_opts) M.make_claude_curl_args = function(code_opts)
return { return {
url = Utils.trim(Config.claude.endpoint, { suffix = "/" }) .. "/v1/messages", url = Utils.trim(Config.claude.endpoint, { suffix = "/" }) .. "/v1/messages",
headers = { headers = {
@ -343,7 +340,7 @@ H.make_claude_curl_args = function(code_opts)
model = Config.claude.model, model = Config.claude.model,
system = system_prompt, system = system_prompt,
stream = true, stream = true,
messages = H.make_claude_message(code_opts), messages = M.make_claude_message(code_opts),
temperature = Config.claude.temperature, temperature = Config.claude.temperature,
max_tokens = Config.claude.max_tokens, max_tokens = Config.claude.max_tokens,
}, },
@ -354,7 +351,7 @@ end
---@param opts AvantePromptOptions ---@param opts AvantePromptOptions
---@return AvanteOpenAIMessage[] ---@return AvanteOpenAIMessage[]
H.make_openai_message = function(opts) M.make_openai_message = function(opts)
local user_prompt = base_user_prompt local user_prompt = base_user_prompt
.. "\n\nCODE:\n" .. "\n\nCODE:\n"
.. "```" .. "```"
@ -390,7 +387,7 @@ H.make_openai_message = function(opts)
end end
---@type AvanteResponseParser ---@type AvanteResponseParser
H.parse_openai_response = function(data_stream, _, opts) M.parse_openai_response = function(data_stream, _, opts)
if data_stream:match('"%[DONE%]":') then if data_stream:match('"%[DONE%]":') then
opts.on_complete(nil) opts.on_complete(nil)
return return
@ -409,7 +406,7 @@ H.parse_openai_response = function(data_stream, _, opts)
end end
---@type AvanteCurlArgsBuilder ---@type AvanteCurlArgsBuilder
H.make_openai_curl_args = function(code_opts) M.make_openai_curl_args = function(code_opts)
return { return {
url = Utils.trim(Config.openai.endpoint, { suffix = "/" }) .. "/v1/chat/completions", url = Utils.trim(Config.openai.endpoint, { suffix = "/" }) .. "/v1/chat/completions",
headers = { headers = {
@ -418,7 +415,7 @@ H.make_openai_curl_args = function(code_opts)
}, },
body = { body = {
model = Config.openai.model, model = Config.openai.model,
messages = H.make_openai_message(code_opts), messages = M.make_openai_message(code_opts),
temperature = Config.openai.temperature, temperature = Config.openai.temperature,
max_tokens = Config.openai.max_tokens, max_tokens = Config.openai.max_tokens,
stream = true, stream = true,
@ -429,13 +426,13 @@ end
------------------------------Azure------------------------------ ------------------------------Azure------------------------------
---@type AvanteAiMessageBuilder ---@type AvanteAiMessageBuilder
H.make_azure_message = H.make_openai_message M.make_azure_message = M.make_openai_message
---@type AvanteResponseParser ---@type AvanteResponseParser
H.parse_azure_response = H.parse_openai_response M.parse_azure_response = M.parse_openai_response
---@type AvanteCurlArgsBuilder ---@type AvanteCurlArgsBuilder
H.make_azure_curl_args = function(code_opts) M.make_azure_curl_args = function(code_opts)
return { return {
url = Config.azure.endpoint url = Config.azure.endpoint
.. "/openai/deployments/" .. "/openai/deployments/"
@ -447,7 +444,7 @@ H.make_azure_curl_args = function(code_opts)
["api-key"] = E.value("azure"), ["api-key"] = E.value("azure"),
}, },
body = { body = {
messages = H.make_openai_message(code_opts), messages = M.make_openai_message(code_opts),
temperature = Config.azure.temperature, temperature = Config.azure.temperature,
max_tokens = Config.azure.max_tokens, max_tokens = Config.azure.max_tokens,
stream = true, stream = true,
@ -458,13 +455,13 @@ end
------------------------------Deepseek------------------------------ ------------------------------Deepseek------------------------------
---@type AvanteAiMessageBuilder ---@type AvanteAiMessageBuilder
H.make_deepseek_message = H.make_openai_message M.make_deepseek_message = M.make_openai_message
---@type AvanteResponseParser ---@type AvanteResponseParser
H.parse_deepseek_response = H.parse_openai_response M.parse_deepseek_response = M.parse_openai_response
---@type AvanteCurlArgsBuilder ---@type AvanteCurlArgsBuilder
H.make_deepseek_curl_args = function(code_opts) M.make_deepseek_curl_args = function(code_opts)
return { return {
url = Utils.trim(Config.deepseek.endpoint, { suffix = "/" }) .. "/chat/completions", url = Utils.trim(Config.deepseek.endpoint, { suffix = "/" }) .. "/chat/completions",
headers = { headers = {
@ -473,7 +470,7 @@ H.make_deepseek_curl_args = function(code_opts)
}, },
body = { body = {
model = Config.deepseek.model, model = Config.deepseek.model,
messages = H.make_openai_message(code_opts), messages = M.make_openai_message(code_opts),
temperature = Config.deepseek.temperature, temperature = Config.deepseek.temperature,
max_tokens = Config.deepseek.max_tokens, max_tokens = Config.deepseek.max_tokens,
stream = true, stream = true,
@ -484,13 +481,13 @@ end
------------------------------Groq------------------------------ ------------------------------Groq------------------------------
---@type AvanteAiMessageBuilder ---@type AvanteAiMessageBuilder
H.make_groq_message = H.make_openai_message M.make_groq_message = M.make_openai_message
---@type AvanteResponseParser ---@type AvanteResponseParser
H.parse_groq_response = H.parse_openai_response M.parse_groq_response = M.parse_openai_response
---@type AvanteCurlArgsBuilder ---@type AvanteCurlArgsBuilder
H.make_groq_curl_args = function(code_opts) M.make_groq_curl_args = function(code_opts)
return { return {
url = Utils.trim(Config.groq.endpoint, { suffix = "/" }) .. "/openai/v1/chat/completions", url = Utils.trim(Config.groq.endpoint, { suffix = "/" }) .. "/openai/v1/chat/completions",
headers = { headers = {
@ -499,7 +496,7 @@ H.make_groq_curl_args = function(code_opts)
}, },
body = { body = {
model = Config.groq.model, model = Config.groq.model,
messages = H.make_openai_message(code_opts), messages = M.make_openai_message(code_opts),
temperature = Config.groq.temperature, temperature = Config.groq.temperature,
max_tokens = Config.groq.max_tokens, max_tokens = Config.groq.max_tokens,
stream = true, stream = true,
@ -537,7 +534,7 @@ M.stream = function(question, code_lang, code_content, selected_content_content,
local ProviderConfig = nil local ProviderConfig = nil
if E.is_default(provider) then if E.is_default(provider) then
spec = H["make_" .. provider .. "_curl_args"](code_opts) spec = M["make_" .. provider .. "_curl_args"](code_opts)
else else
ProviderConfig = Config.vendors[provider] ProviderConfig = Config.vendors[provider]
spec = ProviderConfig.parse_curl_args(ProviderConfig, code_opts) spec = ProviderConfig.parse_curl_args(ProviderConfig, code_opts)
@ -555,7 +552,7 @@ M.stream = function(question, code_lang, code_content, selected_content_content,
if ProviderConfig ~= nil then if ProviderConfig ~= nil then
ProviderConfig.parse_response_data(data_match, current_event_state, handler_opts) ProviderConfig.parse_response_data(data_match, current_event_state, handler_opts)
else else
H["parse_" .. provider .. "_response"](data_match, current_event_state, handler_opts) M["parse_" .. provider .. "_response"](data_match, current_event_state, handler_opts)
end end
end end
end end
@ -603,6 +600,7 @@ M.stream = function(question, code_lang, code_content, selected_content_content,
return active_job return active_job
end end
---@private
function M.setup() function M.setup()
local has = E[Config.provider] local has = E[Config.provider]
if not has then if not has then
@ -623,6 +621,7 @@ function M.refresh(provider)
require("avante.config").override({ provider = provider }) require("avante.config").override({ provider = provider })
end end
---@private
M.commands = function() M.commands = function()
api.nvim_create_user_command("AvanteSwitchProvider", function(args) api.nvim_create_user_command("AvanteSwitchProvider", function(args)
local cmd = vim.trim(args.args or "") local cmd = vim.trim(args.args or "")
@ -647,16 +646,4 @@ end
M.SYSTEM_PROMPT = system_prompt M.SYSTEM_PROMPT = system_prompt
M.BASE_PROMPT = base_user_prompt M.BASE_PROMPT = base_user_prompt
return setmetatable(M, { return M
__index = function(t, k)
local h = H[k]
if h then
return H[k]
end
local v = t[k]
if v then
return t[k]
end
error("Failed to find key: " .. k)
end,
})