feat: support apply current code snippet (#391)

This commit is contained in:
yetone 2024-08-30 15:01:23 +08:00 committed by GitHub
parent ea73816665
commit 8c71e1f624
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -183,15 +183,26 @@ local function realign_line_numbers(code_lines, snippet)
return snippet return snippet
end end
---@class AvanteCodeSnippet
---@field range integer[]
---@field content string
---@field lang string
---@field explanation string
---@field start_line_in_response_buf integer
---@field end_line_in_response_buf integer
---@param code_content string
---@param response_content string
---@return AvanteCodeSnippet[]
local function extract_code_snippets(code_content, response_content) local function extract_code_snippets(code_content, response_content)
local code_lines = vim.split(code_content, "\n") local code_lines = vim.split(code_content, "\n")
local snippets = {} local snippets = {}
local current_snippet = {} local current_snippet = {}
local in_code_block = false local in_code_block = false
local lang, start_line, end_line local lang, start_line, end_line, start_line_in_response_buf
local explanation = "" local explanation = ""
for _, line in ipairs(vim.split(response_content, "\n")) do for idx, line in ipairs(vim.split(response_content, "\n")) do
local start_line_str, end_line_str = line:match("^Replace lines: (%d+)-(%d+)") local start_line_str, end_line_str = line:match("^Replace lines: (%d+)-(%d+)")
if start_line_str ~= nil and end_line_str ~= nil then if start_line_str ~= nil and end_line_str ~= nil then
start_line = tonumber(start_line_str) start_line = tonumber(start_line_str)
@ -205,6 +216,8 @@ local function extract_code_snippets(code_content, response_content)
content = table.concat(current_snippet, "\n"), content = table.concat(current_snippet, "\n"),
lang = lang, lang = lang,
explanation = explanation, explanation = explanation,
start_line_in_response_buf = start_line_in_response_buf,
end_line_in_response_buf = idx,
} }
snippet = realign_line_numbers(code_lines, snippet) snippet = realign_line_numbers(code_lines, snippet)
table.insert(snippets, snippet) table.insert(snippets, snippet)
@ -219,6 +232,7 @@ local function extract_code_snippets(code_content, response_content)
lang = "text" lang = "text"
end end
in_code_block = true in_code_block = true
start_line_in_response_buf = idx
end end
elseif in_code_block then elseif in_code_block then
table.insert(current_snippet, line) table.insert(current_snippet, line)
@ -318,10 +332,26 @@ local function parse_codeblocks(buf)
return codeblocks return codeblocks
end end
function Sidebar:apply() ---@param current_cursor boolean
function Sidebar:apply(current_cursor)
local content = table.concat(Utils.get_buf_lines(0, -1, self.code.bufnr), "\n") local content = table.concat(Utils.get_buf_lines(0, -1, self.code.bufnr), "\n")
local response = self:get_content_between_separators() local response, response_start_line = self:get_content_between_separators()
local snippets = extract_code_snippets(content, response) local snippets = extract_code_snippets(content, response)
if current_cursor then
if self.result and self.result.winid then
local cursor_line = Utils.get_cursor_pos(self.result.winid)
for _, snippet in ipairs(snippets) do
if
cursor_line >= snippet.start_line_in_response_buf + response_start_line
and cursor_line <= snippet.end_line_in_response_buf + response_start_line
then
snippets = { snippet }
break
end
end
end
end
local conflict_content = get_conflict_content(content, snippets) local conflict_content = get_conflict_content(content, snippets)
vim.defer_fn(function() vim.defer_fn(function()
@ -501,7 +531,7 @@ function Sidebar:on_mount()
current_apply_extmark_id = current_apply_extmark_id =
api.nvim_buf_set_extmark(self.result.bufnr, CODEBLOCK_KEYBINDING_NAMESPACE, block.start_line, -1, { api.nvim_buf_set_extmark(self.result.bufnr, CODEBLOCK_KEYBINDING_NAMESPACE, block.start_line, -1, {
virt_text = { { " [<A>: apply patch] ", "Keyword" } }, virt_text = { { " [<a>: apply this, <A>: apply all] ", "Keyword" } },
virt_text_pos = "right_align", virt_text_pos = "right_align",
hl_group = "Keyword", hl_group = "Keyword",
priority = PRIORITY, priority = PRIORITY,
@ -509,8 +539,11 @@ function Sidebar:on_mount()
end end
local function bind_apply_key() local function bind_apply_key()
vim.keymap.set("n", "a", function()
self:apply(true)
end, { buffer = self.result.bufnr, noremap = true, silent = true })
vim.keymap.set("n", "A", function() vim.keymap.set("n", "A", function()
self:apply() self:apply(false)
end, { buffer = self.result.bufnr, noremap = true, silent = true }) end, { buffer = self.result.bufnr, noremap = true, silent = true })
end end
@ -878,7 +911,7 @@ function Sidebar:update_content_with_history(history)
self:update_content(content) self:update_content(content)
end end
---@return string ---@return string, integer
function Sidebar:get_content_between_separators() function Sidebar:get_content_between_separators()
local separator = "---" local separator = "---"
local cursor_line, _ = Utils.get_cursor_pos() local cursor_line, _ = Utils.get_cursor_pos()
@ -910,7 +943,7 @@ function Sidebar:get_content_between_separators()
end end
local content = table.concat(vim.list_slice(lines, start_line, end_line), "\n") local content = table.concat(vim.list_slice(lines, start_line, end_line), "\n")
return content return content, start_line
end end
---@alias AvanteSlashCommands "clear" | "help" | "lines" ---@alias AvanteSlashCommands "clear" | "help" | "lines"
@ -1150,7 +1183,7 @@ function Sidebar:create_input()
) )
if Config.behaviour.auto_apply_diff_after_generation then if Config.behaviour.auto_apply_diff_after_generation then
self:apply() self:apply(false)
end end
end end