feat: avante repo map rust crate (#628)
This commit is contained in:
parent
5461342fce
commit
0d90c047ef
12
.github/workflows/ci.yaml
vendored
12
.github/workflows/ci.yaml
vendored
@ -29,6 +29,18 @@ jobs:
|
||||
uses: lunarmodules/luacheck@v1
|
||||
with:
|
||||
args: ./lua/
|
||||
rust-tests:
|
||||
name: Run Rust tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
- uses: dtolnay/rust-toolchain@master
|
||||
with:
|
||||
toolchain: stable
|
||||
components: clippy, rustfmt
|
||||
- name: Run rust tests
|
||||
run: cargo test --features luajit
|
||||
rust:
|
||||
name: Check Rust style
|
||||
runs-on: ubuntu-latest
|
||||
|
2
.github/workflows/release.yaml
vendored
2
.github/workflows/release.yaml
vendored
@ -74,6 +74,7 @@ jobs:
|
||||
fi
|
||||
cp target/release/libavante_templates.$EXT results/avante_templates.$EXT
|
||||
cp target/release/libavante_tokenizers.$EXT results/avante_tokenizers.$EXT
|
||||
cp target/release/libavante_repo_map.$EXT results/avante_repo_map.$EXT
|
||||
|
||||
cd results
|
||||
tar zcvf avante_lib-${{ matrix.os }}-${{ matrix.feature }}.tar.gz *.${EXT}
|
||||
@ -85,6 +86,7 @@ jobs:
|
||||
|
||||
Copy-Item -Path "target\release\avante_templates.dll" -Destination "results\avante_templates.dll"
|
||||
Copy-Item -Path "target\release\avante_tokenizers.dll" -Destination "results\avante_tokenizers.dll"
|
||||
Copy-Item -Path "target\release\avante_repo_map.dll" -Destination "results\avante_repo_map.dll"
|
||||
|
||||
Set-Location -Path results
|
||||
|
||||
|
133
Cargo.lock
generated
133
Cargo.lock
generated
@ -29,6 +29,27 @@ version = "1.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0"
|
||||
|
||||
[[package]]
|
||||
name = "avante-repo-map"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"minijinja",
|
||||
"mlua",
|
||||
"serde",
|
||||
"tree-sitter",
|
||||
"tree-sitter-c",
|
||||
"tree-sitter-cpp",
|
||||
"tree-sitter-go",
|
||||
"tree-sitter-javascript",
|
||||
"tree-sitter-language",
|
||||
"tree-sitter-lua",
|
||||
"tree-sitter-python",
|
||||
"tree-sitter-ruby",
|
||||
"tree-sitter-rust",
|
||||
"tree-sitter-typescript",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "avante-templates"
|
||||
version = "0.1.0"
|
||||
@ -117,9 +138,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
|
||||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.1.15"
|
||||
version = "1.1.21"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "57b6a275aa2903740dc87da01c62040406b8812552e97129a63ea8850a17c6e6"
|
||||
checksum = "07b1695e2c7e8fc85310cde85aeaab7e3097f593c91d209d3f9df76c928100f0"
|
||||
dependencies = [
|
||||
"shlex",
|
||||
]
|
||||
@ -1275,6 +1296,114 @@ dependencies = [
|
||||
"unicode_categories",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tree-sitter"
|
||||
version = "0.23.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "20f4cd3642c47a85052a887d86704f4eac272969f61b686bdd3f772122aabaff"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"regex",
|
||||
"regex-syntax",
|
||||
"tree-sitter-language",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tree-sitter-c"
|
||||
version = "0.23.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c8b3fb515e498e258799a31d78e6603767cd6892770d9e2290ec00af5c3ad80b"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"tree-sitter-language",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tree-sitter-cpp"
|
||||
version = "0.23.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1d67e862242878d6ee50e1e5814f267ee3eea0168aea2cdbd700ccfb4c74b6d3"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"tree-sitter-language",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tree-sitter-go"
|
||||
version = "0.23.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "caf57626e4c9b6d6efaf8a8d5ee1241c5f178ae7bfdf693713ae6a774f01424e"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"tree-sitter-language",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tree-sitter-javascript"
|
||||
version = "0.23.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "59e1f62f8babb640b909f30675d1addeb1f17802f2a4d2af287569753b243977"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"tree-sitter-language",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tree-sitter-language"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2545046bd1473dac6c626659cc2567c6c0ff302fc8b84a56c4243378276f7f57"
|
||||
|
||||
[[package]]
|
||||
name = "tree-sitter-lua"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5cdb9adf0965fec58e7660cbb3a059dbb12ebeec9459e6dcbae3db004739641e"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"tree-sitter-language",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tree-sitter-python"
|
||||
version = "0.23.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "65661b1a3e24139e2e54207e47d910ab07e28790d78efc7d5dc3a11ce2a110eb"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"tree-sitter-language",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tree-sitter-ruby"
|
||||
version = "0.23.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6ec5ee842e27791e0adffa0b2a177614de51d2a26e5c7e84d014ed7f097e5ed0"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"tree-sitter-language",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tree-sitter-rust"
|
||||
version = "0.23.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cffbbcb780348fbae8395742ae5b34c1fd794e4085d43aac9f259387f9a84dc8"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"tree-sitter-language",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tree-sitter-typescript"
|
||||
version = "0.23.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "aecf1585ae2a9dddc2b1d4c0e2140b2ec9876e2a25fd79de47fcf7dae0384685"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"tree-sitter-language",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typeid"
|
||||
version = "1.0.2"
|
||||
|
@ -11,6 +11,7 @@ version = "0.1.0"
|
||||
[workspace.dependencies]
|
||||
avante-tokenizers = { path = "crates/avante-tokenizers" }
|
||||
avante-templates = { path = "crates/avante-templates" }
|
||||
avante-repo-map = { path = "crates/avante-repo-map" }
|
||||
minijinja = { version = "2.2.0", features = [
|
||||
"loader",
|
||||
"json",
|
||||
|
10
Makefile
10
Makefile
@ -22,13 +22,15 @@ all: luajit
|
||||
define make_definitions
|
||||
ifeq ($(BUILD_FROM_SOURCE),true)
|
||||
ifeq ($(TARGET_LIBRARY), all)
|
||||
$1: $(BUILD_DIR)/libAvanteTokenizers-$1.$(EXT) $(BUILD_DIR)/libAvanteTemplates-$1.$(EXT)
|
||||
$1: $(BUILD_DIR)/libAvanteTokenizers-$1.$(EXT) $(BUILD_DIR)/libAvanteTemplates-$1.$(EXT) $(BUILD_DIR)/libAvanteRepoMap-$1.$(EXT)
|
||||
else ifeq ($(TARGET_LIBRARY), tokenizers)
|
||||
$1: $(BUILD_DIR)/libAvanteTokenizers-$1.$(EXT)
|
||||
else ifeq ($(TARGET_LIBRARY), templates)
|
||||
$1: $(BUILD_DIR)/libAvanteTemplates-$1.$(EXT)
|
||||
else ifeq ($(TARGET_LIBRARY), repo-map)
|
||||
$1: $(BUILD_DIR)/libAvanteRepoMap-$1.$(EXT)
|
||||
else
|
||||
$$(error TARGET_LIBRARY must be one of all, tokenizers, templates)
|
||||
$$(error TARGET_LIBRARY must be one of all, tokenizers, templates, repo-map)
|
||||
endif
|
||||
else
|
||||
$1:
|
||||
@ -41,16 +43,18 @@ $(foreach lua_version,$(LUA_VERSIONS),$(eval $(call make_definitions,$(lua_versi
|
||||
define build_package
|
||||
$1-$2:
|
||||
cargo build --release --features=$1 -p avante-$2
|
||||
cp target/release/libavante_$2.$(EXT) $(BUILD_DIR)/avante_$2.$(EXT)
|
||||
cp target/release/libavante_$(shell echo $2 | tr - _).$(EXT) $(BUILD_DIR)/avante_$(shell echo $2 | tr - _).$(EXT)
|
||||
endef
|
||||
|
||||
define build_targets
|
||||
$(BUILD_DIR)/libAvanteTokenizers-$1.$(EXT): $(BUILD_DIR) $1-tokenizers
|
||||
$(BUILD_DIR)/libAvanteTemplates-$1.$(EXT): $(BUILD_DIR) $1-templates
|
||||
$(BUILD_DIR)/libAvanteRepoMap-$1.$(EXT): $(BUILD_DIR) $1-repo-map
|
||||
endef
|
||||
|
||||
$(foreach lua_version,$(LUA_VERSIONS),$(eval $(call build_package,$(lua_version),tokenizers)))
|
||||
$(foreach lua_version,$(LUA_VERSIONS),$(eval $(call build_package,$(lua_version),templates)))
|
||||
$(foreach lua_version,$(LUA_VERSIONS),$(eval $(call build_package,$(lua_version),repo-map)))
|
||||
$(foreach lua_version,$(LUA_VERSIONS),$(eval $(call build_targets,$(lua_version))))
|
||||
|
||||
$(BUILD_DIR):
|
||||
|
38
crates/avante-repo-map/Cargo.toml
Normal file
38
crates/avante-repo-map/Cargo.toml
Normal file
@ -0,0 +1,38 @@
|
||||
[lib]
|
||||
crate-type = ["cdylib"]
|
||||
|
||||
[package]
|
||||
name = "avante-repo-map"
|
||||
edition.workspace = true
|
||||
rust-version.workspace = true
|
||||
license.workspace = true
|
||||
version.workspace = true
|
||||
|
||||
[build-dependencies]
|
||||
cc="*"
|
||||
|
||||
[dependencies]
|
||||
mlua = { workspace = true }
|
||||
minijinja = { workspace = true }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
tree-sitter = "0.23"
|
||||
tree-sitter-language = "0.1"
|
||||
tree-sitter-rust = "0.23"
|
||||
tree-sitter-python = "0.23"
|
||||
tree-sitter-javascript = "0.23"
|
||||
tree-sitter-typescript = "0.23"
|
||||
tree-sitter-go = "0.23"
|
||||
tree-sitter-c = "0.23"
|
||||
tree-sitter-cpp = "0.23"
|
||||
tree-sitter-lua = "0.2"
|
||||
tree-sitter-ruby = "0.23"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[features]
|
||||
lua51 = ["mlua/lua51"]
|
||||
lua52 = ["mlua/lua52"]
|
||||
lua53 = ["mlua/lua53"]
|
||||
lua54 = ["mlua/lua54"]
|
||||
luajit = ["mlua/luajit"]
|
11
crates/avante-repo-map/queries/tree-sitter-c-defs.scm
Normal file
11
crates/avante-repo-map/queries/tree-sitter-c-defs.scm
Normal file
@ -0,0 +1,11 @@
|
||||
;; Capture extern functions, variables, public classes, and methods
|
||||
(function_definition
|
||||
(storage_class_specifier) @extern
|
||||
) @function
|
||||
(class_specifier
|
||||
(public) @class
|
||||
(function_definition) @method
|
||||
) @class
|
||||
(declaration
|
||||
(storage_class_specifier) @extern
|
||||
) @variable
|
11
crates/avante-repo-map/queries/tree-sitter-cpp-defs.scm
Normal file
11
crates/avante-repo-map/queries/tree-sitter-cpp-defs.scm
Normal file
@ -0,0 +1,11 @@
|
||||
;; Capture extern functions, variables, public classes, and methods
|
||||
(function_definition
|
||||
(storage_class_specifier) @extern
|
||||
) @function
|
||||
(class_specifier
|
||||
(public) @class
|
||||
(function_definition) @method
|
||||
) @class
|
||||
(declaration
|
||||
(storage_class_specifier) @extern
|
||||
) @variable
|
18
crates/avante-repo-map/queries/tree-sitter-go-defs.scm
Normal file
18
crates/avante-repo-map/queries/tree-sitter-go-defs.scm
Normal file
@ -0,0 +1,18 @@
|
||||
;; Capture top-level functions and struct definitions
|
||||
(var_declaration
|
||||
(var_spec) @variable
|
||||
)
|
||||
(const_declaration
|
||||
(const_spec) @variable
|
||||
)
|
||||
(function_declaration) @function
|
||||
(type_declaration
|
||||
(type_spec (struct_type)) @class
|
||||
)
|
||||
(type_declaration
|
||||
(type_spec
|
||||
(struct_type
|
||||
(field_declaration_list
|
||||
(field_declaration) @class_variable)))
|
||||
)
|
||||
(method_declaration) @method
|
@ -0,0 +1,23 @@
|
||||
;; Capture exported functions, arrow functions, variables, classes, and method definitions
|
||||
(export_statement
|
||||
declaration: (lexical_declaration
|
||||
(variable_declarator) @variable
|
||||
)
|
||||
)
|
||||
(export_statement
|
||||
declaration: (function_declaration) @function
|
||||
)
|
||||
(export_statement
|
||||
declaration: (class_declaration
|
||||
body: (class_body
|
||||
(field_definition) @class_variable
|
||||
)
|
||||
)
|
||||
)
|
||||
(export_statement
|
||||
declaration: (class_declaration
|
||||
body: (class_body
|
||||
(method_definition) @method
|
||||
)
|
||||
)
|
||||
)
|
3
crates/avante-repo-map/queries/tree-sitter-lua-defs.scm
Normal file
3
crates/avante-repo-map/queries/tree-sitter-lua-defs.scm
Normal file
@ -0,0 +1,3 @@
|
||||
;; Capture function and method definitions
|
||||
(variable_list) @variable
|
||||
(function_declaration) @function
|
25
crates/avante-repo-map/queries/tree-sitter-python-defs.scm
Normal file
25
crates/avante-repo-map/queries/tree-sitter-python-defs.scm
Normal file
@ -0,0 +1,25 @@
|
||||
;; Capture top-level functions, class, and method definitions
|
||||
(module
|
||||
(expression_statement
|
||||
(assignment) @assignment
|
||||
)
|
||||
)
|
||||
(module
|
||||
(function_definition) @function
|
||||
)
|
||||
(module
|
||||
(class_definition
|
||||
body: (block
|
||||
(expression_statement
|
||||
(assignment) @class_assignment
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
(module
|
||||
(class_definition
|
||||
body: (block
|
||||
(function_definition) @method
|
||||
)
|
||||
)
|
||||
)
|
16
crates/avante-repo-map/queries/tree-sitter-ruby-defs.scm
Normal file
16
crates/avante-repo-map/queries/tree-sitter-ruby-defs.scm
Normal file
@ -0,0 +1,16 @@
|
||||
;; Capture top-level methods, class definitions, and methods within classes
|
||||
(program
|
||||
(class
|
||||
(body_statement
|
||||
(call) @class_call
|
||||
(assignment) @class_assignment
|
||||
(method) @method
|
||||
)
|
||||
) @class
|
||||
)
|
||||
(program
|
||||
(method) @function
|
||||
)
|
||||
(program
|
||||
(assignment) @assignment
|
||||
)
|
20
crates/avante-repo-map/queries/tree-sitter-rust-defs.scm
Normal file
20
crates/avante-repo-map/queries/tree-sitter-rust-defs.scm
Normal file
@ -0,0 +1,20 @@
|
||||
;; Capture public functions, structs, methods, and variable definitions
|
||||
(function_item) @function
|
||||
(impl_item
|
||||
body: (declaration_list
|
||||
(function_item) @method
|
||||
)
|
||||
)
|
||||
(struct_item) @class
|
||||
(struct_item
|
||||
body: (field_declaration_list
|
||||
(field_declaration) @class_variable
|
||||
)
|
||||
)
|
||||
(enum_item
|
||||
body: (enum_variant_list
|
||||
(enum_variant) @enum_item
|
||||
)
|
||||
)
|
||||
(const_item) @variable
|
||||
(static_item) @variable
|
@ -0,0 +1,33 @@
|
||||
;; Capture exported functions, arrow functions, variables, classes, and method definitions
|
||||
(export_statement
|
||||
declaration: (lexical_declaration
|
||||
(variable_declarator) @variable
|
||||
)
|
||||
)
|
||||
(export_statement
|
||||
declaration: (function_declaration) @function
|
||||
)
|
||||
(export_statement
|
||||
declaration: (class_declaration
|
||||
body: (class_body
|
||||
(public_field_definition) @class_variable
|
||||
)
|
||||
)
|
||||
)
|
||||
(interface_declaration
|
||||
body: (interface_body
|
||||
(property_signature) @class_variable
|
||||
)
|
||||
)
|
||||
(type_alias_declaration
|
||||
value: (object_type
|
||||
(property_signature) @class_variable
|
||||
)
|
||||
)
|
||||
(export_statement
|
||||
declaration: (class_declaration
|
||||
body: (class_body
|
||||
(method_definition) @method
|
||||
)
|
||||
)
|
||||
)
|
894
crates/avante-repo-map/src/lib.rs
Normal file
894
crates/avante-repo-map/src/lib.rs
Normal file
@ -0,0 +1,894 @@
|
||||
use mlua::prelude::*;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use tree_sitter::{Node, Parser, Query, QueryCursor};
|
||||
use tree_sitter_language::LanguageFn;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Func {
|
||||
pub name: String,
|
||||
pub params: String,
|
||||
pub return_type: String,
|
||||
pub accessibility_modifier: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Class {
|
||||
pub name: String,
|
||||
pub methods: Vec<Func>,
|
||||
pub properties: Vec<Variable>,
|
||||
pub visibility_modifier: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Enum {
|
||||
pub name: String,
|
||||
pub items: Vec<Variable>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Variable {
|
||||
pub name: String,
|
||||
pub value_type: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Definition {
|
||||
Func(Func),
|
||||
Class(Class),
|
||||
Enum(Enum),
|
||||
Variable(Variable),
|
||||
}
|
||||
|
||||
fn get_ts_language(language: &str) -> Result<LanguageFn, String> {
|
||||
match language {
|
||||
"rust" => Ok(tree_sitter_rust::LANGUAGE),
|
||||
"python" => Ok(tree_sitter_python::LANGUAGE),
|
||||
"javascript" => Ok(tree_sitter_javascript::LANGUAGE),
|
||||
"typescript" => Ok(tree_sitter_typescript::LANGUAGE_TSX),
|
||||
"go" => Ok(tree_sitter_go::LANGUAGE),
|
||||
"c" => Ok(tree_sitter_c::LANGUAGE),
|
||||
"cpp" => Ok(tree_sitter_cpp::LANGUAGE),
|
||||
"lua" => Ok(tree_sitter_lua::LANGUAGE),
|
||||
"ruby" => Ok(tree_sitter_ruby::LANGUAGE),
|
||||
_ => Err(format!("Unsupported language: {language}")),
|
||||
}
|
||||
}
|
||||
|
||||
const C_QUERY: &str = include_str!("../queries/tree-sitter-c-defs.scm");
|
||||
const CPP_QUERY: &str = include_str!("../queries/tree-sitter-cpp-defs.scm");
|
||||
const GO_QUERY: &str = include_str!("../queries/tree-sitter-go-defs.scm");
|
||||
const JAVASCRIPT_QUERY: &str = include_str!("../queries/tree-sitter-javascript-defs.scm");
|
||||
const LUA_QUERY: &str = include_str!("../queries/tree-sitter-lua-defs.scm");
|
||||
const PYTHON_QUERY: &str = include_str!("../queries/tree-sitter-python-defs.scm");
|
||||
const RUST_QUERY: &str = include_str!("../queries/tree-sitter-rust-defs.scm");
|
||||
const TYPESCRIPT_QUERY: &str = include_str!("../queries/tree-sitter-typescript-defs.scm");
|
||||
const RUBY_QUERY: &str = include_str!("../queries/tree-sitter-ruby-defs.scm");
|
||||
|
||||
fn get_definitions_query(language: &str) -> Result<Query, String> {
|
||||
let ts_language = get_ts_language(language)?;
|
||||
let contents = match language {
|
||||
"c" => C_QUERY,
|
||||
"cpp" => CPP_QUERY,
|
||||
"go" => GO_QUERY,
|
||||
"javascript" => JAVASCRIPT_QUERY,
|
||||
"lua" => LUA_QUERY,
|
||||
"python" => PYTHON_QUERY,
|
||||
"rust" => RUST_QUERY,
|
||||
"typescript" => TYPESCRIPT_QUERY,
|
||||
"ruby" => RUBY_QUERY,
|
||||
_ => return Err(format!("Unsupported language: {language}")),
|
||||
};
|
||||
let query = Query::new(&ts_language.into(), contents)
|
||||
.unwrap_or_else(|_| panic!("Failed to parse query for {language}"));
|
||||
Ok(query)
|
||||
}
|
||||
|
||||
fn get_closest_ancestor_name(node: &Node, source: &str) -> String {
|
||||
let mut parent = node.parent();
|
||||
while let Some(parent_node) = parent {
|
||||
let name_node = parent_node.child_by_field_name("name");
|
||||
if let Some(name_node) = name_node {
|
||||
return get_node_text(&name_node, source.as_bytes()).to_string();
|
||||
}
|
||||
parent = parent_node.parent();
|
||||
}
|
||||
String::new()
|
||||
}
|
||||
|
||||
fn find_ancestor_by_type<'a>(node: &'a Node, parent_type: &str) -> Option<Node<'a>> {
|
||||
let mut parent = node.parent();
|
||||
while let Some(parent_node) = parent {
|
||||
if parent_node.kind() == parent_type {
|
||||
return Some(parent_node);
|
||||
}
|
||||
parent = parent_node.parent();
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn find_descendant_by_type<'a>(node: &'a Node, child_type: &str) -> Option<Node<'a>> {
|
||||
let mut cursor = node.walk();
|
||||
for i in 0..node.descendant_count() {
|
||||
cursor.goto_descendant(i);
|
||||
let node = cursor.node();
|
||||
if node.kind() == child_type {
|
||||
return Some(node);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn find_child_by_type<'a>(node: &'a Node, child_type: &str) -> Option<Node<'a>> {
|
||||
node.children(&mut node.walk())
|
||||
.find(|child| child.kind() == child_type)
|
||||
}
|
||||
|
||||
fn get_node_text<'a>(node: &'a Node, source: &'a [u8]) -> String {
|
||||
node.utf8_text(source).unwrap_or_default().to_string()
|
||||
}
|
||||
|
||||
fn get_node_type<'a>(node: &'a Node, source: &'a [u8]) -> String {
|
||||
let predefined_type_node = find_descendant_by_type(node, "predefined_type");
|
||||
if let Some(type_node) = predefined_type_node {
|
||||
return type_node.utf8_text(source).unwrap().to_string();
|
||||
}
|
||||
let value_type_node = node.child_by_field_name("type");
|
||||
value_type_node
|
||||
.map(|n| n.utf8_text(source).unwrap().to_string())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn is_first_letter_uppercase(name: &str) -> bool {
|
||||
if name.is_empty() {
|
||||
return false;
|
||||
}
|
||||
name.chars().next().unwrap().is_uppercase()
|
||||
}
|
||||
|
||||
// Given a language, parse the given source code and return exported definitions
|
||||
fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>, String> {
|
||||
let ts_language = get_ts_language(language)?;
|
||||
|
||||
let mut definitions = Vec::new();
|
||||
let mut parser = Parser::new();
|
||||
parser
|
||||
.set_language(&ts_language.into())
|
||||
.unwrap_or_else(|_| panic!("Failed to set language for {language}"));
|
||||
let tree = parser
|
||||
.parse(source, None)
|
||||
.unwrap_or_else(|| panic!("Failed to parse source code for {language}"));
|
||||
let root_node = tree.root_node();
|
||||
|
||||
let query = get_definitions_query(language)?;
|
||||
let mut query_cursor = QueryCursor::new();
|
||||
let captures = query_cursor.captures(&query, root_node, source.as_bytes());
|
||||
|
||||
let mut class_def_map: HashMap<String, RefCell<Class>> = HashMap::new();
|
||||
let mut enum_def_map: HashMap<String, RefCell<Enum>> = HashMap::new();
|
||||
|
||||
let ensure_class_def = |name: &str, class_def_map: &mut HashMap<String, RefCell<Class>>| {
|
||||
class_def_map.entry(name.to_string()).or_insert_with(|| {
|
||||
RefCell::new(Class {
|
||||
name: name.to_string(),
|
||||
methods: vec![],
|
||||
properties: vec![],
|
||||
visibility_modifier: None,
|
||||
})
|
||||
});
|
||||
};
|
||||
|
||||
let ensure_enum_def = |name: &str, enum_def_map: &mut HashMap<String, RefCell<Enum>>| {
|
||||
enum_def_map.entry(name.to_string()).or_insert_with(|| {
|
||||
RefCell::new(Enum {
|
||||
name: name.to_string(),
|
||||
items: vec![],
|
||||
})
|
||||
});
|
||||
};
|
||||
|
||||
for (m, _) in captures {
|
||||
for capture in m.captures {
|
||||
let capture_name = &query.capture_names()[capture.index as usize];
|
||||
let node = capture.node;
|
||||
let name_node = node.child_by_field_name("name");
|
||||
let name = name_node
|
||||
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
|
||||
.unwrap_or("");
|
||||
match *capture_name {
|
||||
"class" => {
|
||||
if !name.is_empty() {
|
||||
if language == "go" && !is_first_letter_uppercase(name) {
|
||||
continue;
|
||||
}
|
||||
ensure_class_def(name, &mut class_def_map);
|
||||
let visibility_modifier_node =
|
||||
find_child_by_type(&node, "visibility_modifier");
|
||||
let visibility_modifier = visibility_modifier_node
|
||||
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
|
||||
.unwrap_or("");
|
||||
let class_def = class_def_map.get_mut(name).unwrap();
|
||||
class_def.borrow_mut().visibility_modifier =
|
||||
if visibility_modifier.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(visibility_modifier.to_string())
|
||||
};
|
||||
}
|
||||
}
|
||||
"enum_item" => {
|
||||
let visibility_modifier_node =
|
||||
find_descendant_by_type(&node, "visibility_modifier");
|
||||
let visibility_modifier = visibility_modifier_node
|
||||
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
|
||||
.unwrap_or("");
|
||||
if language == "rust" && !visibility_modifier.contains("pub") {
|
||||
continue;
|
||||
}
|
||||
let enum_name = get_closest_ancestor_name(&node, source);
|
||||
if !enum_name.is_empty()
|
||||
&& language == "go"
|
||||
&& !is_first_letter_uppercase(&enum_name)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
ensure_enum_def(&enum_name, &mut enum_def_map);
|
||||
let enum_def = enum_def_map.get_mut(&enum_name).unwrap();
|
||||
let enum_type_node = find_descendant_by_type(&node, "type_identifier");
|
||||
let enum_type = enum_type_node
|
||||
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
|
||||
.unwrap_or("");
|
||||
let variable = Variable {
|
||||
name: name.to_string(),
|
||||
value_type: enum_type.to_string(),
|
||||
};
|
||||
enum_def.borrow_mut().items.push(variable);
|
||||
}
|
||||
"method" => {
|
||||
let visibility_modifier_node =
|
||||
find_descendant_by_type(&node, "visibility_modifier");
|
||||
let visibility_modifier = visibility_modifier_node
|
||||
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
|
||||
.unwrap_or("");
|
||||
if language == "rust" && !visibility_modifier.contains("pub") {
|
||||
continue;
|
||||
}
|
||||
if !name.is_empty() && language == "go" && !is_first_letter_uppercase(name) {
|
||||
continue;
|
||||
}
|
||||
let params_node = node.child_by_field_name("parameters");
|
||||
let params = params_node
|
||||
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
|
||||
.unwrap_or("()");
|
||||
let mut return_type_node = node.child_by_field_name("return_type");
|
||||
if return_type_node.is_none() {
|
||||
return_type_node = node.child_by_field_name("result");
|
||||
}
|
||||
let mut return_type = "void".to_string();
|
||||
if return_type_node.is_some() {
|
||||
return_type = get_node_type(&return_type_node.unwrap(), source.as_bytes());
|
||||
if return_type.is_empty() {
|
||||
return_type = return_type_node
|
||||
.unwrap()
|
||||
.utf8_text(source.as_bytes())
|
||||
.unwrap_or("void")
|
||||
.to_string();
|
||||
}
|
||||
}
|
||||
|
||||
let impl_item_node = find_ancestor_by_type(&node, "impl_item");
|
||||
let receiver_node = node.child_by_field_name("receiver");
|
||||
let class_name = if let Some(impl_item) = impl_item_node {
|
||||
let impl_type_node = impl_item.child_by_field_name("type");
|
||||
impl_type_node
|
||||
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
|
||||
.unwrap_or("")
|
||||
.to_string()
|
||||
} else if let Some(receiver) = receiver_node {
|
||||
let type_identifier_node =
|
||||
find_descendant_by_type(&receiver, "type_identifier");
|
||||
type_identifier_node
|
||||
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
|
||||
.unwrap_or("")
|
||||
.to_string()
|
||||
} else {
|
||||
get_closest_ancestor_name(&node, source).to_string()
|
||||
};
|
||||
|
||||
if language == "go" && !is_first_letter_uppercase(&class_name) {
|
||||
continue;
|
||||
}
|
||||
|
||||
ensure_class_def(&class_name, &mut class_def_map);
|
||||
let class_def = class_def_map.get_mut(&class_name).unwrap();
|
||||
|
||||
let accessibility_modifier_node =
|
||||
find_descendant_by_type(&node, "accessibility_modifier");
|
||||
let accessibility_modifier = accessibility_modifier_node
|
||||
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
|
||||
.unwrap_or("");
|
||||
|
||||
let func = Func {
|
||||
name: name.to_string(),
|
||||
params: params.to_string(),
|
||||
return_type: return_type.to_string(),
|
||||
accessibility_modifier: if accessibility_modifier.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(accessibility_modifier.to_string())
|
||||
},
|
||||
};
|
||||
class_def.borrow_mut().methods.push(func);
|
||||
}
|
||||
"class_assignment" => {
|
||||
let visibility_modifier_node =
|
||||
find_descendant_by_type(&node, "visibility_modifier");
|
||||
let visibility_modifier = visibility_modifier_node
|
||||
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
|
||||
.unwrap_or("");
|
||||
if language == "rust" && !visibility_modifier.contains("pub") {
|
||||
continue;
|
||||
}
|
||||
let left_node = node.child_by_field_name("left");
|
||||
let left = left_node
|
||||
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
|
||||
.unwrap_or("");
|
||||
let value_type = get_node_type(&node, source.as_bytes());
|
||||
let class_name = get_closest_ancestor_name(&node, source);
|
||||
if !class_name.is_empty()
|
||||
&& language == "go"
|
||||
&& !is_first_letter_uppercase(&class_name)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
if class_name.is_empty() {
|
||||
continue;
|
||||
}
|
||||
ensure_class_def(&class_name, &mut class_def_map);
|
||||
let class_def = class_def_map.get_mut(&class_name).unwrap();
|
||||
let variable = Variable {
|
||||
name: left.to_string(),
|
||||
value_type: value_type.to_string(),
|
||||
};
|
||||
class_def.borrow_mut().properties.push(variable);
|
||||
}
|
||||
"class_variable" => {
|
||||
let visibility_modifier_node =
|
||||
find_descendant_by_type(&node, "visibility_modifier");
|
||||
let visibility_modifier = visibility_modifier_node
|
||||
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
|
||||
.unwrap_or("");
|
||||
if language == "rust" && !visibility_modifier.contains("pub") {
|
||||
continue;
|
||||
}
|
||||
let value_type = get_node_type(&node, source.as_bytes());
|
||||
let class_name = get_closest_ancestor_name(&node, source);
|
||||
if !class_name.is_empty()
|
||||
&& language == "go"
|
||||
&& !is_first_letter_uppercase(&class_name)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
if class_name.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if !name.is_empty() && language == "go" && !is_first_letter_uppercase(name) {
|
||||
continue;
|
||||
}
|
||||
ensure_class_def(&class_name, &mut class_def_map);
|
||||
let class_def = class_def_map.get_mut(&class_name).unwrap();
|
||||
let variable = Variable {
|
||||
name: name.to_string(),
|
||||
value_type: value_type.to_string(),
|
||||
};
|
||||
class_def.borrow_mut().properties.push(variable);
|
||||
}
|
||||
"function" | "arrow_function" => {
|
||||
let visibility_modifier_node =
|
||||
find_descendant_by_type(&node, "visibility_modifier");
|
||||
let visibility_modifier = visibility_modifier_node
|
||||
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
|
||||
.unwrap_or("");
|
||||
if language == "rust" && !visibility_modifier.contains("pub") {
|
||||
continue;
|
||||
}
|
||||
if !name.is_empty() && language == "go" && !is_first_letter_uppercase(name) {
|
||||
continue;
|
||||
}
|
||||
let impl_item_node = find_ancestor_by_type(&node, "impl_item");
|
||||
if impl_item_node.is_some() {
|
||||
continue;
|
||||
}
|
||||
let function_node = find_ancestor_by_type(&node, "function_declaration")
|
||||
.or_else(|| find_ancestor_by_type(&node, "function_definition"));
|
||||
if function_node.is_some() {
|
||||
continue;
|
||||
}
|
||||
let params_node = node.child_by_field_name("parameters");
|
||||
let params = params_node
|
||||
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
|
||||
.unwrap_or("()");
|
||||
let mut return_type_node = node.child_by_field_name("return_type");
|
||||
if return_type_node.is_none() {
|
||||
return_type_node = node.child_by_field_name("result");
|
||||
}
|
||||
let mut return_type = "void".to_string();
|
||||
if return_type_node.is_some() {
|
||||
return_type = get_node_type(&return_type_node.unwrap(), source.as_bytes());
|
||||
if return_type.is_empty() {
|
||||
return_type = return_type_node
|
||||
.unwrap()
|
||||
.utf8_text(source.as_bytes())
|
||||
.unwrap_or("void")
|
||||
.to_string();
|
||||
}
|
||||
}
|
||||
|
||||
let accessibility_modifier_node =
|
||||
find_descendant_by_type(&node, "accessibility_modifier");
|
||||
let accessibility_modifier = accessibility_modifier_node
|
||||
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
|
||||
.unwrap_or("");
|
||||
|
||||
let func = Func {
|
||||
name: name.to_string(),
|
||||
params: params.to_string(),
|
||||
return_type: return_type.to_string(),
|
||||
accessibility_modifier: if accessibility_modifier.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(accessibility_modifier.to_string())
|
||||
},
|
||||
};
|
||||
definitions.push(Definition::Func(func));
|
||||
}
|
||||
"assignment" => {
|
||||
let visibility_modifier_node =
|
||||
find_descendant_by_type(&node, "visibility_modifier");
|
||||
let visibility_modifier = visibility_modifier_node
|
||||
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
|
||||
.unwrap_or("");
|
||||
if language == "rust" && !visibility_modifier.contains("pub") {
|
||||
continue;
|
||||
}
|
||||
let impl_item_node = find_ancestor_by_type(&node, "impl_item")
|
||||
.or_else(|| find_ancestor_by_type(&node, "class_declaration"))
|
||||
.or_else(|| find_ancestor_by_type(&node, "class_definition"));
|
||||
if impl_item_node.is_some() {
|
||||
continue;
|
||||
}
|
||||
let function_node = find_ancestor_by_type(&node, "function_declaration")
|
||||
.or_else(|| find_ancestor_by_type(&node, "function_definition"));
|
||||
if function_node.is_some() {
|
||||
continue;
|
||||
}
|
||||
let left_node = node.child_by_field_name("left");
|
||||
let left = left_node
|
||||
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
|
||||
.unwrap_or("");
|
||||
if !left.is_empty() && language == "go" && !is_first_letter_uppercase(left) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let value_type = get_node_type(&node, source.as_bytes());
|
||||
let variable = Variable {
|
||||
name: left.to_string(),
|
||||
value_type: value_type.to_string(),
|
||||
};
|
||||
definitions.push(Definition::Variable(variable));
|
||||
}
|
||||
"variable" => {
|
||||
let visibility_modifier_node =
|
||||
find_descendant_by_type(&node, "visibility_modifier");
|
||||
let visibility_modifier = visibility_modifier_node
|
||||
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
|
||||
.unwrap_or("");
|
||||
if language == "rust" && !visibility_modifier.contains("pub") {
|
||||
continue;
|
||||
}
|
||||
let impl_item_node = find_ancestor_by_type(&node, "impl_item")
|
||||
.or_else(|| find_ancestor_by_type(&node, "class_declaration"))
|
||||
.or_else(|| find_ancestor_by_type(&node, "class_definition"));
|
||||
if impl_item_node.is_some() {
|
||||
continue;
|
||||
}
|
||||
let function_node = find_ancestor_by_type(&node, "function_declaration")
|
||||
.or_else(|| find_ancestor_by_type(&node, "function_definition"));
|
||||
if function_node.is_some() {
|
||||
continue;
|
||||
}
|
||||
let value_node = node.child_by_field_name("value");
|
||||
if value_node.is_some() {
|
||||
let value_type = value_node.unwrap().kind();
|
||||
if value_type == "arrow_function" {
|
||||
let params_node = value_node.unwrap().child_by_field_name("parameters");
|
||||
let params = params_node
|
||||
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
|
||||
.unwrap_or("()");
|
||||
let mut return_type = "void".to_string();
|
||||
let return_type_node =
|
||||
value_node.unwrap().child_by_field_name("return_type");
|
||||
if return_type_node.is_some() {
|
||||
return_type =
|
||||
get_node_type(&return_type_node.unwrap(), source.as_bytes());
|
||||
}
|
||||
let func = Func {
|
||||
name: name.to_string(),
|
||||
params: params.to_string(),
|
||||
return_type,
|
||||
accessibility_modifier: None,
|
||||
};
|
||||
definitions.push(Definition::Func(func));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
let value_type = get_node_type(&node, source.as_bytes());
|
||||
if !name.is_empty() && language == "go" && !is_first_letter_uppercase(name) {
|
||||
continue;
|
||||
}
|
||||
let variable = Variable {
|
||||
name: name.to_string(),
|
||||
value_type: value_type.to_string(),
|
||||
};
|
||||
definitions.push(Definition::Variable(variable));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (_, def) in class_def_map {
|
||||
let class_def = def.into_inner();
|
||||
if language == "rust" {
|
||||
if let Some(visibility_modifier) = &class_def.visibility_modifier {
|
||||
if visibility_modifier.contains("pub") {
|
||||
definitions.push(Definition::Class(class_def));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
definitions.push(Definition::Class(class_def));
|
||||
}
|
||||
}
|
||||
|
||||
for (_, def) in enum_def_map {
|
||||
definitions.push(Definition::Enum(def.into_inner()));
|
||||
}
|
||||
|
||||
Ok(definitions)
|
||||
}
|
||||
|
||||
fn stringify_function(func: &Func) -> String {
|
||||
let mut res = format!("func {}", func.name);
|
||||
if func.params.is_empty() {
|
||||
res = format!("{res}()");
|
||||
} else {
|
||||
res = format!("{res}{}", func.params);
|
||||
}
|
||||
if !func.return_type.is_empty() {
|
||||
res = format!("{res} -> {}", func.return_type);
|
||||
}
|
||||
if let Some(modifier) = &func.accessibility_modifier {
|
||||
res = format!("{modifier} {res}");
|
||||
}
|
||||
format!("{res};")
|
||||
}
|
||||
|
||||
fn stringify_variable(variable: &Variable) -> String {
|
||||
let mut res = format!("var {}", variable.name);
|
||||
if !variable.value_type.is_empty() {
|
||||
res = format!("{res}:{}", variable.value_type);
|
||||
}
|
||||
format!("{res};")
|
||||
}
|
||||
|
||||
fn stringify_enum_item(item: &Variable) -> String {
|
||||
let mut res = item.name.clone();
|
||||
if !item.value_type.is_empty() {
|
||||
res = format!("{res}:{}", item.value_type);
|
||||
}
|
||||
format!("{res};")
|
||||
}
|
||||
|
||||
fn stringify_class(class: &Class) -> String {
|
||||
let mut res = format!("class {}{{", class.name);
|
||||
for method in &class.methods {
|
||||
let method_str = stringify_function(method);
|
||||
res = format!("{res}{method_str}");
|
||||
}
|
||||
for property in &class.properties {
|
||||
let property_str = stringify_variable(property);
|
||||
res = format!("{res}{property_str}");
|
||||
}
|
||||
format!("{res}}};")
|
||||
}
|
||||
|
||||
fn stringify_enum(enum_def: &Enum) -> String {
|
||||
let mut res = format!("enum {}{{", enum_def.name);
|
||||
for item in &enum_def.items {
|
||||
let item_str = stringify_enum_item(item);
|
||||
res = format!("{res}{item_str}");
|
||||
}
|
||||
format!("{res}}};")
|
||||
}
|
||||
|
||||
fn stringify_definitions(definitions: &Vec<Definition>) -> String {
|
||||
let mut res = String::new();
|
||||
for definition in definitions {
|
||||
match definition {
|
||||
Definition::Class(class) => res = format!("{res}{}", stringify_class(class)),
|
||||
Definition::Enum(enum_def) => res = format!("{res}{}", stringify_enum(enum_def)),
|
||||
Definition::Func(func) => res = format!("{res}{}", stringify_function(func)),
|
||||
Definition::Variable(variable) => {
|
||||
let variable_str = stringify_variable(variable);
|
||||
res = format!("{res}{variable_str}");
|
||||
}
|
||||
}
|
||||
}
|
||||
res
|
||||
}
|
||||
|
||||
pub fn get_definitions_string(language: &str, source: &str) -> LuaResult<String> {
|
||||
let definitions =
|
||||
extract_definitions(language, source).map_err(|e| LuaError::RuntimeError(e.to_string()))?;
|
||||
let stringified = stringify_definitions(&definitions);
|
||||
Ok(stringified)
|
||||
}
|
||||
|
||||
#[mlua::lua_module]
|
||||
fn avante_repo_map(lua: &Lua) -> LuaResult<LuaTable> {
|
||||
let exports = lua.create_table()?;
|
||||
exports.set(
|
||||
"stringify_definitions",
|
||||
lua.create_function(move |_, (language, source): (String, String)| {
|
||||
get_definitions_string(language.as_str(), source.as_str())
|
||||
})?,
|
||||
)?;
|
||||
Ok(exports)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_rust() {
|
||||
let source = r#"
|
||||
// This is a test comment
|
||||
pub const TEST_CONST: u32 = 1;
|
||||
pub static TEST_STATIC: u32 = 2;
|
||||
const INNER_TEST_CONST: u32 = 3;
|
||||
static INNER_TEST_STATIC: u32 = 4;
|
||||
pub(crate) struct TestStruct {
|
||||
pub test_field: String,
|
||||
inner_test_field: String,
|
||||
}
|
||||
impl TestStruct {
|
||||
pub fn test_method(&self, a: u32, b: u32) -> u32 {
|
||||
a + b
|
||||
}
|
||||
fn inner_test_method(&self, a: u32, b: u32) -> u32 {
|
||||
a + b
|
||||
}
|
||||
}
|
||||
struct InnerTestStruct {
|
||||
pub test_field: String,
|
||||
inner_test_field: String,
|
||||
}
|
||||
impl InnerTestStruct {
|
||||
pub fn test_method(&self, a: u32, b: u32) -> u32 {
|
||||
a + b
|
||||
}
|
||||
fn inner_test_method(&self, a: u32, b: u32) -> u32 {
|
||||
a + b
|
||||
}
|
||||
}
|
||||
pub enum TestEnum {
|
||||
TestEnumField1,
|
||||
TestEnumField2,
|
||||
}
|
||||
enum InnerTestEnum {
|
||||
InnerTestEnumField1,
|
||||
InnerTestEnumField2,
|
||||
}
|
||||
pub fn test_fn(a: u32, b: u32) -> u32 {
|
||||
a + b
|
||||
}
|
||||
fn inner_test_fn(a: u32, b: u32) -> u32 {
|
||||
a + b
|
||||
}
|
||||
"#;
|
||||
let definitions = extract_definitions("rust", source).unwrap();
|
||||
let stringified = stringify_definitions(&definitions);
|
||||
println!("{stringified}");
|
||||
let expected = "var TEST_CONST:u32;var TEST_STATIC:u32;func test_fn(a: u32, b: u32) -> u32;class TestStruct{func test_method(&self, a: u32, b: u32) -> u32;var test_field:String;};";
|
||||
assert_eq!(stringified, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_go() {
|
||||
let source = r#"
|
||||
// This is a test comment
|
||||
package main
|
||||
import "fmt"
|
||||
const TestConst string = "test"
|
||||
const innerTestConst string = "test"
|
||||
var TestVar string
|
||||
var innerTestVar string
|
||||
type TestStruct struct {
|
||||
TestField string
|
||||
innerTestField string
|
||||
}
|
||||
func (t *TestStruct) TestMethod(a int, b int) (int, error) {
|
||||
return a + b, nil
|
||||
}
|
||||
func (t *TestStruct) innerTestMethod(a int, b int) (int, error) {
|
||||
return a + b, nil
|
||||
}
|
||||
type innerTestStruct struct {
|
||||
innerTestField string
|
||||
}
|
||||
func (t *innerTestStruct) testMethod(a int, b int) (int, error) {
|
||||
return a + b, nil
|
||||
}
|
||||
func (t *innerTestStruct) innerTestMethod(a int, b int) (int, error) {
|
||||
return a + b, nil
|
||||
}
|
||||
func TestFunc(a int, b int) (int, error) {
|
||||
return a + b, nil
|
||||
}
|
||||
func innerTestFunc(a int, b int) (int, error) {
|
||||
return a + b, nil
|
||||
}
|
||||
"#;
|
||||
let definitions = extract_definitions("go", source).unwrap();
|
||||
let stringified = stringify_definitions(&definitions);
|
||||
println!("{stringified}");
|
||||
let expected = "var TestConst:string;var TestVar:string;func TestFunc(a int, b int) -> (int, error);class TestStruct{func TestMethod(a int, b int) -> (int, error);var TestField:string;};";
|
||||
assert_eq!(stringified, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_python() {
|
||||
let source = r#"
|
||||
# This is a test comment
|
||||
test_var: str = "test"
|
||||
class TestClass:
|
||||
def __init__(self, a, b):
|
||||
self.a = a
|
||||
self.b = b
|
||||
def test_method(self, a: int, b: int) -> int:
|
||||
return a + b
|
||||
def test_func(a: int, b: int) -> int:
|
||||
return a + b
|
||||
"#;
|
||||
let definitions = extract_definitions("python", source).unwrap();
|
||||
let stringified = stringify_definitions(&definitions);
|
||||
println!("{stringified}");
|
||||
let expected = "var test_var:str;func test_func(a: int, b: int) -> int;class TestClass{func __init__(self, a, b) -> void;func test_method(self, a: int, b: int) -> int;};";
|
||||
assert_eq!(stringified, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_typescript() {
|
||||
let source = r#"
|
||||
// This is a test comment
|
||||
export const testVar: string = "test";
|
||||
const innerTestVar: string = "test";
|
||||
export class TestClass {
|
||||
a: number;
|
||||
b: number;
|
||||
constructor(a: number, b: number) {
|
||||
this.a = a;
|
||||
this.b = b;
|
||||
}
|
||||
testMethod(a: number, b: number): number {
|
||||
return a + b;
|
||||
}
|
||||
}
|
||||
class InnerTestClass {
|
||||
a: number;
|
||||
b: number;
|
||||
}
|
||||
export function testFunc(a: number, b: number) {
|
||||
return a + b;
|
||||
}
|
||||
export const testFunc2 = (a: number, b: number) => {
|
||||
return a + b;
|
||||
}
|
||||
export const testFunc3 = (a: number, b: number): number => {
|
||||
return a + b;
|
||||
}
|
||||
function innerTestFunc(a: number, b: number) {
|
||||
return a + b;
|
||||
}
|
||||
"#;
|
||||
let definitions = extract_definitions("typescript", source).unwrap();
|
||||
let stringified = stringify_definitions(&definitions);
|
||||
println!("{stringified}");
|
||||
let expected = "var testVar:string;func testFunc(a: number, b: number) -> void;func testFunc2(a: number, b: number) -> void;func testFunc3(a: number, b: number) -> number;class TestClass{func constructor(a: number, b: number) -> void;func testMethod(a: number, b: number) -> number;var a:number;var b:number;};"
|
||||
;
|
||||
assert_eq!(stringified, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_javascript() {
|
||||
let source = r#"
|
||||
// This is a test comment
|
||||
export const testVar = "test";
|
||||
const innerTestVar = "test";
|
||||
export class TestClass {
|
||||
constructor(a, b) {
|
||||
this.a = a;
|
||||
this.b = b;
|
||||
}
|
||||
testMethod(a, b) {
|
||||
return a + b;
|
||||
}
|
||||
}
|
||||
class InnerTestClass {
|
||||
constructor(a, b) {
|
||||
this.a = a;
|
||||
this.b = b;
|
||||
}
|
||||
}
|
||||
export const testFunc = function(a, b) {
|
||||
return a + b;
|
||||
}
|
||||
export const testFunc2 = (a, b) => {
|
||||
return a + b;
|
||||
}
|
||||
export const testFunc3 = (a, b) => a + b;
|
||||
function innerTestFunc(a, b) {
|
||||
return a + b;
|
||||
}
|
||||
"#;
|
||||
let definitions = extract_definitions("javascript", source).unwrap();
|
||||
let stringified = stringify_definitions(&definitions);
|
||||
println!("{stringified}");
|
||||
let expected = "var testVar;var testFunc;func testFunc2(a, b) -> void;func testFunc3(a, b) -> void;class TestClass{func constructor(a, b) -> void;func testMethod(a, b) -> void;};";
|
||||
assert_eq!(stringified, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ruby() {
|
||||
let source = r#"
|
||||
# This is a test comment
|
||||
test_var = "test"
|
||||
def test_func(a, b)
|
||||
return a + b
|
||||
end
|
||||
class TestClass
|
||||
attr_accessor :a, :b
|
||||
def initialize(a, b)
|
||||
@a = a
|
||||
@b = b
|
||||
end
|
||||
def test_method(a, b)
|
||||
return a + b
|
||||
end
|
||||
end
|
||||
"#;
|
||||
let definitions = extract_definitions("ruby", source).unwrap();
|
||||
let stringified = stringify_definitions(&definitions);
|
||||
println!("{stringified}");
|
||||
// FIXME:
|
||||
let expected = "var test_var;func test_func(a, b) -> void;";
|
||||
assert_eq!(stringified, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lua() {
|
||||
let source = r#"
|
||||
-- This is a test comment
|
||||
local test_var = "test"
|
||||
function test_func(a, b)
|
||||
return a + b
|
||||
end
|
||||
"#;
|
||||
let definitions = extract_definitions("lua", source).unwrap();
|
||||
let stringified = stringify_definitions(&definitions);
|
||||
println!("{stringified}");
|
||||
let expected = "var test_var;func test_func(a, b) -> void;";
|
||||
assert_eq!(stringified, expected);
|
||||
}
|
||||
}
|
@ -397,6 +397,7 @@ function M.setup(opts)
|
||||
|
||||
if M.did_setup then return end
|
||||
|
||||
require("avante.repo_map").setup()
|
||||
require("avante.path").setup()
|
||||
require("avante.highlights").setup()
|
||||
require("avante.diff").setup()
|
||||
|
@ -85,6 +85,7 @@ M.stream = function(opts)
|
||||
:totable()
|
||||
|
||||
Utils.debug(user_prompts)
|
||||
-- print(user_prompts[1])
|
||||
|
||||
---@type AvantePromptOptions
|
||||
local code_opts = {
|
||||
|
147
lua/avante/repo_map.lua
Normal file
147
lua/avante/repo_map.lua
Normal file
@ -0,0 +1,147 @@
|
||||
local Utils = require("avante.utils")
|
||||
|
||||
local filetype_map = {
|
||||
["javascriptreact"] = "javascript",
|
||||
["typescriptreact"] = "typescript",
|
||||
}
|
||||
|
||||
---@class AvanteRepoMap
|
||||
---@field stringify_definitions fun(lang: string, source: string): string
|
||||
local repo_map_lib = nil
|
||||
|
||||
---@class avante.utils.repo_map
|
||||
local RepoMap = {}
|
||||
|
||||
function RepoMap.setup()
|
||||
vim.defer_fn(function()
|
||||
local ok, core = pcall(require, "avante_repo_map")
|
||||
if not ok then
|
||||
error("Failed to load avante_repo_map")
|
||||
return
|
||||
end
|
||||
|
||||
if repo_map_lib == nil then repo_map_lib = core end
|
||||
end, 1000)
|
||||
end
|
||||
|
||||
function RepoMap.get_ts_lang(filepath)
|
||||
local filetype = vim.filetype.match({ filename = filepath })
|
||||
return filetype_map[filetype] or filetype
|
||||
end
|
||||
|
||||
function RepoMap.get_filetype(filepath) return vim.filetype.match({ filename = filepath }) end
|
||||
|
||||
function RepoMap._build_repo_map(project_root, file_ext)
|
||||
local output = {}
|
||||
local gitignore_path = project_root .. "/.gitignore"
|
||||
local ignore_patterns, negate_patterns = Utils.parse_gitignore(gitignore_path)
|
||||
local filepaths = Utils.scan_directory(project_root, ignore_patterns, negate_patterns)
|
||||
vim.iter(filepaths):each(function(filepath)
|
||||
if not Utils.is_same_file_ext(file_ext, filepath) then return end
|
||||
local definitions =
|
||||
repo_map_lib.stringify_definitions(RepoMap.get_ts_lang(filepath), Utils.file.read_content(filepath) or "")
|
||||
if definitions == "" then return end
|
||||
table.insert(output, {
|
||||
path = Utils.relative_path(filepath),
|
||||
lang = RepoMap.get_filetype(filepath),
|
||||
defs = definitions,
|
||||
})
|
||||
end)
|
||||
return output
|
||||
end
|
||||
|
||||
local cache = {}
|
||||
|
||||
function RepoMap.get_repo_map(file_ext)
|
||||
file_ext = file_ext or vim.fn.expand("%:e")
|
||||
local project_root = Utils.root.get()
|
||||
local cache_key = project_root .. "." .. file_ext
|
||||
local cached = cache[cache_key]
|
||||
if cached then return cached end
|
||||
|
||||
local PPath = require("plenary.path")
|
||||
local Path = require("avante.path")
|
||||
local repo_map
|
||||
|
||||
local function build_and_save()
|
||||
repo_map = RepoMap._build_repo_map(project_root, file_ext)
|
||||
cache[cache_key] = repo_map
|
||||
Path.repo_map.save(project_root, file_ext, repo_map)
|
||||
end
|
||||
|
||||
repo_map = Path.repo_map.load(project_root, file_ext)
|
||||
|
||||
if not repo_map or next(repo_map) == nil then
|
||||
build_and_save()
|
||||
if not repo_map then return end
|
||||
else
|
||||
local timer = vim.loop.new_timer()
|
||||
|
||||
if timer then
|
||||
timer:start(
|
||||
0,
|
||||
0,
|
||||
vim.schedule_wrap(function()
|
||||
build_and_save()
|
||||
timer:close()
|
||||
end)
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
local update_repo_map = vim.schedule_wrap(function(rel_filepath)
|
||||
if rel_filepath and Utils.is_same_file_ext(file_ext, rel_filepath) then
|
||||
local abs_filepath = PPath:new(project_root):joinpath(rel_filepath):absolute()
|
||||
local definitions = repo_map_lib.stringify_definitions(
|
||||
RepoMap.get_ts_lang(abs_filepath),
|
||||
Utils.file.read_content(abs_filepath) or ""
|
||||
)
|
||||
if definitions == "" then return end
|
||||
local found = false
|
||||
for _, m in ipairs(repo_map) do
|
||||
if m.path == rel_filepath then
|
||||
m.defs = definitions
|
||||
found = true
|
||||
break
|
||||
end
|
||||
end
|
||||
if not found then
|
||||
table.insert(repo_map, {
|
||||
path = Utils.relative_path(abs_filepath),
|
||||
lang = RepoMap.get_filetype(abs_filepath),
|
||||
defs = definitions,
|
||||
})
|
||||
end
|
||||
cache[cache_key] = repo_map
|
||||
Path.repo_map.save(project_root, file_ext, repo_map)
|
||||
end
|
||||
end)
|
||||
|
||||
local handle = vim.loop.new_fs_event()
|
||||
|
||||
if handle then
|
||||
handle:start(project_root, { recursive = true }, function(err, rel_filepath)
|
||||
if err then
|
||||
print("Error watching directory " .. project_root .. ":", err)
|
||||
return
|
||||
end
|
||||
|
||||
if rel_filepath then update_repo_map(rel_filepath) end
|
||||
end)
|
||||
end
|
||||
|
||||
vim.api.nvim_create_autocmd({ "BufReadPost", "BufNewFile" }, {
|
||||
callback = function(ev)
|
||||
vim.defer_fn(function()
|
||||
local filepath = vim.api.nvim_buf_get_name(ev.buf)
|
||||
if not vim.startswith(filepath, project_root) then return end
|
||||
local rel_filepath = Utils.relative_path(filepath)
|
||||
update_repo_map(rel_filepath)
|
||||
end, 0)
|
||||
end,
|
||||
})
|
||||
|
||||
return repo_map
|
||||
end
|
||||
|
||||
return RepoMap
|
@ -2,6 +2,7 @@ local Utils = require("avante.utils")
|
||||
local Config = require("avante.config")
|
||||
local Llm = require("avante.llm")
|
||||
local Provider = require("avante.providers")
|
||||
local RepoMap = require("avante.repo_map")
|
||||
|
||||
local api = vim.api
|
||||
local fn = vim.fn
|
||||
@ -394,7 +395,7 @@ function Selection:create_editing_input()
|
||||
|
||||
local mentions = Utils.extract_mentions(input)
|
||||
input = mentions.new_content
|
||||
local project_context = mentions.enable_project_context and Utils.repo_map.get_repo_map(file_ext) or nil
|
||||
local project_context = mentions.enable_project_context and RepoMap.get_repo_map(file_ext) or nil
|
||||
|
||||
Llm.stream({
|
||||
bufnr = code_bufnr,
|
||||
|
@ -11,6 +11,7 @@ local Diff = require("avante.diff")
|
||||
local Llm = require("avante.llm")
|
||||
local Utils = require("avante.utils")
|
||||
local Highlights = require("avante.highlights")
|
||||
local RepoMap = require("avante.repo_map")
|
||||
|
||||
local RESULT_BUF_NAME = "AVANTE_RESULT"
|
||||
local VIEW_BUFFER_UPDATED_PATTERN = "AvanteViewBufferUpdated"
|
||||
@ -1295,7 +1296,7 @@ function Sidebar:create_input(opts)
|
||||
|
||||
local file_ext = api.nvim_buf_get_name(self.code.bufnr):match("^.+%.(.+)$")
|
||||
|
||||
local project_context = mentions.enable_project_context and Utils.repo_map.get_repo_map(file_ext) or nil
|
||||
local project_context = mentions.enable_project_context and RepoMap.get_repo_map(file_ext) or nil
|
||||
|
||||
Llm.stream({
|
||||
bufnr = self.code.bufnr,
|
||||
|
@ -6,6 +6,7 @@ local lsp = vim.lsp
|
||||
---@field tokens avante.utils.tokens
|
||||
---@field root avante.utils.root
|
||||
---@field repo_map avante.utils.repo_map
|
||||
---@field file avante.utils.file
|
||||
local M = {}
|
||||
|
||||
setmetatable(M, {
|
||||
|
@ -1,730 +0,0 @@
|
||||
local parsers = require("nvim-treesitter.parsers")
|
||||
local Config = require("avante.config")
|
||||
|
||||
local get_node_text = vim.treesitter.get_node_text
|
||||
|
||||
---@class avante.utils.repo_map
|
||||
local RepoMap = {}
|
||||
|
||||
local dependencies_queries = {
|
||||
lua = [[
|
||||
(function_call
|
||||
name: (identifier) @function_name
|
||||
arguments: (arguments
|
||||
(string) @required_file))
|
||||
]],
|
||||
|
||||
python = [[
|
||||
(import_from_statement
|
||||
module_name: (dotted_name) @import_module)
|
||||
(import_statement
|
||||
(dotted_name) @import_module)
|
||||
]],
|
||||
|
||||
javascript = [[
|
||||
(import_statement
|
||||
source: (string) @import_module)
|
||||
(call_expression
|
||||
function: (identifier) @function_name
|
||||
arguments: (arguments
|
||||
(string) @required_file))
|
||||
]],
|
||||
|
||||
typescript = [[
|
||||
(import_statement
|
||||
source: (string) @import_module)
|
||||
(call_expression
|
||||
function: (identifier) @function_name
|
||||
arguments: (arguments
|
||||
(string) @required_file))
|
||||
]],
|
||||
|
||||
go = [[
|
||||
(import_spec
|
||||
path: (interpreted_string_literal) @import_module)
|
||||
]],
|
||||
|
||||
rust = [[
|
||||
(use_declaration
|
||||
(scoped_identifier) @import_module)
|
||||
(use_declaration
|
||||
(identifier) @import_module)
|
||||
]],
|
||||
|
||||
c = [[
|
||||
(preproc_include
|
||||
(string_literal) @import_module)
|
||||
(preproc_include
|
||||
(system_lib_string) @import_module)
|
||||
]],
|
||||
|
||||
cpp = [[
|
||||
(preproc_include
|
||||
(string_literal) @import_module)
|
||||
(preproc_include
|
||||
(system_lib_string) @import_module)
|
||||
]],
|
||||
}
|
||||
|
||||
local definitions_queries = {
|
||||
python = [[
|
||||
;; Capture top-level functions, class, and method definitions
|
||||
(module
|
||||
(expression_statement
|
||||
(assignment) @assignment
|
||||
)
|
||||
)
|
||||
(module
|
||||
(function_definition) @function
|
||||
)
|
||||
(module
|
||||
(class_definition
|
||||
body: (block
|
||||
(expression_statement
|
||||
(assignment) @class_assignment
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
(module
|
||||
(class_definition
|
||||
body: (block
|
||||
(function_definition) @method
|
||||
)
|
||||
)
|
||||
)
|
||||
]],
|
||||
javascript = [[
|
||||
;; Capture exported functions, arrow functions, variables, classes, and method definitions
|
||||
(export_statement
|
||||
declaration: (lexical_declaration
|
||||
(variable_declarator) @variable
|
||||
)
|
||||
)
|
||||
(export_statement
|
||||
declaration: (function_declaration) @function
|
||||
)
|
||||
(export_statement
|
||||
declaration: (class_declaration
|
||||
body: (class_body
|
||||
(field_definition) @class_variable
|
||||
)
|
||||
)
|
||||
)
|
||||
(export_statement
|
||||
declaration: (class_declaration
|
||||
body: (class_body
|
||||
(method_definition) @method
|
||||
)
|
||||
)
|
||||
)
|
||||
]],
|
||||
typescript = [[
|
||||
;; Capture exported functions, arrow functions, variables, classes, and method definitions
|
||||
(export_statement
|
||||
declaration: (lexical_declaration
|
||||
(variable_declarator) @variable
|
||||
)
|
||||
)
|
||||
(export_statement
|
||||
declaration: (function_declaration) @function
|
||||
)
|
||||
(export_statement
|
||||
declaration: (class_declaration
|
||||
body: (class_body
|
||||
(public_field_definition) @class_variable
|
||||
)
|
||||
)
|
||||
)
|
||||
(interface_declaration
|
||||
body: (interface_body
|
||||
(property_signature) @class_variable
|
||||
)
|
||||
)
|
||||
(type_alias_declaration
|
||||
value: (object_type
|
||||
(property_signature) @class_variable
|
||||
)
|
||||
)
|
||||
(export_statement
|
||||
declaration: (class_declaration
|
||||
body: (class_body
|
||||
(method_definition) @method
|
||||
)
|
||||
)
|
||||
)
|
||||
]],
|
||||
rust = [[
|
||||
;; Capture public functions, structs, methods, and variable definitions
|
||||
(function_item) @function
|
||||
(impl_item
|
||||
body: (declaration_list
|
||||
(function_item) @method
|
||||
)
|
||||
)
|
||||
(struct_item
|
||||
body: (field_declaration_list
|
||||
(field_declaration) @class_variable
|
||||
)
|
||||
)
|
||||
(enum_item
|
||||
body: (enum_variant_list
|
||||
(enum_variant) @enum_item
|
||||
)
|
||||
)
|
||||
(const_item) @variable
|
||||
]],
|
||||
go = [[
|
||||
;; Capture top-level functions and struct definitions
|
||||
(var_declaration
|
||||
(var_spec) @variable
|
||||
)
|
||||
(const_declaration
|
||||
(const_spec) @variable
|
||||
)
|
||||
(function_declaration) @function
|
||||
(type_declaration
|
||||
(type_spec (struct_type)) @class
|
||||
)
|
||||
(type_declaration
|
||||
(type_spec
|
||||
(struct_type
|
||||
(field_declaration_list
|
||||
(field_declaration) @class_variable)))
|
||||
)
|
||||
(method_declaration) @method
|
||||
]],
|
||||
c = [[
|
||||
;; Capture extern functions, variables, public classes, and methods
|
||||
(function_definition
|
||||
(storage_class_specifier) @extern
|
||||
) @function
|
||||
(class_specifier
|
||||
(public) @class
|
||||
(function_definition) @method
|
||||
) @class
|
||||
(declaration
|
||||
(storage_class_specifier) @extern
|
||||
) @variable
|
||||
]],
|
||||
cpp = [[
|
||||
;; Capture extern functions, variables, public classes, and methods
|
||||
(function_definition
|
||||
(storage_class_specifier) @extern
|
||||
) @function
|
||||
(class_specifier
|
||||
(public) @class
|
||||
(function_definition) @method
|
||||
) @class
|
||||
(declaration
|
||||
(storage_class_specifier) @extern
|
||||
) @variable
|
||||
]],
|
||||
lua = [[
|
||||
;; Capture function and method definitions
|
||||
(variable_list) @variable
|
||||
(function_declaration) @function
|
||||
]],
|
||||
ruby = [[
|
||||
;; Capture top-level methods, class definitions, and methods within classes
|
||||
(method) @function
|
||||
(assignment) @assignment
|
||||
(class
|
||||
body: (body_statement
|
||||
(assignment) @class_assignment
|
||||
(method) @method
|
||||
)
|
||||
)
|
||||
]],
|
||||
}
|
||||
|
||||
local queries_filetype_map = {
|
||||
["javascriptreact"] = "javascript",
|
||||
["typescriptreact"] = "typescript",
|
||||
}
|
||||
|
||||
local function get_query(queries, filetype)
|
||||
filetype = queries_filetype_map[filetype] or filetype
|
||||
return queries[filetype]
|
||||
end
|
||||
|
||||
local function get_ts_lang(bufnr)
|
||||
local lang = parsers.get_buf_lang(bufnr)
|
||||
return lang
|
||||
end
|
||||
|
||||
function RepoMap.get_parser(bufnr)
|
||||
local lang = get_ts_lang(bufnr)
|
||||
if not lang then return end
|
||||
local parser = parsers.get_parser(bufnr, lang)
|
||||
return parser, lang
|
||||
end
|
||||
|
||||
function RepoMap.extract_dependencies(bufnr)
|
||||
local parser, lang = RepoMap.get_parser(bufnr)
|
||||
if not lang or not parser or not dependencies_queries[lang] then
|
||||
print("No parser or query available for this buffer's language: " .. (lang or "unknown"))
|
||||
return {}
|
||||
end
|
||||
|
||||
local dependencies = {}
|
||||
local tree = parser:parse()[1]
|
||||
local root = tree:root()
|
||||
local filetype = vim.api.nvim_get_option_value("filetype", { buf = bufnr })
|
||||
|
||||
local query = get_query(dependencies_queries, filetype)
|
||||
if not query then return dependencies end
|
||||
|
||||
local query_obj = vim.treesitter.query.parse(lang, query)
|
||||
|
||||
for _, node, _ in query_obj:iter_captures(root, bufnr, 0, -1) do
|
||||
-- local name = query.captures[id]
|
||||
local required_file = vim.treesitter.get_node_text(node, bufnr):gsub('"', ""):gsub("'", "")
|
||||
table.insert(dependencies, required_file)
|
||||
end
|
||||
|
||||
return dependencies
|
||||
end
|
||||
|
||||
function RepoMap.get_filetype_by_filepath(filepath) return vim.filetype.match({ filename = filepath }) end
|
||||
|
||||
function RepoMap.parse_file(filepath)
|
||||
local File = require("avante.utils.file")
|
||||
local source = File.read_content(filepath)
|
||||
|
||||
local filetype = RepoMap.get_filetype_by_filepath(filepath)
|
||||
local lang = parsers.ft_to_lang(filetype)
|
||||
if lang then
|
||||
local ok, parser = pcall(vim.treesitter.get_string_parser, source, lang)
|
||||
if ok then
|
||||
local tree = parser:parse()[1]
|
||||
local node = tree:root()
|
||||
return { node = node, source = source }
|
||||
else
|
||||
print("parser error", parser)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
local function get_closest_parent_name(node, source)
|
||||
local parent = node:parent()
|
||||
while parent do
|
||||
local name = parent:field("name")[1]
|
||||
if name then return get_node_text(name, source) end
|
||||
parent = parent:parent()
|
||||
end
|
||||
return ""
|
||||
end
|
||||
|
||||
local function find_parent_by_type(node, type)
|
||||
local parent = node:parent()
|
||||
while parent do
|
||||
if parent:type() == type then return parent end
|
||||
parent = parent:parent()
|
||||
end
|
||||
return nil
|
||||
end
|
||||
|
||||
local function find_child_by_type(node, type)
|
||||
for child in node:iter_children() do
|
||||
if child:type() == type then return child end
|
||||
local res = find_child_by_type(child, type)
|
||||
if res then return res end
|
||||
end
|
||||
return nil
|
||||
end
|
||||
|
||||
local function get_node_type(node, source)
|
||||
local node_type
|
||||
local predefined_type_node = find_child_by_type(node, "predefined_type")
|
||||
if predefined_type_node then
|
||||
node_type = get_node_text(predefined_type_node, source)
|
||||
else
|
||||
local value_type_node = node:field("type")[1]
|
||||
node_type = value_type_node and get_node_text(value_type_node, source) or ""
|
||||
end
|
||||
return node_type
|
||||
end
|
||||
|
||||
-- Function to extract definitions from the file
|
||||
function RepoMap.extract_definitions(filepath)
|
||||
local Utils = require("avante.utils")
|
||||
|
||||
local filetype = RepoMap.get_filetype_by_filepath(filepath)
|
||||
|
||||
if not filetype then return {} end
|
||||
|
||||
-- Get the corresponding query for the detected language
|
||||
local query = get_query(definitions_queries, filetype)
|
||||
if not query then return {} end
|
||||
|
||||
local parsed = RepoMap.parse_file(filepath)
|
||||
if not parsed then return {} end
|
||||
|
||||
-- Get the current buffer's syntax tree
|
||||
local root = parsed.node
|
||||
|
||||
local lang = parsers.ft_to_lang(filetype)
|
||||
|
||||
-- Parse the query
|
||||
local query_obj = vim.treesitter.query.parse(lang, query)
|
||||
|
||||
-- Store captured results
|
||||
local definitions = {}
|
||||
|
||||
local class_def_map = {}
|
||||
local enum_def_map = {}
|
||||
|
||||
local function get_class_def(name)
|
||||
local def = class_def_map[name]
|
||||
if def == nil then
|
||||
def = {
|
||||
type = "class",
|
||||
name = name,
|
||||
methods = {},
|
||||
properties = {},
|
||||
}
|
||||
class_def_map[name] = def
|
||||
end
|
||||
return def
|
||||
end
|
||||
|
||||
local function get_enum_def(name)
|
||||
local def = enum_def_map[name]
|
||||
if def == nil then
|
||||
def = {
|
||||
type = "enum",
|
||||
name = name,
|
||||
items = {},
|
||||
}
|
||||
enum_def_map[name] = def
|
||||
end
|
||||
return def
|
||||
end
|
||||
|
||||
for _, captures, _ in query_obj:iter_matches(root, parsed.source) do
|
||||
for id, node in pairs(captures) do
|
||||
local type = query_obj.captures[id]
|
||||
local name_node = node:field("name")[1]
|
||||
local name = name_node and get_node_text(name_node, parsed.source) or ""
|
||||
|
||||
if type == "class" then
|
||||
if name ~= "" then get_class_def(name) end
|
||||
elseif type == "enum_item" then
|
||||
local enum_name = get_closest_parent_name(node, parsed.source)
|
||||
if enum_name and filetype == "go" and not Utils.is_first_letter_uppercase(enum_name) then goto continue end
|
||||
local enum_def = get_enum_def(enum_name)
|
||||
local enum_type_node = find_child_by_type(node, "type_identifier")
|
||||
local enum_type = enum_type_node and get_node_text(enum_type_node, parsed.source) or ""
|
||||
table.insert(enum_def.items, {
|
||||
name = name,
|
||||
type = enum_type,
|
||||
})
|
||||
elseif type == "method" then
|
||||
if name and filetype == "go" and not Utils.is_first_letter_uppercase(name) then goto continue end
|
||||
local params_node = node:field("parameters")[1]
|
||||
local params = params_node and get_node_text(params_node, parsed.source) or "()"
|
||||
local return_type_node = node:field("return_type")[1] or node:field("result")[1]
|
||||
local return_type = return_type_node and get_node_text(return_type_node, parsed.source) or "void"
|
||||
|
||||
local class_name
|
||||
local impl_item_node = find_parent_by_type(node, "impl_item")
|
||||
local receiver_node = node:field("receiver")[1]
|
||||
if impl_item_node then
|
||||
local impl_type_node = impl_item_node:field("type")[1]
|
||||
class_name = impl_type_node and get_node_text(impl_type_node, parsed.source) or ""
|
||||
elseif receiver_node then
|
||||
local type_identifier_node = find_child_by_type(receiver_node, "type_identifier")
|
||||
class_name = type_identifier_node and get_node_text(type_identifier_node, parsed.source) or ""
|
||||
else
|
||||
class_name = get_closest_parent_name(node, parsed.source)
|
||||
end
|
||||
local class_def = get_class_def(class_name)
|
||||
|
||||
local accessibility_modifier_node = find_child_by_type(node, "accessibility_modifier")
|
||||
local accessibility_modifier = accessibility_modifier_node
|
||||
and get_node_text(accessibility_modifier_node, parsed.source)
|
||||
or ""
|
||||
|
||||
table.insert(class_def.methods, {
|
||||
type = "function",
|
||||
name = name,
|
||||
params = params,
|
||||
return_type = return_type,
|
||||
accessibility_modifier = accessibility_modifier,
|
||||
})
|
||||
elseif type == "class_assignment" then
|
||||
local left_node = node:field("left")[1]
|
||||
local left = left_node and get_node_text(left_node, parsed.source) or ""
|
||||
|
||||
local value_type = get_node_type(node, parsed.source)
|
||||
|
||||
local class_name = get_closest_parent_name(node, parsed.source)
|
||||
if class_name and filetype == "go" and not Utils.is_first_letter_uppercase(class_name) then goto continue end
|
||||
|
||||
local class_def = get_class_def(class_name)
|
||||
|
||||
table.insert(class_def.properties, {
|
||||
type = "variable",
|
||||
name = left,
|
||||
value_type = value_type,
|
||||
})
|
||||
elseif type == "class_variable" then
|
||||
local value_type = get_node_type(node, parsed.source)
|
||||
|
||||
local class_name = get_closest_parent_name(node, parsed.source)
|
||||
if class_name and filetype == "go" and not Utils.is_first_letter_uppercase(class_name) then goto continue end
|
||||
|
||||
local class_def = get_class_def(class_name)
|
||||
|
||||
table.insert(class_def.properties, {
|
||||
type = "variable",
|
||||
name = name,
|
||||
value_type = value_type,
|
||||
})
|
||||
elseif type == "function" or type == "arrow_function" then
|
||||
if name and filetype == "go" and not Utils.is_first_letter_uppercase(name) then goto continue end
|
||||
local impl_item_node = find_parent_by_type(node, "impl_item")
|
||||
if impl_item_node then goto continue end
|
||||
local function_node = find_parent_by_type(node, "function_declaration")
|
||||
or find_parent_by_type(node, "function_definition")
|
||||
if function_node then goto continue end
|
||||
-- Extract function parameters and return type
|
||||
local params_node = node:field("parameters")[1]
|
||||
local params = params_node and get_node_text(params_node, parsed.source) or "()"
|
||||
local return_type_node = node:field("return_type")[1] or node:field("result")[1]
|
||||
local return_type = return_type_node and get_node_text(return_type_node, parsed.source) or "void"
|
||||
|
||||
local accessibility_modifier_node = find_child_by_type(node, "accessibility_modifier")
|
||||
local accessibility_modifier = accessibility_modifier_node
|
||||
and get_node_text(accessibility_modifier_node, parsed.source)
|
||||
or ""
|
||||
|
||||
local def = {
|
||||
type = "function",
|
||||
name = name,
|
||||
params = params,
|
||||
return_type = return_type,
|
||||
accessibility_modifier = accessibility_modifier,
|
||||
}
|
||||
table.insert(definitions, def)
|
||||
elseif type == "assignment" then
|
||||
local impl_item_node = find_parent_by_type(node, "impl_item")
|
||||
or find_parent_by_type(node, "class_declaration")
|
||||
or find_parent_by_type(node, "class_definition")
|
||||
if impl_item_node then goto continue end
|
||||
local function_node = find_parent_by_type(node, "function_declaration")
|
||||
or find_parent_by_type(node, "function_definition")
|
||||
if function_node then goto continue end
|
||||
|
||||
local left_node = node:field("left")[1]
|
||||
local left = left_node and get_node_text(left_node, parsed.source) or ""
|
||||
|
||||
if left and filetype == "go" and not Utils.is_first_letter_uppercase(left) then goto continue end
|
||||
|
||||
local value_type = get_node_type(node, parsed.source)
|
||||
|
||||
local def = {
|
||||
type = "variable",
|
||||
name = left,
|
||||
value_type = value_type,
|
||||
}
|
||||
table.insert(definitions, def)
|
||||
elseif type == "variable" then
|
||||
local impl_item_node = find_parent_by_type(node, "impl_item")
|
||||
or find_parent_by_type(node, "class_declaration")
|
||||
or find_parent_by_type(node, "class_definition")
|
||||
if impl_item_node then goto continue end
|
||||
local function_node = find_parent_by_type(node, "function_declaration")
|
||||
or find_parent_by_type(node, "function_definition")
|
||||
if function_node then goto continue end
|
||||
|
||||
local value_type = get_node_type(node, parsed.source)
|
||||
|
||||
if name and filetype == "go" and not Utils.is_first_letter_uppercase(name) then goto continue end
|
||||
|
||||
local def = { type = "variable", name = name, value_type = value_type }
|
||||
table.insert(definitions, def)
|
||||
end
|
||||
::continue::
|
||||
end
|
||||
end
|
||||
|
||||
for _, def in pairs(class_def_map) do
|
||||
table.insert(definitions, def)
|
||||
end
|
||||
|
||||
for _, def in pairs(enum_def_map) do
|
||||
table.insert(definitions, def)
|
||||
end
|
||||
|
||||
return definitions
|
||||
end
|
||||
|
||||
local function stringify_function(def)
|
||||
local res = "func " .. def.name .. def.params .. ":" .. def.return_type .. ";"
|
||||
if def.accessibility_modifier and def.accessibility_modifier ~= "" then
|
||||
res = def.accessibility_modifier .. " " .. res
|
||||
end
|
||||
return res
|
||||
end
|
||||
|
||||
local function stringify_variable(def)
|
||||
local res = "var " .. def.name
|
||||
if def.value_type and def.value_type ~= "" then res = res .. ":" .. def.value_type end
|
||||
return res .. ";"
|
||||
end
|
||||
|
||||
local function stringify_enum_item(def)
|
||||
local res = def.name
|
||||
if def.value_type and def.value_type ~= "" then res = res .. ":" .. def.value_type end
|
||||
return res .. ";"
|
||||
end
|
||||
|
||||
-- Function to load file content into a temporary buffer, process it, and then delete the buffer
|
||||
function RepoMap.stringify_definitions(filepath)
|
||||
if vim.endswith(filepath, "~") then return "" end
|
||||
|
||||
-- Extract definitions
|
||||
local definitions = RepoMap.extract_definitions(filepath)
|
||||
|
||||
local output = ""
|
||||
-- Print or process the definitions
|
||||
for _, def in ipairs(definitions) do
|
||||
if def.type == "class" then
|
||||
output = output .. def.type .. " " .. def.name .. "{"
|
||||
for _, property in ipairs(def.properties) do
|
||||
output = output .. stringify_variable(property)
|
||||
end
|
||||
for _, method in ipairs(def.methods) do
|
||||
output = output .. stringify_function(method)
|
||||
end
|
||||
output = output .. "}"
|
||||
elseif def.type == "enum" then
|
||||
output = output .. def.type .. " " .. def.name .. "{"
|
||||
for _, item in ipairs(def.items) do
|
||||
output = output .. stringify_enum_item(item) .. ""
|
||||
end
|
||||
output = output .. "}"
|
||||
elseif def.type == "function" then
|
||||
output = output .. stringify_function(def)
|
||||
elseif def.type == "variable" then
|
||||
output = output .. stringify_variable(def)
|
||||
end
|
||||
end
|
||||
|
||||
return output
|
||||
end
|
||||
|
||||
function RepoMap._build_repo_map(project_root, file_ext)
|
||||
local Utils = require("avante.utils")
|
||||
local output = {}
|
||||
local gitignore_path = project_root .. "/.gitignore"
|
||||
local ignore_patterns, negate_patterns = Utils.parse_gitignore(gitignore_path)
|
||||
local filepaths = Utils.scan_directory(project_root, ignore_patterns, negate_patterns)
|
||||
vim.iter(filepaths):each(function(filepath)
|
||||
if not Utils.is_same_file_ext(file_ext, filepath) then return end
|
||||
local definitions = RepoMap.stringify_definitions(filepath)
|
||||
if definitions == "" then return end
|
||||
table.insert(output, {
|
||||
path = Utils.relative_path(filepath),
|
||||
lang = RepoMap.get_filetype_by_filepath(filepath),
|
||||
defs = definitions,
|
||||
})
|
||||
end)
|
||||
return output
|
||||
end
|
||||
|
||||
local cache = {}
|
||||
|
||||
function RepoMap.get_repo_map(file_ext)
|
||||
file_ext = file_ext or vim.fn.expand("%:e")
|
||||
local Utils = require("avante.utils")
|
||||
local project_root = Utils.root.get()
|
||||
local cache_key = project_root .. "." .. file_ext
|
||||
local cached = cache[cache_key]
|
||||
if cached then return cached end
|
||||
|
||||
local PPath = require("plenary.path")
|
||||
local Path = require("avante.path")
|
||||
local repo_map
|
||||
|
||||
local function build_and_save()
|
||||
repo_map = RepoMap._build_repo_map(project_root, file_ext)
|
||||
cache[cache_key] = repo_map
|
||||
Path.repo_map.save(project_root, file_ext, repo_map)
|
||||
end
|
||||
|
||||
repo_map = Path.repo_map.load(project_root, file_ext)
|
||||
|
||||
if not repo_map or next(repo_map) == nil then
|
||||
build_and_save()
|
||||
if not repo_map then return end
|
||||
else
|
||||
local timer = vim.loop.new_timer()
|
||||
|
||||
if timer then
|
||||
timer:start(
|
||||
0,
|
||||
0,
|
||||
vim.schedule_wrap(function()
|
||||
build_and_save()
|
||||
timer:close()
|
||||
end)
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
local update_repo_map = vim.schedule_wrap(function(rel_filepath)
|
||||
if rel_filepath and Utils.is_same_file_ext(file_ext, rel_filepath) then
|
||||
local abs_filepath = PPath:new(project_root):joinpath(rel_filepath):absolute()
|
||||
local definitions = RepoMap.stringify_definitions(abs_filepath)
|
||||
if definitions == "" then return end
|
||||
local found = false
|
||||
for _, m in ipairs(repo_map) do
|
||||
if m.path == rel_filepath then
|
||||
m.defs = definitions
|
||||
found = true
|
||||
break
|
||||
end
|
||||
end
|
||||
if not found then
|
||||
table.insert(repo_map, {
|
||||
path = Utils.relative_path(abs_filepath),
|
||||
lang = RepoMap.get_filetype_by_filepath(abs_filepath),
|
||||
defs = definitions,
|
||||
})
|
||||
end
|
||||
cache[cache_key] = repo_map
|
||||
Path.repo_map.save(project_root, file_ext, repo_map)
|
||||
end
|
||||
end)
|
||||
|
||||
local handle = vim.loop.new_fs_event()
|
||||
|
||||
if handle then
|
||||
handle:start(project_root, { recursive = true }, function(err, rel_filepath)
|
||||
if err then
|
||||
print("Error watching directory " .. project_root .. ":", err)
|
||||
return
|
||||
end
|
||||
|
||||
if rel_filepath then update_repo_map(rel_filepath) end
|
||||
end)
|
||||
end
|
||||
|
||||
vim.api.nvim_create_autocmd({ "BufReadPost", "BufNewFile" }, {
|
||||
callback = function(ev)
|
||||
vim.defer_fn(function()
|
||||
local filepath = vim.api.nvim_buf_get_name(ev.buf)
|
||||
if not vim.startswith(filepath, project_root) then return end
|
||||
local rel_filepath = Utils.relative_path(filepath)
|
||||
update_repo_map(rel_filepath)
|
||||
end, 0)
|
||||
end,
|
||||
})
|
||||
|
||||
return repo_map
|
||||
end
|
||||
|
||||
return RepoMap
|
Loading…
x
Reference in New Issue
Block a user