feat: ask selected code block (#39)

This commit is contained in:
yetone 2024-08-17 22:29:05 +08:00 committed by GitHub
parent dea737bf05
commit 3dca5f4764
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 399 additions and 91 deletions

View File

@ -89,7 +89,7 @@ Default setup configuration:
},
},
mappings = {
show_sidebar = "<leader>aa",
ask = "<leader>aa",
diff = {
ours = "co",
theirs = "ct",

View File

@ -19,8 +19,9 @@ Your primary task is to suggest code modifications with precise line number rang
1. Carefully analyze the original code, paying close attention to its structure and line numbers. Line numbers start from 1 and include ALL lines, even empty ones.
2. When suggesting modifications:
a. Explain why the change is necessary or beneficial.
b. Provide the exact code snippet to be replaced using this format:
a. Use the language in the question to reply. If there are non-English parts in the question, use the language of those parts.
b. Explain why the change is necessary or beneficial.
c. Provide the exact code snippet to be replaced using this format:
Replace lines: {{start_line}}-{{end_line}}
```{{language}}
@ -58,14 +59,12 @@ Replace lines: {{start_line}}-{{end_line}}
Remember: Accurate line numbers are CRITICAL. The range start_line to end_line must include ALL lines to be replaced, from the very first to the very last. Double-check every range before finalizing your response, paying special attention to the start_line to ensure it hasn't shifted down. Ensure that your line numbers perfectly match the original code structure without any overall shift.
]]
local function call_claude_api_stream(question, code_lang, code_content, on_chunk, on_complete)
local function call_claude_api_stream(question, code_lang, code_content, selected_code_content, on_chunk, on_complete)
local api_key = os.getenv("ANTHROPIC_API_KEY")
if not api_key then
error("ANTHROPIC_API_KEY environment variable is not set")
end
local user_prompt = base_user_prompt
local tokens = Config.claude.max_tokens
local headers = {
["Content-Type"] = "application/json",
@ -79,33 +78,56 @@ local function call_claude_api_stream(question, code_lang, code_content, on_chun
text = string.format("<code>```%s\n%s```</code>", code_lang, code_content),
}
if Tiktoken.count(code_prompt_obj.text) > 1024 then
code_prompt_obj.cache_control = { type = "ephemeral" }
end
if selected_code_content then
code_prompt_obj.text = string.format("<code_context>```%s\n%s```</code_context>", code_lang, code_content)
end
local message_content = {
code_prompt_obj,
}
if selected_code_content then
local selected_code_obj = {
type = "text",
text = string.format("<code>```%s\n%s```</code>", code_lang, selected_code_content),
}
if Tiktoken.count(selected_code_obj.text) > 1024 then
selected_code_obj.cache_control = { type = "ephemeral" }
end
table.insert(message_content, selected_code_obj)
end
table.insert(message_content, {
type = "text",
text = string.format("<question>%s</question>", question),
})
local user_prompt = base_user_prompt
local user_prompt_obj = {
type = "text",
text = user_prompt,
}
if Tiktoken.count(code_prompt_obj.text) > 1024 then
code_prompt_obj.cache_control = { type = "ephemeral" }
end
if Tiktoken.count(user_prompt_obj.text) > 1024 then
user_prompt_obj.cache_control = { type = "ephemeral" }
end
table.insert(message_content, user_prompt_obj)
local body = {
model = Config.claude.model,
system = system_prompt,
messages = {
{
role = "user",
content = {
code_prompt_obj,
{
type = "text",
text = string.format("<question>%s</question>", question),
},
user_prompt_obj,
},
content = message_content,
},
},
stream = true,
@ -154,21 +176,39 @@ local function call_claude_api_stream(question, code_lang, code_content, on_chun
})
end
local function call_openai_api_stream(question, code_lang, code_content, on_chunk, on_complete)
local function call_openai_api_stream(question, code_lang, code_content, selected_code_content, on_chunk, on_complete)
local api_key = os.getenv("OPENAI_API_KEY")
if not api_key and Config.provider == "openai" then
error("OPENAI_API_KEY environment variable is not set")
end
local user_prompt = base_user_prompt
.. "\n\nQUESTION:\n"
.. question
.. "\n\nCODE:\n"
.. "```"
.. code_lang
.. "\n"
.. code_content
.. "\n```"
.. "\n\nQUESTION:\n"
.. question
if selected_code_content then
user_prompt = base_user_prompt
.. "\n\nCODE CONTEXT:\n"
.. "```"
.. code_lang
.. "\n"
.. code_content
.. "\n```"
.. "\n\nCODE:\n"
.. "```"
.. code_lang
.. "\n"
.. selected_code_content
.. "\n```"
.. "\n\nQUESTION:\n"
.. question
end
local url, headers, body
if Config.provider == "azure" then
@ -258,13 +298,14 @@ end
---@param question string
---@param code_lang string
---@param code_content string
---@param selected_content_content string | nil
---@param on_chunk fun(chunk: string): any
---@param on_complete fun(err: string|nil): any
function M.call_ai_api_stream(question, code_lang, code_content, on_chunk, on_complete)
function M.call_ai_api_stream(question, code_lang, code_content, selected_content_content, on_chunk, on_complete)
if Config.provider == "openai" or Config.provider == "azure" then
call_openai_api_stream(question, code_lang, code_content, on_chunk, on_complete)
call_openai_api_stream(question, code_lang, code_content, selected_content_content, on_chunk, on_complete)
elseif Config.provider == "claude" then
call_claude_api_stream(question, code_lang, code_content, on_chunk, on_complete)
call_claude_api_stream(question, code_lang, code_content, selected_content_content, on_chunk, on_complete)
end
end

View File

@ -32,7 +32,8 @@ M.defaults = {
},
},
mappings = {
show_sidebar = "<leader>aa",
ask = "<leader>aa",
edit = "<leader>ae",
diff = {
ours = "co",
theirs = "ct",

View File

@ -4,6 +4,7 @@ local Tiktoken = require("avante.tiktoken")
local Sidebar = require("avante.sidebar")
local Config = require("avante.config")
local Diff = require("avante.diff")
local Selection = require("avante.selection")
---@class Avante
local M = {
@ -11,6 +12,7 @@ local M = {
sidebars = {},
---@type avante.Sidebar
current = nil,
selection = nil,
_once = false,
}
@ -35,7 +37,7 @@ H.commands = function()
end
H.keymaps = function()
vim.keymap.set({ "n" }, Config.mappings.show_sidebar, M.toggle, { noremap = true })
vim.keymap.set({ "n", "v" }, Config.mappings.ask, M.toggle, { noremap = true })
end
H.autocmds = function()
@ -76,7 +78,9 @@ H.autocmds = function()
if s then
s:destroy()
end
M.sidebars[tab] = nil
if tab ~= nil then
M.sidebars[tab] = nil
end
end,
})
@ -137,6 +141,10 @@ function M.setup(opts)
highlights = Config.highlights.diff,
})
local selection = Selection:new()
selection:setup()
M.selection = selection
-- setup helpers
H.autocmds()
H.commands()

24
lua/avante/range.lua Normal file
View File

@ -0,0 +1,24 @@
--@class avante.Range
--@field start table Selection start point
--@field start.line number Line number of the selection start
--@field start.col number Column number of the selection start
--@field finish table Selection end point
--@field finish.line number Line number of the selection end
--@field finish.col number Column number of the selection end
local Range = {}
Range.__index = Range
-- Create a selection range
-- @param start table Selection start point
-- @param start.line number Line number of the selection start
-- @param start.col number Column number of the selection start
-- @param finish table Selection end point
-- @param finish.line number Line number of the selection end
-- @param finish.col number Column number of the selection end
function Range.new(start, finish)
local self = setmetatable({}, Range)
self.start = start
self.finish = finish
return self
end
return Range

95
lua/avante/selection.lua Normal file
View File

@ -0,0 +1,95 @@
local Config = require("avante.config")
local api = vim.api
local fn = vim.fn
local NAMESPACE = api.nvim_create_namespace("avante_selection")
local PRIORITY = vim.highlight.priorities.user
local Selection = {}
function Selection:new()
return setmetatable({
hints_popup_extmark_id = nil,
edit_popup_renderer = nil,
augroup = api.nvim_create_augroup("avante_selection", { clear = true }),
}, { __index = self })
end
function Selection:get_virt_text_line()
local current_pos = fn.getpos(".")
-- Get the current and start position line numbers
local current_line = current_pos[2] - 1 -- 0-indexed
-- Ensure line numbers are not negative and don't exceed buffer range
local total_lines = api.nvim_buf_line_count(0)
if current_line < 0 then
current_line = 0
end
if current_line >= total_lines then
current_line = total_lines - 1
end
-- Take the first line of the selection to ensure virt_text is always in the top right corner
return current_line
end
function Selection:show_hints_popup()
self:close_hints_popup()
local hint_text = string.format(" [Ask %s] ", Config.mappings.ask)
local virt_text_line = self:get_virt_text_line()
self.hints_popup_extmark_id = vim.api.nvim_buf_set_extmark(0, NAMESPACE, virt_text_line, -1, {
virt_text = { { hint_text, "Keyword" } },
virt_text_pos = "eol",
priority = PRIORITY,
})
end
function Selection:close_hints_popup()
if self.hints_popup_extmark_id then
vim.api.nvim_buf_del_extmark(0, NAMESPACE, self.hints_popup_extmark_id)
self.hints_popup_extmark_id = nil
end
end
function Selection:setup()
vim.api.nvim_create_autocmd({ "ModeChanged" }, {
group = self.augroup,
pattern = { "n:v", "n:V", "n:" }, -- Entering Visual mode from Normal mode
callback = function()
self:show_hints_popup()
end,
})
api.nvim_create_autocmd({ "CursorMoved", "CursorMovedI" }, {
group = self.augroup,
callback = function()
if vim.fn.mode() == "v" or vim.fn.mode() == "V" or vim.fn.mode() == "" then
self:show_hints_popup()
else
self:close_hints_popup()
end
end,
})
api.nvim_create_autocmd({ "ModeChanged" }, {
group = self.augroup,
pattern = { "v:n", "v:i", "v:c" }, -- Switching from visual mode back to normal, insert, or other modes
callback = function()
self:close_hints_popup()
end,
})
end
function Selection:delete_autocmds()
if self.augroup then
vim.api.nvim_del_augroup_by_id(self.augroup)
end
self.augroup = nil
end
return Selection

View File

@ -0,0 +1,17 @@
--@class avante.SelectionResult
--@field content string Selected content
--@field range avante.Range Selection range
local SelectionResult = {}
SelectionResult.__index = SelectionResult
-- Create a selection content and range
--@param content string Selected content
--@param range avante.Range Selection range
function SelectionResult.new(content, range)
local self = setmetatable({}, SelectionResult)
self.content = content
self.range = range
return self
end
return SelectionResult

View File

@ -20,6 +20,7 @@ local Sidebar = {}
---@class avante.SidebarState
---@field win integer
---@field buf integer
---@field selection avante.SelectionResult | nil
---@class avante.Sidebar
---@field id integer
@ -33,11 +34,11 @@ local Sidebar = {}
function Sidebar:new(id)
return setmetatable({
id = id,
code = { buf = 0, win = 0 },
code = { buf = 0, win = 0, selection = nil },
winid = { result = 0, input = 0 },
view = View:new(),
renderer = nil,
}, { __index = Sidebar })
}, { __index = self })
end
--- This function should only be used on TabClosed, nothing else.
@ -63,10 +64,17 @@ function Sidebar:reset()
end
function Sidebar:open()
local in_visual_mode = Utils.in_visual_mode() and self:in_code_win()
if not self.view:is_open() then
self:intialize()
self:render()
else
if in_visual_mode then
self:close()
self:intialize()
self:render()
return self
end
self:focus()
end
return self
@ -86,8 +94,13 @@ function Sidebar:focus()
return false
end
function Sidebar:in_code_win()
return self.code.win == api.nvim_get_current_win()
end
function Sidebar:toggle()
if self.view:is_open() then
local in_visual_mode = Utils.in_visual_mode() and self:in_code_win()
if self.view:is_open() and not in_visual_mode then
self:close()
return false
else
@ -108,6 +121,7 @@ end
function Sidebar:intialize()
self.code.win = api.nvim_get_current_win()
self.code.buf = api.nvim_get_current_buf()
self.code.selection = Utils.get_visual_selection_and_range()
local split_command = "botright vs"
local layout = Config.get_renderer_layout_options()
@ -123,15 +137,15 @@ function Sidebar:intialize()
})
self.renderer:on_mount(function()
local components = self.renderer:get_focusable_components()
-- current layout is a
-- [ chat ]
-- <gap>
-- [ input ]
self.winid.result = components[1].winid
self.winid.input = components[2].winid
self.winid.result = self.renderer:get_component_by_id("result").winid
self.winid.input = self.renderer:get_component_by_id("input").winid
self.augroup = api.nvim_create_augroup("avante_" .. self.id .. self.view.win, { clear = true })
local filetype = api.nvim_get_option_value("filetype", { buf = self.code.buf })
local selected_code_buf = self.renderer:get_component_by_id("selected_code").bufnr
api.nvim_buf_set_option(selected_code_buf, "filetype", filetype)
api.nvim_set_option_value("wrap", false, { win = self.renderer:get_component_by_id("selected_code").winid })
api.nvim_create_autocmd("BufEnter", {
group = self.augroup,
buffer = self.view.buf,
@ -215,7 +229,11 @@ function Sidebar:update_content(content, focus, callback)
-- XXX: omit error for now, but should fix me why it can't jump here.
return err
end)
api.nvim_win_set_cursor(self.winid.result, { api.nvim_buf_line_count(self.view.buf), 0 })
xpcall(function()
api.nvim_win_set_cursor(self.winid.result, { api.nvim_buf_line_count(self.view.buf), 0 })
end, function(err)
return err
end)
end
end, 0)
return self
@ -260,10 +278,12 @@ local function is_cursor_in_codeblock(codeblocks)
return nil
end
local function prepend_line_number(content)
local function prepend_line_number(content, start_line)
start_line = start_line or 1
local lines = vim.split(content, "\n")
local result = {}
for i, line in ipairs(lines) do
i = i + start_line - 1
table.insert(result, "L" .. i .. ": " .. line)
end
return table.concat(result, "\n")
@ -560,25 +580,53 @@ function Sidebar:render()
local content = self:get_code_content()
local content_with_line_numbers = prepend_line_number(content)
local selected_code_content_with_line_numbers = nil
if self.code.selection ~= nil then
selected_code_content_with_line_numbers =
prepend_line_number(self.code.selection.content, self.code.selection.range.start.line)
end
local full_response = ""
signal.is_loading = true
local filetype = api.nvim_get_option_value("filetype", { buf = self.code.buf })
AiBot.call_ai_api_stream(user_input, filetype, content_with_line_numbers, function(chunk)
full_response = full_response .. chunk
self:update_content(
"## " .. timestamp .. "\n\n> " .. user_input:gsub("\n", "\n> ") .. "\n\n" .. full_response,
true
)
vim.schedule(function()
vim.cmd("redraw")
end)
end, function(err)
signal.is_loading = false
AiBot.call_ai_api_stream(
user_input,
filetype,
content_with_line_numbers,
selected_code_content_with_line_numbers,
function(chunk)
full_response = full_response .. chunk
self:update_content(
"## " .. timestamp .. "\n\n> " .. user_input:gsub("\n", "\n> ") .. "\n\n" .. full_response,
true
)
vim.schedule(function()
vim.cmd("redraw")
end)
end,
function(err)
signal.is_loading = false
if err ~= nil then
if err ~= nil then
self:update_content(
"## "
.. timestamp
.. "\n\n> "
.. user_input:gsub("\n", "\n> ")
.. "\n\n"
.. full_response
.. "\n\n🚨 Error: "
.. vim.inspect(err),
true
)
return
end
-- Execute when the stream request is actually completed
self:update_content(
"## "
.. timestamp
@ -586,37 +634,23 @@ function Sidebar:render()
.. user_input:gsub("\n", "\n> ")
.. "\n\n"
.. full_response
.. "\n\n🚨 Error: "
.. vim.inspect(err),
true
.. "\n\n**Generation complete!** Please review the code suggestions above.\n\n\n\n",
true,
function()
api.nvim_exec_autocmds("User", { pattern = VIEW_BUFFER_UPDATED_PATTERN })
end
)
return
api.nvim_set_current_win(self.winid.result)
-- Display notification
-- show_notification("Content generation complete!")
-- Save chat history
table.insert(chat_history or {}, { timestamp = timestamp, requirement = user_input, response = full_response })
save_chat_history(self, chat_history)
end
-- Execute when the stream request is actually completed
self:update_content(
"## "
.. timestamp
.. "\n\n> "
.. user_input:gsub("\n", "\n> ")
.. "\n\n"
.. full_response
.. "\n\n**Generation complete!** Please review the code suggestions above.\n\n\n\n",
true,
function()
api.nvim_exec_autocmds("User", { pattern = VIEW_BUFFER_UPDATED_PATTERN })
end
)
api.nvim_set_current_win(self.winid.result)
-- Display notification
-- show_notification("Content generation complete!")
-- Save chat history
table.insert(chat_history or {}, { timestamp = timestamp, requirement = user_input, response = full_response })
save_chat_history(self, chat_history)
end)
)
end
local body = function()
@ -625,15 +659,38 @@ function Sidebar:render()
local code_file_fullpath = api.nvim_buf_get_name(self.code.buf)
local code_filename = fn.fnamemodify(code_file_fullpath, ":t")
local input_label = string.format(" 🙋 with %s %s (<Tab> switch focus): ", icon, code_filename)
if self.code.selection ~= nil then
input_label = string.format(
" 🙋 with selected code in %s %s(%d:%d) (<Tab> switch focus): ",
icon,
code_filename,
self.code.selection.range.start.line,
self.code.selection.range.finish.line
)
end
local selected_code_lines_count = 0
local selected_code_max_lines_count = 10
local selected_code_size = 0
if self.code.selection ~= nil then
local selected_code_lines = vim.split(self.code.selection.content, "\n")
selected_code_lines_count = #selected_code_lines
selected_code_size = math.min(selected_code_lines_count, selected_code_max_lines_count) + 4
end
return N.rows(
{ flex = 0 },
N.box(
{
direction = "column",
size = vim.o.lines - 4,
size = vim.o.lines - 4 - selected_code_size,
},
N.buffer({
id = "response",
id = "result",
flex = 1,
buf = self.view.buf,
autoscroll = true,
@ -650,16 +707,35 @@ function Sidebar:render()
})
),
N.gap(1),
N.paragraph({
hidden = self.code.selection == nil,
id = "selected_code",
lines = self.code.selection and self.code.selection.content or "",
border_label = {
text = "💻 Selected Code"
.. (
selected_code_lines_count > selected_code_max_lines_count
and " (Show only the first " .. tostring(selected_code_max_lines_count) .. " lines)"
or ""
),
align = "center",
},
align = "left",
is_focusable = false,
max_lines = selected_code_max_lines_count,
padding = {
top = 1,
bottom = 1,
left = 1,
right = 1,
},
}),
N.columns(
{ flex = 0 },
N.text_input({
id = "text-input",
id = "input",
border_label = {
text = string.format(
" 🙋 with %s %s (<Tab> key to switch between result and input): ",
icon,
code_filename
),
text = input_label,
},
placeholder = "Enter your question",
autofocus = true,

View File

@ -1,11 +1,57 @@
local Range = require("avante.range")
local SelectionResult = require("avante.selection_result")
local M = {}
function M.trim_suffix(str, suffix)
return string.gsub(str, suffix .. "$", "")
end
function M.trim_line_number_prefix(line)
return line:gsub("^L%d+: ", "")
end
function M.in_visual_mode()
local current_mode = vim.fn.mode()
return current_mode == "v" or current_mode == "V" or current_mode == ""
end
-- Get the selected content and range in Visual mode
-- @return avante.SelectionResult | nil Selected content and range
function M.get_visual_selection_and_range()
if not M.in_visual_mode() then
return nil
end
-- Get the start and end positions of Visual mode
local start_pos = vim.fn.getpos("v")
local end_pos = vim.fn.getpos(".")
-- Get the start and end line and column numbers
local start_line = start_pos[2]
local start_col = start_pos[3]
local end_line = end_pos[2]
local end_col = end_pos[3]
-- If the start point is after the end point, swap them
if start_line > end_line or (start_line == end_line and start_col > end_col) then
start_line, end_line = end_line, start_line
start_col, end_col = end_col, start_col
end
local content = ""
local range = Range.new({ line = start_line, col = start_col }, { line = end_line, col = end_col })
-- Check if it's a single-line selection
if start_line == end_line then
-- Get partial content of a single line
local line = vim.fn.getline(start_line)
-- content = string.sub(line, start_col, end_col)
content = line
else
-- Multi-line selection: Get all lines in the selection
local lines = vim.fn.getline(start_line, end_line)
-- Extract partial content of the first line
-- lines[1] = string.sub(lines[1], start_col)
-- Extract partial content of the last line
-- lines[#lines] = string.sub(lines[#lines], 1, end_col)
-- Concatenate all lines in the selection into a string
content = table.concat(lines, "\n")
end
if not content then
return nil
end
-- Return the selected content and range
return SelectionResult.new(content, range)
end
return M