From 3dca5f47644bc1413ff153288ed8b8955d0c40a5 Mon Sep 17 00:00:00 2001 From: yetone Date: Sat, 17 Aug 2024 22:29:05 +0800 Subject: [PATCH] feat: ask selected code block (#39) --- README.md | 2 +- lua/avante/ai_bot.lua | 87 ++++++++++---- lua/avante/config.lua | 3 +- lua/avante/init.lua | 12 +- lua/avante/range.lua | 24 ++++ lua/avante/selection.lua | 95 +++++++++++++++ lua/avante/selection_result.lua | 17 +++ lua/avante/sidebar.lua | 198 ++++++++++++++++++++++---------- lua/avante/utils.lua | 52 ++++++++- 9 files changed, 399 insertions(+), 91 deletions(-) create mode 100644 lua/avante/range.lua create mode 100644 lua/avante/selection.lua create mode 100644 lua/avante/selection_result.lua diff --git a/README.md b/README.md index 2e30b48..78f5c39 100644 --- a/README.md +++ b/README.md @@ -89,7 +89,7 @@ Default setup configuration: }, }, mappings = { - show_sidebar = "aa", + ask = "aa", diff = { ours = "co", theirs = "ct", diff --git a/lua/avante/ai_bot.lua b/lua/avante/ai_bot.lua index fd3241a..d52a265 100644 --- a/lua/avante/ai_bot.lua +++ b/lua/avante/ai_bot.lua @@ -19,8 +19,9 @@ Your primary task is to suggest code modifications with precise line number rang 1. Carefully analyze the original code, paying close attention to its structure and line numbers. Line numbers start from 1 and include ALL lines, even empty ones. 2. When suggesting modifications: - a. Explain why the change is necessary or beneficial. - b. Provide the exact code snippet to be replaced using this format: + a. Use the language in the question to reply. If there are non-English parts in the question, use the language of those parts. + b. Explain why the change is necessary or beneficial. + c. Provide the exact code snippet to be replaced using this format: Replace lines: {{start_line}}-{{end_line}} ```{{language}} @@ -58,14 +59,12 @@ Replace lines: {{start_line}}-{{end_line}} Remember: Accurate line numbers are CRITICAL. The range start_line to end_line must include ALL lines to be replaced, from the very first to the very last. Double-check every range before finalizing your response, paying special attention to the start_line to ensure it hasn't shifted down. Ensure that your line numbers perfectly match the original code structure without any overall shift. ]] -local function call_claude_api_stream(question, code_lang, code_content, on_chunk, on_complete) +local function call_claude_api_stream(question, code_lang, code_content, selected_code_content, on_chunk, on_complete) local api_key = os.getenv("ANTHROPIC_API_KEY") if not api_key then error("ANTHROPIC_API_KEY environment variable is not set") end - local user_prompt = base_user_prompt - local tokens = Config.claude.max_tokens local headers = { ["Content-Type"] = "application/json", @@ -79,33 +78,56 @@ local function call_claude_api_stream(question, code_lang, code_content, on_chun text = string.format("```%s\n%s```", code_lang, code_content), } + if Tiktoken.count(code_prompt_obj.text) > 1024 then + code_prompt_obj.cache_control = { type = "ephemeral" } + end + + if selected_code_content then + code_prompt_obj.text = string.format("```%s\n%s```", code_lang, code_content) + end + + local message_content = { + code_prompt_obj, + } + + if selected_code_content then + local selected_code_obj = { + type = "text", + text = string.format("```%s\n%s```", code_lang, selected_code_content), + } + + if Tiktoken.count(selected_code_obj.text) > 1024 then + selected_code_obj.cache_control = { type = "ephemeral" } + end + + table.insert(message_content, selected_code_obj) + end + + table.insert(message_content, { + type = "text", + text = string.format("%s", question), + }) + + local user_prompt = base_user_prompt + local user_prompt_obj = { type = "text", text = user_prompt, } - if Tiktoken.count(code_prompt_obj.text) > 1024 then - code_prompt_obj.cache_control = { type = "ephemeral" } - end - if Tiktoken.count(user_prompt_obj.text) > 1024 then user_prompt_obj.cache_control = { type = "ephemeral" } end + table.insert(message_content, user_prompt_obj) + local body = { model = Config.claude.model, system = system_prompt, messages = { { role = "user", - content = { - code_prompt_obj, - { - type = "text", - text = string.format("%s", question), - }, - user_prompt_obj, - }, + content = message_content, }, }, stream = true, @@ -154,21 +176,39 @@ local function call_claude_api_stream(question, code_lang, code_content, on_chun }) end -local function call_openai_api_stream(question, code_lang, code_content, on_chunk, on_complete) +local function call_openai_api_stream(question, code_lang, code_content, selected_code_content, on_chunk, on_complete) local api_key = os.getenv("OPENAI_API_KEY") if not api_key and Config.provider == "openai" then error("OPENAI_API_KEY environment variable is not set") end local user_prompt = base_user_prompt - .. "\n\nQUESTION:\n" - .. question .. "\n\nCODE:\n" .. "```" .. code_lang .. "\n" .. code_content .. "\n```" + .. "\n\nQUESTION:\n" + .. question + + if selected_code_content then + user_prompt = base_user_prompt + .. "\n\nCODE CONTEXT:\n" + .. "```" + .. code_lang + .. "\n" + .. code_content + .. "\n```" + .. "\n\nCODE:\n" + .. "```" + .. code_lang + .. "\n" + .. selected_code_content + .. "\n```" + .. "\n\nQUESTION:\n" + .. question + end local url, headers, body if Config.provider == "azure" then @@ -258,13 +298,14 @@ end ---@param question string ---@param code_lang string ---@param code_content string +---@param selected_content_content string | nil ---@param on_chunk fun(chunk: string): any ---@param on_complete fun(err: string|nil): any -function M.call_ai_api_stream(question, code_lang, code_content, on_chunk, on_complete) +function M.call_ai_api_stream(question, code_lang, code_content, selected_content_content, on_chunk, on_complete) if Config.provider == "openai" or Config.provider == "azure" then - call_openai_api_stream(question, code_lang, code_content, on_chunk, on_complete) + call_openai_api_stream(question, code_lang, code_content, selected_content_content, on_chunk, on_complete) elseif Config.provider == "claude" then - call_claude_api_stream(question, code_lang, code_content, on_chunk, on_complete) + call_claude_api_stream(question, code_lang, code_content, selected_content_content, on_chunk, on_complete) end end diff --git a/lua/avante/config.lua b/lua/avante/config.lua index 40e44c8..d6ee32c 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -32,7 +32,8 @@ M.defaults = { }, }, mappings = { - show_sidebar = "aa", + ask = "aa", + edit = "ae", diff = { ours = "co", theirs = "ct", diff --git a/lua/avante/init.lua b/lua/avante/init.lua index fe5cdf3..00f2bf4 100644 --- a/lua/avante/init.lua +++ b/lua/avante/init.lua @@ -4,6 +4,7 @@ local Tiktoken = require("avante.tiktoken") local Sidebar = require("avante.sidebar") local Config = require("avante.config") local Diff = require("avante.diff") +local Selection = require("avante.selection") ---@class Avante local M = { @@ -11,6 +12,7 @@ local M = { sidebars = {}, ---@type avante.Sidebar current = nil, + selection = nil, _once = false, } @@ -35,7 +37,7 @@ H.commands = function() end H.keymaps = function() - vim.keymap.set({ "n" }, Config.mappings.show_sidebar, M.toggle, { noremap = true }) + vim.keymap.set({ "n", "v" }, Config.mappings.ask, M.toggle, { noremap = true }) end H.autocmds = function() @@ -76,7 +78,9 @@ H.autocmds = function() if s then s:destroy() end - M.sidebars[tab] = nil + if tab ~= nil then + M.sidebars[tab] = nil + end end, }) @@ -137,6 +141,10 @@ function M.setup(opts) highlights = Config.highlights.diff, }) + local selection = Selection:new() + selection:setup() + M.selection = selection + -- setup helpers H.autocmds() H.commands() diff --git a/lua/avante/range.lua b/lua/avante/range.lua new file mode 100644 index 0000000..652b919 --- /dev/null +++ b/lua/avante/range.lua @@ -0,0 +1,24 @@ +--@class avante.Range +--@field start table Selection start point +--@field start.line number Line number of the selection start +--@field start.col number Column number of the selection start +--@field finish table Selection end point +--@field finish.line number Line number of the selection end +--@field finish.col number Column number of the selection end +local Range = {} +Range.__index = Range +-- Create a selection range +-- @param start table Selection start point +-- @param start.line number Line number of the selection start +-- @param start.col number Column number of the selection start +-- @param finish table Selection end point +-- @param finish.line number Line number of the selection end +-- @param finish.col number Column number of the selection end +function Range.new(start, finish) + local self = setmetatable({}, Range) + self.start = start + self.finish = finish + return self +end + +return Range diff --git a/lua/avante/selection.lua b/lua/avante/selection.lua new file mode 100644 index 0000000..1906aa6 --- /dev/null +++ b/lua/avante/selection.lua @@ -0,0 +1,95 @@ +local Config = require("avante.config") + +local api = vim.api +local fn = vim.fn + +local NAMESPACE = api.nvim_create_namespace("avante_selection") +local PRIORITY = vim.highlight.priorities.user + +local Selection = {} + +function Selection:new() + return setmetatable({ + hints_popup_extmark_id = nil, + edit_popup_renderer = nil, + augroup = api.nvim_create_augroup("avante_selection", { clear = true }), + }, { __index = self }) +end + +function Selection:get_virt_text_line() + local current_pos = fn.getpos(".") + + -- Get the current and start position line numbers + local current_line = current_pos[2] - 1 -- 0-indexed + + -- Ensure line numbers are not negative and don't exceed buffer range + local total_lines = api.nvim_buf_line_count(0) + if current_line < 0 then + current_line = 0 + end + if current_line >= total_lines then + current_line = total_lines - 1 + end + + -- Take the first line of the selection to ensure virt_text is always in the top right corner + return current_line +end + +function Selection:show_hints_popup() + self:close_hints_popup() + + local hint_text = string.format(" [Ask %s] ", Config.mappings.ask) + + local virt_text_line = self:get_virt_text_line() + + self.hints_popup_extmark_id = vim.api.nvim_buf_set_extmark(0, NAMESPACE, virt_text_line, -1, { + virt_text = { { hint_text, "Keyword" } }, + virt_text_pos = "eol", + priority = PRIORITY, + }) +end + +function Selection:close_hints_popup() + if self.hints_popup_extmark_id then + vim.api.nvim_buf_del_extmark(0, NAMESPACE, self.hints_popup_extmark_id) + self.hints_popup_extmark_id = nil + end +end + +function Selection:setup() + vim.api.nvim_create_autocmd({ "ModeChanged" }, { + group = self.augroup, + pattern = { "n:v", "n:V", "n:" }, -- Entering Visual mode from Normal mode + callback = function() + self:show_hints_popup() + end, + }) + + api.nvim_create_autocmd({ "CursorMoved", "CursorMovedI" }, { + group = self.augroup, + callback = function() + if vim.fn.mode() == "v" or vim.fn.mode() == "V" or vim.fn.mode() == "" then + self:show_hints_popup() + else + self:close_hints_popup() + end + end, + }) + + api.nvim_create_autocmd({ "ModeChanged" }, { + group = self.augroup, + pattern = { "v:n", "v:i", "v:c" }, -- Switching from visual mode back to normal, insert, or other modes + callback = function() + self:close_hints_popup() + end, + }) +end + +function Selection:delete_autocmds() + if self.augroup then + vim.api.nvim_del_augroup_by_id(self.augroup) + end + self.augroup = nil +end + +return Selection diff --git a/lua/avante/selection_result.lua b/lua/avante/selection_result.lua new file mode 100644 index 0000000..e44d12d --- /dev/null +++ b/lua/avante/selection_result.lua @@ -0,0 +1,17 @@ +--@class avante.SelectionResult +--@field content string Selected content +--@field range avante.Range Selection range +local SelectionResult = {} +SelectionResult.__index = SelectionResult + +-- Create a selection content and range +--@param content string Selected content +--@param range avante.Range Selection range +function SelectionResult.new(content, range) + local self = setmetatable({}, SelectionResult) + self.content = content + self.range = range + return self +end + +return SelectionResult diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index 715dbf3..ae46f77 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -20,6 +20,7 @@ local Sidebar = {} ---@class avante.SidebarState ---@field win integer ---@field buf integer +---@field selection avante.SelectionResult | nil ---@class avante.Sidebar ---@field id integer @@ -33,11 +34,11 @@ local Sidebar = {} function Sidebar:new(id) return setmetatable({ id = id, - code = { buf = 0, win = 0 }, + code = { buf = 0, win = 0, selection = nil }, winid = { result = 0, input = 0 }, view = View:new(), renderer = nil, - }, { __index = Sidebar }) + }, { __index = self }) end --- This function should only be used on TabClosed, nothing else. @@ -63,10 +64,17 @@ function Sidebar:reset() end function Sidebar:open() + local in_visual_mode = Utils.in_visual_mode() and self:in_code_win() if not self.view:is_open() then self:intialize() self:render() else + if in_visual_mode then + self:close() + self:intialize() + self:render() + return self + end self:focus() end return self @@ -86,8 +94,13 @@ function Sidebar:focus() return false end +function Sidebar:in_code_win() + return self.code.win == api.nvim_get_current_win() +end + function Sidebar:toggle() - if self.view:is_open() then + local in_visual_mode = Utils.in_visual_mode() and self:in_code_win() + if self.view:is_open() and not in_visual_mode then self:close() return false else @@ -108,6 +121,7 @@ end function Sidebar:intialize() self.code.win = api.nvim_get_current_win() self.code.buf = api.nvim_get_current_buf() + self.code.selection = Utils.get_visual_selection_and_range() local split_command = "botright vs" local layout = Config.get_renderer_layout_options() @@ -123,15 +137,15 @@ function Sidebar:intialize() }) self.renderer:on_mount(function() - local components = self.renderer:get_focusable_components() - -- current layout is a - -- [ chat ] - -- - -- [ input ] - self.winid.result = components[1].winid - self.winid.input = components[2].winid + self.winid.result = self.renderer:get_component_by_id("result").winid + self.winid.input = self.renderer:get_component_by_id("input").winid self.augroup = api.nvim_create_augroup("avante_" .. self.id .. self.view.win, { clear = true }) + local filetype = api.nvim_get_option_value("filetype", { buf = self.code.buf }) + local selected_code_buf = self.renderer:get_component_by_id("selected_code").bufnr + api.nvim_buf_set_option(selected_code_buf, "filetype", filetype) + api.nvim_set_option_value("wrap", false, { win = self.renderer:get_component_by_id("selected_code").winid }) + api.nvim_create_autocmd("BufEnter", { group = self.augroup, buffer = self.view.buf, @@ -215,7 +229,11 @@ function Sidebar:update_content(content, focus, callback) -- XXX: omit error for now, but should fix me why it can't jump here. return err end) - api.nvim_win_set_cursor(self.winid.result, { api.nvim_buf_line_count(self.view.buf), 0 }) + xpcall(function() + api.nvim_win_set_cursor(self.winid.result, { api.nvim_buf_line_count(self.view.buf), 0 }) + end, function(err) + return err + end) end end, 0) return self @@ -260,10 +278,12 @@ local function is_cursor_in_codeblock(codeblocks) return nil end -local function prepend_line_number(content) +local function prepend_line_number(content, start_line) + start_line = start_line or 1 local lines = vim.split(content, "\n") local result = {} for i, line in ipairs(lines) do + i = i + start_line - 1 table.insert(result, "L" .. i .. ": " .. line) end return table.concat(result, "\n") @@ -560,25 +580,53 @@ function Sidebar:render() local content = self:get_code_content() local content_with_line_numbers = prepend_line_number(content) + + local selected_code_content_with_line_numbers = nil + if self.code.selection ~= nil then + selected_code_content_with_line_numbers = + prepend_line_number(self.code.selection.content, self.code.selection.range.start.line) + end + local full_response = "" signal.is_loading = true local filetype = api.nvim_get_option_value("filetype", { buf = self.code.buf }) - AiBot.call_ai_api_stream(user_input, filetype, content_with_line_numbers, function(chunk) - full_response = full_response .. chunk - self:update_content( - "## " .. timestamp .. "\n\n> " .. user_input:gsub("\n", "\n> ") .. "\n\n" .. full_response, - true - ) - vim.schedule(function() - vim.cmd("redraw") - end) - end, function(err) - signal.is_loading = false + AiBot.call_ai_api_stream( + user_input, + filetype, + content_with_line_numbers, + selected_code_content_with_line_numbers, + function(chunk) + full_response = full_response .. chunk + self:update_content( + "## " .. timestamp .. "\n\n> " .. user_input:gsub("\n", "\n> ") .. "\n\n" .. full_response, + true + ) + vim.schedule(function() + vim.cmd("redraw") + end) + end, + function(err) + signal.is_loading = false - if err ~= nil then + if err ~= nil then + self:update_content( + "## " + .. timestamp + .. "\n\n> " + .. user_input:gsub("\n", "\n> ") + .. "\n\n" + .. full_response + .. "\n\n🚨 Error: " + .. vim.inspect(err), + true + ) + return + end + + -- Execute when the stream request is actually completed self:update_content( "## " .. timestamp @@ -586,37 +634,23 @@ function Sidebar:render() .. user_input:gsub("\n", "\n> ") .. "\n\n" .. full_response - .. "\n\n🚨 Error: " - .. vim.inspect(err), - true + .. "\n\n**Generation complete!** Please review the code suggestions above.\n\n\n\n", + true, + function() + api.nvim_exec_autocmds("User", { pattern = VIEW_BUFFER_UPDATED_PATTERN }) + end ) - return + + api.nvim_set_current_win(self.winid.result) + + -- Display notification + -- show_notification("Content generation complete!") + + -- Save chat history + table.insert(chat_history or {}, { timestamp = timestamp, requirement = user_input, response = full_response }) + save_chat_history(self, chat_history) end - - -- Execute when the stream request is actually completed - self:update_content( - "## " - .. timestamp - .. "\n\n> " - .. user_input:gsub("\n", "\n> ") - .. "\n\n" - .. full_response - .. "\n\n**Generation complete!** Please review the code suggestions above.\n\n\n\n", - true, - function() - api.nvim_exec_autocmds("User", { pattern = VIEW_BUFFER_UPDATED_PATTERN }) - end - ) - - api.nvim_set_current_win(self.winid.result) - - -- Display notification - -- show_notification("Content generation complete!") - - -- Save chat history - table.insert(chat_history or {}, { timestamp = timestamp, requirement = user_input, response = full_response }) - save_chat_history(self, chat_history) - end) + ) end local body = function() @@ -625,15 +659,38 @@ function Sidebar:render() local code_file_fullpath = api.nvim_buf_get_name(self.code.buf) local code_filename = fn.fnamemodify(code_file_fullpath, ":t") + local input_label = string.format(" šŸ™‹ with %s %s ( switch focus): ", icon, code_filename) + + if self.code.selection ~= nil then + input_label = string.format( + " šŸ™‹ with selected code in %s %s(%d:%d) ( switch focus): ", + icon, + code_filename, + self.code.selection.range.start.line, + self.code.selection.range.finish.line + ) + end + + local selected_code_lines_count = 0 + local selected_code_max_lines_count = 10 + + local selected_code_size = 0 + + if self.code.selection ~= nil then + local selected_code_lines = vim.split(self.code.selection.content, "\n") + selected_code_lines_count = #selected_code_lines + selected_code_size = math.min(selected_code_lines_count, selected_code_max_lines_count) + 4 + end + return N.rows( { flex = 0 }, N.box( { direction = "column", - size = vim.o.lines - 4, + size = vim.o.lines - 4 - selected_code_size, }, N.buffer({ - id = "response", + id = "result", flex = 1, buf = self.view.buf, autoscroll = true, @@ -650,16 +707,35 @@ function Sidebar:render() }) ), N.gap(1), + N.paragraph({ + hidden = self.code.selection == nil, + id = "selected_code", + lines = self.code.selection and self.code.selection.content or "", + border_label = { + text = "šŸ’» Selected Code" + .. ( + selected_code_lines_count > selected_code_max_lines_count + and " (Show only the first " .. tostring(selected_code_max_lines_count) .. " lines)" + or "" + ), + align = "center", + }, + align = "left", + is_focusable = false, + max_lines = selected_code_max_lines_count, + padding = { + top = 1, + bottom = 1, + left = 1, + right = 1, + }, + }), N.columns( { flex = 0 }, N.text_input({ - id = "text-input", + id = "input", border_label = { - text = string.format( - " šŸ™‹ with %s %s ( key to switch between result and input): ", - icon, - code_filename - ), + text = input_label, }, placeholder = "Enter your question", autofocus = true, diff --git a/lua/avante/utils.lua b/lua/avante/utils.lua index 8bb3828..d2a1f03 100644 --- a/lua/avante/utils.lua +++ b/lua/avante/utils.lua @@ -1,11 +1,57 @@ +local Range = require("avante.range") +local SelectionResult = require("avante.selection_result") local M = {} - function M.trim_suffix(str, suffix) return string.gsub(str, suffix .. "$", "") end - function M.trim_line_number_prefix(line) return line:gsub("^L%d+: ", "") end - +function M.in_visual_mode() + local current_mode = vim.fn.mode() + return current_mode == "v" or current_mode == "V" or current_mode == "" +end +-- Get the selected content and range in Visual mode +-- @return avante.SelectionResult | nil Selected content and range +function M.get_visual_selection_and_range() + if not M.in_visual_mode() then + return nil + end + -- Get the start and end positions of Visual mode + local start_pos = vim.fn.getpos("v") + local end_pos = vim.fn.getpos(".") + -- Get the start and end line and column numbers + local start_line = start_pos[2] + local start_col = start_pos[3] + local end_line = end_pos[2] + local end_col = end_pos[3] + -- If the start point is after the end point, swap them + if start_line > end_line or (start_line == end_line and start_col > end_col) then + start_line, end_line = end_line, start_line + start_col, end_col = end_col, start_col + end + local content = "" + local range = Range.new({ line = start_line, col = start_col }, { line = end_line, col = end_col }) + -- Check if it's a single-line selection + if start_line == end_line then + -- Get partial content of a single line + local line = vim.fn.getline(start_line) + -- content = string.sub(line, start_col, end_col) + content = line + else + -- Multi-line selection: Get all lines in the selection + local lines = vim.fn.getline(start_line, end_line) + -- Extract partial content of the first line + -- lines[1] = string.sub(lines[1], start_col) + -- Extract partial content of the last line + -- lines[#lines] = string.sub(lines[#lines], 1, end_col) + -- Concatenate all lines in the selection into a string + content = table.concat(lines, "\n") + end + if not content then + return nil + end + -- Return the selected content and range + return SelectionResult.new(content, range) +end return M