feat: repo map (#496)

* feat: repo map

* chore: remove breakline

* chore: remove spaces

* fix: golang public method

* feat: mentions for editing input
This commit is contained in:
yetone 2024-09-23 18:52:26 +08:00 committed by GitHub
parent 8dbfe85dd4
commit 8e1018fef7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 1191 additions and 64 deletions

View File

@ -38,6 +38,7 @@ For building binary if you wish to build from source, then `cargo` is required.
build = "make", build = "make",
-- build = "powershell -ExecutionPolicy Bypass -File Build.ps1 -BuildFromSource false" -- for windows -- build = "powershell -ExecutionPolicy Bypass -File Build.ps1 -BuildFromSource false" -- for windows
dependencies = { dependencies = {
"nvim-treesitter/nvim-treesitter",
"stevearc/dressing.nvim", "stevearc/dressing.nvim",
"nvim-lua/plenary.nvim", "nvim-lua/plenary.nvim",
"MunifTanjim/nui.nvim", "MunifTanjim/nui.nvim",
@ -427,7 +428,7 @@ If you have the following structure:
- [x] Slash commands - [x] Slash commands
- [x] Edit the selected block - [x] Edit the selected block
- [x] Smart Tab (Cursor Flow) - [x] Smart Tab (Cursor Flow)
- [ ] Chat with project - [x] Chat with project (You can use `@codebase` to chat with the whole project)
- [ ] Chat with selected files - [ ] Chat with selected files
## Roadmap ## Roadmap

View File

@ -21,6 +21,7 @@ struct TemplateContext {
ask: bool, ask: bool,
question: String, question: String,
code_lang: String, code_lang: String,
filepath: String,
file_content: String, file_content: String,
selected_code: Option<String>, selected_code: Option<String>,
project_context: Option<String>, project_context: Option<String>,
@ -45,6 +46,7 @@ fn render(state: &State, template: &str, context: TemplateContext) -> LuaResult<
ask => context.ask, ask => context.ask,
question => context.question, question => context.question,
code_lang => context.code_lang, code_lang => context.code_lang,
filepath => context.filepath,
file_content => context.file_content, file_content => context.file_content,
selected_code => context.selected_code, selected_code => context.selected_code,
project_context => context.project_context, project_context => context.project_context,

View File

@ -60,11 +60,14 @@ M.stream = function(opts)
Path.prompts.initialize(Path.prompts.get(opts.bufnr)) Path.prompts.initialize(Path.prompts.get(opts.bufnr))
local filepath = Utils.relative_path(api.nvim_buf_get_name(opts.bufnr))
local template_opts = { local template_opts = {
use_xml_format = Provider.use_xml_format, use_xml_format = Provider.use_xml_format,
ask = opts.ask, -- TODO: add mode without ask instruction ask = opts.ask, -- TODO: add mode without ask instruction
question = original_instructions, question = original_instructions,
code_lang = opts.code_lang, code_lang = opts.code_lang,
filepath = filepath,
file_content = opts.file_content, file_content = opts.file_content,
selected_code = opts.selected_code, selected_code = opts.selected_code,
project_context = opts.project_context, project_context = opts.project_context,

View File

@ -28,17 +28,17 @@ end
H.get_mode_file = function(mode) return string.format("custom.%s.avanterules", mode) end H.get_mode_file = function(mode) return string.format("custom.%s.avanterules", mode) end
-- History path -- History path
local M = {} local History = {}
-- Returns the Path to the chat history file for the given buffer. -- Returns the Path to the chat history file for the given buffer.
---@param bufnr integer ---@param bufnr integer
---@return Path ---@return Path
M.get = function(bufnr) return Path:new(Config.history.storage_path):joinpath(H.filename(bufnr)) end History.get = function(bufnr) return Path:new(Config.history.storage_path):joinpath(H.filename(bufnr)) end
-- Loads the chat history for the given buffer. -- Loads the chat history for the given buffer.
---@param bufnr integer ---@param bufnr integer
M.load = function(bufnr) History.load = function(bufnr)
local history_file = M.get(bufnr) local history_file = History.get(bufnr)
if history_file:exists() then if history_file:exists() then
local content = history_file:read() local content = history_file:read()
return content ~= nil and vim.json.decode(content) or {} return content ~= nil and vim.json.decode(content) or {}
@ -49,29 +49,29 @@ end
-- Saves the chat history for the given buffer. -- Saves the chat history for the given buffer.
---@param bufnr integer ---@param bufnr integer
---@param history table ---@param history table
M.save = function(bufnr, history) History.save = function(bufnr, history)
local history_file = M.get(bufnr) local history_file = History.get(bufnr)
history_file:write(vim.json.encode(history), "w") history_file:write(vim.json.encode(history), "w")
end end
P.history = M P.history = History
-- Prompt path -- Prompt path
local N = {} local Prompt = {}
---@class AvanteTemplates ---@class AvanteTemplates
---@field initialize fun(directory: string): nil ---@field initialize fun(directory: string): nil
---@field render fun(template: string, context: TemplateOptions): string ---@field render fun(template: string, context: TemplateOptions): string
local templates = nil local templates = nil
N.templates = { planning = nil, editing = nil, suggesting = nil } Prompt.templates = { planning = nil, editing = nil, suggesting = nil }
-- Creates a directory in the cache path for the given buffer and copies the custom prompts to it. -- Creates a directory in the cache path for the given buffer and copies the custom prompts to it.
-- We need to do this beacuse the prompt template engine requires a given directory to load all required files. -- We need to do this beacuse the prompt template engine requires a given directory to load all required files.
-- PERF: Hmm instead of copy to cache, we can also load in globals context, but it requires some work on bindings. (eh maybe?) -- PERF: Hmm instead of copy to cache, we can also load in globals context, but it requires some work on bindings. (eh maybe?)
---@param bufnr number ---@param bufnr number
---@return string the resulted cache_directory to be loaded with avante_templates ---@return string the resulted cache_directory to be loaded with avante_templates
N.get = function(bufnr) Prompt.get = function(bufnr)
if not P.available() then error("Make sure to build avante (missing avante_templates)", 2) end if not P.available() then error("Make sure to build avante (missing avante_templates)", 2) end
-- get root directory of given bufnr -- get root directory of given bufnr
@ -85,19 +85,19 @@ N.get = function(bufnr)
local scanner = Scan.scan_dir(directory:absolute(), { depth = 1, add_dirs = true }) local scanner = Scan.scan_dir(directory:absolute(), { depth = 1, add_dirs = true })
for _, entry in ipairs(scanner) do for _, entry in ipairs(scanner) do
local file = Path:new(entry) local file = Path:new(entry)
if entry:find("planning") and N.templates.planning == nil then if entry:find("planning") and Prompt.templates.planning == nil then
N.templates.planning = file:read() Prompt.templates.planning = file:read()
elseif entry:find("editing") and N.templates.editing == nil then elseif entry:find("editing") and Prompt.templates.editing == nil then
N.templates.editing = file:read() Prompt.templates.editing = file:read()
elseif entry:find("suggesting") and N.templates.suggesting == nil then elseif entry:find("suggesting") and Prompt.templates.suggesting == nil then
N.templates.suggesting = file:read() Prompt.templates.suggesting = file:read()
end end
end end
Path:new(debug.getinfo(1).source:match("@?(.*/)"):gsub("/lua/avante/path.lua$", "") .. "templates") Path:new(debug.getinfo(1).source:match("@?(.*/)"):gsub("/lua/avante/path.lua$", "") .. "templates")
:copy({ destination = cache_prompt_dir, recursive = true }) :copy({ destination = cache_prompt_dir, recursive = true })
vim.iter(N.templates):filter(function(_, v) return v ~= nil end):each(function(k, v) vim.iter(Prompt.templates):filter(function(_, v) return v ~= nil end):each(function(k, v)
local f = cache_prompt_dir:joinpath(H.get_mode_file(k)) local f = cache_prompt_dir:joinpath(H.get_mode_file(k))
f:write(v, "w") f:write(v, "w")
end) end)
@ -106,22 +106,53 @@ N.get = function(bufnr)
end end
---@param mode LlmMode ---@param mode LlmMode
N.get_file = function(mode) Prompt.get_file = function(mode)
if N.templates[mode] ~= nil then return H.get_mode_file(mode) end if Prompt.templates[mode] ~= nil then return H.get_mode_file(mode) end
return string.format("%s.avanterules", mode) return string.format("%s.avanterules", mode)
end end
---@param path string ---@param path string
---@param opts TemplateOptions ---@param opts TemplateOptions
N.render_file = function(path, opts) return templates.render(path, opts) end Prompt.render_file = function(path, opts) return templates.render(path, opts) end
---@param mode LlmMode ---@param mode LlmMode
---@param opts TemplateOptions ---@param opts TemplateOptions
N.render_mode = function(mode, opts) return templates.render(N.get_file(mode), opts) end Prompt.render_mode = function(mode, opts) return templates.render(Prompt.get_file(mode), opts) end
N.initialize = function(directory) templates.initialize(directory) end Prompt.initialize = function(directory) templates.initialize(directory) end
P.prompts = N P.prompts = Prompt
local RepoMap = {}
-- Get a chat history file name given a buffer
---@param project_root string
---@param ext string
---@return string
RepoMap.filename = function(project_root, ext)
-- Replace path separators with double underscores
local path_with_separators = fn.substitute(project_root, "/", "__", "g")
-- Replace other non-alphanumeric characters with single underscores
return fn.substitute(path_with_separators, "[^A-Za-z0-9._]", "_", "g") .. "." .. ext .. ".repo_map.json"
end
RepoMap.get = function(project_root, ext) return Path:new(P.data_path):joinpath(RepoMap.filename(project_root, ext)) end
RepoMap.save = function(project_root, ext, data)
local file = RepoMap.get(project_root, ext)
file:write(vim.json.encode(data), "w")
end
RepoMap.load = function(project_root, ext)
local file = RepoMap.get(project_root, ext)
if file:exists() then
local content = file:read()
return content ~= nil and vim.json.decode(content) or {}
end
return nil
end
P.repo_map = RepoMap
P.setup = function() P.setup = function()
local history_path = Path:new(Config.history.storage_path) local history_path = Path:new(Config.history.storage_path)
@ -132,6 +163,10 @@ P.setup = function()
if not cache_path:exists() then cache_path:mkdir({ parents = true }) end if not cache_path:exists() then cache_path:mkdir({ parents = true }) end
P.cache_path = cache_path P.cache_path = cache_path
local data_path = Path:new(vim.fn.stdpath("data") .. "/avante")
if not data_path:exists() then data_path:mkdir({ parents = true }) end
P.data_path = data_path
vim.defer_fn(function() vim.defer_fn(function()
local ok, module = pcall(require, "avante_templates") local ok, module = pcall(require, "avante_templates")
---@cast module AvanteTemplates ---@cast module AvanteTemplates

View File

@ -1,7 +1,6 @@
local Utils = require("avante.utils") local Utils = require("avante.utils")
local Config = require("avante.config") local Config = require("avante.config")
local Llm = require("avante.llm") local Llm = require("avante.llm")
local Highlights = require("avante.highlights")
local Provider = require("avante.providers") local Provider = require("avante.providers")
local api = vim.api local api = vim.api
@ -391,10 +390,16 @@ function Selection:create_editing_input()
end end
local filetype = api.nvim_get_option_value("filetype", { buf = code_bufnr }) local filetype = api.nvim_get_option_value("filetype", { buf = code_bufnr })
local file_ext = api.nvim_buf_get_name(code_bufnr):match("^.+%.(.+)$")
local mentions = Utils.extract_mentions(input)
input = mentions.new_content
local project_context = mentions.enable_project_context and Utils.repo_map.get_repo_map(file_ext) or nil
Llm.stream({ Llm.stream({
bufnr = code_bufnr, bufnr = code_bufnr,
ask = true, ask = true,
project_context = vim.json.encode(project_context),
file_content = code_content, file_content = code_content,
code_lang = filetype, code_lang = filetype,
selected_code = self.selection.content, selected_code = self.selection.content,
@ -453,6 +458,25 @@ function Selection:create_editing_input()
end, end,
}) })
api.nvim_create_autocmd("InsertEnter", {
group = self.augroup,
buffer = bufnr,
once = true,
desc = "Setup the completion of helpers in the input buffer",
callback = function()
local has_cmp, cmp = pcall(require, "cmp")
if has_cmp then
cmp.register_source("avante_mentions", require("cmp_avante.mentions").new(Utils.get_mentions(), bufnr))
cmp.setup.buffer({
enabled = true,
sources = {
{ name = "avante_mentions" },
},
})
end
end,
})
api.nvim_create_autocmd("User", { api.nvim_create_autocmd("User", {
pattern = "AvanteEditSubmitted", pattern = "AvanteEditSubmitted",
callback = function(ev) callback = function(ev)

View File

@ -441,10 +441,8 @@ local function insert_conflict_contents(bufnr, snippets)
local result = {} local result = {}
table.insert(result, "<<<<<<< HEAD") table.insert(result, "<<<<<<< HEAD")
if start_line ~= end_line then for i = start_line, end_line do
for i = start_line, end_line do table.insert(result, lines[i])
table.insert(result, lines[i])
end
end end
table.insert(result, "=======") table.insert(result, "=======")
@ -460,8 +458,6 @@ local function insert_conflict_contents(bufnr, snippets)
table.insert(result, line) table.insert(result, line)
end end
if start_line == end_line then table.insert(result, lines[start_line]) end
table.insert(result, ">>>>>>> Snippet") table.insert(result, ">>>>>>> Snippet")
api.nvim_buf_set_lines(bufnr, offset + start_line - 1, offset + end_line, false, result) api.nvim_buf_set_lines(bufnr, offset + start_line - 1, offset + end_line, false, result)
@ -1296,9 +1292,17 @@ function Sidebar:create_input(opts)
Path.history.save(self.code.bufnr, chat_history) Path.history.save(self.code.bufnr, chat_history)
end end
local mentions = Utils.extract_mentions(request)
request = mentions.new_content
local file_ext = api.nvim_buf_get_name(self.code.bufnr):match("^.+%.(.+)$")
local project_context = mentions.enable_project_context and Utils.repo_map.get_repo_map(file_ext) or nil
Llm.stream({ Llm.stream({
bufnr = self.code.bufnr, bufnr = self.code.bufnr,
ask = opts.ask, ask = opts.ask,
project_context = vim.json.encode(project_context),
file_content = content_with_line_numbers, file_content = content_with_line_numbers,
code_lang = filetype, code_lang = filetype,
selected_code = selected_code_content_with_line_numbers, selected_code = selected_code_content_with_line_numbers,
@ -1358,9 +1362,9 @@ function Sidebar:create_input(opts)
local function place_sign_at_first_line(bufnr) local function place_sign_at_first_line(bufnr)
local group = "avante_input_prompt_group" local group = "avante_input_prompt_group"
vim.fn.sign_unplace(group, { buffer = bufnr }) fn.sign_unplace(group, { buffer = bufnr })
vim.fn.sign_place(0, group, "AvanteInputPromptSign", bufnr, { lnum = 1 }) fn.sign_place(0, group, "AvanteInputPromptSign", bufnr, { lnum = 1 })
end end
place_sign_at_first_line(self.input.bufnr) place_sign_at_first_line(self.input.bufnr)
@ -1387,11 +1391,16 @@ function Sidebar:create_input(opts)
if not self.registered_cmp then if not self.registered_cmp then
self.registered_cmp = true self.registered_cmp = true
cmp.register_source("avante_commands", require("cmp_avante.commands").new(self)) cmp.register_source("avante_commands", require("cmp_avante.commands").new(self))
cmp.register_source(
"avante_mentions",
require("cmp_avante.mentions").new(Utils.get_mentions(), self.input.bufnr)
)
end end
cmp.setup.buffer({ cmp.setup.buffer({
enabled = true, enabled = true,
sources = { sources = {
{ name = "avante_commands" }, { name = "avante_commands" },
{ name = "avante_mentions" },
}, },
}) })
end end

View File

@ -82,24 +82,26 @@ function Suggestion:suggest()
return return
end end
Utils.debug("full_response: " .. vim.inspect(full_response)) Utils.debug("full_response: " .. vim.inspect(full_response))
local cursor_row, cursor_col = Utils.get_cursor_pos() vim.schedule(function()
if cursor_row ~= doc.position.row or cursor_col ~= doc.position.col then return end local cursor_row, cursor_col = Utils.get_cursor_pos()
local ok, suggestions = pcall(vim.json.decode, full_response) if cursor_row ~= doc.position.row or cursor_col ~= doc.position.col then return end
if not ok then local ok, suggestions = pcall(vim.json.decode, full_response)
Utils.error("Error while decoding suggestions: " .. full_response, { once = true, title = "Avante" }) if not ok then
return Utils.error("Error while decoding suggestions: " .. full_response, { once = true, title = "Avante" })
end return
if not suggestions then end
Utils.info("No suggestions found", { once = true, title = "Avante" }) if not suggestions then
return Utils.info("No suggestions found", { once = true, title = "Avante" })
end return
suggestions = vim end
.iter(suggestions) suggestions = vim
:map(function(s) return { row = s.row, col = s.col, content = Utils.trim_all_line_numbers(s.content) } end) .iter(suggestions)
:totable() :map(function(s) return { row = s.row, col = s.col, content = Utils.trim_all_line_numbers(s.content) } end)
ctx.suggestions = suggestions :totable()
ctx.current_suggestion_idx = 1 ctx.suggestions = suggestions
self:show() ctx.current_suggestion_idx = 1
self:show()
end)
end, end,
}) })
end end

View File

@ -1,4 +1,6 @@
{%- if use_xml_format -%} {%- if use_xml_format -%}
<filepath>{{filepath}}</filepath>
{%- if selected_code -%} {%- if selected_code -%}
<context> <context>
```{{code_lang}} ```{{code_lang}}
@ -19,6 +21,8 @@
</code> </code>
{%- endif %} {%- endif %}
{% else %} {% else %}
FILEPATH: {{filepath}}
{%- if selected_code -%} {%- if selected_code -%}
CONTEXT: CONTEXT:
```{{code_lang}} ```{{code_lang}}

View File

@ -17,6 +17,7 @@ Your task is to suggest code modifications at the cursor position. Follow these
{% endraw %} {% endraw %}
3. When suggesting suggested code: 3. When suggesting suggested code:
- DO NOT include three backticks: {%raw%}```{%endraw%} in your suggestion. Treat the suggested code AS IS.
- Each element in the returned list is a COMPLETE and INDEPENDENT code snippet. - Each element in the returned list is a COMPLETE and INDEPENDENT code snippet.
- MUST be a valid json format. Don't be lazy! - MUST be a valid json format. Don't be lazy!
- Only return the new code to be inserted. - Only return the new code to be inserted.
@ -29,4 +30,3 @@ Your task is to suggest code modifications at the cursor position. Follow these
Remember to ONLY RETURN the suggested code snippet, without any additional formatting or explanation. Remember to ONLY RETURN the suggested code snippet, without any additional formatting or explanation.
{% endblock %} {% endblock %}

36
lua/avante/utils/file.lua Normal file
View File

@ -0,0 +1,36 @@
local LRUCache = require("avante.utils.lru_cache")
---@class avante.utils.file
local M = {}
local api = vim.api
local fn = vim.fn
local _file_content_lru_cache = LRUCache:new(60)
api.nvim_create_autocmd("BufWritePost", {
callback = function()
local filepath = api.nvim_buf_get_name(0)
local keys = _file_content_lru_cache:keys()
if vim.tbl_contains(keys, filepath) then
local content = table.concat(api.nvim_buf_get_lines(0, 0, -1, false), "\n")
_file_content_lru_cache:set(filepath, content)
end
end,
})
function M.read_content(filepath)
local cached_content = _file_content_lru_cache:get(filepath)
if cached_content then return cached_content end
local content = fn.readfile(filepath)
if content then
content = table.concat(content, "\n")
_file_content_lru_cache:set(filepath, content)
return content
end
return nil
end
return M

View File

@ -5,6 +5,7 @@ local lsp = vim.lsp
---@class avante.utils: LazyUtilCore ---@class avante.utils: LazyUtilCore
---@field tokens avante.utils.tokens ---@field tokens avante.utils.tokens
---@field root avante.utils.root ---@field root avante.utils.root
---@field repo_map avante.utils.repo_map
local M = {} local M = {}
setmetatable(M, { setmetatable(M, {
@ -444,7 +445,7 @@ function M.get_indentation(code) return code:match("^%s*") or "" end
--- remove indentation from code: spaces or tabs --- remove indentation from code: spaces or tabs
function M.remove_indentation(code) return code:gsub("^%s*", "") end function M.remove_indentation(code) return code:gsub("^%s*", "") end
local function relative_path(absolute) function M.relative_path(absolute)
local relative = fn.fnamemodify(absolute, ":.") local relative = fn.fnamemodify(absolute, ":.")
if string.sub(relative, 0, 1) == "/" then return fn.fnamemodify(absolute, ":t") end if string.sub(relative, 0, 1) == "/" then return fn.fnamemodify(absolute, ":t") end
return relative return relative
@ -462,7 +463,7 @@ function M.get_doc()
local doc = { local doc = {
uri = params.textDocument.uri, uri = params.textDocument.uri,
version = api.nvim_buf_get_var(0, "changedtick"), version = api.nvim_buf_get_var(0, "changedtick"),
relativePath = relative_path(absolute), relativePath = M.relative_path(absolute),
insertSpaces = vim.o.expandtab, insertSpaces = vim.o.expandtab,
tabSize = fn.shiftwidth(), tabSize = fn.shiftwidth(),
indentSize = fn.shiftwidth(), indentSize = fn.shiftwidth(),
@ -520,4 +521,126 @@ function M.winline(winid)
return line return line
end end
function M.get_project_root() return M.root.get() end
function M.is_same_file_ext(target_ext, filepath)
local ext = fn.fnamemodify(filepath, ":e")
if target_ext == "tsx" and ext == "ts" then return true end
if target_ext == "jsx" and ext == "js" then return true end
return ext == target_ext
end
-- Get recent filepaths in the same project and same file ext
function M.get_recent_filepaths(limit, filenames)
local project_root = M.get_project_root()
local current_ext = fn.expand("%:e")
local oldfiles = vim.v.oldfiles
local recent_files = {}
for _, file in ipairs(oldfiles) do
if vim.startswith(file, project_root) and M.is_same_file_ext(current_ext, file) then
if filenames and #filenames > 0 then
for _, filename in ipairs(filenames) do
if file:find(filename) then table.insert(recent_files, file) end
end
else
table.insert(recent_files, file)
end
if #recent_files >= (limit or 10) then break end
end
end
return recent_files
end
local function pattern_to_lua(pattern)
local lua_pattern = pattern:gsub("[%(%)%.%%%+%-%*%?%[%]%^%$]", "%%%1")
lua_pattern = lua_pattern:gsub("%*%*/", ".-/")
lua_pattern = lua_pattern:gsub("%*", "[^/]*")
lua_pattern = lua_pattern:gsub("%?", ".")
if lua_pattern:sub(-1) == "/" then lua_pattern = lua_pattern .. ".*" end
return lua_pattern
end
function M.parse_gitignore(gitignore_path)
local ignore_patterns = { ".git", ".worktree", "__pycache__", "node_modules" }
local negate_patterns = {}
local file = io.open(gitignore_path, "r")
if not file then return ignore_patterns, negate_patterns end
for line in file:lines() do
if line:match("%S") and not line:match("^#") then
local trimmed_line = line:match("^%s*(.-)%s*$")
if trimmed_line:sub(1, 1) == "!" then
table.insert(negate_patterns, pattern_to_lua(trimmed_line:sub(2)))
else
table.insert(ignore_patterns, pattern_to_lua(trimmed_line))
end
end
end
file:close()
return ignore_patterns, negate_patterns
end
local function is_ignored(file, ignore_patterns, negate_patterns)
for _, pattern in ipairs(negate_patterns) do
if file:match(pattern) then return false end
end
for _, pattern in ipairs(ignore_patterns) do
if file:match(pattern) then return true end
end
return false
end
function M.scan_directory(directory, ignore_patterns, negate_patterns)
local files = {}
local handle = vim.loop.fs_scandir(directory)
if not handle then return files end
while true do
local name, type = vim.loop.fs_scandir_next(handle)
if not name then break end
local full_path = directory .. "/" .. name
if type == "directory" then
vim.list_extend(files, M.scan_directory(full_path, ignore_patterns, negate_patterns))
elseif type == "file" then
if not is_ignored(full_path, ignore_patterns, negate_patterns) then table.insert(files, full_path) end
end
end
return files
end
function M.is_first_letter_uppercase(str) return string.match(str, "^[A-Z]") ~= nil end
---@param content string
---@return { new_content: string, enable_project_context: boolean }
function M.extract_mentions(content)
-- if content contains @codebase, enable project context and remove @codebase
local new_content = content
local enable_project_context = false
if content:match("@codebase") then
enable_project_context = true
new_content = content:gsub("@codebase", "")
end
return { new_content = new_content, enable_project_context = enable_project_context }
end
---@alias AvanteMentions "codebase"
---@alias AvanteMentionCallback fun(args: string, cb?: fun(args: string): nil): nil
---@alias AvanteMention {description: string, command: AvanteMentions, details: string, shorthelp?: string, callback?: AvanteMentionCallback}
---@return AvanteMention[]
function M.get_mentions()
return {
{
description = "codebase",
command = "codebase",
details = "repo map",
},
}
end
return M return M

View File

@ -0,0 +1,115 @@
local LRUCache = {}
LRUCache.__index = LRUCache
function LRUCache:new(capacity)
return setmetatable({
capacity = capacity,
cache = {},
head = nil,
tail = nil,
size = 0,
}, LRUCache)
end
-- Internal function: Move node to head (indicating most recently used)
function LRUCache:_move_to_head(node)
if self.head == node then return end
-- Disconnect the node
if node.prev then node.prev.next = node.next end
if node.next then node.next.prev = node.prev end
if self.tail == node then self.tail = node.prev end
-- Insert the node at the head
node.next = self.head
node.prev = nil
if self.head then self.head.prev = node end
self.head = node
if not self.tail then self.tail = node end
end
-- Get value from cache
function LRUCache:get(key)
local node = self.cache[key]
if not node then return nil end
self:_move_to_head(node)
return node.value
end
-- Set value in cache
function LRUCache:set(key, value)
local node = self.cache[key]
if node then
node.value = value
self:_move_to_head(node)
else
node = { key = key, value = value }
self.cache[key] = node
self.size = self.size + 1
self:_move_to_head(node)
if self.size > self.capacity then
local tail_key = self.tail.key
self.tail = self.tail.prev
if self.tail then self.tail.next = nil end
self.cache[tail_key] = nil
self.size = self.size - 1
end
end
end
-- Remove specified cache entry
function LRUCache:remove(key)
local node = self.cache[key]
if not node then return end
if node.prev then
node.prev.next = node.next
else
self.head = node.next
end
if node.next then
node.next.prev = node.prev
else
self.tail = node.prev
end
self.cache[key] = nil
self.size = self.size - 1
end
-- Get current size of cache
function LRUCache:get_size() return self.size end
-- Get capacity of cache
function LRUCache:get_capacity() return self.capacity end
-- Print current cache contents (for debugging)
function LRUCache:print_cache()
local node = self.head
while node do
print(node.key, node.value)
node = node.next
end
end
function LRUCache:keys()
local keys = {}
local node = self.head
while node do
table.insert(keys, node.key)
node = node.next
end
return keys
end
return LRUCache

View File

@ -0,0 +1,730 @@
local parsers = require("nvim-treesitter.parsers")
local Config = require("avante.config")
local get_node_text = vim.treesitter.get_node_text
---@class avante.utils.repo_map
local RepoMap = {}
local dependencies_queries = {
lua = [[
(function_call
name: (identifier) @function_name
arguments: (arguments
(string) @required_file))
]],
python = [[
(import_from_statement
module_name: (dotted_name) @import_module)
(import_statement
(dotted_name) @import_module)
]],
javascript = [[
(import_statement
source: (string) @import_module)
(call_expression
function: (identifier) @function_name
arguments: (arguments
(string) @required_file))
]],
typescript = [[
(import_statement
source: (string) @import_module)
(call_expression
function: (identifier) @function_name
arguments: (arguments
(string) @required_file))
]],
go = [[
(import_spec
path: (interpreted_string_literal) @import_module)
]],
rust = [[
(use_declaration
(scoped_identifier) @import_module)
(use_declaration
(identifier) @import_module)
]],
c = [[
(preproc_include
(string_literal) @import_module)
(preproc_include
(system_lib_string) @import_module)
]],
cpp = [[
(preproc_include
(string_literal) @import_module)
(preproc_include
(system_lib_string) @import_module)
]],
}
local definitions_queries = {
python = [[
;; Capture top-level functions, class, and method definitions
(module
(expression_statement
(assignment) @assignment
)
)
(module
(function_definition) @function
)
(module
(class_definition
body: (block
(expression_statement
(assignment) @class_assignment
)
)
)
)
(module
(class_definition
body: (block
(function_definition) @method
)
)
)
]],
javascript = [[
;; Capture exported functions, arrow functions, variables, classes, and method definitions
(export_statement
declaration: (lexical_declaration
(variable_declarator) @variable
)
)
(export_statement
declaration: (function_declaration) @function
)
(export_statement
declaration: (class_declaration
body: (class_body
(field_definition) @class_variable
)
)
)
(export_statement
declaration: (class_declaration
body: (class_body
(method_definition) @method
)
)
)
]],
typescript = [[
;; Capture exported functions, arrow functions, variables, classes, and method definitions
(export_statement
declaration: (lexical_declaration
(variable_declarator) @variable
)
)
(export_statement
declaration: (function_declaration) @function
)
(export_statement
declaration: (class_declaration
body: (class_body
(public_field_definition) @class_variable
)
)
)
(interface_declaration
body: (interface_body
(property_signature) @class_variable
)
)
(type_alias_declaration
value: (object_type
(property_signature) @class_variable
)
)
(export_statement
declaration: (class_declaration
body: (class_body
(method_definition) @method
)
)
)
]],
rust = [[
;; Capture public functions, structs, methods, and variable definitions
(function_item) @function
(impl_item
body: (declaration_list
(function_item) @method
)
)
(struct_item
body: (field_declaration_list
(field_declaration) @class_variable
)
)
(enum_item
body: (enum_variant_list
(enum_variant) @enum_item
)
)
(const_item) @variable
]],
go = [[
;; Capture top-level functions and struct definitions
(var_declaration
(var_spec) @variable
)
(const_declaration
(const_spec) @variable
)
(function_declaration) @function
(type_declaration
(type_spec (struct_type)) @class
)
(type_declaration
(type_spec
(struct_type
(field_declaration_list
(field_declaration) @class_variable)))
)
(method_declaration) @method
]],
c = [[
;; Capture extern functions, variables, public classes, and methods
(function_definition
(storage_class_specifier) @extern
) @function
(class_specifier
(public) @class
(function_definition) @method
) @class
(declaration
(storage_class_specifier) @extern
) @variable
]],
cpp = [[
;; Capture extern functions, variables, public classes, and methods
(function_definition
(storage_class_specifier) @extern
) @function
(class_specifier
(public) @class
(function_definition) @method
) @class
(declaration
(storage_class_specifier) @extern
) @variable
]],
lua = [[
;; Capture function and method definitions
(variable_list) @variable
(function_declaration) @function
]],
ruby = [[
;; Capture top-level methods, class definitions, and methods within classes
(method) @function
(assignment) @assignment
(class
body: (body_statement
(assignment) @class_assignment
(method) @method
)
)
]],
}
local queries_filetype_map = {
["javascriptreact"] = "javascript",
["typescriptreact"] = "typescript",
}
local function get_query(queries, filetype)
filetype = queries_filetype_map[filetype] or filetype
return queries[filetype]
end
local function get_ts_lang(bufnr)
local lang = parsers.get_buf_lang(bufnr)
return lang
end
function RepoMap.get_parser(bufnr)
local lang = get_ts_lang(bufnr)
if not lang then return end
local parser = parsers.get_parser(bufnr, lang)
return parser, lang
end
function RepoMap.extract_dependencies(bufnr)
local parser, lang = RepoMap.get_parser(bufnr)
if not lang or not parser or not dependencies_queries[lang] then
print("No parser or query available for this buffer's language: " .. (lang or "unknown"))
return {}
end
local dependencies = {}
local tree = parser:parse()[1]
local root = tree:root()
local filetype = vim.api.nvim_get_option_value("filetype", { buf = bufnr })
local query = get_query(dependencies_queries, filetype)
if not query then return dependencies end
local query_obj = vim.treesitter.query.parse(lang, query)
for _, node, _ in query_obj:iter_captures(root, bufnr, 0, -1) do
-- local name = query.captures[id]
local required_file = vim.treesitter.get_node_text(node, bufnr):gsub('"', ""):gsub("'", "")
table.insert(dependencies, required_file)
end
return dependencies
end
function RepoMap.get_filetype_by_filepath(filepath) return vim.filetype.match({ filename = filepath }) end
function RepoMap.parse_file(filepath)
local File = require("avante.utils.file")
local source = File.read_content(filepath)
local filetype = RepoMap.get_filetype_by_filepath(filepath)
local lang = parsers.ft_to_lang(filetype)
if lang then
local ok, parser = pcall(vim.treesitter.get_string_parser, source, lang)
if ok then
local tree = parser:parse()[1]
local node = tree:root()
return { node = node, source = source }
else
print("parser error", parser)
end
end
end
local function get_closest_parent_name(node, source)
local parent = node:parent()
while parent do
local name = parent:field("name")[1]
if name then return get_node_text(name, source) end
parent = parent:parent()
end
return ""
end
local function find_parent_by_type(node, type)
local parent = node:parent()
while parent do
if parent:type() == type then return parent end
parent = parent:parent()
end
return nil
end
local function find_child_by_type(node, type)
for child in node:iter_children() do
if child:type() == type then return child end
local res = find_child_by_type(child, type)
if res then return res end
end
return nil
end
local function get_node_type(node, source)
local node_type
local predefined_type_node = find_child_by_type(node, "predefined_type")
if predefined_type_node then
node_type = get_node_text(predefined_type_node, source)
else
local value_type_node = node:field("type")[1]
node_type = value_type_node and get_node_text(value_type_node, source) or ""
end
return node_type
end
-- Function to extract definitions from the file
function RepoMap.extract_definitions(filepath)
local Utils = require("avante.utils")
local filetype = RepoMap.get_filetype_by_filepath(filepath)
if not filetype then return {} end
-- Get the corresponding query for the detected language
local query = get_query(definitions_queries, filetype)
if not query then return {} end
local parsed = RepoMap.parse_file(filepath)
if not parsed then return {} end
-- Get the current buffer's syntax tree
local root = parsed.node
local lang = parsers.ft_to_lang(filetype)
-- Parse the query
local query_obj = vim.treesitter.query.parse(lang, query)
-- Store captured results
local definitions = {}
local class_def_map = {}
local enum_def_map = {}
local function get_class_def(name)
local def = class_def_map[name]
if def == nil then
def = {
type = "class",
name = name,
methods = {},
properties = {},
}
class_def_map[name] = def
end
return def
end
local function get_enum_def(name)
local def = enum_def_map[name]
if def == nil then
def = {
type = "enum",
name = name,
items = {},
}
enum_def_map[name] = def
end
return def
end
for _, captures, _ in query_obj:iter_matches(root, parsed.source) do
for id, node in pairs(captures) do
local type = query_obj.captures[id]
local name_node = node:field("name")[1]
local name = name_node and get_node_text(name_node, parsed.source) or ""
if type == "class" then
if name ~= "" then get_class_def(name) end
elseif type == "enum_item" then
local enum_name = get_closest_parent_name(node, parsed.source)
if enum_name and filetype == "go" and not Utils.is_first_letter_uppercase(enum_name) then goto continue end
local enum_def = get_enum_def(enum_name)
local enum_type_node = find_child_by_type(node, "type_identifier")
local enum_type = enum_type_node and get_node_text(enum_type_node, parsed.source) or ""
table.insert(enum_def.items, {
name = name,
type = enum_type,
})
elseif type == "method" then
if name and filetype == "go" and not Utils.is_first_letter_uppercase(name) then goto continue end
local params_node = node:field("parameters")[1]
local params = params_node and get_node_text(params_node, parsed.source) or "()"
local return_type_node = node:field("return_type")[1] or node:field("result")[1]
local return_type = return_type_node and get_node_text(return_type_node, parsed.source) or "void"
local class_name
local impl_item_node = find_parent_by_type(node, "impl_item")
local receiver_node = node:field("receiver")[1]
if impl_item_node then
local impl_type_node = impl_item_node:field("type")[1]
class_name = impl_type_node and get_node_text(impl_type_node, parsed.source) or ""
elseif receiver_node then
local type_identifier_node = find_child_by_type(receiver_node, "type_identifier")
class_name = type_identifier_node and get_node_text(type_identifier_node, parsed.source) or ""
else
class_name = get_closest_parent_name(node, parsed.source)
end
local class_def = get_class_def(class_name)
local accessibility_modifier_node = find_child_by_type(node, "accessibility_modifier")
local accessibility_modifier = accessibility_modifier_node
and get_node_text(accessibility_modifier_node, parsed.source)
or ""
table.insert(class_def.methods, {
type = "function",
name = name,
params = params,
return_type = return_type,
accessibility_modifier = accessibility_modifier,
})
elseif type == "class_assignment" then
local left_node = node:field("left")[1]
local left = left_node and get_node_text(left_node, parsed.source) or ""
local value_type = get_node_type(node, parsed.source)
local class_name = get_closest_parent_name(node, parsed.source)
if class_name and filetype == "go" and not Utils.is_first_letter_uppercase(class_name) then goto continue end
local class_def = get_class_def(class_name)
table.insert(class_def.properties, {
type = "variable",
name = left,
value_type = value_type,
})
elseif type == "class_variable" then
local value_type = get_node_type(node, parsed.source)
local class_name = get_closest_parent_name(node, parsed.source)
if class_name and filetype == "go" and not Utils.is_first_letter_uppercase(class_name) then goto continue end
local class_def = get_class_def(class_name)
table.insert(class_def.properties, {
type = "variable",
name = name,
value_type = value_type,
})
elseif type == "function" or type == "arrow_function" then
if name and filetype == "go" and not Utils.is_first_letter_uppercase(name) then goto continue end
local impl_item_node = find_parent_by_type(node, "impl_item")
if impl_item_node then goto continue end
local function_node = find_parent_by_type(node, "function_declaration")
or find_parent_by_type(node, "function_definition")
if function_node then goto continue end
-- Extract function parameters and return type
local params_node = node:field("parameters")[1]
local params = params_node and get_node_text(params_node, parsed.source) or "()"
local return_type_node = node:field("return_type")[1] or node:field("result")[1]
local return_type = return_type_node and get_node_text(return_type_node, parsed.source) or "void"
local accessibility_modifier_node = find_child_by_type(node, "accessibility_modifier")
local accessibility_modifier = accessibility_modifier_node
and get_node_text(accessibility_modifier_node, parsed.source)
or ""
local def = {
type = "function",
name = name,
params = params,
return_type = return_type,
accessibility_modifier = accessibility_modifier,
}
table.insert(definitions, def)
elseif type == "assignment" then
local impl_item_node = find_parent_by_type(node, "impl_item")
or find_parent_by_type(node, "class_declaration")
or find_parent_by_type(node, "class_definition")
if impl_item_node then goto continue end
local function_node = find_parent_by_type(node, "function_declaration")
or find_parent_by_type(node, "function_definition")
if function_node then goto continue end
local left_node = node:field("left")[1]
local left = left_node and get_node_text(left_node, parsed.source) or ""
if left and filetype == "go" and not Utils.is_first_letter_uppercase(left) then goto continue end
local value_type = get_node_type(node, parsed.source)
local def = {
type = "variable",
name = left,
value_type = value_type,
}
table.insert(definitions, def)
elseif type == "variable" then
local impl_item_node = find_parent_by_type(node, "impl_item")
or find_parent_by_type(node, "class_declaration")
or find_parent_by_type(node, "class_definition")
if impl_item_node then goto continue end
local function_node = find_parent_by_type(node, "function_declaration")
or find_parent_by_type(node, "function_definition")
if function_node then goto continue end
local value_type = get_node_type(node, parsed.source)
if name and filetype == "go" and not Utils.is_first_letter_uppercase(name) then goto continue end
local def = { type = "variable", name = name, value_type = value_type }
table.insert(definitions, def)
end
::continue::
end
end
for _, def in pairs(class_def_map) do
table.insert(definitions, def)
end
for _, def in pairs(enum_def_map) do
table.insert(definitions, def)
end
return definitions
end
local function stringify_function(def)
local res = "func " .. def.name .. def.params .. ":" .. def.return_type .. ";"
if def.accessibility_modifier and def.accessibility_modifier ~= "" then
res = def.accessibility_modifier .. " " .. res
end
return res
end
local function stringify_variable(def)
local res = "var " .. def.name
if def.value_type and def.value_type ~= "" then res = res .. ":" .. def.value_type end
return res .. ";"
end
local function stringify_enum_item(def)
local res = def.name
if def.value_type and def.value_type ~= "" then res = res .. ":" .. def.value_type end
return res .. ";"
end
-- Function to load file content into a temporary buffer, process it, and then delete the buffer
function RepoMap.stringify_definitions(filepath)
if vim.endswith(filepath, "~") then return "" end
-- Extract definitions
local definitions = RepoMap.extract_definitions(filepath)
local output = ""
-- Print or process the definitions
for _, def in ipairs(definitions) do
if def.type == "class" then
output = output .. def.type .. " " .. def.name .. "{"
for _, property in ipairs(def.properties) do
output = output .. stringify_variable(property)
end
for _, method in ipairs(def.methods) do
output = output .. stringify_function(method)
end
output = output .. "}"
elseif def.type == "enum" then
output = output .. def.type .. " " .. def.name .. "{"
for _, item in ipairs(def.items) do
output = output .. stringify_enum_item(item) .. ""
end
output = output .. "}"
elseif def.type == "function" then
output = output .. stringify_function(def)
elseif def.type == "variable" then
output = output .. stringify_variable(def)
end
end
return output
end
function RepoMap._build_repo_map(project_root, file_ext)
local Utils = require("avante.utils")
local output = {}
local gitignore_path = project_root .. "/.gitignore"
local ignore_patterns, negate_patterns = Utils.parse_gitignore(gitignore_path)
local filepaths = Utils.scan_directory(project_root, ignore_patterns, negate_patterns)
vim.iter(filepaths):each(function(filepath)
if not Utils.is_same_file_ext(file_ext, filepath) then return end
local definitions = RepoMap.stringify_definitions(filepath)
if definitions == "" then return end
table.insert(output, {
path = Utils.relative_path(filepath),
lang = RepoMap.get_filetype_by_filepath(filepath),
defs = definitions,
})
end)
return output
end
local cache = {}
function RepoMap.get_repo_map(file_ext)
file_ext = file_ext or vim.fn.expand("%:e")
local Utils = require("avante.utils")
local project_root = Utils.root.get()
local cache_key = project_root .. "." .. file_ext
local cached = cache[cache_key]
if cached then return cached end
local PPath = require("plenary.path")
local Path = require("avante.path")
local repo_map
local function build_and_save()
repo_map = RepoMap._build_repo_map(project_root, file_ext)
cache[cache_key] = repo_map
Path.repo_map.save(project_root, file_ext, repo_map)
end
repo_map = Path.repo_map.load(project_root, file_ext)
if not repo_map or next(repo_map) == nil then
build_and_save()
if not repo_map then return end
else
local timer = vim.loop.new_timer()
if timer then
timer:start(
0,
0,
vim.schedule_wrap(function()
build_and_save()
timer:close()
end)
)
end
end
local update_repo_map = vim.schedule_wrap(function(rel_filepath)
if rel_filepath and Utils.is_same_file_ext(file_ext, rel_filepath) then
local abs_filepath = PPath:new(project_root):joinpath(rel_filepath):absolute()
local definitions = RepoMap.stringify_definitions(abs_filepath)
if definitions == "" then return end
local found = false
for _, m in ipairs(repo_map) do
if m.path == rel_filepath then
m.defs = definitions
found = true
break
end
end
if not found then
table.insert(repo_map, {
path = Utils.relative_path(abs_filepath),
lang = RepoMap.get_filetype_by_filepath(abs_filepath),
defs = definitions,
})
end
cache[cache_key] = repo_map
Path.repo_map.save(project_root, file_ext, repo_map)
end
end)
local handle = vim.loop.new_fs_event()
if handle then
handle:start(project_root, { recursive = true }, function(err, rel_filepath)
if err then
print("Error watching directory " .. project_root .. ":", err)
return
end
if rel_filepath then update_repo_map(rel_filepath) end
end)
end
vim.api.nvim_create_autocmd({ "BufReadPost", "BufNewFile" }, {
callback = function(ev)
vim.defer_fn(function()
local filepath = vim.api.nvim_buf_get_name(ev.buf)
if not vim.startswith(filepath, project_root) then return end
local rel_filepath = Utils.relative_path(filepath)
update_repo_map(rel_filepath)
end, 0)
end,
})
return repo_map
end
return RepoMap

View File

@ -1,24 +1,24 @@
---@class source ---@class commands_source
---@field sidebar avante.Sidebar ---@field sidebar avante.Sidebar
local source = {} local commands_source = {}
---@param sidebar avante.Sidebar ---@param sidebar avante.Sidebar
function source.new(sidebar) function commands_source.new(sidebar)
---@type cmp.Source ---@type cmp.Source
return setmetatable({ return setmetatable({
sidebar = sidebar, sidebar = sidebar,
}, { __index = source }) }, { __index = commands_source })
end end
function source:is_available() return vim.bo.filetype == "AvanteInput" end function commands_source:is_available() return vim.bo.filetype == "AvanteInput" end
source.get_position_encoding_kind = function() return "utf-8" end commands_source.get_position_encoding_kind = function() return "utf-8" end
function source:get_trigger_characters() return { "/" } end function commands_source:get_trigger_characters() return { "/" } end
function source:get_keyword_pattern() return [[\%(@\|#\|/\)\k*]] end function commands_source:get_keyword_pattern() return [[\%(@\|#\|/\)\k*]] end
function source:complete(_, callback) function commands_source:complete(_, callback)
local kind = require("cmp").lsp.CompletionItemKind.Variable local kind = require("cmp").lsp.CompletionItemKind.Variable
local items = {} local items = {}
@ -39,4 +39,4 @@ function source:complete(_, callback)
}) })
end end
return source return commands_source

View File

@ -0,0 +1,43 @@
---@class mentions_source
---@field mentions {description: string, command: AvanteMentions, details: string, shorthelp?: string, callback?: AvanteMentionCallback}[]
---@field bufnr integer
local mentions_source = {}
---@param mentions {description: string, command: AvanteMentions, details: string, shorthelp?: string, callback?: AvanteMentionCallback}[]
---@param bufnr integer
function mentions_source.new(mentions, bufnr)
---@type cmp.Source
return setmetatable({
mentions = mentions,
bufnr = bufnr,
}, { __index = mentions_source })
end
function mentions_source:is_available() return vim.api.nvim_get_current_buf() == self.bufnr end
mentions_source.get_position_encoding_kind = function() return "utf-8" end
function mentions_source:get_trigger_characters() return { "@" } end
function mentions_source:get_keyword_pattern() return [[\%(@\|#\|/\)\k*]] end
function mentions_source:complete(_, callback)
local kind = require("cmp").lsp.CompletionItemKind.Variable
local items = {}
for _, mention in ipairs(self.mentions) do
table.insert(items, {
label = "@" .. mention.command .. " ",
kind = kind,
detail = mention.details,
})
end
callback({
items = items,
isIncomplete = false,
})
end
return mentions_source