avante.nvim/lua/avante/utils/repo_map.lua
yetone 8e1018fef7
feat: repo map (#496)
* feat: repo map

* chore: remove breakline

* chore: remove spaces

* fix: golang public method

* feat: mentions for editing input
2024-09-23 18:52:26 +08:00

731 lines
22 KiB
Lua

local parsers = require("nvim-treesitter.parsers")
local Config = require("avante.config")
local get_node_text = vim.treesitter.get_node_text
--- Module table; the public `RepoMap.*` functions defined below are attached to
--- it and it is the value returned at the end of this file.
---@class avante.utils.repo_map
local RepoMap = {}
--- Tree-sitter query sources used to locate import/require statements.
--- Keys are language names; `RepoMap.extract_dependencies` looks an entry up and
--- compiles it with `vim.treesitter.query.parse`.
--- Capture names used: `@function_name`, `@required_file`, `@import_module`.
local dependencies_queries = {
-- Lua: calls whose callee is a plain identifier and that have a string argument.
lua = [[
(function_call
name: (identifier) @function_name
arguments: (arguments
(string) @required_file))
]],
-- Python: module names of `from ... import` and `import ...` statements.
python = [[
(import_from_statement
module_name: (dotted_name) @import_module)
(import_statement
(dotted_name) @import_module)
]],
-- JavaScript: import sources, plus identifier calls that have a string argument.
javascript = [[
(import_statement
source: (string) @import_module)
(call_expression
function: (identifier) @function_name
arguments: (arguments
(string) @required_file))
]],
-- TypeScript: same patterns as the JavaScript entry.
typescript = [[
(import_statement
source: (string) @import_module)
(call_expression
function: (identifier) @function_name
arguments: (arguments
(string) @required_file))
]],
-- Go: the path literal of each import spec.
go = [[
(import_spec
path: (interpreted_string_literal) @import_module)
]],
-- Rust: `use` declarations whose argument is a scoped or plain identifier.
rust = [[
(use_declaration
(scoped_identifier) @import_module)
(use_declaration
(identifier) @import_module)
]],
-- C: `#include` directives with either a quoted or a system (<...>) path.
c = [[
(preproc_include
(string_literal) @import_module)
(preproc_include
(system_lib_string) @import_module)
]],
-- C++: same patterns as the C entry.
cpp = [[
(preproc_include
(string_literal) @import_module)
(preproc_include
(system_lib_string) @import_module)
]],
}
--- Tree-sitter query sources used to locate definitions, keyed by filetype
--- (after the `queries_filetype_map` aliasing done by `get_query`).
--- `RepoMap.extract_definitions` dispatches on the capture name; the names it
--- handles are: `class`, `enum_item`, `method`, `class_assignment`,
--- `class_variable`, `function`, `arrow_function`, `assignment` and `variable`.
--- Other capture names appearing below (e.g. `@extern`) have no dedicated branch there.
local definitions_queries = {
-- Python: module-level assignments and functions, plus assignments and
-- functions found directly inside a module-level class body.
python = [[
;; Capture top-level functions, class, and method definitions
(module
(expression_statement
(assignment) @assignment
)
)
(module
(function_definition) @function
)
(module
(class_definition
body: (block
(expression_statement
(assignment) @class_assignment
)
)
)
)
(module
(class_definition
body: (block
(function_definition) @method
)
)
)
]],
-- JavaScript: only declarations wrapped in an `export` statement are captured.
javascript = [[
;; Capture exported functions, arrow functions, variables, classes, and method definitions
(export_statement
declaration: (lexical_declaration
(variable_declarator) @variable
)
)
(export_statement
declaration: (function_declaration) @function
)
(export_statement
declaration: (class_declaration
body: (class_body
(field_definition) @class_variable
)
)
)
(export_statement
declaration: (class_declaration
body: (class_body
(method_definition) @method
)
)
)
]],
-- TypeScript: exported declarations, plus property signatures of interfaces
-- and object-type aliases (these two patterns do not require `export`).
typescript = [[
;; Capture exported functions, arrow functions, variables, classes, and method definitions
(export_statement
declaration: (lexical_declaration
(variable_declarator) @variable
)
)
(export_statement
declaration: (function_declaration) @function
)
(export_statement
declaration: (class_declaration
body: (class_body
(public_field_definition) @class_variable
)
)
)
(interface_declaration
body: (interface_body
(property_signature) @class_variable
)
)
(type_alias_declaration
value: (object_type
(property_signature) @class_variable
)
)
(export_statement
declaration: (class_declaration
body: (class_body
(method_definition) @method
)
)
)
]],
-- Rust: every function item is captured as @function; functions inside an
-- `impl` body are additionally captured as @method.
rust = [[
;; Capture public functions, structs, methods, and variable definitions
(function_item) @function
(impl_item
body: (declaration_list
(function_item) @method
)
)
(struct_item
body: (field_declaration_list
(field_declaration) @class_variable
)
)
(enum_item
body: (enum_variant_list
(enum_variant) @enum_item
)
)
(const_item) @variable
]],
-- Go: var/const specs, functions, struct types (and their fields) and methods.
go = [[
;; Capture top-level functions and struct definitions
(var_declaration
(var_spec) @variable
)
(const_declaration
(const_spec) @variable
)
(function_declaration) @function
(type_declaration
(type_spec (struct_type)) @class
)
(type_declaration
(type_spec
(struct_type
(field_declaration_list
(field_declaration) @class_variable)))
)
(method_declaration) @method
]],
-- C: functions and declarations that carry a storage class specifier.
-- NOTE(review): `class_specifier` and `(public)` look like C++-oriented node
-- names; verify that they exist in the C grammar, otherwise compiling this
-- query may fail for C files.
c = [[
;; Capture extern functions, variables, public classes, and methods
(function_definition
(storage_class_specifier) @extern
) @function
(class_specifier
(public) @class
(function_definition) @method
) @class
(declaration
(storage_class_specifier) @extern
) @variable
]],
-- C++: same patterns as the C entry.
-- NOTE(review): confirm that `(public)` is a valid named node in the C++ grammar.
cpp = [[
;; Capture extern functions, variables, public classes, and methods
(function_definition
(storage_class_specifier) @extern
) @function
(class_specifier
(public) @class
(function_definition) @method
) @class
(declaration
(storage_class_specifier) @extern
) @variable
]],
-- Lua: every variable list and every function declaration.
lua = [[
;; Capture function and method definitions
(variable_list) @variable
(function_declaration) @function
]],
-- Ruby: methods and assignments, plus assignments/methods inside a class body.
ruby = [[
;; Capture top-level methods, class definitions, and methods within classes
(method) @function
(assignment) @assignment
(class
body: (body_statement
(assignment) @class_assignment
(method) @method
)
)
]],
}
--- Filetypes that reuse the query of another filetype.
local queries_filetype_map = {
  javascriptreact = "javascript",
  typescriptreact = "typescript",
}

--- Look up the query registered for a filetype, following the aliases above.
---@param queries table<string, string> query sources keyed by filetype
---@param filetype string|nil
---@return string|nil query the query source, or nil when none is registered
local function get_query(queries, filetype)
  local canonical = queries_filetype_map[filetype]
  if not canonical then canonical = filetype end
  return queries[canonical]
end
--- Ask nvim-treesitter which parser language is associated with a buffer.
---@param bufnr integer
---@return string|nil lang
local function get_ts_lang(bufnr)
  -- The parentheses keep exactly one return value.
  return (parsers.get_buf_lang(bufnr))
end
--- Get the tree-sitter parser of a buffer together with its language.
---@param bufnr integer
---@return any parser, string lang  nothing is returned when no language is found
function RepoMap.get_parser(bufnr)
  local buf_lang = get_ts_lang(bufnr)
  if buf_lang then
    -- The call is not the last expression, so only its first result is kept.
    return parsers.get_parser(bufnr, buf_lang), buf_lang
  end
end
--- Tell whether a captured string node is an argument of a call to `require`.
--- Expects the shape produced by the dependency queries: string -> arguments -> call.
---@param string_node any captured string node
---@param bufnr integer buffer the node belongs to
---@return boolean
local function is_require_argument(string_node, bufnr)
  local arguments = string_node:parent()
  local call = arguments and arguments:parent()
  if not call then return false end
  -- The Lua query names the callee field `name`, the JS/TS queries `function`.
  local callee = call:field("name")[1] or call:field("function")[1]
  return callee ~= nil and get_node_text(callee, bufnr) == "require"
end

--- Collect the modules/files a buffer imports, based on `dependencies_queries`.
---
--- `@import_module` captures are always reported. `@required_file` captures are
--- reported only when the enclosing call is `require(...)`, and `@function_name`
--- captures (the callee itself) are never reported.
---
---@param bufnr integer
---@return string[] dependencies captured texts with all `"` and `'` characters removed
function RepoMap.extract_dependencies(bufnr)
  local parser, lang = RepoMap.get_parser(bufnr)

  local query
  if lang and parser then
    local filetype = vim.api.nvim_get_option_value("filetype", { buf = bufnr })
    -- Prefer the filetype lookup (it honours `queries_filetype_map`), and fall
    -- back to the parser language so previously supported buffers keep working.
    query = get_query(dependencies_queries, filetype) or dependencies_queries[lang]
  end
  if not query then
    print("No parser or query available for this buffer's language: " .. (lang or "unknown"))
    return {}
  end

  local dependencies = {}
  local tree = parser:parse()[1]
  if not tree then return dependencies end
  local root = tree:root()
  local query_obj = vim.treesitter.query.parse(lang, query)

  for id, node in query_obj:iter_captures(root, bufnr, 0, -1) do
    local capture = query_obj.captures[id]
    local is_dependency = capture ~= "function_name"
      and (capture ~= "required_file" or is_require_argument(node, bufnr))
    if is_dependency then
      -- gsub returns (string, count); assigning to a single local keeps the string only.
      local text = get_node_text(node, bufnr):gsub('"', ""):gsub("'", "")
      dependencies[#dependencies + 1] = text
    end
  end
  return dependencies
end
--- Detect a filetype from a file path using `vim.filetype.match`.
---@param filepath string
---@return string|nil filetype followed by any extra values `vim.filetype.match` returns
function RepoMap.get_filetype_by_filepath(filepath)
  local match_opts = { filename = filepath }
  return vim.filetype.match(match_opts)
end
--- Read a file and parse its content with the tree-sitter parser of its filetype.
---
---@param filepath string
---@return { node: any, source: string }|nil parsed root node and the source it was parsed
---  from; nil when the content, filetype, parser or syntax tree is unavailable
function RepoMap.parse_file(filepath)
  local File = require("avante.utils.file")
  local source = File.read_content(filepath)
  -- Without content there is nothing to parse; bail out instead of handing nil to the parser.
  if not source then return nil end

  local filetype = RepoMap.get_filetype_by_filepath(filepath)
  -- Never hand a nil filetype to ft_to_lang.
  if not filetype then return nil end

  local lang = parsers.ft_to_lang(filetype)
  if not lang then return nil end

  -- get_string_parser raises when no parser is available for `lang`.
  local ok, parser = pcall(vim.treesitter.get_string_parser, source, lang)
  if not ok then
    print("parser error", parser)
    return nil
  end

  local tree = parser:parse()[1]
  if not tree then return nil end
  return { node = tree:root(), source = source }
end
--- Text of the `name` field of the nearest ancestor that has one.
---@param node any tree-sitter node
---@param source any source passed through to `get_node_text`
---@return string name "" when no ancestor has a `name` field
local function get_closest_parent_name(node, source)
  local ancestor = node:parent()
  if not ancestor then return "" end

  local name_node = ancestor:field("name")[1]
  if name_node then return get_node_text(name_node, source) end

  -- Keep climbing; this is a tail call, so the stack does not grow.
  return get_closest_parent_name(ancestor, source)
end
--- Find the nearest strict ancestor of `node` whose type equals `node_type`.
---@param node any tree-sitter node (the node itself is never considered)
---@param node_type string
---@return any|nil ancestor
local function find_parent_by_type(node, node_type)
  local ancestor = node:parent()
  while ancestor and ancestor:type() ~= node_type do
    ancestor = ancestor:parent()
  end
  return ancestor or nil
end
--- Depth-first search for the first descendant of `node` with the given type.
--- A child is tested before its own subtree, and each subtree is exhausted
--- before the next sibling is visited.
---@param node any tree-sitter node (the node itself is never considered)
---@param node_type string
---@return any|nil descendant
local function find_child_by_type(node, node_type)
  for child in node:iter_children() do
    local hit = (child:type() == node_type) and child or find_child_by_type(child, node_type)
    if hit then return hit end
  end
  return nil
end
--- Best-effort textual type of a node.
--- A `predefined_type` descendant wins; otherwise the node's `type` field is used.
---@param node any tree-sitter node
---@param source any source passed through to `get_node_text`
---@return string node_type "" when neither is present
local function get_node_type(node, source)
  local predefined = find_child_by_type(node, "predefined_type")
  if predefined then return (get_node_text(predefined, source)) end

  local declared = node:field("type")[1]
  if not declared then return "" end
  return (get_node_text(declared, source))
end
--- Extract the definitions (functions, variables, classes and enums) found in a file.
---
--- The file is parsed with tree-sitter and matched against the `definitions_queries`
--- entry for its filetype. The capture name decides how a node is recorded:
---   * `class`                               -> registers a class (only when it has a name)
---   * `method`                              -> appended to the owning class' `methods`
---   * `class_assignment` / `class_variable` -> appended to the owning class' `properties`
---   * `enum_item`                           -> appended to the owning enum's `items`
---   * `function` / `arrow_function`         -> function definition (skipped when nested)
---   * `assignment` / `variable`             -> variable definition (skipped when nested)
--- For Go files, an entry is skipped when `Utils.is_first_letter_uppercase` rejects
--- its name (or, for properties and enum items, the name of its owner).
---
---@param filepath string
---@return table[] definitions functions and variables in match order, followed by the
---  classes and then the enums, each in the order they were first encountered
function RepoMap.extract_definitions(filepath)
  local Utils = require("avante.utils")

  local filetype = RepoMap.get_filetype_by_filepath(filepath)
  if not filetype then return {} end

  -- Get the corresponding query for the detected filetype
  local query = get_query(definitions_queries, filetype)
  if not query then return {} end

  local parsed = RepoMap.parse_file(filepath)
  if not parsed then return {} end

  local root = parsed.node
  local source = parsed.source
  local lang = parsers.ft_to_lang(filetype)
  local query_obj = vim.treesitter.query.parse(lang, query)

  local definitions = {}
  -- The maps give lookup by name; the lists remember creation order so that the
  -- result does not depend on the unspecified iteration order of `pairs`.
  local class_def_map, class_defs = {}, {}
  local enum_def_map, enum_defs = {}, {}

  --- Get (creating on first use) the class definition registered under `name`.
  local function get_class_def(name)
    local def = class_def_map[name]
    if def == nil then
      def = { type = "class", name = name, methods = {}, properties = {} }
      class_def_map[name] = def
      class_defs[#class_defs + 1] = def
    end
    return def
  end

  --- Get (creating on first use) the enum definition registered under `name`.
  local function get_enum_def(name)
    local def = enum_def_map[name]
    if def == nil then
      def = { type = "enum", name = name, items = {} }
      enum_def_map[name] = def
      enum_defs[#enum_defs + 1] = def
    end
    return def
  end

  --- Text of `node`, or `default` when the node is missing.
  local function text_of(node, default) return node and get_node_text(node, source) or default end

  --- True when the Go export rule applies and rejects `identifier`.
  local function is_unexported_go(identifier)
    return filetype == "go" and not Utils.is_first_letter_uppercase(identifier)
  end

  --- True when `node` lives inside an impl/class or inside another function.
  local function is_nested_definition(node)
    return (
      find_parent_by_type(node, "impl_item")
      or find_parent_by_type(node, "class_declaration")
      or find_parent_by_type(node, "class_definition")
      or find_parent_by_type(node, "function_declaration")
      or find_parent_by_type(node, "function_definition")
    ) ~= nil
  end

  --- Build the function/method record shared by the `function` and `method` captures.
  local function build_function_def(node, name)
    local return_type_node = node:field("return_type")[1] or node:field("result")[1]
    return {
      type = "function",
      name = name,
      params = text_of(node:field("parameters")[1], "()"),
      return_type = text_of(return_type_node, "void"),
      accessibility_modifier = text_of(find_child_by_type(node, "accessibility_modifier"), ""),
    }
  end

  --- Name of the class owning a method node: the `impl` type when inside an impl
  --- block, else the receiver's type identifier, else the closest named ancestor.
  local function get_method_owner(node)
    local impl_item_node = find_parent_by_type(node, "impl_item")
    if impl_item_node then return text_of(impl_item_node:field("type")[1], "") end
    local receiver_node = node:field("receiver")[1]
    if receiver_node then return text_of(find_child_by_type(receiver_node, "type_identifier"), "") end
    return get_closest_parent_name(node, source)
  end

  --- Record a single captured node according to its capture name.
  local function handle_capture(capture_name, node)
    local name = text_of(node:field("name")[1], "")

    if capture_name == "class" then
      if name ~= "" then get_class_def(name) end
    elseif capture_name == "enum_item" then
      local enum_name = get_closest_parent_name(node, source)
      if is_unexported_go(enum_name) then return end
      local enum_def = get_enum_def(enum_name)
      table.insert(enum_def.items, {
        name = name,
        type = text_of(find_child_by_type(node, "type_identifier"), ""),
      })
    elseif capture_name == "method" then
      if is_unexported_go(name) then return end
      local class_def = get_class_def(get_method_owner(node))
      table.insert(class_def.methods, build_function_def(node, name))
    elseif capture_name == "class_assignment" then
      local left = text_of(node:field("left")[1], "")
      local value_type = get_node_type(node, source)
      local class_name = get_closest_parent_name(node, source)
      if is_unexported_go(class_name) then return end
      table.insert(get_class_def(class_name).properties, {
        type = "variable",
        name = left,
        value_type = value_type,
      })
    elseif capture_name == "class_variable" then
      local value_type = get_node_type(node, source)
      local class_name = get_closest_parent_name(node, source)
      if is_unexported_go(class_name) then return end
      table.insert(get_class_def(class_name).properties, {
        type = "variable",
        name = name,
        value_type = value_type,
      })
    elseif capture_name == "function" or capture_name == "arrow_function" then
      if is_unexported_go(name) then return end
      -- Functions inside an impl block are reported through the `method` capture.
      if find_parent_by_type(node, "impl_item") then return end
      -- Skip functions nested inside another function.
      if find_parent_by_type(node, "function_declaration") or find_parent_by_type(node, "function_definition") then
        return
      end
      table.insert(definitions, build_function_def(node, name))
    elseif capture_name == "assignment" then
      if is_nested_definition(node) then return end
      local left = text_of(node:field("left")[1], "")
      if is_unexported_go(left) then return end
      table.insert(definitions, {
        type = "variable",
        name = left,
        value_type = get_node_type(node, source),
      })
    elseif capture_name == "variable" then
      if is_nested_definition(node) then return end
      local value_type = get_node_type(node, source)
      if is_unexported_go(name) then return end
      table.insert(definitions, { type = "variable", name = name, value_type = value_type })
    end
  end

  for _, captures in query_obj:iter_matches(root, source) do
    for id, captured in pairs(captures) do
      local capture_name = query_obj.captures[id]
      if type(captured) == "table" then
        -- Depending on the Neovim version, a match maps a capture id either to a
        -- single node (userdata) or to a list of nodes; accept both shapes.
        for _, node in ipairs(captured) do
          handle_capture(capture_name, node)
        end
      else
        handle_capture(capture_name, captured)
      end
    end
  end

  for _, def in ipairs(class_defs) do
    table.insert(definitions, def)
  end
  for _, def in ipairs(enum_defs) do
    table.insert(definitions, def)
  end
  return definitions
end
--- Render a function definition as `[modifier ]func <name><params>:<return_type>;`.
---@param def { name: string, params: string, return_type: string, accessibility_modifier: string|nil }
---@return string
local function stringify_function(def)
  local signature = "func " .. def.name .. def.params .. ":" .. def.return_type .. ";"
  local modifier = def.accessibility_modifier
  if not modifier or modifier == "" then return signature end
  return modifier .. " " .. signature
end
--- Render a variable definition as `var <name>[:<value_type>];`.
---@param def { name: string, value_type: string|nil }
---@return string
local function stringify_variable(def)
  local declaration = "var " .. def.name
  local value_type = def.value_type
  local suffix = (value_type and value_type ~= "") and (":" .. value_type) or ""
  return declaration .. suffix .. ";"
end
--- Render an enum item as `<name>[:<type>];`.
---
--- `RepoMap.extract_definitions` stores the item's type under the `type` field, so
--- that field is used when `value_type` is absent or empty; `value_type` still
--- takes precedence when it is set.
---@param def { name: string, value_type: string|nil, type: string|nil }
---@return string
local function stringify_enum_item(def)
  local item_type = def.value_type
  if item_type == nil or item_type == "" then item_type = def.type end

  local res = def.name
  if item_type and item_type ~= "" then res = res .. ":" .. item_type end
  return res .. ";"
end
--- Extract the definitions of a file and render them as one compact string.
---
--- Classes are rendered as `class <name>{<properties><methods>}`, enums as
--- `enum <name>{<items>}`, and top-level functions/variables with
--- `stringify_function` / `stringify_variable`. No separators are inserted.
---
---@param filepath string
---@return string output "" for paths ending in "~" and for files without definitions
function RepoMap.stringify_definitions(filepath)
  -- Paths ending in "~" are skipped outright.
  if vim.endswith(filepath, "~") then return "" end

  local definitions = RepoMap.extract_definitions(filepath)

  -- Collect the pieces and join them once, rather than re-allocating the whole
  -- string on every concatenation.
  local parts = {}
  local function emit(piece) parts[#parts + 1] = piece end

  for _, def in ipairs(definitions) do
    if def.type == "class" then
      emit(def.type .. " " .. def.name .. "{")
      for _, property in ipairs(def.properties) do
        emit(stringify_variable(property))
      end
      for _, method in ipairs(def.methods) do
        emit(stringify_function(method))
      end
      emit("}")
    elseif def.type == "enum" then
      emit(def.type .. " " .. def.name .. "{")
      for _, item in ipairs(def.items) do
        emit(stringify_enum_item(item))
      end
      emit("}")
    elseif def.type == "function" then
      emit(stringify_function(def))
    elseif def.type == "variable" then
      emit(stringify_variable(def))
    end
  end

  return table.concat(parts)
end
--- Build the repo map of a project from scratch.
---
--- Scans `project_root` (applying the patterns parsed from its `.gitignore`),
--- keeps the files accepted by `Utils.is_same_file_ext(file_ext, filepath)` and
--- records every file that yields a non-empty definitions string.
---
---@param project_root string
---@param file_ext string
---@return { path: string, lang: string|nil, defs: string }[]
function RepoMap._build_repo_map(project_root, file_ext)
  local Utils = require("avante.utils")
  local ignore_patterns, negate_patterns = Utils.parse_gitignore(project_root .. "/.gitignore")
  local filepaths = Utils.scan_directory(project_root, ignore_patterns, negate_patterns)

  local entries = {}
  -- assumes scan_directory returns a plain list of paths — TODO confirm
  for _, filepath in ipairs(filepaths) do
    if Utils.is_same_file_ext(file_ext, filepath) then
      local defs = RepoMap.stringify_definitions(filepath)
      if defs ~= "" then
        entries[#entries + 1] = {
          path = Utils.relative_path(filepath),
          lang = RepoMap.get_filetype_by_filepath(filepath),
          defs = defs,
        }
      end
    end
  end
  return entries
end
--- In-memory repo maps, keyed by `<project_root>.<file_ext>`.
local cache = {}

--- Get the repo map for the current project and a file extension.
---
--- On a cache miss, the map is loaded through `Path.repo_map.load`. When nothing
--- usable is stored it is built synchronously; otherwise the stored map is
--- returned immediately and a rebuild is scheduled in the background. A cache
--- miss also registers a recursive fs watcher on the project root and a
--- `BufReadPost`/`BufNewFile` autocmd that refresh the entry of a changed file.
---
---@param file_ext? string defaults to the extension of the current buffer's file
---@return { path: string, lang: string|nil, defs: string }[]|nil repo_map
function RepoMap.get_repo_map(file_ext)
  file_ext = file_ext or vim.fn.expand("%:e")
  local Utils = require("avante.utils")
  local project_root = Utils.root.get()
  local cache_key = project_root .. "." .. file_ext
  local cached = cache[cache_key]
  if cached then return cached end

  local PPath = require("plenary.path")
  local Path = require("avante.path")
  -- `vim.uv` is the current name of the libuv bindings; `vim.loop` is the older alias.
  local uv = vim.uv or vim.loop
  local repo_map

  --- Rebuild the map from scratch, then publish it to the cache and to disk.
  local function build_and_save()
    repo_map = RepoMap._build_repo_map(project_root, file_ext)
    cache[cache_key] = repo_map
    Path.repo_map.save(project_root, file_ext, repo_map)
  end

  repo_map = Path.repo_map.load(project_root, file_ext)
  if not repo_map or next(repo_map) == nil then
    build_and_save()
    if not repo_map then return end
  else
    -- Publish the loaded map right away so that calls made before the
    -- background rebuild finishes hit the cache instead of registering another
    -- timer, watcher and autocmd.
    cache[cache_key] = repo_map
    local timer = uv.new_timer()
    if timer then
      timer:start(
        0,
        0,
        vim.schedule_wrap(function()
          -- One-shot timer: release it first so it is closed even if the rebuild errors.
          timer:close()
          build_and_save()
        end)
      )
    end
  end

  --- Refresh (or add) the entry of a single file and persist the map.
  local update_repo_map = vim.schedule_wrap(function(rel_filepath)
    if rel_filepath and Utils.is_same_file_ext(file_ext, rel_filepath) then
      local abs_filepath = PPath:new(project_root):joinpath(rel_filepath):absolute()
      local definitions = RepoMap.stringify_definitions(abs_filepath)
      -- Files without definitions leave the map untouched.
      if definitions == "" then return end
      local found = false
      for _, m in ipairs(repo_map) do
        if m.path == rel_filepath then
          m.defs = definitions
          found = true
          break
        end
      end
      if not found then
        table.insert(repo_map, {
          path = Utils.relative_path(abs_filepath),
          lang = RepoMap.get_filetype_by_filepath(abs_filepath),
          defs = definitions,
        })
      end
      cache[cache_key] = repo_map
      Path.repo_map.save(project_root, file_ext, repo_map)
    end
  end)

  local handle = uv.new_fs_event()
  if handle then
    handle:start(project_root, { recursive = true }, function(err, rel_filepath)
      if err then
        print("Error watching directory " .. project_root .. ":", err)
        return
      end
      if rel_filepath then update_repo_map(rel_filepath) end
    end)
  end

  vim.api.nvim_create_autocmd({ "BufReadPost", "BufNewFile" }, {
    callback = function(ev)
      vim.defer_fn(function()
        -- The buffer may have been wiped between the event and this deferred call.
        if not vim.api.nvim_buf_is_valid(ev.buf) then return end
        local filepath = vim.api.nvim_buf_get_name(ev.buf)
        if not vim.startswith(filepath, project_root) then return end
        local rel_filepath = Utils.relative_path(filepath)
        update_repo_map(rel_filepath)
      end, 0)
    end,
  })

  return repo_map
end
return RepoMap