feat(repo-map): zig support (#663)

* feature: zig support for repo map

* Update crates/avante-repo-map/Cargo.toml

Co-authored-by: yetone <yetoneful@gmail.com>

* fix: update lint error

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>

---------

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Co-authored-by: Aaron Pham <Aaronpham0103@gmail.com>
Co-authored-by: yetone <yetoneful@gmail.com>
Co-authored-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
Boy Maas 2024-09-29 19:27:10 +02:00 committed by GitHub
parent d28fece472
commit bac46cee83
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 288 additions and 7 deletions

11
Cargo.lock generated
View File

@ -48,6 +48,7 @@ dependencies = [
"tree-sitter-ruby", "tree-sitter-ruby",
"tree-sitter-rust", "tree-sitter-rust",
"tree-sitter-typescript", "tree-sitter-typescript",
"tree-sitter-zig",
] ]
[[package]] [[package]]
@ -1404,6 +1405,16 @@ dependencies = [
"tree-sitter-language", "tree-sitter-language",
] ]
[[package]]
name = "tree-sitter-zig"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2994e37b8ef1f715b931a5ff084a1b1713b1bc56e7aaebd148cc3efe0bf29ad9"
dependencies = [
"cc",
"tree-sitter-language",
]
[[package]] [[package]]
name = "typeid" name = "typeid"
version = "1.0.2" version = "1.0.2"

View File

@ -26,6 +26,7 @@ tree-sitter-c = "0.23"
tree-sitter-cpp = "0.23" tree-sitter-cpp = "0.23"
tree-sitter-lua = "0.2" tree-sitter-lua = "0.2"
tree-sitter-ruby = "0.23" tree-sitter-ruby = "0.23"
tree-sitter-zig = "1.0.2"
[lints] [lints]
workspace = true workspace = true

View File

@ -0,0 +1,23 @@
;; Capture functions, structs, methods, variable definitions, and unions in Zig
(variable_declaration (identifier)
(struct_declaration
(container_field) @class_variable))
(variable_declaration (identifier)
(struct_declaration
(function_declaration
name: (identifier) @method)))
(variable_declaration (identifier)
(enum_declaration
(container_field
type: (identifier) @enum_item)))
(variable_declaration (identifier)
(union_declaration
(container_field
name: (identifier) @union_item)))
(source_file (function_declaration) @function)
(source_file (variable_declaration (identifier) @variable))

View File

@ -26,6 +26,12 @@ pub struct Enum {
pub items: Vec<Variable>, pub items: Vec<Variable>,
} }
#[derive(Debug, Clone)]
pub struct Union {
pub name: String,
pub items: Vec<Variable>,
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Variable { pub struct Variable {
pub name: String, pub name: String,
@ -38,6 +44,7 @@ pub enum Definition {
Class(Class), Class(Class),
Enum(Enum), Enum(Enum),
Variable(Variable), Variable(Variable),
Union(Union),
} }
fn get_ts_language(language: &str) -> Option<LanguageFn> { fn get_ts_language(language: &str) -> Option<LanguageFn> {
@ -51,6 +58,7 @@ fn get_ts_language(language: &str) -> Option<LanguageFn> {
"cpp" => Some(tree_sitter_cpp::LANGUAGE), "cpp" => Some(tree_sitter_cpp::LANGUAGE),
"lua" => Some(tree_sitter_lua::LANGUAGE), "lua" => Some(tree_sitter_lua::LANGUAGE),
"ruby" => Some(tree_sitter_ruby::LANGUAGE), "ruby" => Some(tree_sitter_ruby::LANGUAGE),
"zig" => Some(tree_sitter_zig::LANGUAGE),
_ => None, _ => None,
} }
} }
@ -62,6 +70,7 @@ const JAVASCRIPT_QUERY: &str = include_str!("../queries/tree-sitter-javascript-d
const LUA_QUERY: &str = include_str!("../queries/tree-sitter-lua-defs.scm"); const LUA_QUERY: &str = include_str!("../queries/tree-sitter-lua-defs.scm");
const PYTHON_QUERY: &str = include_str!("../queries/tree-sitter-python-defs.scm"); const PYTHON_QUERY: &str = include_str!("../queries/tree-sitter-python-defs.scm");
const RUST_QUERY: &str = include_str!("../queries/tree-sitter-rust-defs.scm"); const RUST_QUERY: &str = include_str!("../queries/tree-sitter-rust-defs.scm");
const ZIG_QUERY: &str = include_str!("../queries/tree-sitter-zig-defs.scm");
const TYPESCRIPT_QUERY: &str = include_str!("../queries/tree-sitter-typescript-defs.scm"); const TYPESCRIPT_QUERY: &str = include_str!("../queries/tree-sitter-typescript-defs.scm");
const RUBY_QUERY: &str = include_str!("../queries/tree-sitter-ruby-defs.scm"); const RUBY_QUERY: &str = include_str!("../queries/tree-sitter-ruby-defs.scm");
@ -79,6 +88,7 @@ fn get_definitions_query(language: &str) -> Result<Query, String> {
"lua" => LUA_QUERY, "lua" => LUA_QUERY,
"python" => PYTHON_QUERY, "python" => PYTHON_QUERY,
"rust" => RUST_QUERY, "rust" => RUST_QUERY,
"zig" => ZIG_QUERY,
"typescript" => TYPESCRIPT_QUERY, "typescript" => TYPESCRIPT_QUERY,
"ruby" => RUBY_QUERY, "ruby" => RUBY_QUERY,
_ => return Err(format!("Unsupported language: {language}")), _ => return Err(format!("Unsupported language: {language}")),
@ -128,6 +138,49 @@ fn find_child_by_type<'a>(node: &'a Node, child_type: &str) -> Option<Node<'a>>
.find(|child| child.kind() == child_type) .find(|child| child.kind() == child_type)
} }
// Zig-specific function to find the parent variable declaration
fn zig_find_parent_variable_declaration_name<'a>(
node: &'a Node,
source: &'a [u8],
) -> Option<String> {
let vardec = find_ancestor_by_type(node, "variable_declaration");
if let Some(vardec) = vardec {
// Find the identifier child node, which represents the class name
let identifier_node = find_child_by_type(&vardec, "identifier");
if let Some(identifier_node) = identifier_node {
return Some(get_node_text(&identifier_node, source));
}
}
None
}
fn zig_is_declaration_public<'a>(node: &'a Node, declaration_type: &str, source: &'a [u8]) -> bool {
let declaration = find_ancestor_by_type(node, declaration_type);
if let Some(declaration) = declaration {
let declaration_text = get_node_text(&declaration, source);
return declaration_text.starts_with("pub");
}
false
}
fn zig_is_variable_declaration_public<'a>(node: &'a Node, source: &'a [u8]) -> bool {
zig_is_declaration_public(node, "variable_declaration", source)
}
fn zig_is_function_declaration_public<'a>(node: &'a Node, source: &'a [u8]) -> bool {
zig_is_declaration_public(node, "function_declaration", source)
}
fn zig_find_type_in_parent<'a>(node: &'a Node, source: &'a [u8]) -> Option<String> {
// First go to the parent and then get the child_by_field_name "type"
if let Some(parent) = node.parent() {
if let Some(type_node) = parent.child_by_field_name("type") {
return Some(get_node_text(&type_node, source));
}
}
None
}
fn get_node_text<'a>(node: &'a Node, source: &'a [u8]) -> String { fn get_node_text<'a>(node: &'a Node, source: &'a [u8]) -> String {
node.utf8_text(source).unwrap_or_default().to_string() node.utf8_text(source).unwrap_or_default().to_string()
} }
@ -176,6 +229,7 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
let mut class_def_map: HashMap<String, RefCell<Class>> = HashMap::new(); let mut class_def_map: HashMap<String, RefCell<Class>> = HashMap::new();
let mut enum_def_map: HashMap<String, RefCell<Enum>> = HashMap::new(); let mut enum_def_map: HashMap<String, RefCell<Enum>> = HashMap::new();
let mut union_def_map: HashMap<String, RefCell<Union>> = HashMap::new();
let ensure_class_def = |name: &str, class_def_map: &mut HashMap<String, RefCell<Class>>| { let ensure_class_def = |name: &str, class_def_map: &mut HashMap<String, RefCell<Class>>| {
class_def_map.entry(name.to_string()).or_insert_with(|| { class_def_map.entry(name.to_string()).or_insert_with(|| {
@ -197,14 +251,26 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
}); });
}; };
let ensure_union_def = |name: &str, union_def_map: &mut HashMap<String, RefCell<Union>>| {
union_def_map.entry(name.to_string()).or_insert_with(|| {
RefCell::new(Union {
name: name.to_string(),
items: vec![],
})
});
};
for (m, _) in captures { for (m, _) in captures {
for capture in m.captures { for capture in m.captures {
let capture_name = &query.capture_names()[capture.index as usize]; let capture_name = &query.capture_names()[capture.index as usize];
let node = capture.node; let node = capture.node;
let node_text = node.utf8_text(source.as_bytes()).unwrap();
let name_node = node.child_by_field_name("name"); let name_node = node.child_by_field_name("name");
let name = name_node let name = name_node
.map(|n| n.utf8_text(source.as_bytes()).unwrap()) .map(|n| n.utf8_text(source.as_bytes()).unwrap())
.unwrap_or(""); .unwrap_or(node_text);
match *capture_name { match *capture_name {
"class" => { "class" => {
if !name.is_empty() { if !name.is_empty() {
@ -235,7 +301,17 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
if language == "rust" && !visibility_modifier.contains("pub") { if language == "rust" && !visibility_modifier.contains("pub") {
continue; continue;
} }
let enum_name = get_closest_ancestor_name(&node, source); if language == "zig"
&& !zig_is_variable_declaration_public(&node, source.as_bytes())
{
continue;
}
let mut enum_name = get_closest_ancestor_name(&node, source);
if language == "zig" {
enum_name =
zig_find_parent_variable_declaration_name(&node, source.as_bytes())
.unwrap_or_default();
}
if !enum_name.is_empty() if !enum_name.is_empty()
&& language == "go" && language == "go"
&& !is_first_letter_uppercase(&enum_name) && !is_first_letter_uppercase(&enum_name)
@ -254,6 +330,28 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
}; };
enum_def.borrow_mut().items.push(variable); enum_def.borrow_mut().items.push(variable);
} }
"union_item" => {
if language != "zig" {
continue;
}
if !zig_is_variable_declaration_public(&node, source.as_bytes()) {
continue;
}
let union_name =
zig_find_parent_variable_declaration_name(&node, source.as_bytes())
.unwrap_or_default();
ensure_union_def(&union_name, &mut union_def_map);
let union_def = union_def_map.get_mut(&union_name).unwrap();
let union_type_node = find_descendant_by_type(&node, "type_identifier");
let union_type = union_type_node
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
.unwrap_or("");
let variable = Variable {
name: name.to_string(),
value_type: union_type.to_string(),
};
union_def.borrow_mut().items.push(variable);
}
"method" => { "method" => {
let visibility_modifier_node = let visibility_modifier_node =
find_descendant_by_type(&node, "visibility_modifier"); find_descendant_by_type(&node, "visibility_modifier");
@ -263,10 +361,25 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
if language == "rust" && !visibility_modifier.contains("pub") { if language == "rust" && !visibility_modifier.contains("pub") {
continue; continue;
} }
if language == "zig"
&& !(zig_is_function_declaration_public(&node, source.as_bytes())
&& zig_is_variable_declaration_public(&node, source.as_bytes()))
{
continue;
}
if !name.is_empty() && language == "go" && !is_first_letter_uppercase(name) { if !name.is_empty() && language == "go" && !is_first_letter_uppercase(name) {
continue; continue;
} }
let params_node = node.child_by_field_name("parameters"); let mut params_node = node.child_by_field_name("parameters");
let function_node = find_ancestor_by_type(&node, "function_declaration");
if language == "zig" {
params_node = function_node
.as_ref()
.and_then(|n| find_child_by_type(n, "parameters"));
}
let params = params_node let params = params_node
.map(|n| n.utf8_text(source.as_bytes()).unwrap()) .map(|n| n.utf8_text(source.as_bytes()).unwrap())
.unwrap_or("()"); .unwrap_or("()");
@ -288,7 +401,10 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
let impl_item_node = find_ancestor_by_type(&node, "impl_item"); let impl_item_node = find_ancestor_by_type(&node, "impl_item");
let receiver_node = node.child_by_field_name("receiver"); let receiver_node = node.child_by_field_name("receiver");
let class_name = if let Some(impl_item) = impl_item_node { let class_name = if language == "zig" {
zig_find_parent_variable_declaration_name(&node, source.as_bytes())
.unwrap_or_default()
} else if let Some(impl_item) = impl_item_node {
let impl_type_node = impl_item.child_by_field_name("type"); let impl_type_node = impl_item.child_by_field_name("type");
impl_type_node impl_type_node
.map(|n| n.utf8_text(source.as_bytes()).unwrap()) .map(|n| n.utf8_text(source.as_bytes()).unwrap())
@ -371,8 +487,22 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
if language == "rust" && !visibility_modifier.contains("pub") { if language == "rust" && !visibility_modifier.contains("pub") {
continue; continue;
} }
let value_type = get_node_type(&node, source.as_bytes()); let value_type = get_node_type(&node, source.as_bytes());
let class_name = get_closest_ancestor_name(&node, source);
if language == "zig" {
// when top level class is not public, skip
if !zig_is_variable_declaration_public(&node, source.as_bytes()) {
continue;
}
}
let mut class_name = get_closest_ancestor_name(&node, source);
if language == "zig" {
class_name =
zig_find_parent_variable_declaration_name(&node, source.as_bytes())
.unwrap_or_default();
}
if !class_name.is_empty() if !class_name.is_empty()
&& language == "go" && language == "go"
&& !is_first_letter_uppercase(&class_name) && !is_first_letter_uppercase(&class_name)
@ -399,9 +529,19 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
let visibility_modifier = visibility_modifier_node let visibility_modifier = visibility_modifier_node
.map(|n| n.utf8_text(source.as_bytes()).unwrap()) .map(|n| n.utf8_text(source.as_bytes()).unwrap())
.unwrap_or(""); .unwrap_or("");
if language == "rust" && !visibility_modifier.contains("pub") { if language == "rust" && !visibility_modifier.contains("pub") {
continue; continue;
} }
if language == "zig" {
let variable_declaration_text =
node.utf8_text(source.as_bytes()).unwrap_or("");
if !variable_declaration_text.contains("pub") {
continue;
}
}
if !name.is_empty() && language == "go" && !is_first_letter_uppercase(name) { if !name.is_empty() && language == "go" && !is_first_letter_uppercase(name) {
continue; continue;
} }
@ -493,9 +633,17 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
let visibility_modifier = visibility_modifier_node let visibility_modifier = visibility_modifier_node
.map(|n| n.utf8_text(source.as_bytes()).unwrap()) .map(|n| n.utf8_text(source.as_bytes()).unwrap())
.unwrap_or(""); .unwrap_or("");
if language == "rust" && !visibility_modifier.contains("pub") { if language == "rust" && !visibility_modifier.contains("pub") {
continue; continue;
} }
if language == "zig"
&& !zig_is_variable_declaration_public(&node, source.as_bytes())
{
continue;
}
let impl_item_node = find_ancestor_by_type(&node, "impl_item") let impl_item_node = find_ancestor_by_type(&node, "impl_item")
.or_else(|| find_ancestor_by_type(&node, "class_declaration")) .or_else(|| find_ancestor_by_type(&node, "class_declaration"))
.or_else(|| find_ancestor_by_type(&node, "class_definition")); .or_else(|| find_ancestor_by_type(&node, "class_definition"));
@ -532,7 +680,15 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
continue; continue;
} }
} }
let value_type = get_node_type(&node, source.as_bytes());
let mut value_type = get_node_type(&node, source.as_bytes());
if language == "zig" {
if let Some(zig_type) = zig_find_type_in_parent(&node, source.as_bytes()) {
value_type = zig_type;
} else {
continue;
};
}
if !name.is_empty() && language == "go" && !is_first_letter_uppercase(name) { if !name.is_empty() && language == "go" && !is_first_letter_uppercase(name) {
continue; continue;
} }
@ -563,6 +719,9 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
for (_, def) in enum_def_map { for (_, def) in enum_def_map {
definitions.push(Definition::Enum(def.into_inner())); definitions.push(Definition::Enum(def.into_inner()));
} }
for (_, def) in union_def_map {
definitions.push(Definition::Union(def.into_inner()));
}
Ok(definitions) Ok(definitions)
} }
@ -599,6 +758,14 @@ fn stringify_enum_item(item: &Variable) -> String {
format!("{res};") format!("{res};")
} }
fn stringify_union_item(item: &Variable) -> String {
let mut res = item.name.clone();
if !item.value_type.is_empty() {
res = format!("{res}:{}", item.value_type);
}
format!("{res};")
}
fn stringify_class(class: &Class) -> String { fn stringify_class(class: &Class) -> String {
let mut res = format!("class {}{{", class.name); let mut res = format!("class {}{{", class.name);
for method in &class.methods { for method in &class.methods {
@ -620,6 +787,14 @@ fn stringify_enum(enum_def: &Enum) -> String {
} }
format!("{res}}};") format!("{res}}};")
} }
fn stringify_union(union_def: &Union) -> String {
let mut res = format!("union {}{{", union_def.name);
for item in &union_def.items {
let item_str = stringify_union_item(item);
res = format!("{res}{item_str}");
}
format!("{res}}};")
}
fn stringify_definitions(definitions: &Vec<Definition>) -> String { fn stringify_definitions(definitions: &Vec<Definition>) -> String {
let mut res = String::new(); let mut res = String::new();
@ -627,6 +802,7 @@ fn stringify_definitions(definitions: &Vec<Definition>) -> String {
match definition { match definition {
Definition::Class(class) => res = format!("{res}{}", stringify_class(class)), Definition::Class(class) => res = format!("{res}{}", stringify_class(class)),
Definition::Enum(enum_def) => res = format!("{res}{}", stringify_enum(enum_def)), Definition::Enum(enum_def) => res = format!("{res}{}", stringify_enum(enum_def)),
Definition::Union(union_def) => res = format!("{res}{}", stringify_union(union_def)),
Definition::Func(func) => res = format!("{res}{}", stringify_function(func)), Definition::Func(func) => res = format!("{res}{}", stringify_function(func)),
Definition::Variable(variable) => { Definition::Variable(variable) => {
let variable_str = stringify_variable(variable); let variable_str = stringify_variable(variable);
@ -660,7 +836,6 @@ fn avante_repo_map(lua: &Lua) -> LuaResult<LuaTable> {
mod tests { mod tests {
use super::*; use super::*;
#[test]
fn test_rust() { fn test_rust() {
let source = r#" let source = r#"
// This is a test comment // This is a test comment
@ -718,6 +893,77 @@ mod tests {
assert_eq!(stringified, expected); assert_eq!(stringified, expected);
} }
#[test]
fn test_zig() {
let source = r#"
// This is a test comment
pub const TEST_CONST: u32 = 1;
pub var TEST_VAR: u32 = 2;
const INNER_TEST_CONST: u32 = 3;
var INNER_TEST_VAR: u32 = 4;
pub const TestStruct = struct {
test_field: []const u8,
test_field2: u64,
pub fn test_method(_: *TestStruct, a: u32, b: u32) u32 {
return a + b;
}
fn inner_test_method(_: *TestStruct, a: u32, b: u32) u32 {
return a + b;
}
};
const InnerTestStruct = struct {
test_field: []const u8,
test_field2: u64,
pub fn test_method(_: *InnerTestStruct, a: u32, b: u32) u32 {
return a + b;
}
fn inner_test_method(_: *InnerTestStruct, a: u32, b: u32) u32 {
return a + b;
}
};
pub const TestEnum = enum {
TestEnumField1,
TestEnumField2,
};
const InnerTestEnum = enum {
InnerTestEnumField1,
InnerTestEnumField2,
};
pub const TestUnion = union {
TestUnionField1: u32,
TestUnionField2: u64,
};
const InnerTestUnion = union {
InnerTestUnionField1: u32,
InnerTestUnionField2: u64,
};
pub fn test_fn(a: u32, b: u32) u32 {
const inner_var_in_func = 1;
const InnerStructInFunc = struct {
c: u32,
};
_ = InnerStructInFunc;
return a + b + inner_var_in_func;
}
fn inner_test_fn(a: u32, b: u32) u32 {
return a + b;
}
"#;
let definitions = extract_definitions("zig", source).unwrap();
let stringified = stringify_definitions(&definitions);
println!("{stringified}");
let expected = "var TEST_CONST:u32;var TEST_VAR:u32;func test_fn() -> void;class TestStruct{func test_method(_: *TestStruct, a: u32, b: u32) -> void;var test_field:[]const u8;var test_field2:u64;};enum TestEnum{TestEnumField1;TestEnumField2;};union TestUnion{TestUnionField1;TestUnionField2;};";
assert_eq!(stringified, expected);
}
#[test] #[test]
fn test_go() { fn test_go() {
let source = r#" let source = r#"