feat: write to multiple files (#720)

This commit is contained in:
yetone 2024-10-14 20:15:11 +08:00 committed by GitHub
parent 347d9be730
commit b19573cb2a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 122 additions and 65 deletions

View File

@ -298,17 +298,20 @@ end
---@field explanation string ---@field explanation string
---@field start_line_in_response_buf integer ---@field start_line_in_response_buf integer
---@field end_line_in_response_buf integer ---@field end_line_in_response_buf integer
---@field filepath string
---@param response_content string ---@param response_content string
---@return AvanteCodeSnippet[] ---@return table<string, AvanteCodeSnippet[]>
local function extract_code_snippets(response_content) local function extract_code_snippets_map(response_content)
local snippets = {} local snippets = {}
local current_snippet = {} local current_snippet = {}
local in_code_block = false local in_code_block = false
local lang, start_line, end_line, start_line_in_response_buf local lang, start_line, end_line, start_line_in_response_buf
local explanation = "" local explanation = ""
for idx, line in ipairs(vim.split(response_content, "\n")) do local lines = vim.split(response_content, "\n")
for idx, line in ipairs(lines) do
local _, start_line_str, end_line_str = local _, start_line_str, end_line_str =
line:match("^%s*(%d*)[%.%)%s]*[Aa]?n?d?%s*[Rr]eplace%s+[Ll]ines:?%s*(%d+)%-(%d+)") line:match("^%s*(%d*)[%.%)%s]*[Aa]?n?d?%s*[Rr]eplace%s+[Ll]ines:?%s*(%d+)%-(%d+)")
if start_line_str ~= nil and end_line_str ~= nil then if start_line_str ~= nil and end_line_str ~= nil then
@ -337,6 +340,7 @@ local function extract_code_snippets(response_content)
explanation = explanation, explanation = explanation,
start_line_in_response_buf = start_line_in_response_buf, start_line_in_response_buf = start_line_in_response_buf,
end_line_in_response_buf = idx, end_line_in_response_buf = idx,
filepath = lines[start_line_in_response_buf - 2],
} }
table.insert(snippets, snippet) table.insert(snippets, snippet)
end end
@ -357,21 +361,33 @@ local function extract_code_snippets(response_content)
end end
end end
return snippets local snippets_map = {}
for _, snippet in ipairs(snippets) do
snippets_map[snippet.filepath] = snippets_map[snippet.filepath] or {}
table.insert(snippets_map[snippet.filepath], snippet)
end end
---@param snippets AvanteCodeSnippet[] return snippets_map
---@return AvanteCodeSnippet[] end
local function ensure_snippets_no_overlap(original_content, snippets)
---@param snippets_map table<string, AvanteCodeSnippet[]>
---@return table<string, AvanteCodeSnippet[]>
local function ensure_snippets_no_overlap(snippets_map)
local new_snippets_map = {}
for filepath, snippets in pairs(snippets_map) do
table.sort(snippets, function(a, b) return a.range[1] < b.range[1] end) table.sort(snippets, function(a, b) return a.range[1] < b.range[1] end)
local original_content = ""
if Utils.file.exists(filepath) then original_content = Utils.file.read_content(filepath) or "" end
local original_lines = vim.split(original_content, "\n") local original_lines = vim.split(original_content, "\n")
local result = {} local new_snippets = {}
local last_end_line = 0 local last_end_line = 0
for _, snippet in ipairs(snippets) do for _, snippet in ipairs(snippets) do
if snippet.range[1] > last_end_line then if snippet.range[1] > last_end_line then
table.insert(result, snippet) table.insert(new_snippets, snippet)
last_end_line = snippet.range[2] last_end_line = snippet.range[2]
else else
local snippet_lines = vim.split(snippet.content, "\n") local snippet_lines = vim.split(snippet.content, "\n")
@ -390,15 +406,17 @@ local function ensure_snippets_no_overlap(original_content, snippets)
if new_start_line ~= nil then if new_start_line ~= nil then
snippet.content = table.concat(vim.list_slice(snippet_lines, new_start_line - snippet.range[1] + 1), "\n") snippet.content = table.concat(vim.list_slice(snippet_lines, new_start_line - snippet.range[1] + 1), "\n")
snippet.range[1] = new_start_line snippet.range[1] = new_start_line
table.insert(result, snippet) table.insert(new_snippets, snippet)
last_end_line = snippet.range[2] last_end_line = snippet.range[2]
else else
Utils.error("Failed to ensure snippets no overlap", { once = true, title = "Avante" }) Utils.error("Failed to ensure snippets no overlap", { once = true, title = "Avante" })
end end
end end
end end
new_snippets_map[filepath] = new_snippets
end
return result return new_snippets_map
end end
local function insert_conflict_contents(bufnr, snippets) local function insert_conflict_contents(bufnr, snippets)
@ -494,40 +512,47 @@ end
---@param current_cursor boolean ---@param current_cursor boolean
function Sidebar:apply(current_cursor) function Sidebar:apply(current_cursor)
local content = table.concat(Utils.get_buf_lines(0, -1, self.code.bufnr), "\n")
local response, response_start_line = self:get_content_between_separators() local response, response_start_line = self:get_content_between_separators()
local all_snippets = extract_code_snippets(response) local all_snippets_map = extract_code_snippets_map(response)
all_snippets = ensure_snippets_no_overlap(content, all_snippets) all_snippets_map = ensure_snippets_no_overlap(all_snippets_map)
local selected_snippets = {} local selected_snippets_map = {}
if current_cursor then if current_cursor then
if self.result and self.result.winid then if self.result and self.result.winid then
local cursor_line = Utils.get_cursor_pos(self.result.winid) local cursor_line = Utils.get_cursor_pos(self.result.winid)
for _, snippet in ipairs(all_snippets) do for filepath, snippets in pairs(all_snippets_map) do
for _, snippet in ipairs(snippets) do
if if
cursor_line >= snippet.start_line_in_response_buf + response_start_line - 1 cursor_line >= snippet.start_line_in_response_buf + response_start_line - 1
and cursor_line <= snippet.end_line_in_response_buf + response_start_line - 1 and cursor_line <= snippet.end_line_in_response_buf + response_start_line - 1
then then
selected_snippets = { snippet } selected_snippets_map[filepath] = { snippet }
break break
end end
end end
end end
end
else else
selected_snippets = all_snippets selected_snippets_map = all_snippets_map
end end
vim.defer_fn(function() vim.defer_fn(function()
insert_conflict_contents(self.code.bufnr, selected_snippets) for filepath, snippets in pairs(selected_snippets_map) do
local bufnr = Utils.get_opened_buffer(filepath)
api.nvim_set_current_win(self.code.winid) if not bufnr then bufnr = Utils.create_new_buffer_with_file(filepath) end
insert_conflict_contents(bufnr, snippets)
local winid = Utils.get_winid(bufnr)
if not winid then goto continue end
api.nvim_set_current_win(winid)
api.nvim_feedkeys(api.nvim_replace_termcodes("<Esc>", true, false, true), "n", true) api.nvim_feedkeys(api.nvim_replace_termcodes("<Esc>", true, false, true), "n", true)
Diff.add_visited_buffer(self.code.bufnr) Diff.add_visited_buffer(bufnr)
Diff.process(self.code.bufnr) Diff.process(bufnr)
api.nvim_win_set_cursor(self.code.winid, { 1, 0 }) api.nvim_win_set_cursor(winid, { 1, 0 })
vim.defer_fn(function() vim.defer_fn(function()
Diff.find_next("ours") Diff.find_next("ours")
vim.cmd("normal! zz") vim.cmd("normal! zz")
end, 100) end, 100)
::continue::
end
end, 10) end, 10)
end end

View File

@ -33,4 +33,9 @@ function M.read_content(filepath)
return nil return nil
end end
function M.exists(filepath)
local stat = vim.loop.fs_stat(filepath)
return stat ~= nil
end
return M return M

View File

@ -373,6 +373,12 @@ function M.get_win_options(winid, opt_name, key)
end end
end end
function M.get_winid(bufnr)
for _, winid in ipairs(api.nvim_list_wins()) do
if api.nvim_win_get_buf(winid) == bufnr then return winid end
end
end
function M.unlock_buf(bufnr) function M.unlock_buf(bufnr)
vim.bo[bufnr].modified = false vim.bo[bufnr].modified = false
vim.bo[bufnr].modifiable = true vim.bo[bufnr].modifiable = true
@ -655,4 +661,25 @@ function M.get_mentions()
} }
end end
function M.get_opened_buffer(filepath)
for _, buf in ipairs(api.nvim_list_bufs()) do
if fn.buflisted(buf) == 1 and fn.bufname(buf) == filepath then return buf end
end
return nil
end
function M.create_new_buffer_with_file(filepath)
local buf = api.nvim_create_buf(false, true)
api.nvim_buf_set_name(buf, filepath)
api.nvim_set_option_value("buftype", "", { buf = buf })
api.nvim_set_current_buf(buf)
vim.cmd("edit " .. filepath)
return buf
end
return M return M