feat: repo map (#496)

* feat: repo map

* chore: remove breakline

* chore: remove spaces

* fix: golang public method

* feat: mentions for editing input
This commit is contained in:
yetone 2024-09-23 18:52:26 +08:00 committed by GitHub
parent 8dbfe85dd4
commit 8e1018fef7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 1191 additions and 64 deletions

View File

@ -38,6 +38,7 @@ For building binary if you wish to build from source, then `cargo` is required.
build = "make",
-- build = "powershell -ExecutionPolicy Bypass -File Build.ps1 -BuildFromSource false" -- for windows
dependencies = {
"nvim-treesitter/nvim-treesitter",
"stevearc/dressing.nvim",
"nvim-lua/plenary.nvim",
"MunifTanjim/nui.nvim",
@ -427,7 +428,7 @@ If you have the following structure:
- [x] Slash commands
- [x] Edit the selected block
- [x] Smart Tab (Cursor Flow)
- [ ] Chat with project
- [x] Chat with project (You can use `@codebase` to chat with the whole project)
- [ ] Chat with selected files
## Roadmap

View File

@ -21,6 +21,7 @@ struct TemplateContext {
ask: bool,
question: String,
code_lang: String,
filepath: String,
file_content: String,
selected_code: Option<String>,
project_context: Option<String>,
@ -45,6 +46,7 @@ fn render(state: &State, template: &str, context: TemplateContext) -> LuaResult<
ask => context.ask,
question => context.question,
code_lang => context.code_lang,
filepath => context.filepath,
file_content => context.file_content,
selected_code => context.selected_code,
project_context => context.project_context,

View File

@ -60,11 +60,14 @@ M.stream = function(opts)
Path.prompts.initialize(Path.prompts.get(opts.bufnr))
local filepath = Utils.relative_path(api.nvim_buf_get_name(opts.bufnr))
local template_opts = {
use_xml_format = Provider.use_xml_format,
ask = opts.ask, -- TODO: add mode without ask instruction
question = original_instructions,
code_lang = opts.code_lang,
filepath = filepath,
file_content = opts.file_content,
selected_code = opts.selected_code,
project_context = opts.project_context,

View File

@ -28,17 +28,17 @@ end
H.get_mode_file = function(mode) return string.format("custom.%s.avanterules", mode) end
-- History path
local M = {}
local History = {}
-- Returns the Path to the chat history file for the given buffer.
---@param bufnr integer
---@return Path
M.get = function(bufnr) return Path:new(Config.history.storage_path):joinpath(H.filename(bufnr)) end
History.get = function(bufnr) return Path:new(Config.history.storage_path):joinpath(H.filename(bufnr)) end
-- Loads the chat history for the given buffer.
---@param bufnr integer
M.load = function(bufnr)
local history_file = M.get(bufnr)
History.load = function(bufnr)
local history_file = History.get(bufnr)
if history_file:exists() then
local content = history_file:read()
return content ~= nil and vim.json.decode(content) or {}
@ -49,29 +49,29 @@ end
-- Saves the chat history for the given buffer.
---@param bufnr integer
---@param history table
M.save = function(bufnr, history)
local history_file = M.get(bufnr)
History.save = function(bufnr, history)
local history_file = History.get(bufnr)
history_file:write(vim.json.encode(history), "w")
end
P.history = M
P.history = History
-- Prompt path
local N = {}
local Prompt = {}
---@class AvanteTemplates
---@field initialize fun(directory: string): nil
---@field render fun(template: string, context: TemplateOptions): string
local templates = nil
N.templates = { planning = nil, editing = nil, suggesting = nil }
Prompt.templates = { planning = nil, editing = nil, suggesting = nil }
-- Creates a directory in the cache path for the given buffer and copies the custom prompts to it.
-- We need to do this beacuse the prompt template engine requires a given directory to load all required files.
-- PERF: Hmm instead of copy to cache, we can also load in globals context, but it requires some work on bindings. (eh maybe?)
---@param bufnr number
---@return string the resulted cache_directory to be loaded with avante_templates
N.get = function(bufnr)
Prompt.get = function(bufnr)
if not P.available() then error("Make sure to build avante (missing avante_templates)", 2) end
-- get root directory of given bufnr
@ -85,19 +85,19 @@ N.get = function(bufnr)
local scanner = Scan.scan_dir(directory:absolute(), { depth = 1, add_dirs = true })
for _, entry in ipairs(scanner) do
local file = Path:new(entry)
if entry:find("planning") and N.templates.planning == nil then
N.templates.planning = file:read()
elseif entry:find("editing") and N.templates.editing == nil then
N.templates.editing = file:read()
elseif entry:find("suggesting") and N.templates.suggesting == nil then
N.templates.suggesting = file:read()
if entry:find("planning") and Prompt.templates.planning == nil then
Prompt.templates.planning = file:read()
elseif entry:find("editing") and Prompt.templates.editing == nil then
Prompt.templates.editing = file:read()
elseif entry:find("suggesting") and Prompt.templates.suggesting == nil then
Prompt.templates.suggesting = file:read()
end
end
Path:new(debug.getinfo(1).source:match("@?(.*/)"):gsub("/lua/avante/path.lua$", "") .. "templates")
:copy({ destination = cache_prompt_dir, recursive = true })
vim.iter(N.templates):filter(function(_, v) return v ~= nil end):each(function(k, v)
vim.iter(Prompt.templates):filter(function(_, v) return v ~= nil end):each(function(k, v)
local f = cache_prompt_dir:joinpath(H.get_mode_file(k))
f:write(v, "w")
end)
@ -106,22 +106,53 @@ N.get = function(bufnr)
end
---@param mode LlmMode
N.get_file = function(mode)
if N.templates[mode] ~= nil then return H.get_mode_file(mode) end
Prompt.get_file = function(mode)
if Prompt.templates[mode] ~= nil then return H.get_mode_file(mode) end
return string.format("%s.avanterules", mode)
end
---@param path string
---@param opts TemplateOptions
N.render_file = function(path, opts) return templates.render(path, opts) end
Prompt.render_file = function(path, opts) return templates.render(path, opts) end
---@param mode LlmMode
---@param opts TemplateOptions
N.render_mode = function(mode, opts) return templates.render(N.get_file(mode), opts) end
Prompt.render_mode = function(mode, opts) return templates.render(Prompt.get_file(mode), opts) end
N.initialize = function(directory) templates.initialize(directory) end
Prompt.initialize = function(directory) templates.initialize(directory) end
P.prompts = N
P.prompts = Prompt
local RepoMap = {}
-- Get a chat history file name given a buffer
---@param project_root string
---@param ext string
---@return string
RepoMap.filename = function(project_root, ext)
-- Replace path separators with double underscores
local path_with_separators = fn.substitute(project_root, "/", "__", "g")
-- Replace other non-alphanumeric characters with single underscores
return fn.substitute(path_with_separators, "[^A-Za-z0-9._]", "_", "g") .. "." .. ext .. ".repo_map.json"
end
RepoMap.get = function(project_root, ext) return Path:new(P.data_path):joinpath(RepoMap.filename(project_root, ext)) end
RepoMap.save = function(project_root, ext, data)
local file = RepoMap.get(project_root, ext)
file:write(vim.json.encode(data), "w")
end
RepoMap.load = function(project_root, ext)
local file = RepoMap.get(project_root, ext)
if file:exists() then
local content = file:read()
return content ~= nil and vim.json.decode(content) or {}
end
return nil
end
P.repo_map = RepoMap
P.setup = function()
local history_path = Path:new(Config.history.storage_path)
@ -132,6 +163,10 @@ P.setup = function()
if not cache_path:exists() then cache_path:mkdir({ parents = true }) end
P.cache_path = cache_path
local data_path = Path:new(vim.fn.stdpath("data") .. "/avante")
if not data_path:exists() then data_path:mkdir({ parents = true }) end
P.data_path = data_path
vim.defer_fn(function()
local ok, module = pcall(require, "avante_templates")
---@cast module AvanteTemplates

View File

@ -1,7 +1,6 @@
local Utils = require("avante.utils")
local Config = require("avante.config")
local Llm = require("avante.llm")
local Highlights = require("avante.highlights")
local Provider = require("avante.providers")
local api = vim.api
@ -391,10 +390,16 @@ function Selection:create_editing_input()
end
local filetype = api.nvim_get_option_value("filetype", { buf = code_bufnr })
local file_ext = api.nvim_buf_get_name(code_bufnr):match("^.+%.(.+)$")
local mentions = Utils.extract_mentions(input)
input = mentions.new_content
local project_context = mentions.enable_project_context and Utils.repo_map.get_repo_map(file_ext) or nil
Llm.stream({
bufnr = code_bufnr,
ask = true,
project_context = vim.json.encode(project_context),
file_content = code_content,
code_lang = filetype,
selected_code = self.selection.content,
@ -453,6 +458,25 @@ function Selection:create_editing_input()
end,
})
api.nvim_create_autocmd("InsertEnter", {
group = self.augroup,
buffer = bufnr,
once = true,
desc = "Setup the completion of helpers in the input buffer",
callback = function()
local has_cmp, cmp = pcall(require, "cmp")
if has_cmp then
cmp.register_source("avante_mentions", require("cmp_avante.mentions").new(Utils.get_mentions(), bufnr))
cmp.setup.buffer({
enabled = true,
sources = {
{ name = "avante_mentions" },
},
})
end
end,
})
api.nvim_create_autocmd("User", {
pattern = "AvanteEditSubmitted",
callback = function(ev)

View File

@ -441,10 +441,8 @@ local function insert_conflict_contents(bufnr, snippets)
local result = {}
table.insert(result, "<<<<<<< HEAD")
if start_line ~= end_line then
for i = start_line, end_line do
table.insert(result, lines[i])
end
for i = start_line, end_line do
table.insert(result, lines[i])
end
table.insert(result, "=======")
@ -460,8 +458,6 @@ local function insert_conflict_contents(bufnr, snippets)
table.insert(result, line)
end
if start_line == end_line then table.insert(result, lines[start_line]) end
table.insert(result, ">>>>>>> Snippet")
api.nvim_buf_set_lines(bufnr, offset + start_line - 1, offset + end_line, false, result)
@ -1296,9 +1292,17 @@ function Sidebar:create_input(opts)
Path.history.save(self.code.bufnr, chat_history)
end
local mentions = Utils.extract_mentions(request)
request = mentions.new_content
local file_ext = api.nvim_buf_get_name(self.code.bufnr):match("^.+%.(.+)$")
local project_context = mentions.enable_project_context and Utils.repo_map.get_repo_map(file_ext) or nil
Llm.stream({
bufnr = self.code.bufnr,
ask = opts.ask,
project_context = vim.json.encode(project_context),
file_content = content_with_line_numbers,
code_lang = filetype,
selected_code = selected_code_content_with_line_numbers,
@ -1358,9 +1362,9 @@ function Sidebar:create_input(opts)
local function place_sign_at_first_line(bufnr)
local group = "avante_input_prompt_group"
vim.fn.sign_unplace(group, { buffer = bufnr })
fn.sign_unplace(group, { buffer = bufnr })
vim.fn.sign_place(0, group, "AvanteInputPromptSign", bufnr, { lnum = 1 })
fn.sign_place(0, group, "AvanteInputPromptSign", bufnr, { lnum = 1 })
end
place_sign_at_first_line(self.input.bufnr)
@ -1387,11 +1391,16 @@ function Sidebar:create_input(opts)
if not self.registered_cmp then
self.registered_cmp = true
cmp.register_source("avante_commands", require("cmp_avante.commands").new(self))
cmp.register_source(
"avante_mentions",
require("cmp_avante.mentions").new(Utils.get_mentions(), self.input.bufnr)
)
end
cmp.setup.buffer({
enabled = true,
sources = {
{ name = "avante_commands" },
{ name = "avante_mentions" },
},
})
end

View File

@ -82,24 +82,26 @@ function Suggestion:suggest()
return
end
Utils.debug("full_response: " .. vim.inspect(full_response))
local cursor_row, cursor_col = Utils.get_cursor_pos()
if cursor_row ~= doc.position.row or cursor_col ~= doc.position.col then return end
local ok, suggestions = pcall(vim.json.decode, full_response)
if not ok then
Utils.error("Error while decoding suggestions: " .. full_response, { once = true, title = "Avante" })
return
end
if not suggestions then
Utils.info("No suggestions found", { once = true, title = "Avante" })
return
end
suggestions = vim
.iter(suggestions)
:map(function(s) return { row = s.row, col = s.col, content = Utils.trim_all_line_numbers(s.content) } end)
:totable()
ctx.suggestions = suggestions
ctx.current_suggestion_idx = 1
self:show()
vim.schedule(function()
local cursor_row, cursor_col = Utils.get_cursor_pos()
if cursor_row ~= doc.position.row or cursor_col ~= doc.position.col then return end
local ok, suggestions = pcall(vim.json.decode, full_response)
if not ok then
Utils.error("Error while decoding suggestions: " .. full_response, { once = true, title = "Avante" })
return
end
if not suggestions then
Utils.info("No suggestions found", { once = true, title = "Avante" })
return
end
suggestions = vim
.iter(suggestions)
:map(function(s) return { row = s.row, col = s.col, content = Utils.trim_all_line_numbers(s.content) } end)
:totable()
ctx.suggestions = suggestions
ctx.current_suggestion_idx = 1
self:show()
end)
end,
})
end

View File

@ -1,4 +1,6 @@
{%- if use_xml_format -%}
<filepath>{{filepath}}</filepath>
{%- if selected_code -%}
<context>
```{{code_lang}}
@ -19,6 +21,8 @@
</code>
{%- endif %}
{% else %}
FILEPATH: {{filepath}}
{%- if selected_code -%}
CONTEXT:
```{{code_lang}}

View File

@ -17,6 +17,7 @@ Your task is to suggest code modifications at the cursor position. Follow these
{% endraw %}
3. When suggesting suggested code:
- DO NOT include three backticks: {%raw%}```{%endraw%} in your suggestion. Treat the suggested code AS IS.
- Each element in the returned list is a COMPLETE and INDEPENDENT code snippet.
- MUST be a valid json format. Don't be lazy!
- Only return the new code to be inserted.
@ -29,4 +30,3 @@ Your task is to suggest code modifications at the cursor position. Follow these
Remember to ONLY RETURN the suggested code snippet, without any additional formatting or explanation.
{% endblock %}

36
lua/avante/utils/file.lua Normal file
View File

@ -0,0 +1,36 @@
local LRUCache = require("avante.utils.lru_cache")
---@class avante.utils.file
local M = {}
local api = vim.api
local fn = vim.fn
local _file_content_lru_cache = LRUCache:new(60)
api.nvim_create_autocmd("BufWritePost", {
callback = function()
local filepath = api.nvim_buf_get_name(0)
local keys = _file_content_lru_cache:keys()
if vim.tbl_contains(keys, filepath) then
local content = table.concat(api.nvim_buf_get_lines(0, 0, -1, false), "\n")
_file_content_lru_cache:set(filepath, content)
end
end,
})
function M.read_content(filepath)
local cached_content = _file_content_lru_cache:get(filepath)
if cached_content then return cached_content end
local content = fn.readfile(filepath)
if content then
content = table.concat(content, "\n")
_file_content_lru_cache:set(filepath, content)
return content
end
return nil
end
return M

View File

@ -5,6 +5,7 @@ local lsp = vim.lsp
---@class avante.utils: LazyUtilCore
---@field tokens avante.utils.tokens
---@field root avante.utils.root
---@field repo_map avante.utils.repo_map
local M = {}
setmetatable(M, {
@ -444,7 +445,7 @@ function M.get_indentation(code) return code:match("^%s*") or "" end
--- remove indentation from code: spaces or tabs
function M.remove_indentation(code) return code:gsub("^%s*", "") end
local function relative_path(absolute)
function M.relative_path(absolute)
local relative = fn.fnamemodify(absolute, ":.")
if string.sub(relative, 0, 1) == "/" then return fn.fnamemodify(absolute, ":t") end
return relative
@ -462,7 +463,7 @@ function M.get_doc()
local doc = {
uri = params.textDocument.uri,
version = api.nvim_buf_get_var(0, "changedtick"),
relativePath = relative_path(absolute),
relativePath = M.relative_path(absolute),
insertSpaces = vim.o.expandtab,
tabSize = fn.shiftwidth(),
indentSize = fn.shiftwidth(),
@ -520,4 +521,126 @@ function M.winline(winid)
return line
end
function M.get_project_root() return M.root.get() end
function M.is_same_file_ext(target_ext, filepath)
local ext = fn.fnamemodify(filepath, ":e")
if target_ext == "tsx" and ext == "ts" then return true end
if target_ext == "jsx" and ext == "js" then return true end
return ext == target_ext
end
-- Get recent filepaths in the same project and same file ext
function M.get_recent_filepaths(limit, filenames)
local project_root = M.get_project_root()
local current_ext = fn.expand("%:e")
local oldfiles = vim.v.oldfiles
local recent_files = {}
for _, file in ipairs(oldfiles) do
if vim.startswith(file, project_root) and M.is_same_file_ext(current_ext, file) then
if filenames and #filenames > 0 then
for _, filename in ipairs(filenames) do
if file:find(filename) then table.insert(recent_files, file) end
end
else
table.insert(recent_files, file)
end
if #recent_files >= (limit or 10) then break end
end
end
return recent_files
end
local function pattern_to_lua(pattern)
local lua_pattern = pattern:gsub("[%(%)%.%%%+%-%*%?%[%]%^%$]", "%%%1")
lua_pattern = lua_pattern:gsub("%*%*/", ".-/")
lua_pattern = lua_pattern:gsub("%*", "[^/]*")
lua_pattern = lua_pattern:gsub("%?", ".")
if lua_pattern:sub(-1) == "/" then lua_pattern = lua_pattern .. ".*" end
return lua_pattern
end
function M.parse_gitignore(gitignore_path)
local ignore_patterns = { ".git", ".worktree", "__pycache__", "node_modules" }
local negate_patterns = {}
local file = io.open(gitignore_path, "r")
if not file then return ignore_patterns, negate_patterns end
for line in file:lines() do
if line:match("%S") and not line:match("^#") then
local trimmed_line = line:match("^%s*(.-)%s*$")
if trimmed_line:sub(1, 1) == "!" then
table.insert(negate_patterns, pattern_to_lua(trimmed_line:sub(2)))
else
table.insert(ignore_patterns, pattern_to_lua(trimmed_line))
end
end
end
file:close()
return ignore_patterns, negate_patterns
end
local function is_ignored(file, ignore_patterns, negate_patterns)
for _, pattern in ipairs(negate_patterns) do
if file:match(pattern) then return false end
end
for _, pattern in ipairs(ignore_patterns) do
if file:match(pattern) then return true end
end
return false
end
function M.scan_directory(directory, ignore_patterns, negate_patterns)
local files = {}
local handle = vim.loop.fs_scandir(directory)
if not handle then return files end
while true do
local name, type = vim.loop.fs_scandir_next(handle)
if not name then break end
local full_path = directory .. "/" .. name
if type == "directory" then
vim.list_extend(files, M.scan_directory(full_path, ignore_patterns, negate_patterns))
elseif type == "file" then
if not is_ignored(full_path, ignore_patterns, negate_patterns) then table.insert(files, full_path) end
end
end
return files
end
function M.is_first_letter_uppercase(str) return string.match(str, "^[A-Z]") ~= nil end
---@param content string
---@return { new_content: string, enable_project_context: boolean }
function M.extract_mentions(content)
-- if content contains @codebase, enable project context and remove @codebase
local new_content = content
local enable_project_context = false
if content:match("@codebase") then
enable_project_context = true
new_content = content:gsub("@codebase", "")
end
return { new_content = new_content, enable_project_context = enable_project_context }
end
---@alias AvanteMentions "codebase"
---@alias AvanteMentionCallback fun(args: string, cb?: fun(args: string): nil): nil
---@alias AvanteMention {description: string, command: AvanteMentions, details: string, shorthelp?: string, callback?: AvanteMentionCallback}
---@return AvanteMention[]
function M.get_mentions()
return {
{
description = "codebase",
command = "codebase",
details = "repo map",
},
}
end
return M

View File

@ -0,0 +1,115 @@
local LRUCache = {}
LRUCache.__index = LRUCache
function LRUCache:new(capacity)
return setmetatable({
capacity = capacity,
cache = {},
head = nil,
tail = nil,
size = 0,
}, LRUCache)
end
-- Internal function: Move node to head (indicating most recently used)
function LRUCache:_move_to_head(node)
if self.head == node then return end
-- Disconnect the node
if node.prev then node.prev.next = node.next end
if node.next then node.next.prev = node.prev end
if self.tail == node then self.tail = node.prev end
-- Insert the node at the head
node.next = self.head
node.prev = nil
if self.head then self.head.prev = node end
self.head = node
if not self.tail then self.tail = node end
end
-- Get value from cache
function LRUCache:get(key)
local node = self.cache[key]
if not node then return nil end
self:_move_to_head(node)
return node.value
end
-- Set value in cache
function LRUCache:set(key, value)
local node = self.cache[key]
if node then
node.value = value
self:_move_to_head(node)
else
node = { key = key, value = value }
self.cache[key] = node
self.size = self.size + 1
self:_move_to_head(node)
if self.size > self.capacity then
local tail_key = self.tail.key
self.tail = self.tail.prev
if self.tail then self.tail.next = nil end
self.cache[tail_key] = nil
self.size = self.size - 1
end
end
end
-- Remove specified cache entry
function LRUCache:remove(key)
local node = self.cache[key]
if not node then return end
if node.prev then
node.prev.next = node.next
else
self.head = node.next
end
if node.next then
node.next.prev = node.prev
else
self.tail = node.prev
end
self.cache[key] = nil
self.size = self.size - 1
end
-- Get current size of cache
function LRUCache:get_size() return self.size end
-- Get capacity of cache
function LRUCache:get_capacity() return self.capacity end
-- Print current cache contents (for debugging)
function LRUCache:print_cache()
local node = self.head
while node do
print(node.key, node.value)
node = node.next
end
end
function LRUCache:keys()
local keys = {}
local node = self.head
while node do
table.insert(keys, node.key)
node = node.next
end
return keys
end
return LRUCache

View File

@ -0,0 +1,730 @@
local parsers = require("nvim-treesitter.parsers")
local Config = require("avante.config")
local get_node_text = vim.treesitter.get_node_text
---@class avante.utils.repo_map
local RepoMap = {}
local dependencies_queries = {
lua = [[
(function_call
name: (identifier) @function_name
arguments: (arguments
(string) @required_file))
]],
python = [[
(import_from_statement
module_name: (dotted_name) @import_module)
(import_statement
(dotted_name) @import_module)
]],
javascript = [[
(import_statement
source: (string) @import_module)
(call_expression
function: (identifier) @function_name
arguments: (arguments
(string) @required_file))
]],
typescript = [[
(import_statement
source: (string) @import_module)
(call_expression
function: (identifier) @function_name
arguments: (arguments
(string) @required_file))
]],
go = [[
(import_spec
path: (interpreted_string_literal) @import_module)
]],
rust = [[
(use_declaration
(scoped_identifier) @import_module)
(use_declaration
(identifier) @import_module)
]],
c = [[
(preproc_include
(string_literal) @import_module)
(preproc_include
(system_lib_string) @import_module)
]],
cpp = [[
(preproc_include
(string_literal) @import_module)
(preproc_include
(system_lib_string) @import_module)
]],
}
local definitions_queries = {
python = [[
;; Capture top-level functions, class, and method definitions
(module
(expression_statement
(assignment) @assignment
)
)
(module
(function_definition) @function
)
(module
(class_definition
body: (block
(expression_statement
(assignment) @class_assignment
)
)
)
)
(module
(class_definition
body: (block
(function_definition) @method
)
)
)
]],
javascript = [[
;; Capture exported functions, arrow functions, variables, classes, and method definitions
(export_statement
declaration: (lexical_declaration
(variable_declarator) @variable
)
)
(export_statement
declaration: (function_declaration) @function
)
(export_statement
declaration: (class_declaration
body: (class_body
(field_definition) @class_variable
)
)
)
(export_statement
declaration: (class_declaration
body: (class_body
(method_definition) @method
)
)
)
]],
typescript = [[
;; Capture exported functions, arrow functions, variables, classes, and method definitions
(export_statement
declaration: (lexical_declaration
(variable_declarator) @variable
)
)
(export_statement
declaration: (function_declaration) @function
)
(export_statement
declaration: (class_declaration
body: (class_body
(public_field_definition) @class_variable
)
)
)
(interface_declaration
body: (interface_body
(property_signature) @class_variable
)
)
(type_alias_declaration
value: (object_type
(property_signature) @class_variable
)
)
(export_statement
declaration: (class_declaration
body: (class_body
(method_definition) @method
)
)
)
]],
rust = [[
;; Capture public functions, structs, methods, and variable definitions
(function_item) @function
(impl_item
body: (declaration_list
(function_item) @method
)
)
(struct_item
body: (field_declaration_list
(field_declaration) @class_variable
)
)
(enum_item
body: (enum_variant_list
(enum_variant) @enum_item
)
)
(const_item) @variable
]],
go = [[
;; Capture top-level functions and struct definitions
(var_declaration
(var_spec) @variable
)
(const_declaration
(const_spec) @variable
)
(function_declaration) @function
(type_declaration
(type_spec (struct_type)) @class
)
(type_declaration
(type_spec
(struct_type
(field_declaration_list
(field_declaration) @class_variable)))
)
(method_declaration) @method
]],
c = [[
;; Capture extern functions, variables, public classes, and methods
(function_definition
(storage_class_specifier) @extern
) @function
(class_specifier
(public) @class
(function_definition) @method
) @class
(declaration
(storage_class_specifier) @extern
) @variable
]],
cpp = [[
;; Capture extern functions, variables, public classes, and methods
(function_definition
(storage_class_specifier) @extern
) @function
(class_specifier
(public) @class
(function_definition) @method
) @class
(declaration
(storage_class_specifier) @extern
) @variable
]],
lua = [[
;; Capture function and method definitions
(variable_list) @variable
(function_declaration) @function
]],
ruby = [[
;; Capture top-level methods, class definitions, and methods within classes
(method) @function
(assignment) @assignment
(class
body: (body_statement
(assignment) @class_assignment
(method) @method
)
)
]],
}
local queries_filetype_map = {
["javascriptreact"] = "javascript",
["typescriptreact"] = "typescript",
}
local function get_query(queries, filetype)
filetype = queries_filetype_map[filetype] or filetype
return queries[filetype]
end
local function get_ts_lang(bufnr)
local lang = parsers.get_buf_lang(bufnr)
return lang
end
function RepoMap.get_parser(bufnr)
local lang = get_ts_lang(bufnr)
if not lang then return end
local parser = parsers.get_parser(bufnr, lang)
return parser, lang
end
function RepoMap.extract_dependencies(bufnr)
local parser, lang = RepoMap.get_parser(bufnr)
if not lang or not parser or not dependencies_queries[lang] then
print("No parser or query available for this buffer's language: " .. (lang or "unknown"))
return {}
end
local dependencies = {}
local tree = parser:parse()[1]
local root = tree:root()
local filetype = vim.api.nvim_get_option_value("filetype", { buf = bufnr })
local query = get_query(dependencies_queries, filetype)
if not query then return dependencies end
local query_obj = vim.treesitter.query.parse(lang, query)
for _, node, _ in query_obj:iter_captures(root, bufnr, 0, -1) do
-- local name = query.captures[id]
local required_file = vim.treesitter.get_node_text(node, bufnr):gsub('"', ""):gsub("'", "")
table.insert(dependencies, required_file)
end
return dependencies
end
function RepoMap.get_filetype_by_filepath(filepath) return vim.filetype.match({ filename = filepath }) end
function RepoMap.parse_file(filepath)
local File = require("avante.utils.file")
local source = File.read_content(filepath)
local filetype = RepoMap.get_filetype_by_filepath(filepath)
local lang = parsers.ft_to_lang(filetype)
if lang then
local ok, parser = pcall(vim.treesitter.get_string_parser, source, lang)
if ok then
local tree = parser:parse()[1]
local node = tree:root()
return { node = node, source = source }
else
print("parser error", parser)
end
end
end
local function get_closest_parent_name(node, source)
local parent = node:parent()
while parent do
local name = parent:field("name")[1]
if name then return get_node_text(name, source) end
parent = parent:parent()
end
return ""
end
local function find_parent_by_type(node, type)
local parent = node:parent()
while parent do
if parent:type() == type then return parent end
parent = parent:parent()
end
return nil
end
local function find_child_by_type(node, type)
for child in node:iter_children() do
if child:type() == type then return child end
local res = find_child_by_type(child, type)
if res then return res end
end
return nil
end
local function get_node_type(node, source)
local node_type
local predefined_type_node = find_child_by_type(node, "predefined_type")
if predefined_type_node then
node_type = get_node_text(predefined_type_node, source)
else
local value_type_node = node:field("type")[1]
node_type = value_type_node and get_node_text(value_type_node, source) or ""
end
return node_type
end
-- Function to extract definitions from the file
function RepoMap.extract_definitions(filepath)
local Utils = require("avante.utils")
local filetype = RepoMap.get_filetype_by_filepath(filepath)
if not filetype then return {} end
-- Get the corresponding query for the detected language
local query = get_query(definitions_queries, filetype)
if not query then return {} end
local parsed = RepoMap.parse_file(filepath)
if not parsed then return {} end
-- Get the current buffer's syntax tree
local root = parsed.node
local lang = parsers.ft_to_lang(filetype)
-- Parse the query
local query_obj = vim.treesitter.query.parse(lang, query)
-- Store captured results
local definitions = {}
local class_def_map = {}
local enum_def_map = {}
local function get_class_def(name)
local def = class_def_map[name]
if def == nil then
def = {
type = "class",
name = name,
methods = {},
properties = {},
}
class_def_map[name] = def
end
return def
end
local function get_enum_def(name)
local def = enum_def_map[name]
if def == nil then
def = {
type = "enum",
name = name,
items = {},
}
enum_def_map[name] = def
end
return def
end
for _, captures, _ in query_obj:iter_matches(root, parsed.source) do
for id, node in pairs(captures) do
local type = query_obj.captures[id]
local name_node = node:field("name")[1]
local name = name_node and get_node_text(name_node, parsed.source) or ""
if type == "class" then
if name ~= "" then get_class_def(name) end
elseif type == "enum_item" then
local enum_name = get_closest_parent_name(node, parsed.source)
if enum_name and filetype == "go" and not Utils.is_first_letter_uppercase(enum_name) then goto continue end
local enum_def = get_enum_def(enum_name)
local enum_type_node = find_child_by_type(node, "type_identifier")
local enum_type = enum_type_node and get_node_text(enum_type_node, parsed.source) or ""
table.insert(enum_def.items, {
name = name,
type = enum_type,
})
elseif type == "method" then
if name and filetype == "go" and not Utils.is_first_letter_uppercase(name) then goto continue end
local params_node = node:field("parameters")[1]
local params = params_node and get_node_text(params_node, parsed.source) or "()"
local return_type_node = node:field("return_type")[1] or node:field("result")[1]
local return_type = return_type_node and get_node_text(return_type_node, parsed.source) or "void"
local class_name
local impl_item_node = find_parent_by_type(node, "impl_item")
local receiver_node = node:field("receiver")[1]
if impl_item_node then
local impl_type_node = impl_item_node:field("type")[1]
class_name = impl_type_node and get_node_text(impl_type_node, parsed.source) or ""
elseif receiver_node then
local type_identifier_node = find_child_by_type(receiver_node, "type_identifier")
class_name = type_identifier_node and get_node_text(type_identifier_node, parsed.source) or ""
else
class_name = get_closest_parent_name(node, parsed.source)
end
local class_def = get_class_def(class_name)
local accessibility_modifier_node = find_child_by_type(node, "accessibility_modifier")
local accessibility_modifier = accessibility_modifier_node
and get_node_text(accessibility_modifier_node, parsed.source)
or ""
table.insert(class_def.methods, {
type = "function",
name = name,
params = params,
return_type = return_type,
accessibility_modifier = accessibility_modifier,
})
elseif type == "class_assignment" then
local left_node = node:field("left")[1]
local left = left_node and get_node_text(left_node, parsed.source) or ""
local value_type = get_node_type(node, parsed.source)
local class_name = get_closest_parent_name(node, parsed.source)
if class_name and filetype == "go" and not Utils.is_first_letter_uppercase(class_name) then goto continue end
local class_def = get_class_def(class_name)
table.insert(class_def.properties, {
type = "variable",
name = left,
value_type = value_type,
})
elseif type == "class_variable" then
local value_type = get_node_type(node, parsed.source)
local class_name = get_closest_parent_name(node, parsed.source)
if class_name and filetype == "go" and not Utils.is_first_letter_uppercase(class_name) then goto continue end
local class_def = get_class_def(class_name)
table.insert(class_def.properties, {
type = "variable",
name = name,
value_type = value_type,
})
elseif type == "function" or type == "arrow_function" then
if name and filetype == "go" and not Utils.is_first_letter_uppercase(name) then goto continue end
local impl_item_node = find_parent_by_type(node, "impl_item")
if impl_item_node then goto continue end
local function_node = find_parent_by_type(node, "function_declaration")
or find_parent_by_type(node, "function_definition")
if function_node then goto continue end
-- Extract function parameters and return type
local params_node = node:field("parameters")[1]
local params = params_node and get_node_text(params_node, parsed.source) or "()"
local return_type_node = node:field("return_type")[1] or node:field("result")[1]
local return_type = return_type_node and get_node_text(return_type_node, parsed.source) or "void"
local accessibility_modifier_node = find_child_by_type(node, "accessibility_modifier")
local accessibility_modifier = accessibility_modifier_node
and get_node_text(accessibility_modifier_node, parsed.source)
or ""
local def = {
type = "function",
name = name,
params = params,
return_type = return_type,
accessibility_modifier = accessibility_modifier,
}
table.insert(definitions, def)
elseif type == "assignment" then
local impl_item_node = find_parent_by_type(node, "impl_item")
or find_parent_by_type(node, "class_declaration")
or find_parent_by_type(node, "class_definition")
if impl_item_node then goto continue end
local function_node = find_parent_by_type(node, "function_declaration")
or find_parent_by_type(node, "function_definition")
if function_node then goto continue end
local left_node = node:field("left")[1]
local left = left_node and get_node_text(left_node, parsed.source) or ""
if left and filetype == "go" and not Utils.is_first_letter_uppercase(left) then goto continue end
local value_type = get_node_type(node, parsed.source)
local def = {
type = "variable",
name = left,
value_type = value_type,
}
table.insert(definitions, def)
elseif type == "variable" then
local impl_item_node = find_parent_by_type(node, "impl_item")
or find_parent_by_type(node, "class_declaration")
or find_parent_by_type(node, "class_definition")
if impl_item_node then goto continue end
local function_node = find_parent_by_type(node, "function_declaration")
or find_parent_by_type(node, "function_definition")
if function_node then goto continue end
local value_type = get_node_type(node, parsed.source)
if name and filetype == "go" and not Utils.is_first_letter_uppercase(name) then goto continue end
local def = { type = "variable", name = name, value_type = value_type }
table.insert(definitions, def)
end
::continue::
end
end
for _, def in pairs(class_def_map) do
table.insert(definitions, def)
end
for _, def in pairs(enum_def_map) do
table.insert(definitions, def)
end
return definitions
end
local function stringify_function(def)
local res = "func " .. def.name .. def.params .. ":" .. def.return_type .. ";"
if def.accessibility_modifier and def.accessibility_modifier ~= "" then
res = def.accessibility_modifier .. " " .. res
end
return res
end
local function stringify_variable(def)
local res = "var " .. def.name
if def.value_type and def.value_type ~= "" then res = res .. ":" .. def.value_type end
return res .. ";"
end
local function stringify_enum_item(def)
local res = def.name
if def.value_type and def.value_type ~= "" then res = res .. ":" .. def.value_type end
return res .. ";"
end
-- Function to load file content into a temporary buffer, process it, and then delete the buffer
function RepoMap.stringify_definitions(filepath)
if vim.endswith(filepath, "~") then return "" end
-- Extract definitions
local definitions = RepoMap.extract_definitions(filepath)
local output = ""
-- Print or process the definitions
for _, def in ipairs(definitions) do
if def.type == "class" then
output = output .. def.type .. " " .. def.name .. "{"
for _, property in ipairs(def.properties) do
output = output .. stringify_variable(property)
end
for _, method in ipairs(def.methods) do
output = output .. stringify_function(method)
end
output = output .. "}"
elseif def.type == "enum" then
output = output .. def.type .. " " .. def.name .. "{"
for _, item in ipairs(def.items) do
output = output .. stringify_enum_item(item) .. ""
end
output = output .. "}"
elseif def.type == "function" then
output = output .. stringify_function(def)
elseif def.type == "variable" then
output = output .. stringify_variable(def)
end
end
return output
end
function RepoMap._build_repo_map(project_root, file_ext)
local Utils = require("avante.utils")
local output = {}
local gitignore_path = project_root .. "/.gitignore"
local ignore_patterns, negate_patterns = Utils.parse_gitignore(gitignore_path)
local filepaths = Utils.scan_directory(project_root, ignore_patterns, negate_patterns)
vim.iter(filepaths):each(function(filepath)
if not Utils.is_same_file_ext(file_ext, filepath) then return end
local definitions = RepoMap.stringify_definitions(filepath)
if definitions == "" then return end
table.insert(output, {
path = Utils.relative_path(filepath),
lang = RepoMap.get_filetype_by_filepath(filepath),
defs = definitions,
})
end)
return output
end
local cache = {}
function RepoMap.get_repo_map(file_ext)
file_ext = file_ext or vim.fn.expand("%:e")
local Utils = require("avante.utils")
local project_root = Utils.root.get()
local cache_key = project_root .. "." .. file_ext
local cached = cache[cache_key]
if cached then return cached end
local PPath = require("plenary.path")
local Path = require("avante.path")
local repo_map
local function build_and_save()
repo_map = RepoMap._build_repo_map(project_root, file_ext)
cache[cache_key] = repo_map
Path.repo_map.save(project_root, file_ext, repo_map)
end
repo_map = Path.repo_map.load(project_root, file_ext)
if not repo_map or next(repo_map) == nil then
build_and_save()
if not repo_map then return end
else
local timer = vim.loop.new_timer()
if timer then
timer:start(
0,
0,
vim.schedule_wrap(function()
build_and_save()
timer:close()
end)
)
end
end
local update_repo_map = vim.schedule_wrap(function(rel_filepath)
if rel_filepath and Utils.is_same_file_ext(file_ext, rel_filepath) then
local abs_filepath = PPath:new(project_root):joinpath(rel_filepath):absolute()
local definitions = RepoMap.stringify_definitions(abs_filepath)
if definitions == "" then return end
local found = false
for _, m in ipairs(repo_map) do
if m.path == rel_filepath then
m.defs = definitions
found = true
break
end
end
if not found then
table.insert(repo_map, {
path = Utils.relative_path(abs_filepath),
lang = RepoMap.get_filetype_by_filepath(abs_filepath),
defs = definitions,
})
end
cache[cache_key] = repo_map
Path.repo_map.save(project_root, file_ext, repo_map)
end
end)
local handle = vim.loop.new_fs_event()
if handle then
handle:start(project_root, { recursive = true }, function(err, rel_filepath)
if err then
print("Error watching directory " .. project_root .. ":", err)
return
end
if rel_filepath then update_repo_map(rel_filepath) end
end)
end
vim.api.nvim_create_autocmd({ "BufReadPost", "BufNewFile" }, {
callback = function(ev)
vim.defer_fn(function()
local filepath = vim.api.nvim_buf_get_name(ev.buf)
if not vim.startswith(filepath, project_root) then return end
local rel_filepath = Utils.relative_path(filepath)
update_repo_map(rel_filepath)
end, 0)
end,
})
return repo_map
end
return RepoMap

View File

@ -1,24 +1,24 @@
---@class source
---@class commands_source
---@field sidebar avante.Sidebar
local source = {}
local commands_source = {}
---@param sidebar avante.Sidebar
function source.new(sidebar)
function commands_source.new(sidebar)
---@type cmp.Source
return setmetatable({
sidebar = sidebar,
}, { __index = source })
}, { __index = commands_source })
end
function source:is_available() return vim.bo.filetype == "AvanteInput" end
function commands_source:is_available() return vim.bo.filetype == "AvanteInput" end
source.get_position_encoding_kind = function() return "utf-8" end
commands_source.get_position_encoding_kind = function() return "utf-8" end
function source:get_trigger_characters() return { "/" } end
function commands_source:get_trigger_characters() return { "/" } end
function source:get_keyword_pattern() return [[\%(@\|#\|/\)\k*]] end
function commands_source:get_keyword_pattern() return [[\%(@\|#\|/\)\k*]] end
function source:complete(_, callback)
function commands_source:complete(_, callback)
local kind = require("cmp").lsp.CompletionItemKind.Variable
local items = {}
@ -39,4 +39,4 @@ function source:complete(_, callback)
})
end
return source
return commands_source

View File

@ -0,0 +1,43 @@
---@class mentions_source
---@field mentions {description: string, command: AvanteMentions, details: string, shorthelp?: string, callback?: AvanteMentionCallback}[]
---@field bufnr integer
local mentions_source = {}
---@param mentions {description: string, command: AvanteMentions, details: string, shorthelp?: string, callback?: AvanteMentionCallback}[]
---@param bufnr integer
function mentions_source.new(mentions, bufnr)
---@type cmp.Source
return setmetatable({
mentions = mentions,
bufnr = bufnr,
}, { __index = mentions_source })
end
function mentions_source:is_available() return vim.api.nvim_get_current_buf() == self.bufnr end
mentions_source.get_position_encoding_kind = function() return "utf-8" end
function mentions_source:get_trigger_characters() return { "@" } end
function mentions_source:get_keyword_pattern() return [[\%(@\|#\|/\)\k*]] end
function mentions_source:complete(_, callback)
local kind = require("cmp").lsp.CompletionItemKind.Variable
local items = {}
for _, mention in ipairs(self.mentions) do
table.insert(items, {
label = "@" .. mention.command .. " ",
kind = kind,
detail = mention.details,
})
end
callback({
items = items,
isIncomplete = false,
})
end
return mentions_source