feat(repo-map): zig support (#663)
* feature: zig support for repo map * Update crates/avante-repo-map/Cargo.toml Co-authored-by: yetone <yetoneful@gmail.com> * fix: update lint error Signed-off-by: Aaron Pham <contact@aarnphm.xyz> --------- Signed-off-by: Aaron Pham <contact@aarnphm.xyz> Co-authored-by: Aaron Pham <Aaronpham0103@gmail.com> Co-authored-by: yetone <yetoneful@gmail.com> Co-authored-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
parent
d28fece472
commit
bac46cee83
11
Cargo.lock
generated
11
Cargo.lock
generated
@ -48,6 +48,7 @@ dependencies = [
|
||||
"tree-sitter-ruby",
|
||||
"tree-sitter-rust",
|
||||
"tree-sitter-typescript",
|
||||
"tree-sitter-zig",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -1404,6 +1405,16 @@ dependencies = [
|
||||
"tree-sitter-language",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tree-sitter-zig"
|
||||
version = "1.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2994e37b8ef1f715b931a5ff084a1b1713b1bc56e7aaebd148cc3efe0bf29ad9"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"tree-sitter-language",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typeid"
|
||||
version = "1.0.2"
|
||||
|
@ -26,6 +26,7 @@ tree-sitter-c = "0.23"
|
||||
tree-sitter-cpp = "0.23"
|
||||
tree-sitter-lua = "0.2"
|
||||
tree-sitter-ruby = "0.23"
|
||||
tree-sitter-zig = "1.0.2"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
23
crates/avante-repo-map/queries/tree-sitter-zig-defs.scm
Normal file
23
crates/avante-repo-map/queries/tree-sitter-zig-defs.scm
Normal file
@ -0,0 +1,23 @@
|
||||
;; Capture struct fields, struct methods, enum items, union items, top-level functions, and top-level variable declarations in Zig
|
||||
(variable_declaration (identifier)
|
||||
(struct_declaration
|
||||
(container_field) @class_variable))
|
||||
|
||||
(variable_declaration (identifier)
|
||||
(struct_declaration
|
||||
(function_declaration
|
||||
name: (identifier) @method)))
|
||||
|
||||
(variable_declaration (identifier)
|
||||
(enum_declaration
|
||||
(container_field
|
||||
type: (identifier) @enum_item)))
|
||||
|
||||
(variable_declaration (identifier)
|
||||
(union_declaration
|
||||
(container_field
|
||||
name: (identifier) @union_item)))
|
||||
|
||||
(source_file (function_declaration) @function)
|
||||
|
||||
(source_file (variable_declaration (identifier) @variable))
|
@ -26,6 +26,12 @@ pub struct Enum {
|
||||
pub items: Vec<Variable>,
|
||||
}
|
||||
|
||||
/// A union definition collected for the repo map.
#[derive(Debug, Clone)]
|
||||
pub struct Union {
|
||||
/// Name under which the union is recorded.
pub name: String,
|
||||
/// Items of the union, each stored as a `Variable` (name plus value type).
pub items: Vec<Variable>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Variable {
|
||||
pub name: String,
|
||||
@ -38,6 +44,7 @@ pub enum Definition {
|
||||
Class(Class),
|
||||
Enum(Enum),
|
||||
Variable(Variable),
|
||||
Union(Union),
|
||||
}
|
||||
|
||||
fn get_ts_language(language: &str) -> Option<LanguageFn> {
|
||||
@ -51,6 +58,7 @@ fn get_ts_language(language: &str) -> Option<LanguageFn> {
|
||||
"cpp" => Some(tree_sitter_cpp::LANGUAGE),
|
||||
"lua" => Some(tree_sitter_lua::LANGUAGE),
|
||||
"ruby" => Some(tree_sitter_ruby::LANGUAGE),
|
||||
"zig" => Some(tree_sitter_zig::LANGUAGE),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
@ -62,6 +70,7 @@ const JAVASCRIPT_QUERY: &str = include_str!("../queries/tree-sitter-javascript-d
|
||||
const LUA_QUERY: &str = include_str!("../queries/tree-sitter-lua-defs.scm");
|
||||
const PYTHON_QUERY: &str = include_str!("../queries/tree-sitter-python-defs.scm");
|
||||
const RUST_QUERY: &str = include_str!("../queries/tree-sitter-rust-defs.scm");
|
||||
const ZIG_QUERY: &str = include_str!("../queries/tree-sitter-zig-defs.scm");
|
||||
const TYPESCRIPT_QUERY: &str = include_str!("../queries/tree-sitter-typescript-defs.scm");
|
||||
const RUBY_QUERY: &str = include_str!("../queries/tree-sitter-ruby-defs.scm");
|
||||
|
||||
@ -79,6 +88,7 @@ fn get_definitions_query(language: &str) -> Result<Query, String> {
|
||||
"lua" => LUA_QUERY,
|
||||
"python" => PYTHON_QUERY,
|
||||
"rust" => RUST_QUERY,
|
||||
"zig" => ZIG_QUERY,
|
||||
"typescript" => TYPESCRIPT_QUERY,
|
||||
"ruby" => RUBY_QUERY,
|
||||
_ => return Err(format!("Unsupported language: {language}")),
|
||||
@ -128,6 +138,49 @@ fn find_child_by_type<'a>(node: &'a Node, child_type: &str) -> Option<Node<'a>>
|
||||
.find(|child| child.kind() == child_type)
|
||||
}
|
||||
|
||||
// Zig-specific function to find the parent variable declaration
|
||||
fn zig_find_parent_variable_declaration_name<'a>(
|
||||
node: &'a Node,
|
||||
source: &'a [u8],
|
||||
) -> Option<String> {
|
||||
let vardec = find_ancestor_by_type(node, "variable_declaration");
|
||||
if let Some(vardec) = vardec {
|
||||
// Find the identifier child node, which represents the class name
|
||||
let identifier_node = find_child_by_type(&vardec, "identifier");
|
||||
if let Some(identifier_node) = identifier_node {
|
||||
return Some(get_node_text(&identifier_node, source));
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn zig_is_declaration_public<'a>(node: &'a Node, declaration_type: &str, source: &'a [u8]) -> bool {
|
||||
let declaration = find_ancestor_by_type(node, declaration_type);
|
||||
if let Some(declaration) = declaration {
|
||||
let declaration_text = get_node_text(&declaration, source);
|
||||
return declaration_text.starts_with("pub");
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
fn zig_is_variable_declaration_public<'a>(node: &'a Node, source: &'a [u8]) -> bool {
|
||||
zig_is_declaration_public(node, "variable_declaration", source)
|
||||
}
|
||||
|
||||
fn zig_is_function_declaration_public<'a>(node: &'a Node, source: &'a [u8]) -> bool {
|
||||
zig_is_declaration_public(node, "function_declaration", source)
|
||||
}
|
||||
|
||||
fn zig_find_type_in_parent<'a>(node: &'a Node, source: &'a [u8]) -> Option<String> {
|
||||
// First go to the parent and then get the child_by_field_name "type"
|
||||
if let Some(parent) = node.parent() {
|
||||
if let Some(type_node) = parent.child_by_field_name("type") {
|
||||
return Some(get_node_text(&type_node, source));
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn get_node_text<'a>(node: &'a Node, source: &'a [u8]) -> String {
|
||||
node.utf8_text(source).unwrap_or_default().to_string()
|
||||
}
|
||||
@ -176,6 +229,7 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
|
||||
|
||||
let mut class_def_map: HashMap<String, RefCell<Class>> = HashMap::new();
|
||||
let mut enum_def_map: HashMap<String, RefCell<Enum>> = HashMap::new();
|
||||
let mut union_def_map: HashMap<String, RefCell<Union>> = HashMap::new();
|
||||
|
||||
let ensure_class_def = |name: &str, class_def_map: &mut HashMap<String, RefCell<Class>>| {
|
||||
class_def_map.entry(name.to_string()).or_insert_with(|| {
|
||||
@ -197,14 +251,26 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
|
||||
});
|
||||
};
|
||||
|
||||
let ensure_union_def = |name: &str, union_def_map: &mut HashMap<String, RefCell<Union>>| {
|
||||
union_def_map.entry(name.to_string()).or_insert_with(|| {
|
||||
RefCell::new(Union {
|
||||
name: name.to_string(),
|
||||
items: vec![],
|
||||
})
|
||||
});
|
||||
};
|
||||
|
||||
for (m, _) in captures {
|
||||
for capture in m.captures {
|
||||
let capture_name = &query.capture_names()[capture.index as usize];
|
||||
let node = capture.node;
|
||||
let node_text = node.utf8_text(source.as_bytes()).unwrap();
|
||||
|
||||
let name_node = node.child_by_field_name("name");
|
||||
let name = name_node
|
||||
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
|
||||
.unwrap_or("");
|
||||
.unwrap_or(node_text);
|
||||
|
||||
match *capture_name {
|
||||
"class" => {
|
||||
if !name.is_empty() {
|
||||
@ -235,7 +301,17 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
|
||||
if language == "rust" && !visibility_modifier.contains("pub") {
|
||||
continue;
|
||||
}
|
||||
let enum_name = get_closest_ancestor_name(&node, source);
|
||||
if language == "zig"
|
||||
&& !zig_is_variable_declaration_public(&node, source.as_bytes())
|
||||
{
|
||||
continue;
|
||||
}
|
||||
let mut enum_name = get_closest_ancestor_name(&node, source);
|
||||
if language == "zig" {
|
||||
enum_name =
|
||||
zig_find_parent_variable_declaration_name(&node, source.as_bytes())
|
||||
.unwrap_or_default();
|
||||
}
|
||||
if !enum_name.is_empty()
|
||||
&& language == "go"
|
||||
&& !is_first_letter_uppercase(&enum_name)
|
||||
@ -254,6 +330,28 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
|
||||
};
|
||||
enum_def.borrow_mut().items.push(variable);
|
||||
}
|
||||
"union_item" => {
|
||||
if language != "zig" {
|
||||
continue;
|
||||
}
|
||||
if !zig_is_variable_declaration_public(&node, source.as_bytes()) {
|
||||
continue;
|
||||
}
|
||||
let union_name =
|
||||
zig_find_parent_variable_declaration_name(&node, source.as_bytes())
|
||||
.unwrap_or_default();
|
||||
ensure_union_def(&union_name, &mut union_def_map);
|
||||
let union_def = union_def_map.get_mut(&union_name).unwrap();
|
||||
let union_type_node = find_descendant_by_type(&node, "type_identifier");
|
||||
let union_type = union_type_node
|
||||
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
|
||||
.unwrap_or("");
|
||||
let variable = Variable {
|
||||
name: name.to_string(),
|
||||
value_type: union_type.to_string(),
|
||||
};
|
||||
union_def.borrow_mut().items.push(variable);
|
||||
}
|
||||
"method" => {
|
||||
let visibility_modifier_node =
|
||||
find_descendant_by_type(&node, "visibility_modifier");
|
||||
@ -263,10 +361,25 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
|
||||
if language == "rust" && !visibility_modifier.contains("pub") {
|
||||
continue;
|
||||
}
|
||||
if language == "zig"
|
||||
&& !(zig_is_function_declaration_public(&node, source.as_bytes())
|
||||
&& zig_is_variable_declaration_public(&node, source.as_bytes()))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if !name.is_empty() && language == "go" && !is_first_letter_uppercase(name) {
|
||||
continue;
|
||||
}
|
||||
let params_node = node.child_by_field_name("parameters");
|
||||
let mut params_node = node.child_by_field_name("parameters");
|
||||
|
||||
let function_node = find_ancestor_by_type(&node, "function_declaration");
|
||||
if language == "zig" {
|
||||
params_node = function_node
|
||||
.as_ref()
|
||||
.and_then(|n| find_child_by_type(n, "parameters"));
|
||||
}
|
||||
|
||||
let params = params_node
|
||||
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
|
||||
.unwrap_or("()");
|
||||
@ -288,7 +401,10 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
|
||||
|
||||
let impl_item_node = find_ancestor_by_type(&node, "impl_item");
|
||||
let receiver_node = node.child_by_field_name("receiver");
|
||||
let class_name = if let Some(impl_item) = impl_item_node {
|
||||
let class_name = if language == "zig" {
|
||||
zig_find_parent_variable_declaration_name(&node, source.as_bytes())
|
||||
.unwrap_or_default()
|
||||
} else if let Some(impl_item) = impl_item_node {
|
||||
let impl_type_node = impl_item.child_by_field_name("type");
|
||||
impl_type_node
|
||||
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
|
||||
@ -371,8 +487,22 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
|
||||
if language == "rust" && !visibility_modifier.contains("pub") {
|
||||
continue;
|
||||
}
|
||||
|
||||
let value_type = get_node_type(&node, source.as_bytes());
|
||||
let class_name = get_closest_ancestor_name(&node, source);
|
||||
|
||||
if language == "zig" {
|
||||
// when top level class is not public, skip
|
||||
if !zig_is_variable_declaration_public(&node, source.as_bytes()) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
let mut class_name = get_closest_ancestor_name(&node, source);
|
||||
if language == "zig" {
|
||||
class_name =
|
||||
zig_find_parent_variable_declaration_name(&node, source.as_bytes())
|
||||
.unwrap_or_default();
|
||||
}
|
||||
if !class_name.is_empty()
|
||||
&& language == "go"
|
||||
&& !is_first_letter_uppercase(&class_name)
|
||||
@ -399,9 +529,19 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
|
||||
let visibility_modifier = visibility_modifier_node
|
||||
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
|
||||
.unwrap_or("");
|
||||
|
||||
if language == "rust" && !visibility_modifier.contains("pub") {
|
||||
continue;
|
||||
}
|
||||
|
||||
if language == "zig" {
|
||||
let variable_declaration_text =
|
||||
node.utf8_text(source.as_bytes()).unwrap_or("");
|
||||
if !variable_declaration_text.contains("pub") {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if !name.is_empty() && language == "go" && !is_first_letter_uppercase(name) {
|
||||
continue;
|
||||
}
|
||||
@ -493,9 +633,17 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
|
||||
let visibility_modifier = visibility_modifier_node
|
||||
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
|
||||
.unwrap_or("");
|
||||
|
||||
if language == "rust" && !visibility_modifier.contains("pub") {
|
||||
continue;
|
||||
}
|
||||
|
||||
if language == "zig"
|
||||
&& !zig_is_variable_declaration_public(&node, source.as_bytes())
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
let impl_item_node = find_ancestor_by_type(&node, "impl_item")
|
||||
.or_else(|| find_ancestor_by_type(&node, "class_declaration"))
|
||||
.or_else(|| find_ancestor_by_type(&node, "class_definition"));
|
||||
@ -532,7 +680,15 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
|
||||
continue;
|
||||
}
|
||||
}
|
||||
let value_type = get_node_type(&node, source.as_bytes());
|
||||
|
||||
let mut value_type = get_node_type(&node, source.as_bytes());
|
||||
if language == "zig" {
|
||||
if let Some(zig_type) = zig_find_type_in_parent(&node, source.as_bytes()) {
|
||||
value_type = zig_type;
|
||||
} else {
|
||||
continue;
|
||||
};
|
||||
}
|
||||
if !name.is_empty() && language == "go" && !is_first_letter_uppercase(name) {
|
||||
continue;
|
||||
}
|
||||
@ -563,6 +719,9 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
|
||||
for (_, def) in enum_def_map {
|
||||
definitions.push(Definition::Enum(def.into_inner()));
|
||||
}
|
||||
for (_, def) in union_def_map {
|
||||
definitions.push(Definition::Union(def.into_inner()));
|
||||
}
|
||||
|
||||
Ok(definitions)
|
||||
}
|
||||
@ -599,6 +758,14 @@ fn stringify_enum_item(item: &Variable) -> String {
|
||||
format!("{res};")
|
||||
}
|
||||
|
||||
fn stringify_union_item(item: &Variable) -> String {
|
||||
let mut res = item.name.clone();
|
||||
if !item.value_type.is_empty() {
|
||||
res = format!("{res}:{}", item.value_type);
|
||||
}
|
||||
format!("{res};")
|
||||
}
|
||||
|
||||
fn stringify_class(class: &Class) -> String {
|
||||
let mut res = format!("class {}{{", class.name);
|
||||
for method in &class.methods {
|
||||
@ -620,6 +787,14 @@ fn stringify_enum(enum_def: &Enum) -> String {
|
||||
}
|
||||
format!("{res}}};")
|
||||
}
|
||||
fn stringify_union(union_def: &Union) -> String {
|
||||
let mut res = format!("union {}{{", union_def.name);
|
||||
for item in &union_def.items {
|
||||
let item_str = stringify_union_item(item);
|
||||
res = format!("{res}{item_str}");
|
||||
}
|
||||
format!("{res}}};")
|
||||
}
|
||||
|
||||
fn stringify_definitions(definitions: &Vec<Definition>) -> String {
|
||||
let mut res = String::new();
|
||||
@ -627,6 +802,7 @@ fn stringify_definitions(definitions: &Vec<Definition>) -> String {
|
||||
match definition {
|
||||
Definition::Class(class) => res = format!("{res}{}", stringify_class(class)),
|
||||
Definition::Enum(enum_def) => res = format!("{res}{}", stringify_enum(enum_def)),
|
||||
Definition::Union(union_def) => res = format!("{res}{}", stringify_union(union_def)),
|
||||
Definition::Func(func) => res = format!("{res}{}", stringify_function(func)),
|
||||
Definition::Variable(variable) => {
|
||||
let variable_str = stringify_variable(variable);
|
||||
@ -660,7 +836,6 @@ fn avante_repo_map(lua: &Lua) -> LuaResult<LuaTable> {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_rust() {
|
||||
let source = r#"
|
||||
// This is a test comment
|
||||
@ -718,6 +893,77 @@ mod tests {
|
||||
assert_eq!(stringified, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_zig() {
|
||||
let source = r#"
|
||||
// This is a test comment
|
||||
pub const TEST_CONST: u32 = 1;
|
||||
pub var TEST_VAR: u32 = 2;
|
||||
const INNER_TEST_CONST: u32 = 3;
|
||||
var INNER_TEST_VAR: u32 = 4;
|
||||
pub const TestStruct = struct {
|
||||
test_field: []const u8,
|
||||
test_field2: u64,
|
||||
|
||||
pub fn test_method(_: *TestStruct, a: u32, b: u32) u32 {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
fn inner_test_method(_: *TestStruct, a: u32, b: u32) u32 {
|
||||
return a + b;
|
||||
}
|
||||
};
|
||||
const InnerTestStruct = struct {
|
||||
test_field: []const u8,
|
||||
test_field2: u64,
|
||||
|
||||
pub fn test_method(_: *InnerTestStruct, a: u32, b: u32) u32 {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
fn inner_test_method(_: *InnerTestStruct, a: u32, b: u32) u32 {
|
||||
return a + b;
|
||||
}
|
||||
};
|
||||
pub const TestEnum = enum {
|
||||
TestEnumField1,
|
||||
TestEnumField2,
|
||||
};
|
||||
const InnerTestEnum = enum {
|
||||
InnerTestEnumField1,
|
||||
InnerTestEnumField2,
|
||||
};
|
||||
|
||||
pub const TestUnion = union {
|
||||
TestUnionField1: u32,
|
||||
TestUnionField2: u64,
|
||||
};
|
||||
|
||||
const InnerTestUnion = union {
|
||||
InnerTestUnionField1: u32,
|
||||
InnerTestUnionField2: u64,
|
||||
};
|
||||
|
||||
pub fn test_fn(a: u32, b: u32) u32 {
|
||||
const inner_var_in_func = 1;
|
||||
const InnerStructInFunc = struct {
|
||||
c: u32,
|
||||
};
|
||||
_ = InnerStructInFunc;
|
||||
return a + b + inner_var_in_func;
|
||||
}
|
||||
fn inner_test_fn(a: u32, b: u32) u32 {
|
||||
return a + b;
|
||||
}
|
||||
"#;
|
||||
|
||||
let definitions = extract_definitions("zig", source).unwrap();
|
||||
let stringified = stringify_definitions(&definitions);
|
||||
println!("{stringified}");
|
||||
let expected = "var TEST_CONST:u32;var TEST_VAR:u32;func test_fn() -> void;class TestStruct{func test_method(_: *TestStruct, a: u32, b: u32) -> void;var test_field:[]const u8;var test_field2:u64;};enum TestEnum{TestEnumField1;TestEnumField2;};union TestUnion{TestUnionField1;TestUnionField2;};";
|
||||
assert_eq!(stringified, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_go() {
|
||||
let source = r#"
|
||||
|
Loading…
x
Reference in New Issue
Block a user