feat: avante repo map rust crate (#628)

This commit is contained in:
yetone 2024-09-26 03:45:49 +08:00 committed by GitHub
parent 5461342fce
commit 0d90c047ef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 1399 additions and 737 deletions

View File

@ -29,6 +29,18 @@ jobs:
uses: lunarmodules/luacheck@v1 uses: lunarmodules/luacheck@v1
with: with:
args: ./lua/ args: ./lua/
rust-tests:
name: Run Rust tests
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: Swatinem/rust-cache@v2
- uses: dtolnay/rust-toolchain@master
with:
toolchain: stable
components: clippy, rustfmt
- name: Run rust tests
run: cargo test --features luajit
rust: rust:
name: Check Rust style name: Check Rust style
runs-on: ubuntu-latest runs-on: ubuntu-latest

View File

@ -74,6 +74,7 @@ jobs:
fi fi
cp target/release/libavante_templates.$EXT results/avante_templates.$EXT cp target/release/libavante_templates.$EXT results/avante_templates.$EXT
cp target/release/libavante_tokenizers.$EXT results/avante_tokenizers.$EXT cp target/release/libavante_tokenizers.$EXT results/avante_tokenizers.$EXT
cp target/release/libavante_repo_map.$EXT results/avante_repo_map.$EXT
cd results cd results
tar zcvf avante_lib-${{ matrix.os }}-${{ matrix.feature }}.tar.gz *.${EXT} tar zcvf avante_lib-${{ matrix.os }}-${{ matrix.feature }}.tar.gz *.${EXT}
@ -85,6 +86,7 @@ jobs:
Copy-Item -Path "target\release\avante_templates.dll" -Destination "results\avante_templates.dll" Copy-Item -Path "target\release\avante_templates.dll" -Destination "results\avante_templates.dll"
Copy-Item -Path "target\release\avante_tokenizers.dll" -Destination "results\avante_tokenizers.dll" Copy-Item -Path "target\release\avante_tokenizers.dll" -Destination "results\avante_tokenizers.dll"
Copy-Item -Path "target\release\avante_repo_map.dll" -Destination "results\avante_repo_map.dll"
Set-Location -Path results Set-Location -Path results

133
Cargo.lock generated
View File

@ -29,6 +29,27 @@ version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0"
[[package]]
name = "avante-repo-map"
version = "0.1.0"
dependencies = [
"cc",
"minijinja",
"mlua",
"serde",
"tree-sitter",
"tree-sitter-c",
"tree-sitter-cpp",
"tree-sitter-go",
"tree-sitter-javascript",
"tree-sitter-language",
"tree-sitter-lua",
"tree-sitter-python",
"tree-sitter-ruby",
"tree-sitter-rust",
"tree-sitter-typescript",
]
[[package]] [[package]]
name = "avante-templates" name = "avante-templates"
version = "0.1.0" version = "0.1.0"
@ -117,9 +138,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.1.15" version = "1.1.21"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57b6a275aa2903740dc87da01c62040406b8812552e97129a63ea8850a17c6e6" checksum = "07b1695e2c7e8fc85310cde85aeaab7e3097f593c91d209d3f9df76c928100f0"
dependencies = [ dependencies = [
"shlex", "shlex",
] ]
@ -1275,6 +1296,114 @@ dependencies = [
"unicode_categories", "unicode_categories",
] ]
[[package]]
name = "tree-sitter"
version = "0.23.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "20f4cd3642c47a85052a887d86704f4eac272969f61b686bdd3f772122aabaff"
dependencies = [
"cc",
"regex",
"regex-syntax",
"tree-sitter-language",
]
[[package]]
name = "tree-sitter-c"
version = "0.23.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8b3fb515e498e258799a31d78e6603767cd6892770d9e2290ec00af5c3ad80b"
dependencies = [
"cc",
"tree-sitter-language",
]
[[package]]
name = "tree-sitter-cpp"
version = "0.23.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d67e862242878d6ee50e1e5814f267ee3eea0168aea2cdbd700ccfb4c74b6d3"
dependencies = [
"cc",
"tree-sitter-language",
]
[[package]]
name = "tree-sitter-go"
version = "0.23.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "caf57626e4c9b6d6efaf8a8d5ee1241c5f178ae7bfdf693713ae6a774f01424e"
dependencies = [
"cc",
"tree-sitter-language",
]
[[package]]
name = "tree-sitter-javascript"
version = "0.23.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59e1f62f8babb640b909f30675d1addeb1f17802f2a4d2af287569753b243977"
dependencies = [
"cc",
"tree-sitter-language",
]
[[package]]
name = "tree-sitter-language"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2545046bd1473dac6c626659cc2567c6c0ff302fc8b84a56c4243378276f7f57"
[[package]]
name = "tree-sitter-lua"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5cdb9adf0965fec58e7660cbb3a059dbb12ebeec9459e6dcbae3db004739641e"
dependencies = [
"cc",
"tree-sitter-language",
]
[[package]]
name = "tree-sitter-python"
version = "0.23.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "65661b1a3e24139e2e54207e47d910ab07e28790d78efc7d5dc3a11ce2a110eb"
dependencies = [
"cc",
"tree-sitter-language",
]
[[package]]
name = "tree-sitter-ruby"
version = "0.23.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ec5ee842e27791e0adffa0b2a177614de51d2a26e5c7e84d014ed7f097e5ed0"
dependencies = [
"cc",
"tree-sitter-language",
]
[[package]]
name = "tree-sitter-rust"
version = "0.23.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cffbbcb780348fbae8395742ae5b34c1fd794e4085d43aac9f259387f9a84dc8"
dependencies = [
"cc",
"tree-sitter-language",
]
[[package]]
name = "tree-sitter-typescript"
version = "0.23.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aecf1585ae2a9dddc2b1d4c0e2140b2ec9876e2a25fd79de47fcf7dae0384685"
dependencies = [
"cc",
"tree-sitter-language",
]
[[package]] [[package]]
name = "typeid" name = "typeid"
version = "1.0.2" version = "1.0.2"

View File

@ -11,6 +11,7 @@ version = "0.1.0"
[workspace.dependencies] [workspace.dependencies]
avante-tokenizers = { path = "crates/avante-tokenizers" } avante-tokenizers = { path = "crates/avante-tokenizers" }
avante-templates = { path = "crates/avante-templates" } avante-templates = { path = "crates/avante-templates" }
avante-repo-map = { path = "crates/avante-repo-map" }
minijinja = { version = "2.2.0", features = [ minijinja = { version = "2.2.0", features = [
"loader", "loader",
"json", "json",

View File

@ -22,13 +22,15 @@ all: luajit
define make_definitions define make_definitions
ifeq ($(BUILD_FROM_SOURCE),true) ifeq ($(BUILD_FROM_SOURCE),true)
ifeq ($(TARGET_LIBRARY), all) ifeq ($(TARGET_LIBRARY), all)
$1: $(BUILD_DIR)/libAvanteTokenizers-$1.$(EXT) $(BUILD_DIR)/libAvanteTemplates-$1.$(EXT) $1: $(BUILD_DIR)/libAvanteTokenizers-$1.$(EXT) $(BUILD_DIR)/libAvanteTemplates-$1.$(EXT) $(BUILD_DIR)/libAvanteRepoMap-$1.$(EXT)
else ifeq ($(TARGET_LIBRARY), tokenizers) else ifeq ($(TARGET_LIBRARY), tokenizers)
$1: $(BUILD_DIR)/libAvanteTokenizers-$1.$(EXT) $1: $(BUILD_DIR)/libAvanteTokenizers-$1.$(EXT)
else ifeq ($(TARGET_LIBRARY), templates) else ifeq ($(TARGET_LIBRARY), templates)
$1: $(BUILD_DIR)/libAvanteTemplates-$1.$(EXT) $1: $(BUILD_DIR)/libAvanteTemplates-$1.$(EXT)
else ifeq ($(TARGET_LIBRARY), repo-map)
$1: $(BUILD_DIR)/libAvanteRepoMap-$1.$(EXT)
else else
$$(error TARGET_LIBRARY must be one of all, tokenizers, templates) $$(error TARGET_LIBRARY must be one of all, tokenizers, templates, repo-map)
endif endif
else else
$1: $1:
@ -41,16 +43,18 @@ $(foreach lua_version,$(LUA_VERSIONS),$(eval $(call make_definitions,$(lua_versi
define build_package define build_package
$1-$2: $1-$2:
cargo build --release --features=$1 -p avante-$2 cargo build --release --features=$1 -p avante-$2
cp target/release/libavante_$2.$(EXT) $(BUILD_DIR)/avante_$2.$(EXT) cp target/release/libavante_$(shell echo $2 | tr - _).$(EXT) $(BUILD_DIR)/avante_$(shell echo $2 | tr - _).$(EXT)
endef endef
define build_targets define build_targets
$(BUILD_DIR)/libAvanteTokenizers-$1.$(EXT): $(BUILD_DIR) $1-tokenizers $(BUILD_DIR)/libAvanteTokenizers-$1.$(EXT): $(BUILD_DIR) $1-tokenizers
$(BUILD_DIR)/libAvanteTemplates-$1.$(EXT): $(BUILD_DIR) $1-templates $(BUILD_DIR)/libAvanteTemplates-$1.$(EXT): $(BUILD_DIR) $1-templates
$(BUILD_DIR)/libAvanteRepoMap-$1.$(EXT): $(BUILD_DIR) $1-repo-map
endef endef
$(foreach lua_version,$(LUA_VERSIONS),$(eval $(call build_package,$(lua_version),tokenizers))) $(foreach lua_version,$(LUA_VERSIONS),$(eval $(call build_package,$(lua_version),tokenizers)))
$(foreach lua_version,$(LUA_VERSIONS),$(eval $(call build_package,$(lua_version),templates))) $(foreach lua_version,$(LUA_VERSIONS),$(eval $(call build_package,$(lua_version),templates)))
$(foreach lua_version,$(LUA_VERSIONS),$(eval $(call build_package,$(lua_version),repo-map)))
$(foreach lua_version,$(LUA_VERSIONS),$(eval $(call build_targets,$(lua_version)))) $(foreach lua_version,$(LUA_VERSIONS),$(eval $(call build_targets,$(lua_version))))
$(BUILD_DIR): $(BUILD_DIR):

View File

@ -0,0 +1,38 @@
[lib]
crate-type = ["cdylib"]
[package]
name = "avante-repo-map"
edition.workspace = true
rust-version.workspace = true
license.workspace = true
version.workspace = true
[build-dependencies]
cc="*"
[dependencies]
mlua = { workspace = true }
minijinja = { workspace = true }
serde = { workspace = true, features = ["derive"] }
tree-sitter = "0.23"
tree-sitter-language = "0.1"
tree-sitter-rust = "0.23"
tree-sitter-python = "0.23"
tree-sitter-javascript = "0.23"
tree-sitter-typescript = "0.23"
tree-sitter-go = "0.23"
tree-sitter-c = "0.23"
tree-sitter-cpp = "0.23"
tree-sitter-lua = "0.2"
tree-sitter-ruby = "0.23"
[lints]
workspace = true
[features]
lua51 = ["mlua/lua51"]
lua52 = ["mlua/lua52"]
lua53 = ["mlua/lua53"]
lua54 = ["mlua/lua54"]
luajit = ["mlua/luajit"]

View File

@ -0,0 +1,11 @@
;; Capture extern functions, variables, public classes, and methods
(function_definition
(storage_class_specifier) @extern
) @function
(class_specifier
(public) @class
(function_definition) @method
) @class
(declaration
(storage_class_specifier) @extern
) @variable

View File

@ -0,0 +1,11 @@
;; Capture extern functions, variables, public classes, and methods
(function_definition
(storage_class_specifier) @extern
) @function
(class_specifier
(public) @class
(function_definition) @method
) @class
(declaration
(storage_class_specifier) @extern
) @variable

View File

@ -0,0 +1,18 @@
;; Capture top-level functions and struct definitions
(var_declaration
(var_spec) @variable
)
(const_declaration
(const_spec) @variable
)
(function_declaration) @function
(type_declaration
(type_spec (struct_type)) @class
)
(type_declaration
(type_spec
(struct_type
(field_declaration_list
(field_declaration) @class_variable)))
)
(method_declaration) @method

View File

@ -0,0 +1,23 @@
;; Capture exported functions, arrow functions, variables, classes, and method definitions
(export_statement
declaration: (lexical_declaration
(variable_declarator) @variable
)
)
(export_statement
declaration: (function_declaration) @function
)
(export_statement
declaration: (class_declaration
body: (class_body
(field_definition) @class_variable
)
)
)
(export_statement
declaration: (class_declaration
body: (class_body
(method_definition) @method
)
)
)

View File

@ -0,0 +1,3 @@
;; Capture function and method definitions
(variable_list) @variable
(function_declaration) @function

View File

@ -0,0 +1,25 @@
;; Capture top-level functions, class, and method definitions
(module
(expression_statement
(assignment) @assignment
)
)
(module
(function_definition) @function
)
(module
(class_definition
body: (block
(expression_statement
(assignment) @class_assignment
)
)
)
)
(module
(class_definition
body: (block
(function_definition) @method
)
)
)

View File

@ -0,0 +1,16 @@
;; Capture top-level methods, class definitions, and methods within classes
(program
(class
(body_statement
(call) @class_call
(assignment) @class_assignment
(method) @method
)
) @class
)
(program
(method) @function
)
(program
(assignment) @assignment
)

View File

@ -0,0 +1,20 @@
;; Capture public functions, structs, methods, and variable definitions
(function_item) @function
(impl_item
body: (declaration_list
(function_item) @method
)
)
(struct_item) @class
(struct_item
body: (field_declaration_list
(field_declaration) @class_variable
)
)
(enum_item
body: (enum_variant_list
(enum_variant) @enum_item
)
)
(const_item) @variable
(static_item) @variable

View File

@ -0,0 +1,33 @@
;; Capture exported functions, arrow functions, variables, classes, and method definitions
(export_statement
declaration: (lexical_declaration
(variable_declarator) @variable
)
)
(export_statement
declaration: (function_declaration) @function
)
(export_statement
declaration: (class_declaration
body: (class_body
(public_field_definition) @class_variable
)
)
)
(interface_declaration
body: (interface_body
(property_signature) @class_variable
)
)
(type_alias_declaration
value: (object_type
(property_signature) @class_variable
)
)
(export_statement
declaration: (class_declaration
body: (class_body
(method_definition) @method
)
)
)

View File

@ -0,0 +1,894 @@
use mlua::prelude::*;
use std::cell::RefCell;
use std::collections::HashMap;
use tree_sitter::{Node, Parser, Query, QueryCursor};
use tree_sitter_language::LanguageFn;
#[derive(Debug, Clone)]
pub struct Func {
pub name: String,
pub params: String,
pub return_type: String,
pub accessibility_modifier: Option<String>,
}
#[derive(Debug, Clone)]
pub struct Class {
pub name: String,
pub methods: Vec<Func>,
pub properties: Vec<Variable>,
pub visibility_modifier: Option<String>,
}
#[derive(Debug, Clone)]
pub struct Enum {
pub name: String,
pub items: Vec<Variable>,
}
#[derive(Debug, Clone)]
pub struct Variable {
pub name: String,
pub value_type: String,
}
#[derive(Debug, Clone)]
pub enum Definition {
Func(Func),
Class(Class),
Enum(Enum),
Variable(Variable),
}
fn get_ts_language(language: &str) -> Result<LanguageFn, String> {
match language {
"rust" => Ok(tree_sitter_rust::LANGUAGE),
"python" => Ok(tree_sitter_python::LANGUAGE),
"javascript" => Ok(tree_sitter_javascript::LANGUAGE),
"typescript" => Ok(tree_sitter_typescript::LANGUAGE_TSX),
"go" => Ok(tree_sitter_go::LANGUAGE),
"c" => Ok(tree_sitter_c::LANGUAGE),
"cpp" => Ok(tree_sitter_cpp::LANGUAGE),
"lua" => Ok(tree_sitter_lua::LANGUAGE),
"ruby" => Ok(tree_sitter_ruby::LANGUAGE),
_ => Err(format!("Unsupported language: {language}")),
}
}
const C_QUERY: &str = include_str!("../queries/tree-sitter-c-defs.scm");
const CPP_QUERY: &str = include_str!("../queries/tree-sitter-cpp-defs.scm");
const GO_QUERY: &str = include_str!("../queries/tree-sitter-go-defs.scm");
const JAVASCRIPT_QUERY: &str = include_str!("../queries/tree-sitter-javascript-defs.scm");
const LUA_QUERY: &str = include_str!("../queries/tree-sitter-lua-defs.scm");
const PYTHON_QUERY: &str = include_str!("../queries/tree-sitter-python-defs.scm");
const RUST_QUERY: &str = include_str!("../queries/tree-sitter-rust-defs.scm");
const TYPESCRIPT_QUERY: &str = include_str!("../queries/tree-sitter-typescript-defs.scm");
const RUBY_QUERY: &str = include_str!("../queries/tree-sitter-ruby-defs.scm");
fn get_definitions_query(language: &str) -> Result<Query, String> {
let ts_language = get_ts_language(language)?;
let contents = match language {
"c" => C_QUERY,
"cpp" => CPP_QUERY,
"go" => GO_QUERY,
"javascript" => JAVASCRIPT_QUERY,
"lua" => LUA_QUERY,
"python" => PYTHON_QUERY,
"rust" => RUST_QUERY,
"typescript" => TYPESCRIPT_QUERY,
"ruby" => RUBY_QUERY,
_ => return Err(format!("Unsupported language: {language}")),
};
let query = Query::new(&ts_language.into(), contents)
.unwrap_or_else(|_| panic!("Failed to parse query for {language}"));
Ok(query)
}
fn get_closest_ancestor_name(node: &Node, source: &str) -> String {
let mut parent = node.parent();
while let Some(parent_node) = parent {
let name_node = parent_node.child_by_field_name("name");
if let Some(name_node) = name_node {
return get_node_text(&name_node, source.as_bytes()).to_string();
}
parent = parent_node.parent();
}
String::new()
}
fn find_ancestor_by_type<'a>(node: &'a Node, parent_type: &str) -> Option<Node<'a>> {
let mut parent = node.parent();
while let Some(parent_node) = parent {
if parent_node.kind() == parent_type {
return Some(parent_node);
}
parent = parent_node.parent();
}
None
}
fn find_descendant_by_type<'a>(node: &'a Node, child_type: &str) -> Option<Node<'a>> {
let mut cursor = node.walk();
for i in 0..node.descendant_count() {
cursor.goto_descendant(i);
let node = cursor.node();
if node.kind() == child_type {
return Some(node);
}
}
None
}
fn find_child_by_type<'a>(node: &'a Node, child_type: &str) -> Option<Node<'a>> {
node.children(&mut node.walk())
.find(|child| child.kind() == child_type)
}
fn get_node_text<'a>(node: &'a Node, source: &'a [u8]) -> String {
node.utf8_text(source).unwrap_or_default().to_string()
}
fn get_node_type<'a>(node: &'a Node, source: &'a [u8]) -> String {
let predefined_type_node = find_descendant_by_type(node, "predefined_type");
if let Some(type_node) = predefined_type_node {
return type_node.utf8_text(source).unwrap().to_string();
}
let value_type_node = node.child_by_field_name("type");
value_type_node
.map(|n| n.utf8_text(source).unwrap().to_string())
.unwrap_or_default()
}
fn is_first_letter_uppercase(name: &str) -> bool {
if name.is_empty() {
return false;
}
name.chars().next().unwrap().is_uppercase()
}
// Given a language, parse the given source code and return exported definitions
fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>, String> {
let ts_language = get_ts_language(language)?;
let mut definitions = Vec::new();
let mut parser = Parser::new();
parser
.set_language(&ts_language.into())
.unwrap_or_else(|_| panic!("Failed to set language for {language}"));
let tree = parser
.parse(source, None)
.unwrap_or_else(|| panic!("Failed to parse source code for {language}"));
let root_node = tree.root_node();
let query = get_definitions_query(language)?;
let mut query_cursor = QueryCursor::new();
let captures = query_cursor.captures(&query, root_node, source.as_bytes());
let mut class_def_map: HashMap<String, RefCell<Class>> = HashMap::new();
let mut enum_def_map: HashMap<String, RefCell<Enum>> = HashMap::new();
let ensure_class_def = |name: &str, class_def_map: &mut HashMap<String, RefCell<Class>>| {
class_def_map.entry(name.to_string()).or_insert_with(|| {
RefCell::new(Class {
name: name.to_string(),
methods: vec![],
properties: vec![],
visibility_modifier: None,
})
});
};
let ensure_enum_def = |name: &str, enum_def_map: &mut HashMap<String, RefCell<Enum>>| {
enum_def_map.entry(name.to_string()).or_insert_with(|| {
RefCell::new(Enum {
name: name.to_string(),
items: vec![],
})
});
};
for (m, _) in captures {
for capture in m.captures {
let capture_name = &query.capture_names()[capture.index as usize];
let node = capture.node;
let name_node = node.child_by_field_name("name");
let name = name_node
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
.unwrap_or("");
match *capture_name {
"class" => {
if !name.is_empty() {
if language == "go" && !is_first_letter_uppercase(name) {
continue;
}
ensure_class_def(name, &mut class_def_map);
let visibility_modifier_node =
find_child_by_type(&node, "visibility_modifier");
let visibility_modifier = visibility_modifier_node
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
.unwrap_or("");
let class_def = class_def_map.get_mut(name).unwrap();
class_def.borrow_mut().visibility_modifier =
if visibility_modifier.is_empty() {
None
} else {
Some(visibility_modifier.to_string())
};
}
}
"enum_item" => {
let visibility_modifier_node =
find_descendant_by_type(&node, "visibility_modifier");
let visibility_modifier = visibility_modifier_node
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
.unwrap_or("");
if language == "rust" && !visibility_modifier.contains("pub") {
continue;
}
let enum_name = get_closest_ancestor_name(&node, source);
if !enum_name.is_empty()
&& language == "go"
&& !is_first_letter_uppercase(&enum_name)
{
continue;
}
ensure_enum_def(&enum_name, &mut enum_def_map);
let enum_def = enum_def_map.get_mut(&enum_name).unwrap();
let enum_type_node = find_descendant_by_type(&node, "type_identifier");
let enum_type = enum_type_node
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
.unwrap_or("");
let variable = Variable {
name: name.to_string(),
value_type: enum_type.to_string(),
};
enum_def.borrow_mut().items.push(variable);
}
"method" => {
let visibility_modifier_node =
find_descendant_by_type(&node, "visibility_modifier");
let visibility_modifier = visibility_modifier_node
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
.unwrap_or("");
if language == "rust" && !visibility_modifier.contains("pub") {
continue;
}
if !name.is_empty() && language == "go" && !is_first_letter_uppercase(name) {
continue;
}
let params_node = node.child_by_field_name("parameters");
let params = params_node
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
.unwrap_or("()");
let mut return_type_node = node.child_by_field_name("return_type");
if return_type_node.is_none() {
return_type_node = node.child_by_field_name("result");
}
let mut return_type = "void".to_string();
if return_type_node.is_some() {
return_type = get_node_type(&return_type_node.unwrap(), source.as_bytes());
if return_type.is_empty() {
return_type = return_type_node
.unwrap()
.utf8_text(source.as_bytes())
.unwrap_or("void")
.to_string();
}
}
let impl_item_node = find_ancestor_by_type(&node, "impl_item");
let receiver_node = node.child_by_field_name("receiver");
let class_name = if let Some(impl_item) = impl_item_node {
let impl_type_node = impl_item.child_by_field_name("type");
impl_type_node
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
.unwrap_or("")
.to_string()
} else if let Some(receiver) = receiver_node {
let type_identifier_node =
find_descendant_by_type(&receiver, "type_identifier");
type_identifier_node
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
.unwrap_or("")
.to_string()
} else {
get_closest_ancestor_name(&node, source).to_string()
};
if language == "go" && !is_first_letter_uppercase(&class_name) {
continue;
}
ensure_class_def(&class_name, &mut class_def_map);
let class_def = class_def_map.get_mut(&class_name).unwrap();
let accessibility_modifier_node =
find_descendant_by_type(&node, "accessibility_modifier");
let accessibility_modifier = accessibility_modifier_node
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
.unwrap_or("");
let func = Func {
name: name.to_string(),
params: params.to_string(),
return_type: return_type.to_string(),
accessibility_modifier: if accessibility_modifier.is_empty() {
None
} else {
Some(accessibility_modifier.to_string())
},
};
class_def.borrow_mut().methods.push(func);
}
"class_assignment" => {
let visibility_modifier_node =
find_descendant_by_type(&node, "visibility_modifier");
let visibility_modifier = visibility_modifier_node
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
.unwrap_or("");
if language == "rust" && !visibility_modifier.contains("pub") {
continue;
}
let left_node = node.child_by_field_name("left");
let left = left_node
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
.unwrap_or("");
let value_type = get_node_type(&node, source.as_bytes());
let class_name = get_closest_ancestor_name(&node, source);
if !class_name.is_empty()
&& language == "go"
&& !is_first_letter_uppercase(&class_name)
{
continue;
}
if class_name.is_empty() {
continue;
}
ensure_class_def(&class_name, &mut class_def_map);
let class_def = class_def_map.get_mut(&class_name).unwrap();
let variable = Variable {
name: left.to_string(),
value_type: value_type.to_string(),
};
class_def.borrow_mut().properties.push(variable);
}
"class_variable" => {
let visibility_modifier_node =
find_descendant_by_type(&node, "visibility_modifier");
let visibility_modifier = visibility_modifier_node
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
.unwrap_or("");
if language == "rust" && !visibility_modifier.contains("pub") {
continue;
}
let value_type = get_node_type(&node, source.as_bytes());
let class_name = get_closest_ancestor_name(&node, source);
if !class_name.is_empty()
&& language == "go"
&& !is_first_letter_uppercase(&class_name)
{
continue;
}
if class_name.is_empty() {
continue;
}
if !name.is_empty() && language == "go" && !is_first_letter_uppercase(name) {
continue;
}
ensure_class_def(&class_name, &mut class_def_map);
let class_def = class_def_map.get_mut(&class_name).unwrap();
let variable = Variable {
name: name.to_string(),
value_type: value_type.to_string(),
};
class_def.borrow_mut().properties.push(variable);
}
"function" | "arrow_function" => {
let visibility_modifier_node =
find_descendant_by_type(&node, "visibility_modifier");
let visibility_modifier = visibility_modifier_node
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
.unwrap_or("");
if language == "rust" && !visibility_modifier.contains("pub") {
continue;
}
if !name.is_empty() && language == "go" && !is_first_letter_uppercase(name) {
continue;
}
let impl_item_node = find_ancestor_by_type(&node, "impl_item");
if impl_item_node.is_some() {
continue;
}
let function_node = find_ancestor_by_type(&node, "function_declaration")
.or_else(|| find_ancestor_by_type(&node, "function_definition"));
if function_node.is_some() {
continue;
}
let params_node = node.child_by_field_name("parameters");
let params = params_node
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
.unwrap_or("()");
let mut return_type_node = node.child_by_field_name("return_type");
if return_type_node.is_none() {
return_type_node = node.child_by_field_name("result");
}
let mut return_type = "void".to_string();
if return_type_node.is_some() {
return_type = get_node_type(&return_type_node.unwrap(), source.as_bytes());
if return_type.is_empty() {
return_type = return_type_node
.unwrap()
.utf8_text(source.as_bytes())
.unwrap_or("void")
.to_string();
}
}
let accessibility_modifier_node =
find_descendant_by_type(&node, "accessibility_modifier");
let accessibility_modifier = accessibility_modifier_node
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
.unwrap_or("");
let func = Func {
name: name.to_string(),
params: params.to_string(),
return_type: return_type.to_string(),
accessibility_modifier: if accessibility_modifier.is_empty() {
None
} else {
Some(accessibility_modifier.to_string())
},
};
definitions.push(Definition::Func(func));
}
"assignment" => {
let visibility_modifier_node =
find_descendant_by_type(&node, "visibility_modifier");
let visibility_modifier = visibility_modifier_node
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
.unwrap_or("");
if language == "rust" && !visibility_modifier.contains("pub") {
continue;
}
let impl_item_node = find_ancestor_by_type(&node, "impl_item")
.or_else(|| find_ancestor_by_type(&node, "class_declaration"))
.or_else(|| find_ancestor_by_type(&node, "class_definition"));
if impl_item_node.is_some() {
continue;
}
let function_node = find_ancestor_by_type(&node, "function_declaration")
.or_else(|| find_ancestor_by_type(&node, "function_definition"));
if function_node.is_some() {
continue;
}
let left_node = node.child_by_field_name("left");
let left = left_node
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
.unwrap_or("");
if !left.is_empty() && language == "go" && !is_first_letter_uppercase(left) {
continue;
}
let value_type = get_node_type(&node, source.as_bytes());
let variable = Variable {
name: left.to_string(),
value_type: value_type.to_string(),
};
definitions.push(Definition::Variable(variable));
}
"variable" => {
let visibility_modifier_node =
find_descendant_by_type(&node, "visibility_modifier");
let visibility_modifier = visibility_modifier_node
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
.unwrap_or("");
if language == "rust" && !visibility_modifier.contains("pub") {
continue;
}
let impl_item_node = find_ancestor_by_type(&node, "impl_item")
.or_else(|| find_ancestor_by_type(&node, "class_declaration"))
.or_else(|| find_ancestor_by_type(&node, "class_definition"));
if impl_item_node.is_some() {
continue;
}
let function_node = find_ancestor_by_type(&node, "function_declaration")
.or_else(|| find_ancestor_by_type(&node, "function_definition"));
if function_node.is_some() {
continue;
}
let value_node = node.child_by_field_name("value");
if value_node.is_some() {
let value_type = value_node.unwrap().kind();
if value_type == "arrow_function" {
let params_node = value_node.unwrap().child_by_field_name("parameters");
let params = params_node
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
.unwrap_or("()");
let mut return_type = "void".to_string();
let return_type_node =
value_node.unwrap().child_by_field_name("return_type");
if return_type_node.is_some() {
return_type =
get_node_type(&return_type_node.unwrap(), source.as_bytes());
}
let func = Func {
name: name.to_string(),
params: params.to_string(),
return_type,
accessibility_modifier: None,
};
definitions.push(Definition::Func(func));
continue;
}
}
let value_type = get_node_type(&node, source.as_bytes());
if !name.is_empty() && language == "go" && !is_first_letter_uppercase(name) {
continue;
}
let variable = Variable {
name: name.to_string(),
value_type: value_type.to_string(),
};
definitions.push(Definition::Variable(variable));
}
_ => {}
}
}
}
for (_, def) in class_def_map {
let class_def = def.into_inner();
if language == "rust" {
if let Some(visibility_modifier) = &class_def.visibility_modifier {
if visibility_modifier.contains("pub") {
definitions.push(Definition::Class(class_def));
}
}
} else {
definitions.push(Definition::Class(class_def));
}
}
for (_, def) in enum_def_map {
definitions.push(Definition::Enum(def.into_inner()));
}
Ok(definitions)
}
fn stringify_function(func: &Func) -> String {
let mut res = format!("func {}", func.name);
if func.params.is_empty() {
res = format!("{res}()");
} else {
res = format!("{res}{}", func.params);
}
if !func.return_type.is_empty() {
res = format!("{res} -> {}", func.return_type);
}
if let Some(modifier) = &func.accessibility_modifier {
res = format!("{modifier} {res}");
}
format!("{res};")
}
fn stringify_variable(variable: &Variable) -> String {
let mut res = format!("var {}", variable.name);
if !variable.value_type.is_empty() {
res = format!("{res}:{}", variable.value_type);
}
format!("{res};")
}
fn stringify_enum_item(item: &Variable) -> String {
let mut res = item.name.clone();
if !item.value_type.is_empty() {
res = format!("{res}:{}", item.value_type);
}
format!("{res};")
}
fn stringify_class(class: &Class) -> String {
let mut res = format!("class {}{{", class.name);
for method in &class.methods {
let method_str = stringify_function(method);
res = format!("{res}{method_str}");
}
for property in &class.properties {
let property_str = stringify_variable(property);
res = format!("{res}{property_str}");
}
format!("{res}}};")
}
fn stringify_enum(enum_def: &Enum) -> String {
let mut res = format!("enum {}{{", enum_def.name);
for item in &enum_def.items {
let item_str = stringify_enum_item(item);
res = format!("{res}{item_str}");
}
format!("{res}}};")
}
fn stringify_definitions(definitions: &Vec<Definition>) -> String {
let mut res = String::new();
for definition in definitions {
match definition {
Definition::Class(class) => res = format!("{res}{}", stringify_class(class)),
Definition::Enum(enum_def) => res = format!("{res}{}", stringify_enum(enum_def)),
Definition::Func(func) => res = format!("{res}{}", stringify_function(func)),
Definition::Variable(variable) => {
let variable_str = stringify_variable(variable);
res = format!("{res}{variable_str}");
}
}
}
res
}
pub fn get_definitions_string(language: &str, source: &str) -> LuaResult<String> {
let definitions =
extract_definitions(language, source).map_err(|e| LuaError::RuntimeError(e.to_string()))?;
let stringified = stringify_definitions(&definitions);
Ok(stringified)
}
#[mlua::lua_module]
fn avante_repo_map(lua: &Lua) -> LuaResult<LuaTable> {
let exports = lua.create_table()?;
exports.set(
"stringify_definitions",
lua.create_function(move |_, (language, source): (String, String)| {
get_definitions_string(language.as_str(), source.as_str())
})?,
)?;
Ok(exports)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_rust() {
let source = r#"
// This is a test comment
pub const TEST_CONST: u32 = 1;
pub static TEST_STATIC: u32 = 2;
const INNER_TEST_CONST: u32 = 3;
static INNER_TEST_STATIC: u32 = 4;
pub(crate) struct TestStruct {
pub test_field: String,
inner_test_field: String,
}
impl TestStruct {
pub fn test_method(&self, a: u32, b: u32) -> u32 {
a + b
}
fn inner_test_method(&self, a: u32, b: u32) -> u32 {
a + b
}
}
struct InnerTestStruct {
pub test_field: String,
inner_test_field: String,
}
impl InnerTestStruct {
pub fn test_method(&self, a: u32, b: u32) -> u32 {
a + b
}
fn inner_test_method(&self, a: u32, b: u32) -> u32 {
a + b
}
}
pub enum TestEnum {
TestEnumField1,
TestEnumField2,
}
enum InnerTestEnum {
InnerTestEnumField1,
InnerTestEnumField2,
}
pub fn test_fn(a: u32, b: u32) -> u32 {
a + b
}
fn inner_test_fn(a: u32, b: u32) -> u32 {
a + b
}
"#;
let definitions = extract_definitions("rust", source).unwrap();
let stringified = stringify_definitions(&definitions);
println!("{stringified}");
let expected = "var TEST_CONST:u32;var TEST_STATIC:u32;func test_fn(a: u32, b: u32) -> u32;class TestStruct{func test_method(&self, a: u32, b: u32) -> u32;var test_field:String;};";
assert_eq!(stringified, expected);
}
#[test]
fn test_go() {
let source = r#"
// This is a test comment
package main
import "fmt"
const TestConst string = "test"
const innerTestConst string = "test"
var TestVar string
var innerTestVar string
type TestStruct struct {
TestField string
innerTestField string
}
func (t *TestStruct) TestMethod(a int, b int) (int, error) {
return a + b, nil
}
func (t *TestStruct) innerTestMethod(a int, b int) (int, error) {
return a + b, nil
}
type innerTestStruct struct {
innerTestField string
}
func (t *innerTestStruct) testMethod(a int, b int) (int, error) {
return a + b, nil
}
func (t *innerTestStruct) innerTestMethod(a int, b int) (int, error) {
return a + b, nil
}
func TestFunc(a int, b int) (int, error) {
return a + b, nil
}
func innerTestFunc(a int, b int) (int, error) {
return a + b, nil
}
"#;
let definitions = extract_definitions("go", source).unwrap();
let stringified = stringify_definitions(&definitions);
println!("{stringified}");
let expected = "var TestConst:string;var TestVar:string;func TestFunc(a int, b int) -> (int, error);class TestStruct{func TestMethod(a int, b int) -> (int, error);var TestField:string;};";
assert_eq!(stringified, expected);
}
#[test]
fn test_python() {
let source = r#"
# This is a test comment
test_var: str = "test"
class TestClass:
def __init__(self, a, b):
self.a = a
self.b = b
def test_method(self, a: int, b: int) -> int:
return a + b
def test_func(a: int, b: int) -> int:
return a + b
"#;
let definitions = extract_definitions("python", source).unwrap();
let stringified = stringify_definitions(&definitions);
println!("{stringified}");
let expected = "var test_var:str;func test_func(a: int, b: int) -> int;class TestClass{func __init__(self, a, b) -> void;func test_method(self, a: int, b: int) -> int;};";
assert_eq!(stringified, expected);
}
#[test]
fn test_typescript() {
let source = r#"
// This is a test comment
export const testVar: string = "test";
const innerTestVar: string = "test";
export class TestClass {
a: number;
b: number;
constructor(a: number, b: number) {
this.a = a;
this.b = b;
}
testMethod(a: number, b: number): number {
return a + b;
}
}
class InnerTestClass {
a: number;
b: number;
}
export function testFunc(a: number, b: number) {
return a + b;
}
export const testFunc2 = (a: number, b: number) => {
return a + b;
}
export const testFunc3 = (a: number, b: number): number => {
return a + b;
}
function innerTestFunc(a: number, b: number) {
return a + b;
}
"#;
let definitions = extract_definitions("typescript", source).unwrap();
let stringified = stringify_definitions(&definitions);
println!("{stringified}");
let expected = "var testVar:string;func testFunc(a: number, b: number) -> void;func testFunc2(a: number, b: number) -> void;func testFunc3(a: number, b: number) -> number;class TestClass{func constructor(a: number, b: number) -> void;func testMethod(a: number, b: number) -> number;var a:number;var b:number;};"
;
assert_eq!(stringified, expected);
}
#[test]
fn test_javascript() {
let source = r#"
// This is a test comment
export const testVar = "test";
const innerTestVar = "test";
export class TestClass {
constructor(a, b) {
this.a = a;
this.b = b;
}
testMethod(a, b) {
return a + b;
}
}
class InnerTestClass {
constructor(a, b) {
this.a = a;
this.b = b;
}
}
export const testFunc = function(a, b) {
return a + b;
}
export const testFunc2 = (a, b) => {
return a + b;
}
export const testFunc3 = (a, b) => a + b;
function innerTestFunc(a, b) {
return a + b;
}
"#;
let definitions = extract_definitions("javascript", source).unwrap();
let stringified = stringify_definitions(&definitions);
println!("{stringified}");
let expected = "var testVar;var testFunc;func testFunc2(a, b) -> void;func testFunc3(a, b) -> void;class TestClass{func constructor(a, b) -> void;func testMethod(a, b) -> void;};";
assert_eq!(stringified, expected);
}
#[test]
fn test_ruby() {
let source = r#"
# This is a test comment
test_var = "test"
def test_func(a, b)
return a + b
end
class TestClass
attr_accessor :a, :b
def initialize(a, b)
@a = a
@b = b
end
def test_method(a, b)
return a + b
end
end
"#;
let definitions = extract_definitions("ruby", source).unwrap();
let stringified = stringify_definitions(&definitions);
println!("{stringified}");
// FIXME:
let expected = "var test_var;func test_func(a, b) -> void;";
assert_eq!(stringified, expected);
}
#[test]
fn test_lua() {
let source = r#"
-- This is a test comment
local test_var = "test"
function test_func(a, b)
return a + b
end
"#;
let definitions = extract_definitions("lua", source).unwrap();
let stringified = stringify_definitions(&definitions);
println!("{stringified}");
let expected = "var test_var;func test_func(a, b) -> void;";
assert_eq!(stringified, expected);
}
}

View File

@ -397,6 +397,7 @@ function M.setup(opts)
if M.did_setup then return end if M.did_setup then return end
require("avante.repo_map").setup()
require("avante.path").setup() require("avante.path").setup()
require("avante.highlights").setup() require("avante.highlights").setup()
require("avante.diff").setup() require("avante.diff").setup()

View File

@ -85,6 +85,7 @@ M.stream = function(opts)
:totable() :totable()
Utils.debug(user_prompts) Utils.debug(user_prompts)
-- print(user_prompts[1])
---@type AvantePromptOptions ---@type AvantePromptOptions
local code_opts = { local code_opts = {

147
lua/avante/repo_map.lua Normal file
View File

@ -0,0 +1,147 @@
local Utils = require("avante.utils")
local filetype_map = {
["javascriptreact"] = "javascript",
["typescriptreact"] = "typescript",
}
---@class AvanteRepoMap
---@field stringify_definitions fun(lang: string, source: string): string
local repo_map_lib = nil
---@class avante.utils.repo_map
local RepoMap = {}
function RepoMap.setup()
vim.defer_fn(function()
local ok, core = pcall(require, "avante_repo_map")
if not ok then
error("Failed to load avante_repo_map")
return
end
if repo_map_lib == nil then repo_map_lib = core end
end, 1000)
end
function RepoMap.get_ts_lang(filepath)
local filetype = vim.filetype.match({ filename = filepath })
return filetype_map[filetype] or filetype
end
function RepoMap.get_filetype(filepath) return vim.filetype.match({ filename = filepath }) end
function RepoMap._build_repo_map(project_root, file_ext)
local output = {}
local gitignore_path = project_root .. "/.gitignore"
local ignore_patterns, negate_patterns = Utils.parse_gitignore(gitignore_path)
local filepaths = Utils.scan_directory(project_root, ignore_patterns, negate_patterns)
vim.iter(filepaths):each(function(filepath)
if not Utils.is_same_file_ext(file_ext, filepath) then return end
local definitions =
repo_map_lib.stringify_definitions(RepoMap.get_ts_lang(filepath), Utils.file.read_content(filepath) or "")
if definitions == "" then return end
table.insert(output, {
path = Utils.relative_path(filepath),
lang = RepoMap.get_filetype(filepath),
defs = definitions,
})
end)
return output
end
local cache = {}
function RepoMap.get_repo_map(file_ext)
file_ext = file_ext or vim.fn.expand("%:e")
local project_root = Utils.root.get()
local cache_key = project_root .. "." .. file_ext
local cached = cache[cache_key]
if cached then return cached end
local PPath = require("plenary.path")
local Path = require("avante.path")
local repo_map
local function build_and_save()
repo_map = RepoMap._build_repo_map(project_root, file_ext)
cache[cache_key] = repo_map
Path.repo_map.save(project_root, file_ext, repo_map)
end
repo_map = Path.repo_map.load(project_root, file_ext)
if not repo_map or next(repo_map) == nil then
build_and_save()
if not repo_map then return end
else
local timer = vim.loop.new_timer()
if timer then
timer:start(
0,
0,
vim.schedule_wrap(function()
build_and_save()
timer:close()
end)
)
end
end
local update_repo_map = vim.schedule_wrap(function(rel_filepath)
if rel_filepath and Utils.is_same_file_ext(file_ext, rel_filepath) then
local abs_filepath = PPath:new(project_root):joinpath(rel_filepath):absolute()
local definitions = repo_map_lib.stringify_definitions(
RepoMap.get_ts_lang(abs_filepath),
Utils.file.read_content(abs_filepath) or ""
)
if definitions == "" then return end
local found = false
for _, m in ipairs(repo_map) do
if m.path == rel_filepath then
m.defs = definitions
found = true
break
end
end
if not found then
table.insert(repo_map, {
path = Utils.relative_path(abs_filepath),
lang = RepoMap.get_filetype(abs_filepath),
defs = definitions,
})
end
cache[cache_key] = repo_map
Path.repo_map.save(project_root, file_ext, repo_map)
end
end)
local handle = vim.loop.new_fs_event()
if handle then
handle:start(project_root, { recursive = true }, function(err, rel_filepath)
if err then
print("Error watching directory " .. project_root .. ":", err)
return
end
if rel_filepath then update_repo_map(rel_filepath) end
end)
end
vim.api.nvim_create_autocmd({ "BufReadPost", "BufNewFile" }, {
callback = function(ev)
vim.defer_fn(function()
local filepath = vim.api.nvim_buf_get_name(ev.buf)
if not vim.startswith(filepath, project_root) then return end
local rel_filepath = Utils.relative_path(filepath)
update_repo_map(rel_filepath)
end, 0)
end,
})
return repo_map
end
return RepoMap

View File

@ -2,6 +2,7 @@ local Utils = require("avante.utils")
local Config = require("avante.config") local Config = require("avante.config")
local Llm = require("avante.llm") local Llm = require("avante.llm")
local Provider = require("avante.providers") local Provider = require("avante.providers")
local RepoMap = require("avante.repo_map")
local api = vim.api local api = vim.api
local fn = vim.fn local fn = vim.fn
@ -394,7 +395,7 @@ function Selection:create_editing_input()
local mentions = Utils.extract_mentions(input) local mentions = Utils.extract_mentions(input)
input = mentions.new_content input = mentions.new_content
local project_context = mentions.enable_project_context and Utils.repo_map.get_repo_map(file_ext) or nil local project_context = mentions.enable_project_context and RepoMap.get_repo_map(file_ext) or nil
Llm.stream({ Llm.stream({
bufnr = code_bufnr, bufnr = code_bufnr,

View File

@ -11,6 +11,7 @@ local Diff = require("avante.diff")
local Llm = require("avante.llm") local Llm = require("avante.llm")
local Utils = require("avante.utils") local Utils = require("avante.utils")
local Highlights = require("avante.highlights") local Highlights = require("avante.highlights")
local RepoMap = require("avante.repo_map")
local RESULT_BUF_NAME = "AVANTE_RESULT" local RESULT_BUF_NAME = "AVANTE_RESULT"
local VIEW_BUFFER_UPDATED_PATTERN = "AvanteViewBufferUpdated" local VIEW_BUFFER_UPDATED_PATTERN = "AvanteViewBufferUpdated"
@ -1295,7 +1296,7 @@ function Sidebar:create_input(opts)
local file_ext = api.nvim_buf_get_name(self.code.bufnr):match("^.+%.(.+)$") local file_ext = api.nvim_buf_get_name(self.code.bufnr):match("^.+%.(.+)$")
local project_context = mentions.enable_project_context and Utils.repo_map.get_repo_map(file_ext) or nil local project_context = mentions.enable_project_context and RepoMap.get_repo_map(file_ext) or nil
Llm.stream({ Llm.stream({
bufnr = self.code.bufnr, bufnr = self.code.bufnr,

View File

@ -6,6 +6,7 @@ local lsp = vim.lsp
---@field tokens avante.utils.tokens ---@field tokens avante.utils.tokens
---@field root avante.utils.root ---@field root avante.utils.root
---@field repo_map avante.utils.repo_map ---@field repo_map avante.utils.repo_map
---@field file avante.utils.file
local M = {} local M = {}
setmetatable(M, { setmetatable(M, {

View File

@ -1,730 +0,0 @@
local parsers = require("nvim-treesitter.parsers")
local Config = require("avante.config")
local get_node_text = vim.treesitter.get_node_text
---@class avante.utils.repo_map
local RepoMap = {}
local dependencies_queries = {
lua = [[
(function_call
name: (identifier) @function_name
arguments: (arguments
(string) @required_file))
]],
python = [[
(import_from_statement
module_name: (dotted_name) @import_module)
(import_statement
(dotted_name) @import_module)
]],
javascript = [[
(import_statement
source: (string) @import_module)
(call_expression
function: (identifier) @function_name
arguments: (arguments
(string) @required_file))
]],
typescript = [[
(import_statement
source: (string) @import_module)
(call_expression
function: (identifier) @function_name
arguments: (arguments
(string) @required_file))
]],
go = [[
(import_spec
path: (interpreted_string_literal) @import_module)
]],
rust = [[
(use_declaration
(scoped_identifier) @import_module)
(use_declaration
(identifier) @import_module)
]],
c = [[
(preproc_include
(string_literal) @import_module)
(preproc_include
(system_lib_string) @import_module)
]],
cpp = [[
(preproc_include
(string_literal) @import_module)
(preproc_include
(system_lib_string) @import_module)
]],
}
local definitions_queries = {
python = [[
;; Capture top-level functions, class, and method definitions
(module
(expression_statement
(assignment) @assignment
)
)
(module
(function_definition) @function
)
(module
(class_definition
body: (block
(expression_statement
(assignment) @class_assignment
)
)
)
)
(module
(class_definition
body: (block
(function_definition) @method
)
)
)
]],
javascript = [[
;; Capture exported functions, arrow functions, variables, classes, and method definitions
(export_statement
declaration: (lexical_declaration
(variable_declarator) @variable
)
)
(export_statement
declaration: (function_declaration) @function
)
(export_statement
declaration: (class_declaration
body: (class_body
(field_definition) @class_variable
)
)
)
(export_statement
declaration: (class_declaration
body: (class_body
(method_definition) @method
)
)
)
]],
typescript = [[
;; Capture exported functions, arrow functions, variables, classes, and method definitions
(export_statement
declaration: (lexical_declaration
(variable_declarator) @variable
)
)
(export_statement
declaration: (function_declaration) @function
)
(export_statement
declaration: (class_declaration
body: (class_body
(public_field_definition) @class_variable
)
)
)
(interface_declaration
body: (interface_body
(property_signature) @class_variable
)
)
(type_alias_declaration
value: (object_type
(property_signature) @class_variable
)
)
(export_statement
declaration: (class_declaration
body: (class_body
(method_definition) @method
)
)
)
]],
rust = [[
;; Capture public functions, structs, methods, and variable definitions
(function_item) @function
(impl_item
body: (declaration_list
(function_item) @method
)
)
(struct_item
body: (field_declaration_list
(field_declaration) @class_variable
)
)
(enum_item
body: (enum_variant_list
(enum_variant) @enum_item
)
)
(const_item) @variable
]],
go = [[
;; Capture top-level functions and struct definitions
(var_declaration
(var_spec) @variable
)
(const_declaration
(const_spec) @variable
)
(function_declaration) @function
(type_declaration
(type_spec (struct_type)) @class
)
(type_declaration
(type_spec
(struct_type
(field_declaration_list
(field_declaration) @class_variable)))
)
(method_declaration) @method
]],
c = [[
;; Capture extern functions, variables, public classes, and methods
(function_definition
(storage_class_specifier) @extern
) @function
(class_specifier
(public) @class
(function_definition) @method
) @class
(declaration
(storage_class_specifier) @extern
) @variable
]],
cpp = [[
;; Capture extern functions, variables, public classes, and methods
(function_definition
(storage_class_specifier) @extern
) @function
(class_specifier
(public) @class
(function_definition) @method
) @class
(declaration
(storage_class_specifier) @extern
) @variable
]],
lua = [[
;; Capture function and method definitions
(variable_list) @variable
(function_declaration) @function
]],
ruby = [[
;; Capture top-level methods, class definitions, and methods within classes
(method) @function
(assignment) @assignment
(class
body: (body_statement
(assignment) @class_assignment
(method) @method
)
)
]],
}
local queries_filetype_map = {
["javascriptreact"] = "javascript",
["typescriptreact"] = "typescript",
}
local function get_query(queries, filetype)
filetype = queries_filetype_map[filetype] or filetype
return queries[filetype]
end
local function get_ts_lang(bufnr)
local lang = parsers.get_buf_lang(bufnr)
return lang
end
function RepoMap.get_parser(bufnr)
local lang = get_ts_lang(bufnr)
if not lang then return end
local parser = parsers.get_parser(bufnr, lang)
return parser, lang
end
function RepoMap.extract_dependencies(bufnr)
local parser, lang = RepoMap.get_parser(bufnr)
if not lang or not parser or not dependencies_queries[lang] then
print("No parser or query available for this buffer's language: " .. (lang or "unknown"))
return {}
end
local dependencies = {}
local tree = parser:parse()[1]
local root = tree:root()
local filetype = vim.api.nvim_get_option_value("filetype", { buf = bufnr })
local query = get_query(dependencies_queries, filetype)
if not query then return dependencies end
local query_obj = vim.treesitter.query.parse(lang, query)
for _, node, _ in query_obj:iter_captures(root, bufnr, 0, -1) do
-- local name = query.captures[id]
local required_file = vim.treesitter.get_node_text(node, bufnr):gsub('"', ""):gsub("'", "")
table.insert(dependencies, required_file)
end
return dependencies
end
function RepoMap.get_filetype_by_filepath(filepath) return vim.filetype.match({ filename = filepath }) end
function RepoMap.parse_file(filepath)
local File = require("avante.utils.file")
local source = File.read_content(filepath)
local filetype = RepoMap.get_filetype_by_filepath(filepath)
local lang = parsers.ft_to_lang(filetype)
if lang then
local ok, parser = pcall(vim.treesitter.get_string_parser, source, lang)
if ok then
local tree = parser:parse()[1]
local node = tree:root()
return { node = node, source = source }
else
print("parser error", parser)
end
end
end
local function get_closest_parent_name(node, source)
local parent = node:parent()
while parent do
local name = parent:field("name")[1]
if name then return get_node_text(name, source) end
parent = parent:parent()
end
return ""
end
local function find_parent_by_type(node, type)
local parent = node:parent()
while parent do
if parent:type() == type then return parent end
parent = parent:parent()
end
return nil
end
local function find_child_by_type(node, type)
for child in node:iter_children() do
if child:type() == type then return child end
local res = find_child_by_type(child, type)
if res then return res end
end
return nil
end
local function get_node_type(node, source)
local node_type
local predefined_type_node = find_child_by_type(node, "predefined_type")
if predefined_type_node then
node_type = get_node_text(predefined_type_node, source)
else
local value_type_node = node:field("type")[1]
node_type = value_type_node and get_node_text(value_type_node, source) or ""
end
return node_type
end
-- Function to extract definitions from the file
function RepoMap.extract_definitions(filepath)
local Utils = require("avante.utils")
local filetype = RepoMap.get_filetype_by_filepath(filepath)
if not filetype then return {} end
-- Get the corresponding query for the detected language
local query = get_query(definitions_queries, filetype)
if not query then return {} end
local parsed = RepoMap.parse_file(filepath)
if not parsed then return {} end
-- Get the current buffer's syntax tree
local root = parsed.node
local lang = parsers.ft_to_lang(filetype)
-- Parse the query
local query_obj = vim.treesitter.query.parse(lang, query)
-- Store captured results
local definitions = {}
local class_def_map = {}
local enum_def_map = {}
local function get_class_def(name)
local def = class_def_map[name]
if def == nil then
def = {
type = "class",
name = name,
methods = {},
properties = {},
}
class_def_map[name] = def
end
return def
end
local function get_enum_def(name)
local def = enum_def_map[name]
if def == nil then
def = {
type = "enum",
name = name,
items = {},
}
enum_def_map[name] = def
end
return def
end
for _, captures, _ in query_obj:iter_matches(root, parsed.source) do
for id, node in pairs(captures) do
local type = query_obj.captures[id]
local name_node = node:field("name")[1]
local name = name_node and get_node_text(name_node, parsed.source) or ""
if type == "class" then
if name ~= "" then get_class_def(name) end
elseif type == "enum_item" then
local enum_name = get_closest_parent_name(node, parsed.source)
if enum_name and filetype == "go" and not Utils.is_first_letter_uppercase(enum_name) then goto continue end
local enum_def = get_enum_def(enum_name)
local enum_type_node = find_child_by_type(node, "type_identifier")
local enum_type = enum_type_node and get_node_text(enum_type_node, parsed.source) or ""
table.insert(enum_def.items, {
name = name,
type = enum_type,
})
elseif type == "method" then
if name and filetype == "go" and not Utils.is_first_letter_uppercase(name) then goto continue end
local params_node = node:field("parameters")[1]
local params = params_node and get_node_text(params_node, parsed.source) or "()"
local return_type_node = node:field("return_type")[1] or node:field("result")[1]
local return_type = return_type_node and get_node_text(return_type_node, parsed.source) or "void"
local class_name
local impl_item_node = find_parent_by_type(node, "impl_item")
local receiver_node = node:field("receiver")[1]
if impl_item_node then
local impl_type_node = impl_item_node:field("type")[1]
class_name = impl_type_node and get_node_text(impl_type_node, parsed.source) or ""
elseif receiver_node then
local type_identifier_node = find_child_by_type(receiver_node, "type_identifier")
class_name = type_identifier_node and get_node_text(type_identifier_node, parsed.source) or ""
else
class_name = get_closest_parent_name(node, parsed.source)
end
local class_def = get_class_def(class_name)
local accessibility_modifier_node = find_child_by_type(node, "accessibility_modifier")
local accessibility_modifier = accessibility_modifier_node
and get_node_text(accessibility_modifier_node, parsed.source)
or ""
table.insert(class_def.methods, {
type = "function",
name = name,
params = params,
return_type = return_type,
accessibility_modifier = accessibility_modifier,
})
elseif type == "class_assignment" then
local left_node = node:field("left")[1]
local left = left_node and get_node_text(left_node, parsed.source) or ""
local value_type = get_node_type(node, parsed.source)
local class_name = get_closest_parent_name(node, parsed.source)
if class_name and filetype == "go" and not Utils.is_first_letter_uppercase(class_name) then goto continue end
local class_def = get_class_def(class_name)
table.insert(class_def.properties, {
type = "variable",
name = left,
value_type = value_type,
})
elseif type == "class_variable" then
local value_type = get_node_type(node, parsed.source)
local class_name = get_closest_parent_name(node, parsed.source)
if class_name and filetype == "go" and not Utils.is_first_letter_uppercase(class_name) then goto continue end
local class_def = get_class_def(class_name)
table.insert(class_def.properties, {
type = "variable",
name = name,
value_type = value_type,
})
elseif type == "function" or type == "arrow_function" then
if name and filetype == "go" and not Utils.is_first_letter_uppercase(name) then goto continue end
local impl_item_node = find_parent_by_type(node, "impl_item")
if impl_item_node then goto continue end
local function_node = find_parent_by_type(node, "function_declaration")
or find_parent_by_type(node, "function_definition")
if function_node then goto continue end
-- Extract function parameters and return type
local params_node = node:field("parameters")[1]
local params = params_node and get_node_text(params_node, parsed.source) or "()"
local return_type_node = node:field("return_type")[1] or node:field("result")[1]
local return_type = return_type_node and get_node_text(return_type_node, parsed.source) or "void"
local accessibility_modifier_node = find_child_by_type(node, "accessibility_modifier")
local accessibility_modifier = accessibility_modifier_node
and get_node_text(accessibility_modifier_node, parsed.source)
or ""
local def = {
type = "function",
name = name,
params = params,
return_type = return_type,
accessibility_modifier = accessibility_modifier,
}
table.insert(definitions, def)
elseif type == "assignment" then
local impl_item_node = find_parent_by_type(node, "impl_item")
or find_parent_by_type(node, "class_declaration")
or find_parent_by_type(node, "class_definition")
if impl_item_node then goto continue end
local function_node = find_parent_by_type(node, "function_declaration")
or find_parent_by_type(node, "function_definition")
if function_node then goto continue end
local left_node = node:field("left")[1]
local left = left_node and get_node_text(left_node, parsed.source) or ""
if left and filetype == "go" and not Utils.is_first_letter_uppercase(left) then goto continue end
local value_type = get_node_type(node, parsed.source)
local def = {
type = "variable",
name = left,
value_type = value_type,
}
table.insert(definitions, def)
elseif type == "variable" then
local impl_item_node = find_parent_by_type(node, "impl_item")
or find_parent_by_type(node, "class_declaration")
or find_parent_by_type(node, "class_definition")
if impl_item_node then goto continue end
local function_node = find_parent_by_type(node, "function_declaration")
or find_parent_by_type(node, "function_definition")
if function_node then goto continue end
local value_type = get_node_type(node, parsed.source)
if name and filetype == "go" and not Utils.is_first_letter_uppercase(name) then goto continue end
local def = { type = "variable", name = name, value_type = value_type }
table.insert(definitions, def)
end
::continue::
end
end
for _, def in pairs(class_def_map) do
table.insert(definitions, def)
end
for _, def in pairs(enum_def_map) do
table.insert(definitions, def)
end
return definitions
end
local function stringify_function(def)
local res = "func " .. def.name .. def.params .. ":" .. def.return_type .. ";"
if def.accessibility_modifier and def.accessibility_modifier ~= "" then
res = def.accessibility_modifier .. " " .. res
end
return res
end
local function stringify_variable(def)
local res = "var " .. def.name
if def.value_type and def.value_type ~= "" then res = res .. ":" .. def.value_type end
return res .. ";"
end
local function stringify_enum_item(def)
local res = def.name
if def.value_type and def.value_type ~= "" then res = res .. ":" .. def.value_type end
return res .. ";"
end
-- Function to load file content into a temporary buffer, process it, and then delete the buffer
function RepoMap.stringify_definitions(filepath)
if vim.endswith(filepath, "~") then return "" end
-- Extract definitions
local definitions = RepoMap.extract_definitions(filepath)
local output = ""
-- Print or process the definitions
for _, def in ipairs(definitions) do
if def.type == "class" then
output = output .. def.type .. " " .. def.name .. "{"
for _, property in ipairs(def.properties) do
output = output .. stringify_variable(property)
end
for _, method in ipairs(def.methods) do
output = output .. stringify_function(method)
end
output = output .. "}"
elseif def.type == "enum" then
output = output .. def.type .. " " .. def.name .. "{"
for _, item in ipairs(def.items) do
output = output .. stringify_enum_item(item) .. ""
end
output = output .. "}"
elseif def.type == "function" then
output = output .. stringify_function(def)
elseif def.type == "variable" then
output = output .. stringify_variable(def)
end
end
return output
end
function RepoMap._build_repo_map(project_root, file_ext)
local Utils = require("avante.utils")
local output = {}
local gitignore_path = project_root .. "/.gitignore"
local ignore_patterns, negate_patterns = Utils.parse_gitignore(gitignore_path)
local filepaths = Utils.scan_directory(project_root, ignore_patterns, negate_patterns)
vim.iter(filepaths):each(function(filepath)
if not Utils.is_same_file_ext(file_ext, filepath) then return end
local definitions = RepoMap.stringify_definitions(filepath)
if definitions == "" then return end
table.insert(output, {
path = Utils.relative_path(filepath),
lang = RepoMap.get_filetype_by_filepath(filepath),
defs = definitions,
})
end)
return output
end
local cache = {}
function RepoMap.get_repo_map(file_ext)
file_ext = file_ext or vim.fn.expand("%:e")
local Utils = require("avante.utils")
local project_root = Utils.root.get()
local cache_key = project_root .. "." .. file_ext
local cached = cache[cache_key]
if cached then return cached end
local PPath = require("plenary.path")
local Path = require("avante.path")
local repo_map
local function build_and_save()
repo_map = RepoMap._build_repo_map(project_root, file_ext)
cache[cache_key] = repo_map
Path.repo_map.save(project_root, file_ext, repo_map)
end
repo_map = Path.repo_map.load(project_root, file_ext)
if not repo_map or next(repo_map) == nil then
build_and_save()
if not repo_map then return end
else
local timer = vim.loop.new_timer()
if timer then
timer:start(
0,
0,
vim.schedule_wrap(function()
build_and_save()
timer:close()
end)
)
end
end
local update_repo_map = vim.schedule_wrap(function(rel_filepath)
if rel_filepath and Utils.is_same_file_ext(file_ext, rel_filepath) then
local abs_filepath = PPath:new(project_root):joinpath(rel_filepath):absolute()
local definitions = RepoMap.stringify_definitions(abs_filepath)
if definitions == "" then return end
local found = false
for _, m in ipairs(repo_map) do
if m.path == rel_filepath then
m.defs = definitions
found = true
break
end
end
if not found then
table.insert(repo_map, {
path = Utils.relative_path(abs_filepath),
lang = RepoMap.get_filetype_by_filepath(abs_filepath),
defs = definitions,
})
end
cache[cache_key] = repo_map
Path.repo_map.save(project_root, file_ext, repo_map)
end
end)
local handle = vim.loop.new_fs_event()
if handle then
handle:start(project_root, { recursive = true }, function(err, rel_filepath)
if err then
print("Error watching directory " .. project_root .. ":", err)
return
end
if rel_filepath then update_repo_map(rel_filepath) end
end)
end
vim.api.nvim_create_autocmd({ "BufReadPost", "BufNewFile" }, {
callback = function(ev)
vim.defer_fn(function()
local filepath = vim.api.nvim_buf_get_name(ev.buf)
if not vim.startswith(filepath, project_root) then return end
local rel_filepath = Utils.relative_path(filepath)
update_repo_map(rel_filepath)
end, 0)
end,
})
return repo_map
end
return RepoMap