chore(llm): expose types for support functions (#113)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
Aaron Pham 2024-08-20 07:43:53 -04:00 committed by GitHub
parent 00f1e296b0
commit 94adc992a6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -7,10 +7,6 @@ local Config = require("avante.config")
local Tiktoken = require("avante.tiktoken")
local Dressing = require("avante.ui.dressing")
---@private
---@class AvanteLLMInternal
local H = {}
---@class avante.LLM
local M = {}
@ -248,6 +244,7 @@ Remember: Accurate line numbers are CRITICAL. The range start_line to end_line m
---@field max_tokens number
---
---@class AvanteProvider: AvanteDefaultBaseProvider
---@field model? string
---@field api_key_name string
---@field parse_response_data AvanteResponseParser
---@field parse_curl_args fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput
@ -259,7 +256,7 @@ Remember: Accurate line numbers are CRITICAL. The range start_line to end_line m
---@param opts AvantePromptOptions
---@return AvanteClaudeMessage[]
H.make_claude_message = function(opts)
M.make_claude_message = function(opts)
local code_prompt_obj = {
type = "text",
text = string.format("<code>```%s\n%s```</code>", opts.code_lang, opts.code_content),
@ -317,7 +314,7 @@ H.make_claude_message = function(opts)
end
---@type AvanteResponseParser
H.parse_claude_response = function(data_stream, event_state, opts)
M.parse_claude_response = function(data_stream, event_state, opts)
if event_state == "content_block_delta" then
local json = vim.json.decode(data_stream)
opts.on_chunk(json.delta.text)
@ -330,7 +327,7 @@ H.parse_claude_response = function(data_stream, event_state, opts)
end
---@type AvanteCurlArgsBuilder
H.make_claude_curl_args = function(code_opts)
M.make_claude_curl_args = function(code_opts)
return {
url = Utils.trim(Config.claude.endpoint, { suffix = "/" }) .. "/v1/messages",
headers = {
@ -343,7 +340,7 @@ H.make_claude_curl_args = function(code_opts)
model = Config.claude.model,
system = system_prompt,
stream = true,
messages = H.make_claude_message(code_opts),
messages = M.make_claude_message(code_opts),
temperature = Config.claude.temperature,
max_tokens = Config.claude.max_tokens,
},
@ -354,7 +351,7 @@ end
---@param opts AvantePromptOptions
---@return AvanteOpenAIMessage[]
H.make_openai_message = function(opts)
M.make_openai_message = function(opts)
local user_prompt = base_user_prompt
.. "\n\nCODE:\n"
.. "```"
@ -390,7 +387,7 @@ H.make_openai_message = function(opts)
end
---@type AvanteResponseParser
H.parse_openai_response = function(data_stream, _, opts)
M.parse_openai_response = function(data_stream, _, opts)
if data_stream:match('"%[DONE%]":') then
opts.on_complete(nil)
return
@ -409,7 +406,7 @@ H.parse_openai_response = function(data_stream, _, opts)
end
---@type AvanteCurlArgsBuilder
H.make_openai_curl_args = function(code_opts)
M.make_openai_curl_args = function(code_opts)
return {
url = Utils.trim(Config.openai.endpoint, { suffix = "/" }) .. "/v1/chat/completions",
headers = {
@ -418,7 +415,7 @@ H.make_openai_curl_args = function(code_opts)
},
body = {
model = Config.openai.model,
messages = H.make_openai_message(code_opts),
messages = M.make_openai_message(code_opts),
temperature = Config.openai.temperature,
max_tokens = Config.openai.max_tokens,
stream = true,
@ -429,13 +426,13 @@ end
------------------------------Azure------------------------------
---@type AvanteAiMessageBuilder
H.make_azure_message = H.make_openai_message
M.make_azure_message = M.make_openai_message
---@type AvanteResponseParser
H.parse_azure_response = H.parse_openai_response
M.parse_azure_response = M.parse_openai_response
---@type AvanteCurlArgsBuilder
H.make_azure_curl_args = function(code_opts)
M.make_azure_curl_args = function(code_opts)
return {
url = Config.azure.endpoint
.. "/openai/deployments/"
@ -447,7 +444,7 @@ H.make_azure_curl_args = function(code_opts)
["api-key"] = E.value("azure"),
},
body = {
messages = H.make_openai_message(code_opts),
messages = M.make_openai_message(code_opts),
temperature = Config.azure.temperature,
max_tokens = Config.azure.max_tokens,
stream = true,
@ -458,13 +455,13 @@ end
------------------------------Deepseek------------------------------
---@type AvanteAiMessageBuilder
H.make_deepseek_message = H.make_openai_message
M.make_deepseek_message = M.make_openai_message
---@type AvanteResponseParser
H.parse_deepseek_response = H.parse_openai_response
M.parse_deepseek_response = M.parse_openai_response
---@type AvanteCurlArgsBuilder
H.make_deepseek_curl_args = function(code_opts)
M.make_deepseek_curl_args = function(code_opts)
return {
url = Utils.trim(Config.deepseek.endpoint, { suffix = "/" }) .. "/chat/completions",
headers = {
@ -473,7 +470,7 @@ H.make_deepseek_curl_args = function(code_opts)
},
body = {
model = Config.deepseek.model,
messages = H.make_openai_message(code_opts),
messages = M.make_openai_message(code_opts),
temperature = Config.deepseek.temperature,
max_tokens = Config.deepseek.max_tokens,
stream = true,
@ -484,13 +481,13 @@ end
------------------------------Grok------------------------------
---@type AvanteAiMessageBuilder
H.make_groq_message = H.make_openai_message
M.make_groq_message = M.make_openai_message
---@type AvanteResponseParser
H.parse_groq_response = H.parse_openai_response
M.parse_groq_response = M.parse_openai_response
---@type AvanteCurlArgsBuilder
H.make_groq_curl_args = function(code_opts)
M.make_groq_curl_args = function(code_opts)
return {
url = Utils.trim(Config.groq.endpoint, { suffix = "/" }) .. "/openai/v1/chat/completions",
headers = {
@ -499,7 +496,7 @@ H.make_groq_curl_args = function(code_opts)
},
body = {
model = Config.groq.model,
messages = H.make_openai_message(code_opts),
messages = M.make_openai_message(code_opts),
temperature = Config.groq.temperature,
max_tokens = Config.groq.max_tokens,
stream = true,
@ -537,7 +534,7 @@ M.stream = function(question, code_lang, code_content, selected_content_content,
local ProviderConfig = nil
if E.is_default(provider) then
spec = H["make_" .. provider .. "_curl_args"](code_opts)
spec = M["make_" .. provider .. "_curl_args"](code_opts)
else
ProviderConfig = Config.vendors[provider]
spec = ProviderConfig.parse_curl_args(ProviderConfig, code_opts)
@ -555,7 +552,7 @@ M.stream = function(question, code_lang, code_content, selected_content_content,
if ProviderConfig ~= nil then
ProviderConfig.parse_response_data(data_match, current_event_state, handler_opts)
else
H["parse_" .. provider .. "_response"](data_match, current_event_state, handler_opts)
M["parse_" .. provider .. "_response"](data_match, current_event_state, handler_opts)
end
end
end
@ -603,6 +600,7 @@ M.stream = function(question, code_lang, code_content, selected_content_content,
return active_job
end
---@private
function M.setup()
local has = E[Config.provider]
if not has then
@ -623,6 +621,7 @@ function M.refresh(provider)
require("avante.config").override({ provider = provider })
end
---@private
M.commands = function()
api.nvim_create_user_command("AvanteSwitchProvider", function(args)
local cmd = vim.trim(args.args or "")
@ -647,16 +646,4 @@ end
M.SYSTEM_PROMPT = system_prompt
M.BASE_PROMPT = base_user_prompt
return setmetatable(M, {
__index = function(t, k)
local h = H[k]
if h then
return H[k]
end
local v = t[k]
if v then
return t[k]
end
error("Failed to find key: " .. k)
end,
})
return M