From ba06b9bd9da9ed3f46299a78f2b07a963a906566 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Mon, 19 Aug 2024 05:40:57 -0400 Subject: [PATCH] refactor: cleanup utils and expose lazy (#83) Signed-off-by: Aaron Pham --- lua/avante/config.lua | 18 +- lua/avante/diff.lua | 34 ++-- lua/avante/diff/utils.lua | 88 ---------- lua/avante/llm.lua | 6 +- lua/avante/utils.lua | 66 ------- lua/avante/{diff => utils}/colors.lua | 1 + lua/avante/utils/init.lua | 238 ++++++++++++++++++++++++++ 7 files changed, 270 insertions(+), 181 deletions(-) delete mode 100644 lua/avante/diff/utils.lua delete mode 100644 lua/avante/utils.lua rename lua/avante/{diff => utils}/colors.lua (97%) create mode 100644 lua/avante/utils/init.lua diff --git a/lua/avante/config.lua b/lua/avante/config.lua index 894dd20..ff1106d 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -6,8 +6,11 @@ local M = {} ---@class avante.Config M.defaults = { + debug = false, + ---Currently, default supported providers include "claude", "openai", "azure", "deepseek", "groq" + ---For custom provider, see README.md ---@alias Provider "openai" | "claude" | "azure" | "deepseek" | "groq" | [string] - provider = "claude", -- "claude" or "openai" or "azure" or "deepseek" or "groq" + provider = "claude", openai = { endpoint = "https://api.openai.com", model = "gpt-4o", @@ -39,18 +42,21 @@ M.defaults = { temperature = 0, max_tokens = 4096, }, - --- To add support for custom provider, follow the format below - --- See https://github.com/yetone/avante.nvim/README.md#custom-providers for more details + ---To add support for custom provider, follow the format below + ---See https://github.com/yetone/avante.nvim/README.md#custom-providers for more details ---@type table vendors = {}, + ---Specify the behaviour of avante.nvim + ---1. auto_apply_diff_after_generation: Whether to automatically apply diff after LLM response. + --- This would simulate similar behaviour to cursor. Default to false. behaviour = { - auto_apply_diff_after_generation = false, -- Whether to automatically apply diff after LLM response. + auto_apply_diff_after_generation = false, }, highlights = { ---@type AvanteConflictHighlights diff = { - current = "DiffText", -- need have background color - incoming = "DiffAdd", -- need have background color + current = "DiffText", + incoming = "DiffAdd", }, }, mappings = { diff --git a/lua/avante/diff.lua b/lua/avante/diff.lua index 4a7ef1d..6ecfe4b 100644 --- a/lua/avante/diff.lua +++ b/lua/avante/diff.lua @@ -2,10 +2,8 @@ local M = {} -local color = require("avante.diff.colors") -local utils = require("avante.diff.utils") - local Config = require("avante.config") +local Utils = require("avante.utils") local fn = vim.fn local api = vim.api @@ -283,7 +281,7 @@ local function find_position(bufnr, comparator, opts) if not match then return end - local line = utils.get_cursor_pos() + local line = Utils.get_cursor_pos() line = line - 1 -- Convert to 0-based for position comparison if opts and opts.reverse then @@ -372,7 +370,7 @@ end ---@param range_start? integer ---@param range_end? integer local function parse_buffer(bufnr, range_start, range_end) - local lines = utils.get_buf_lines(range_start or 0, range_end or -1, bufnr) + local lines = Utils.get_buf_lines(range_start or 0, range_end or -1, bufnr) local prev_conflicts = visited_buffers[bufnr].positions ~= nil and #visited_buffers[bufnr].positions > 0 local has_conflict, positions = detect_conflicts(lines) @@ -501,15 +499,15 @@ end ---Derive the colour of the section label highlights based on each sections highlights ---@param highlights AvanteConflictHighlights local function set_highlights(highlights) - local current_color = utils.get_hl(highlights.current) - local incoming_color = utils.get_hl(highlights.incoming) - local ancestor_color = utils.get_hl(highlights.ancestor) + local current_color = Utils.get_hl(highlights.current) + local incoming_color = Utils.get_hl(highlights.incoming) + local ancestor_color = Utils.get_hl(highlights.ancestor) local current_bg = current_color.background or DEFAULT_CURRENT_BG_COLOR local incoming_bg = incoming_color.background or DEFAULT_INCOMING_BG_COLOR local ancestor_bg = ancestor_color.background or DEFAULT_ANCESTOR_BG_COLOR - local current_label_bg = color.shade_color(current_bg, 60) - local incoming_label_bg = color.shade_color(incoming_bg, 60) - local ancestor_label_bg = color.shade_color(ancestor_bg, 60) + local current_label_bg = Utils.colors.shade_color(current_bg, 60) + local incoming_label_bg = Utils.colors.shade_color(incoming_bg, 60) + local ancestor_label_bg = Utils.colors.shade_color(ancestor_bg, 60) api.nvim_set_hl(0, CURRENT_HL, { background = current_bg, bold = true, default = true }) api.nvim_set_hl(0, INCOMING_HL, { background = incoming_bg, bold = true, default = true }) api.nvim_set_hl(0, ANCESTOR_HL, { background = ancestor_bg, bold = true, default = true }) @@ -555,7 +553,7 @@ function M.setup() api.nvim_set_decoration_provider(NAMESPACE, { on_buf = function(_, bufnr, _) - return utils.is_valid_buf(bufnr) + return Utils.is_valid_buf(bufnr) end, on_win = function(_, _, bufnr, _, _) if visited_buffers[bufnr] then @@ -656,10 +654,10 @@ function M.choose(side) local lines = {} if vim.tbl_contains({ SIDES.OURS, SIDES.THEIRS, SIDES.BASE }, side) then local data = position[name_map[side]] - lines = utils.get_buf_lines(data.content_start, data.content_end + 1) + lines = Utils.get_buf_lines(data.content_start, data.content_end + 1) elseif side == SIDES.BOTH then - local first = utils.get_buf_lines(position.current.content_start, position.current.content_end + 1) - local second = utils.get_buf_lines(position.incoming.content_start, position.incoming.content_end + 1) + local first = Utils.get_buf_lines(position.current.content_start, position.current.content_end + 1) + local second = Utils.get_buf_lines(position.incoming.content_start, position.incoming.content_end + 1) lines = vim.list_extend(first, second) elseif side == SIDES.NONE then lines = {} @@ -697,10 +695,10 @@ function M.choose(side) local lines = {} if vim.tbl_contains({ SIDES.OURS, SIDES.THEIRS, SIDES.BASE }, side) then local data = position[name_map[side]] - lines = utils.get_buf_lines(data.content_start, data.content_end + 1) + lines = Utils.get_buf_lines(data.content_start, data.content_end + 1) elseif side == SIDES.BOTH then - local first = utils.get_buf_lines(position.current.content_start, position.current.content_end + 1) - local second = utils.get_buf_lines(position.incoming.content_start, position.incoming.content_end + 1) + local first = Utils.get_buf_lines(position.current.content_start, position.current.content_end + 1) + local second = Utils.get_buf_lines(position.incoming.content_start, position.incoming.content_end + 1) lines = vim.list_extend(first, second) elseif side == SIDES.NONE then lines = {} diff --git a/lua/avante/diff/utils.lua b/lua/avante/diff/utils.lua deleted file mode 100644 index d6cc6ad..0000000 --- a/lua/avante/diff/utils.lua +++ /dev/null @@ -1,88 +0,0 @@ ------------------------------------------------------------------------------// --- Utils ------------------------------------------------------------------------------// -local M = {} - -local api = vim.api -local fn = vim.fn - ---- Wrapper for [vim.notify] ----@param msg string|string[] ----@param level "error" | "trace" | "debug" | "info" | "warn" ----@param once boolean? -function M.notify(msg, level, once) - if type(msg) == "table" then - msg = table.concat(msg, "\n") - end - local lvl = vim.log.levels[level:upper()] or vim.log.levels.INFO - local opts = { title = "Git conflict" } - if once then - return vim.notify_once(msg, lvl, opts) - end - vim.notify(msg, lvl, opts) -end - ---- Start an async job ----@param cmd string ----@param callback fun(data: string[]): nil -function M.job(cmd, callback) - fn.jobstart(cmd, { - stdout_buffered = true, - on_stdout = function(_, data, _) - callback(data) - end, - }) -end - ----Only call the passed function once every timeout in ms ----@param timeout integer ----@param func function ----@return function -function M.throttle(timeout, func) - local timer = vim.loop.new_timer() - local running = false - return function(...) - if not running then - func(...) - running = true - timer:start(timeout, 0, function() - running = false - end) - end - end -end - ----Wrapper around `api.nvim_buf_get_lines` which defaults to the current buffer ----@param start integer ----@param _end integer ----@param buf integer? ----@return string[] -function M.get_buf_lines(start, _end, buf) - return api.nvim_buf_get_lines(buf or 0, start, _end, false) -end - ----Get cursor row and column as (1, 0) based ----@param win_id integer? ----@return integer, integer -function M.get_cursor_pos(win_id) - return unpack(api.nvim_win_get_cursor(win_id or 0)) -end - ----Check if the buffer is likely to have actionable conflict markers ----@param bufnr integer? ----@return boolean -function M.is_valid_buf(bufnr) - bufnr = bufnr or 0 - return #vim.bo[bufnr].buftype == 0 and vim.bo[bufnr].modifiable -end - ----@param name string? ----@return table -function M.get_hl(name) - if not name then - return {} - end - return api.nvim_get_hl_by_name(name, true) -end - -return M diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index e94e804..8e69258 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -88,7 +88,7 @@ E.setup = function(var, refresh) vim.fn.setenv(var, value) else if not E[Config.provider] then - vim.notify_once("Failed to set " .. var .. ". Avante won't work as expected", vim.log.levels.WARN) + Util.warn("Failed to set " .. var .. ". Avante won't work as expected", { once = true, title = "Avante" }) end end end @@ -559,7 +559,7 @@ M.stream = function(question, code_lang, code_content, selected_content_content, callback = function() if active_job then active_job:shutdown() - vim.notify("LLM request cancelled", vim.log.levels.DEBUG) + Utils.debug("LLM request cancelled", { title = "Avante" }) active_job = nil end end, @@ -583,7 +583,7 @@ function M.refresh(provider) if not has then E.setup(E.key(provider), true) else - vim.notify_once("Switch to provider: " .. provider, vim.log.levels.INFO) + Utils.info("Switch to provider: " .. provider, { once = true, title = "Avante" }) end require("avante.config").override({ provider = provider }) end diff --git a/lua/avante/utils.lua b/lua/avante/utils.lua deleted file mode 100644 index 9191a4a..0000000 --- a/lua/avante/utils.lua +++ /dev/null @@ -1,66 +0,0 @@ -local Range = require("avante.range") -local SelectionResult = require("avante.selection_result") -local M = {} ----@param str string ----@param opts? {suffix?: string, prefix?: string} -function M.trim(str, opts) - if not opts then - return str - end - if opts.suffix then - return str:sub(-1) == opts.suffix and str:sub(1, -2) or str - elseif opts.prefix then - return str:sub(1, 1) == opts.prefix and str:sub(2) or str - end -end -function M.trim_line_number_prefix(line) - return line:gsub("^L%d+: ", "") -end -function M.in_visual_mode() - local current_mode = vim.fn.mode() - return current_mode == "v" or current_mode == "V" or current_mode == "" -end --- Get the selected content and range in Visual mode --- @return avante.SelectionResult | nil Selected content and range -function M.get_visual_selection_and_range() - if not M.in_visual_mode() then - return nil - end - -- Get the start and end positions of Visual mode - local start_pos = vim.fn.getpos("v") - local end_pos = vim.fn.getpos(".") - -- Get the start and end line and column numbers - local start_line = start_pos[2] - local start_col = start_pos[3] - local end_line = end_pos[2] - local end_col = end_pos[3] - -- If the start point is after the end point, swap them - if start_line > end_line or (start_line == end_line and start_col > end_col) then - start_line, end_line = end_line, start_line - start_col, end_col = end_col, start_col - end - local content = "" - local range = Range.new({ line = start_line, col = start_col }, { line = end_line, col = end_col }) - -- Check if it's a single-line selection - if start_line == end_line then - -- Get partial content of a single line - local line = vim.fn.getline(start_line) - -- content = string.sub(line, start_col, end_col) - content = line - else - -- Multi-line selection: Get all lines in the selection - local lines = vim.fn.getline(start_line, end_line) - -- Extract partial content of the first line - -- lines[1] = string.sub(lines[1], start_col) - -- Extract partial content of the last line - -- lines[#lines] = string.sub(lines[#lines], 1, end_col) - -- Concatenate all lines in the selection into a string - content = table.concat(lines, "\n") - end - if not content then - return nil - end - -- Return the selected content and range - return SelectionResult.new(content, range) -end -return M diff --git a/lua/avante/diff/colors.lua b/lua/avante/utils/colors.lua similarity index 97% rename from lua/avante/diff/colors.lua rename to lua/avante/utils/colors.lua index 5005519..e5baf40 100644 --- a/lua/avante/diff/colors.lua +++ b/lua/avante/utils/colors.lua @@ -1,3 +1,4 @@ +---@class avante.util.colors local M = {} local bit = require("bit") diff --git a/lua/avante/utils/init.lua b/lua/avante/utils/init.lua new file mode 100644 index 0000000..4a6f45e --- /dev/null +++ b/lua/avante/utils/init.lua @@ -0,0 +1,238 @@ +local api = vim.api +local fn = vim.fn + +---@class avante.Utils: LazyUtilCore +---@field colors avante.util.colors +local M = {} + +setmetatable(M, { + __index = function(t, k) + local ok, lazyutil = pcall(require, "lazy.core.util") + if ok and lazyutil[k] then + return lazyutil[k] + end + + ---@diagnostic disable-next-line: no-unknown + t[k] = require("avante.utils." .. k) + return t[k] + end, +}) + +---@param str string +---@param opts? {suffix?: string, prefix?: string} +function M.trim(str, opts) + if not opts then + return str + end + if opts.suffix then + return str:sub(-1) == opts.suffix and str:sub(1, -2) or str + elseif opts.prefix then + return str:sub(1, 1) == opts.prefix and str:sub(2) or str + end +end + +function M.in_visual_mode() + local current_mode = vim.fn.mode() + return current_mode == "v" or current_mode == "V" or current_mode == "" +end + +-- Get the selected content and range in Visual mode +-- @return avante.SelectionResult | nil Selected content and range +function M.get_visual_selection_and_range() + local Range = require("avante.range") + local SelectionResult = require("avante.selection_result") + + if not M.in_visual_mode() then + return nil + end + -- Get the start and end positions of Visual mode + local start_pos = vim.fn.getpos("v") + local end_pos = vim.fn.getpos(".") + -- Get the start and end line and column numbers + local start_line = start_pos[2] + local start_col = start_pos[3] + local end_line = end_pos[2] + local end_col = end_pos[3] + -- If the start point is after the end point, swap them + if start_line > end_line or (start_line == end_line and start_col > end_col) then + start_line, end_line = end_line, start_line + start_col, end_col = end_col, start_col + end + local content = "" + local range = Range.new({ line = start_line, col = start_col }, { line = end_line, col = end_col }) + -- Check if it's a single-line selection + if start_line == end_line then + -- Get partial content of a single line + local line = vim.fn.getline(start_line) + -- content = string.sub(line, start_col, end_col) + content = line + else + -- Multi-line selection: Get all lines in the selection + local lines = vim.fn.getline(start_line, end_line) + -- Extract partial content of the first line + -- lines[1] = string.sub(lines[1], start_col) + -- Extract partial content of the last line + -- lines[#lines] = string.sub(lines[#lines], 1, end_col) + -- Concatenate all lines in the selection into a string + content = table.concat(lines, "\n") + end + if not content then + return nil + end + -- Return the selected content and range + return SelectionResult.new(content, range) +end + +--- Start an async job +---@param cmd string +---@param callback fun(data: string[]): nil +function M.job(cmd, callback) + fn.jobstart(cmd, { + stdout_buffered = true, + on_stdout = function(_, data, _) + callback(data) + end, + }) +end + +---Only call the passed function once every timeout in ms +---@param timeout integer +---@param func function +---@return function +function M.throttle(timeout, func) + local timer = vim.loop.new_timer() + local running = false + return function(...) + if not running then + func(...) + running = true + timer:start(timeout, 0, function() + running = false + end) + end + end +end + +---Wrapper around `api.nvim_buf_get_lines` which defaults to the current buffer +---@param start integer +---@param _end integer +---@param buf integer? +---@return string[] +function M.get_buf_lines(start, _end, buf) + return api.nvim_buf_get_lines(buf or 0, start, _end, false) +end + +---Get cursor row and column as (1, 0) based +---@param win_id integer? +---@return integer, integer +function M.get_cursor_pos(win_id) + return unpack(api.nvim_win_get_cursor(win_id or 0)) +end + +---Check if the buffer is likely to have actionable conflict markers +---@param bufnr integer? +---@return boolean +function M.is_valid_buf(bufnr) + bufnr = bufnr or 0 + return #vim.bo[bufnr].buftype == 0 and vim.bo[bufnr].modifiable +end + +---@param name string? +---@return table +function M.get_hl(name) + if not name then + return {} + end + return api.nvim_get_hl(0, { name = name }) +end + +--- vendor from lazy.nvim for early access and override + +---@param msg string|string[] +---@param opts? LazyNotifyOpts +function M.notify(msg, opts) + if vim.in_fast_event() then + return vim.schedule(function() + M.notify(msg, opts) + end) + end + + opts = opts or {} + if type(msg) == "table" then + ---@diagnostic disable-next-line: no-unknown + msg = table.concat( + vim.tbl_filter(function(line) + return line or false + end, msg), + "\n" + ) + end + if opts.stacktrace then + msg = msg .. M.pretty_trace({ level = opts.stacklevel or 2 }) + end + local lang = opts.lang or "markdown" + local n = opts.once and vim.notify_once or vim.notify + n(msg, opts.level or vim.log.levels.INFO, { + on_open = function(win) + local ok = pcall(function() + vim.treesitter.language.add("markdown") + end) + if not ok then + pcall(require, "nvim-treesitter") + end + vim.wo[win].conceallevel = 3 + vim.wo[win].concealcursor = "" + vim.wo[win].spell = false + local buf = vim.api.nvim_win_get_buf(win) + if not pcall(vim.treesitter.start, buf, lang) then + vim.bo[buf].filetype = lang + vim.bo[buf].syntax = lang + end + end, + title = opts.title or "lazy.nvim", + }) +end + +---@param msg string|string[] +---@param opts? LazyNotifyOpts +function M.error(msg, opts) + opts = opts or {} + opts.level = vim.log.levels.ERROR + M.notify(msg, opts) +end + +---@param msg string|string[] +---@param opts? LazyNotifyOpts +function M.info(msg, opts) + opts = opts or {} + opts.level = vim.log.levels.INFO + M.notify(msg, opts) +end + +---@param msg string|string[] +---@param opts? LazyNotifyOpts +function M.warn(msg, opts) + opts = opts or {} + opts.level = vim.log.levels.WARN + M.notify(msg, opts) +end + +---@param msg string|table +---@param opts? LazyNotifyOpts +function M.debug(msg, opts) + if not require("avante.config").options.debug then + return + end + opts = opts or {} + if opts.title then + opts.title = "lazy.nvim: " .. opts.title + end + if type(msg) == "string" then + M.notify(msg, opts) + else + opts.lang = "lua" + M.notify(vim.inspect(msg), opts) + end +end + +return M