From b19573cb2a8e30662e2245c3bfbb6e22d09d0905 Mon Sep 17 00:00:00 2001 From: yetone Date: Mon, 14 Oct 2024 20:15:11 +0800 Subject: [PATCH] feat: write to multiple files (#720) --- lua/avante/sidebar.lua | 155 ++++++++++++++++++++++---------------- lua/avante/utils/file.lua | 5 ++ lua/avante/utils/init.lua | 27 +++++++ 3 files changed, 122 insertions(+), 65 deletions(-) diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index 01536a1..b33d3aa 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -298,17 +298,20 @@ end ---@field explanation string ---@field start_line_in_response_buf integer ---@field end_line_in_response_buf integer +---@field filepath string ---@param response_content string ----@return AvanteCodeSnippet[] -local function extract_code_snippets(response_content) +---@return table +local function extract_code_snippets_map(response_content) local snippets = {} local current_snippet = {} local in_code_block = false local lang, start_line, end_line, start_line_in_response_buf local explanation = "" - for idx, line in ipairs(vim.split(response_content, "\n")) do + local lines = vim.split(response_content, "\n") + + for idx, line in ipairs(lines) do local _, start_line_str, end_line_str = line:match("^%s*(%d*)[%.%)%s]*[Aa]?n?d?%s*[Rr]eplace%s+[Ll]ines:?%s*(%d+)%-(%d+)") if start_line_str ~= nil and end_line_str ~= nil then @@ -337,6 +340,7 @@ local function extract_code_snippets(response_content) explanation = explanation, start_line_in_response_buf = start_line_in_response_buf, end_line_in_response_buf = idx, + filepath = lines[start_line_in_response_buf - 2], } table.insert(snippets, snippet) end @@ -357,48 +361,62 @@ local function extract_code_snippets(response_content) end end - return snippets -end - ----@param snippets AvanteCodeSnippet[] ----@return AvanteCodeSnippet[] -local function ensure_snippets_no_overlap(original_content, snippets) - table.sort(snippets, function(a, b) return a.range[1] < b.range[1] end) - - local original_lines = vim.split(original_content, "\n") - - local result = {} - local last_end_line = 0 + local snippets_map = {} for _, snippet in ipairs(snippets) do - if snippet.range[1] > last_end_line then - table.insert(result, snippet) - last_end_line = snippet.range[2] - else - local snippet_lines = vim.split(snippet.content, "\n") - -- Trim the overlapping part - local new_start_line = nil - for i = snippet.range[1], math.min(snippet.range[2], last_end_line) do - if - Utils.remove_indentation(original_lines[i]) - == Utils.remove_indentation(snippet_lines[i - snippet.range[1] + 1]) - then - new_start_line = i + 1 - else - break - end - end - if new_start_line ~= nil then - snippet.content = table.concat(vim.list_slice(snippet_lines, new_start_line - snippet.range[1] + 1), "\n") - snippet.range[1] = new_start_line - table.insert(result, snippet) - last_end_line = snippet.range[2] - else - Utils.error("Failed to ensure snippets no overlap", { once = true, title = "Avante" }) - end - end + snippets_map[snippet.filepath] = snippets_map[snippet.filepath] or {} + table.insert(snippets_map[snippet.filepath], snippet) end - return result + return snippets_map +end + +---@param snippets_map table +---@return table +local function ensure_snippets_no_overlap(snippets_map) + local new_snippets_map = {} + + for filepath, snippets in pairs(snippets_map) do + table.sort(snippets, function(a, b) return a.range[1] < b.range[1] end) + + local original_content = "" + if Utils.file.exists(filepath) then original_content = Utils.file.read_content(filepath) or "" end + + local original_lines = vim.split(original_content, "\n") + + local new_snippets = {} + local last_end_line = 0 + for _, snippet in ipairs(snippets) do + if snippet.range[1] > last_end_line then + table.insert(new_snippets, snippet) + last_end_line = snippet.range[2] + else + local snippet_lines = vim.split(snippet.content, "\n") + -- Trim the overlapping part + local new_start_line = nil + for i = snippet.range[1], math.min(snippet.range[2], last_end_line) do + if + Utils.remove_indentation(original_lines[i]) + == Utils.remove_indentation(snippet_lines[i - snippet.range[1] + 1]) + then + new_start_line = i + 1 + else + break + end + end + if new_start_line ~= nil then + snippet.content = table.concat(vim.list_slice(snippet_lines, new_start_line - snippet.range[1] + 1), "\n") + snippet.range[1] = new_start_line + table.insert(new_snippets, snippet) + last_end_line = snippet.range[2] + else + Utils.error("Failed to ensure snippets no overlap", { once = true, title = "Avante" }) + end + end + end + new_snippets_map[filepath] = new_snippets + end + + return new_snippets_map end local function insert_conflict_contents(bufnr, snippets) @@ -494,40 +512,47 @@ end ---@param current_cursor boolean function Sidebar:apply(current_cursor) - local content = table.concat(Utils.get_buf_lines(0, -1, self.code.bufnr), "\n") local response, response_start_line = self:get_content_between_separators() - local all_snippets = extract_code_snippets(response) - all_snippets = ensure_snippets_no_overlap(content, all_snippets) - local selected_snippets = {} + local all_snippets_map = extract_code_snippets_map(response) + all_snippets_map = ensure_snippets_no_overlap(all_snippets_map) + local selected_snippets_map = {} if current_cursor then if self.result and self.result.winid then local cursor_line = Utils.get_cursor_pos(self.result.winid) - for _, snippet in ipairs(all_snippets) do - if - cursor_line >= snippet.start_line_in_response_buf + response_start_line - 1 - and cursor_line <= snippet.end_line_in_response_buf + response_start_line - 1 - then - selected_snippets = { snippet } - break + for filepath, snippets in pairs(all_snippets_map) do + for _, snippet in ipairs(snippets) do + if + cursor_line >= snippet.start_line_in_response_buf + response_start_line - 1 + and cursor_line <= snippet.end_line_in_response_buf + response_start_line - 1 + then + selected_snippets_map[filepath] = { snippet } + break + end end end end else - selected_snippets = all_snippets + selected_snippets_map = all_snippets_map end vim.defer_fn(function() - insert_conflict_contents(self.code.bufnr, selected_snippets) - - api.nvim_set_current_win(self.code.winid) - api.nvim_feedkeys(api.nvim_replace_termcodes("", true, false, true), "n", true) - Diff.add_visited_buffer(self.code.bufnr) - Diff.process(self.code.bufnr) - api.nvim_win_set_cursor(self.code.winid, { 1, 0 }) - vim.defer_fn(function() - Diff.find_next("ours") - vim.cmd("normal! zz") - end, 100) + for filepath, snippets in pairs(selected_snippets_map) do + local bufnr = Utils.get_opened_buffer(filepath) + if not bufnr then bufnr = Utils.create_new_buffer_with_file(filepath) end + insert_conflict_contents(bufnr, snippets) + local winid = Utils.get_winid(bufnr) + if not winid then goto continue end + api.nvim_set_current_win(winid) + api.nvim_feedkeys(api.nvim_replace_termcodes("", true, false, true), "n", true) + Diff.add_visited_buffer(bufnr) + Diff.process(bufnr) + api.nvim_win_set_cursor(winid, { 1, 0 }) + vim.defer_fn(function() + Diff.find_next("ours") + vim.cmd("normal! zz") + end, 100) + ::continue:: + end end, 10) end diff --git a/lua/avante/utils/file.lua b/lua/avante/utils/file.lua index a043df9..5e3a196 100644 --- a/lua/avante/utils/file.lua +++ b/lua/avante/utils/file.lua @@ -33,4 +33,9 @@ function M.read_content(filepath) return nil end +function M.exists(filepath) + local stat = vim.loop.fs_stat(filepath) + return stat ~= nil +end + return M diff --git a/lua/avante/utils/init.lua b/lua/avante/utils/init.lua index 417a6ed..47f8e6a 100644 --- a/lua/avante/utils/init.lua +++ b/lua/avante/utils/init.lua @@ -373,6 +373,12 @@ function M.get_win_options(winid, opt_name, key) end end +function M.get_winid(bufnr) + for _, winid in ipairs(api.nvim_list_wins()) do + if api.nvim_win_get_buf(winid) == bufnr then return winid end + end +end + function M.unlock_buf(bufnr) vim.bo[bufnr].modified = false vim.bo[bufnr].modifiable = true @@ -655,4 +661,25 @@ function M.get_mentions() } end +function M.get_opened_buffer(filepath) + for _, buf in ipairs(api.nvim_list_bufs()) do + if fn.buflisted(buf) == 1 and fn.bufname(buf) == filepath then return buf end + end + return nil +end + +function M.create_new_buffer_with_file(filepath) + local buf = api.nvim_create_buf(false, true) + + api.nvim_buf_set_name(buf, filepath) + + api.nvim_set_option_value("buftype", "", { buf = buf }) + + api.nvim_set_current_buf(buf) + + vim.cmd("edit " .. filepath) + + return buf +end + return M