From 94adc992a62014c649c439dda6f6629880979afb Mon Sep 17 00:00:00 2001
From: Aaron Pham
Date: Tue, 20 Aug 2024 07:43:53 -0400
Subject: [PATCH] chore(llm): expose types for support functions (#113)

Signed-off-by: Aaron Pham
---
 lua/avante/llm.lua | 65 +++++++++++++++++++---------------------------
 1 file changed, 26 insertions(+), 39 deletions(-)

diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua
index c5056a2..7b2c54b 100644
--- a/lua/avante/llm.lua
+++ b/lua/avante/llm.lua
@@ -7,10 +7,6 @@ local Config = require("avante.config")
 local Tiktoken = require("avante.tiktoken")
 local Dressing = require("avante.ui.dressing")
 
----@private
----@class AvanteLLMInternal
-local H = {}
-
 ---@class avante.LLM
 local M = {}
 
@@ -248,6 +244,7 @@ Remember: Accurate line numbers are CRITICAL. The range start_line to end_line m
 ---@field max_tokens number
 ---
 ---@class AvanteProvider: AvanteDefaultBaseProvider
+---@field model? string
 ---@field api_key_name string
 ---@field parse_response_data AvanteResponseParser
 ---@field parse_curl_args fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput
@@ -259,7 +256,7 @@ Remember: Accurate line numbers are CRITICAL. The range start_line to end_line m
 
 ---@param opts AvantePromptOptions
 ---@return AvanteClaudeMessage[]
-H.make_claude_message = function(opts)
+M.make_claude_message = function(opts)
   local code_prompt_obj = {
     type = "text",
     text = string.format("```%s\n%s```", opts.code_lang, opts.code_content),
@@ -317,7 +314,7 @@ H.make_claude_message = function(opts)
 end
 
 ---@type AvanteResponseParser
-H.parse_claude_response = function(data_stream, event_state, opts)
+M.parse_claude_response = function(data_stream, event_state, opts)
   if event_state == "content_block_delta" then
     local json = vim.json.decode(data_stream)
     opts.on_chunk(json.delta.text)
@@ -330,7 +327,7 @@ H.parse_claude_response = function(data_stream, event_state, opts)
 end
 
 ---@type AvanteCurlArgsBuilder
-H.make_claude_curl_args = function(code_opts)
+M.make_claude_curl_args = function(code_opts)
   return {
     url = Utils.trim(Config.claude.endpoint, { suffix = "/" }) .. "/v1/messages",
     headers = {
@@ -343,7 +340,7 @@ H.make_claude_curl_args = function(code_opts)
       model = Config.claude.model,
       system = system_prompt,
       stream = true,
-      messages = H.make_claude_message(code_opts),
+      messages = M.make_claude_message(code_opts),
       temperature = Config.claude.temperature,
       max_tokens = Config.claude.max_tokens,
     },
@@ -354,7 +351,7 @@ end
 
 ---@param opts AvantePromptOptions
 ---@return AvanteOpenAIMessage[]
-H.make_openai_message = function(opts)
+M.make_openai_message = function(opts)
   local user_prompt = base_user_prompt
     .. "\n\nCODE:\n"
     .. "```"
@@ -390,7 +387,7 @@ H.make_openai_message = function(opts)
 end
 
 ---@type AvanteResponseParser
-H.parse_openai_response = function(data_stream, _, opts)
+M.parse_openai_response = function(data_stream, _, opts)
   if data_stream:match('"%[DONE%]":') then
     opts.on_complete(nil)
     return
@@ -409,7 +406,7 @@ H.parse_openai_response = function(data_stream, _, opts)
 end
 
 ---@type AvanteCurlArgsBuilder
-H.make_openai_curl_args = function(code_opts)
+M.make_openai_curl_args = function(code_opts)
   return {
     url = Utils.trim(Config.openai.endpoint, { suffix = "/" }) .. "/v1/chat/completions",
"/v1/chat/completions", headers = { @@ -418,7 +415,7 @@ H.make_openai_curl_args = function(code_opts) }, body = { model = Config.openai.model, - messages = H.make_openai_message(code_opts), + messages = M.make_openai_message(code_opts), temperature = Config.openai.temperature, max_tokens = Config.openai.max_tokens, stream = true, @@ -429,13 +426,13 @@ end ------------------------------Azure------------------------------ ---@type AvanteAiMessageBuilder -H.make_azure_message = H.make_openai_message +M.make_azure_message = M.make_openai_message ---@type AvanteResponseParser -H.parse_azure_response = H.parse_openai_response +M.parse_azure_response = M.parse_openai_response ---@type AvanteCurlArgsBuilder -H.make_azure_curl_args = function(code_opts) +M.make_azure_curl_args = function(code_opts) return { url = Config.azure.endpoint .. "/openai/deployments/" @@ -447,7 +444,7 @@ H.make_azure_curl_args = function(code_opts) ["api-key"] = E.value("azure"), }, body = { - messages = H.make_openai_message(code_opts), + messages = M.make_openai_message(code_opts), temperature = Config.azure.temperature, max_tokens = Config.azure.max_tokens, stream = true, @@ -458,13 +455,13 @@ end ------------------------------Deepseek------------------------------ ---@type AvanteAiMessageBuilder -H.make_deepseek_message = H.make_openai_message +M.make_deepseek_message = M.make_openai_message ---@type AvanteResponseParser -H.parse_deepseek_response = H.parse_openai_response +M.parse_deepseek_response = M.parse_openai_response ---@type AvanteCurlArgsBuilder -H.make_deepseek_curl_args = function(code_opts) +M.make_deepseek_curl_args = function(code_opts) return { url = Utils.trim(Config.deepseek.endpoint, { suffix = "/" }) .. "/chat/completions", headers = { @@ -473,7 +470,7 @@ H.make_deepseek_curl_args = function(code_opts) }, body = { model = Config.deepseek.model, - messages = H.make_openai_message(code_opts), + messages = M.make_openai_message(code_opts), temperature = Config.deepseek.temperature, max_tokens = Config.deepseek.max_tokens, stream = true, @@ -484,13 +481,13 @@ end ------------------------------Grok------------------------------ ---@type AvanteAiMessageBuilder -H.make_groq_message = H.make_openai_message +M.make_groq_message = M.make_openai_message ---@type AvanteResponseParser -H.parse_groq_response = H.parse_openai_response +M.parse_groq_response = M.parse_openai_response ---@type AvanteCurlArgsBuilder -H.make_groq_curl_args = function(code_opts) +M.make_groq_curl_args = function(code_opts) return { url = Utils.trim(Config.groq.endpoint, { suffix = "/" }) .. "/openai/v1/chat/completions", headers = { @@ -499,7 +496,7 @@ H.make_groq_curl_args = function(code_opts) }, body = { model = Config.groq.model, - messages = H.make_openai_message(code_opts), + messages = M.make_openai_message(code_opts), temperature = Config.groq.temperature, max_tokens = Config.groq.max_tokens, stream = true, @@ -537,7 +534,7 @@ M.stream = function(question, code_lang, code_content, selected_content_content, local ProviderConfig = nil if E.is_default(provider) then - spec = H["make_" .. provider .. "_curl_args"](code_opts) + spec = M["make_" .. provider .. "_curl_args"](code_opts) else ProviderConfig = Config.vendors[provider] spec = ProviderConfig.parse_curl_args(ProviderConfig, code_opts) @@ -555,7 +552,7 @@ M.stream = function(question, code_lang, code_content, selected_content_content, if ProviderConfig ~= nil then ProviderConfig.parse_response_data(data_match, current_event_state, handler_opts) else - H["parse_" .. 
provider .. "_response"](data_match, current_event_state, handler_opts) + M["parse_" .. provider .. "_response"](data_match, current_event_state, handler_opts) end end end @@ -603,6 +600,7 @@ M.stream = function(question, code_lang, code_content, selected_content_content, return active_job end +---@private function M.setup() local has = E[Config.provider] if not has then @@ -623,6 +621,7 @@ function M.refresh(provider) require("avante.config").override({ provider = provider }) end +---@private M.commands = function() api.nvim_create_user_command("AvanteSwitchProvider", function(args) local cmd = vim.trim(args.args or "") @@ -647,16 +646,4 @@ end M.SYSTEM_PROMPT = system_prompt M.BASE_PROMPT = base_user_prompt -return setmetatable(M, { - __index = function(t, k) - local h = H[k] - if h then - return H[k] - end - local v = t[k] - if v then - return t[k] - end - error("Failed to find key: " .. k) - end, -}) +return M