feat(repo-map): C++ improvements (#734)
This commit is contained in:
parent
bbfc315eed
commit
9de95f9e02
@ -1,11 +1,31 @@
|
||||
;; Capture extern functions, variables, public classes, and methods
|
||||
(function_definition
|
||||
(storage_class_specifier) @extern
|
||||
) @function
|
||||
;; Capture functions, variables, nammespaces, classes, methods, and enums
|
||||
(namespace_definition) @namespace
|
||||
(function_definition) @function
|
||||
(class_specifier) @class
|
||||
(class_specifier
|
||||
(public) @class
|
||||
(function_definition) @method
|
||||
) @class
|
||||
(declaration
|
||||
(storage_class_specifier) @extern
|
||||
) @variable
|
||||
body: (field_declaration_list
|
||||
(declaration
|
||||
declarator: (function_declarator))? @method
|
||||
(field_declaration
|
||||
declarator: (function_declarator))? @method
|
||||
(function_definition)? @method
|
||||
(function_declarator)? @method
|
||||
(field_declaration
|
||||
declarator: (field_identifier))? @class_variable
|
||||
)
|
||||
)
|
||||
(struct_specifier) @struct
|
||||
(struct_specifier
|
||||
body: (field_declaration_list
|
||||
(declaration
|
||||
declarator: (function_declarator))? @method
|
||||
(field_declaration
|
||||
declarator: (function_declarator))? @method
|
||||
(function_definition)? @method
|
||||
(function_declarator)? @method
|
||||
(field_declaration
|
||||
declarator: (field_identifier))? @class_variable
|
||||
)
|
||||
)
|
||||
((declaration type: (_))) @variable
|
||||
(enumerator_list ((enumerator) @enum_item))
|
||||
|
@ -1,6 +1,6 @@
|
||||
use mlua::prelude::*;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::BTreeMap;
|
||||
use tree_sitter::{Node, Parser, Query, QueryCursor};
|
||||
use tree_sitter_language::LanguageFn;
|
||||
|
||||
@ -45,6 +45,7 @@ pub enum Definition {
|
||||
Enum(Enum),
|
||||
Variable(Variable),
|
||||
Union(Union),
|
||||
// TODO: Namespace support
|
||||
}
|
||||
|
||||
fn get_ts_language(language: &str) -> Option<LanguageFn> {
|
||||
@ -227,11 +228,11 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
|
||||
let mut query_cursor = QueryCursor::new();
|
||||
let captures = query_cursor.captures(&query, root_node, source.as_bytes());
|
||||
|
||||
let mut class_def_map: HashMap<String, RefCell<Class>> = HashMap::new();
|
||||
let mut enum_def_map: HashMap<String, RefCell<Enum>> = HashMap::new();
|
||||
let mut union_def_map: HashMap<String, RefCell<Union>> = HashMap::new();
|
||||
let mut class_def_map: BTreeMap<String, RefCell<Class>> = BTreeMap::new();
|
||||
let mut enum_def_map: BTreeMap<String, RefCell<Enum>> = BTreeMap::new();
|
||||
let mut union_def_map: BTreeMap<String, RefCell<Union>> = BTreeMap::new();
|
||||
|
||||
let ensure_class_def = |name: &str, class_def_map: &mut HashMap<String, RefCell<Class>>| {
|
||||
let ensure_class_def = |name: &str, class_def_map: &mut BTreeMap<String, RefCell<Class>>| {
|
||||
class_def_map.entry(name.to_string()).or_insert_with(|| {
|
||||
RefCell::new(Class {
|
||||
name: name.to_string(),
|
||||
@ -242,7 +243,7 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
|
||||
});
|
||||
};
|
||||
|
||||
let ensure_enum_def = |name: &str, enum_def_map: &mut HashMap<String, RefCell<Enum>>| {
|
||||
let ensure_enum_def = |name: &str, enum_def_map: &mut BTreeMap<String, RefCell<Enum>>| {
|
||||
enum_def_map.entry(name.to_string()).or_insert_with(|| {
|
||||
RefCell::new(Enum {
|
||||
name: name.to_string(),
|
||||
@ -251,7 +252,7 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
|
||||
});
|
||||
};
|
||||
|
||||
let ensure_union_def = |name: &str, union_def_map: &mut HashMap<String, RefCell<Union>>| {
|
||||
let ensure_union_def = |name: &str, union_def_map: &mut BTreeMap<String, RefCell<Union>>| {
|
||||
union_def_map.entry(name.to_string()).or_insert_with(|| {
|
||||
RefCell::new(Union {
|
||||
name: name.to_string(),
|
||||
@ -260,30 +261,80 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
|
||||
});
|
||||
};
|
||||
|
||||
// Sometimes, multiple queries capture the same node with the same capture name.
|
||||
// We need to ensure that we only add the node to the definition map once.
|
||||
let mut captured_nodes: BTreeMap<String, Vec<usize>> = BTreeMap::new();
|
||||
|
||||
for (m, _) in captures {
|
||||
for capture in m.captures {
|
||||
let capture_name = &query.capture_names()[capture.index as usize];
|
||||
let node = capture.node;
|
||||
let node_text = node.utf8_text(source.as_bytes()).unwrap();
|
||||
|
||||
let name_node = node.child_by_field_name("name");
|
||||
let name = name_node
|
||||
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
|
||||
.unwrap_or(node_text);
|
||||
let node_id = node.id();
|
||||
if captured_nodes
|
||||
.get(*capture_name)
|
||||
.map_or(false, |v| v.contains(&node_id))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
captured_nodes
|
||||
.entry(String::from(*capture_name))
|
||||
.or_default()
|
||||
.push(node_id);
|
||||
|
||||
let name = match language {
|
||||
"cpp" => {
|
||||
if *capture_name == "class" {
|
||||
node.child_by_field_name("name")
|
||||
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
|
||||
.unwrap_or(node_text)
|
||||
.to_string()
|
||||
} else {
|
||||
let ident = find_descendant_by_type(&node, "field_identifier")
|
||||
.or_else(|| find_descendant_by_type(&node, "operator_name"))
|
||||
.or_else(|| find_descendant_by_type(&node, "identifier"))
|
||||
.map(|n| n.utf8_text(source.as_bytes()).unwrap());
|
||||
if let Some(ident) = ident {
|
||||
let scope = node
|
||||
.child_by_field_name("declarator")
|
||||
.and_then(|n| n.child_by_field_name("declarator"))
|
||||
.and_then(|n| n.child_by_field_name("scope"));
|
||||
|
||||
if let Some(scope_node) = scope {
|
||||
format!(
|
||||
"{}::{}",
|
||||
scope_node.utf8_text(source.as_bytes()).unwrap(),
|
||||
ident
|
||||
)
|
||||
} else {
|
||||
ident.to_string()
|
||||
}
|
||||
} else {
|
||||
node_text.to_string()
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => node
|
||||
.child_by_field_name("name")
|
||||
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
|
||||
.unwrap_or(node_text)
|
||||
.to_string(),
|
||||
};
|
||||
|
||||
match *capture_name {
|
||||
"class" => {
|
||||
if !name.is_empty() {
|
||||
if language == "go" && !is_first_letter_uppercase(name) {
|
||||
if language == "go" && !is_first_letter_uppercase(&name) {
|
||||
continue;
|
||||
}
|
||||
ensure_class_def(name, &mut class_def_map);
|
||||
ensure_class_def(&name, &mut class_def_map);
|
||||
let visibility_modifier_node =
|
||||
find_child_by_type(&node, "visibility_modifier");
|
||||
let visibility_modifier = visibility_modifier_node
|
||||
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
|
||||
.unwrap_or("");
|
||||
let class_def = class_def_map.get_mut(name).unwrap();
|
||||
let class_def = class_def_map.get_mut(&name).unwrap();
|
||||
class_def.borrow_mut().visibility_modifier =
|
||||
if visibility_modifier.is_empty() {
|
||||
None
|
||||
@ -353,6 +404,7 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
|
||||
union_def.borrow_mut().items.push(variable);
|
||||
}
|
||||
"method" => {
|
||||
// TODO: C++: Skip private/protected class/struct methods
|
||||
let visibility_modifier_node =
|
||||
find_descendant_by_type(&node, "visibility_modifier");
|
||||
let visibility_modifier = visibility_modifier_node
|
||||
@ -367,11 +419,18 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if !name.is_empty() && language == "go" && !is_first_letter_uppercase(name) {
|
||||
if language == "cpp"
|
||||
&& find_descendant_by_type(&node, "destructor_name").is_some()
|
||||
{
|
||||
continue;
|
||||
}
|
||||
let mut params_node = node.child_by_field_name("parameters");
|
||||
|
||||
if !name.is_empty() && language == "go" && !is_first_letter_uppercase(&name) {
|
||||
continue;
|
||||
}
|
||||
let mut params_node = node
|
||||
.child_by_field_name("parameters")
|
||||
.or_else(|| find_descendant_by_type(&node, "parameter_list"));
|
||||
|
||||
let function_node = find_ancestor_by_type(&node, "function_declaration");
|
||||
if language == "zig" {
|
||||
@ -383,7 +442,23 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
|
||||
let params = params_node
|
||||
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
|
||||
.unwrap_or("()");
|
||||
let mut return_type_node = node.child_by_field_name("return_type");
|
||||
let mut return_type_node = match language {
|
||||
"cpp" => node.child_by_field_name("type"),
|
||||
_ => node.child_by_field_name("return_type"),
|
||||
};
|
||||
if language == "cpp" {
|
||||
let class_specifier_node = find_ancestor_by_type(&node, "class_specifier");
|
||||
let type_identifier_node =
|
||||
class_specifier_node.and_then(|n| n.child_by_field_name("name"));
|
||||
|
||||
if let Some(type_identifier_node) = type_identifier_node {
|
||||
let type_identifier_text =
|
||||
type_identifier_node.utf8_text(source.as_bytes()).unwrap();
|
||||
if name == type_identifier_text {
|
||||
return_type_node = Some(type_identifier_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
if return_type_node.is_none() {
|
||||
return_type_node = node.child_by_field_name("result");
|
||||
}
|
||||
@ -404,6 +479,13 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
|
||||
let class_name = if language == "zig" {
|
||||
zig_find_parent_variable_declaration_name(&node, source.as_bytes())
|
||||
.unwrap_or_default()
|
||||
} else if language == "cpp" {
|
||||
find_ancestor_by_type(&node, "class_specifier")
|
||||
.or_else(|| find_ancestor_by_type(&node, "struct_specifier"))
|
||||
.and_then(|n| n.child_by_field_name("name"))
|
||||
.and_then(|n| n.utf8_text(source.as_bytes()).ok())
|
||||
.unwrap_or("")
|
||||
.to_string()
|
||||
} else if let Some(impl_item) = impl_item_node {
|
||||
let impl_type_node = impl_item.child_by_field_name("type");
|
||||
impl_type_node
|
||||
@ -479,6 +561,7 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
|
||||
class_def.borrow_mut().properties.push(variable);
|
||||
}
|
||||
"class_variable" => {
|
||||
// TODO: C++: Skip private/protected class/struct variables
|
||||
let visibility_modifier_node =
|
||||
find_descendant_by_type(&node, "visibility_modifier");
|
||||
let visibility_modifier = visibility_modifier_node
|
||||
@ -498,6 +581,14 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
|
||||
}
|
||||
|
||||
let mut class_name = get_closest_ancestor_name(&node, source);
|
||||
if language == "cpp" {
|
||||
class_name = find_ancestor_by_type(&node, "class_specifier")
|
||||
.or_else(|| find_ancestor_by_type(&node, "struct_specifier"))
|
||||
.and_then(|n| n.child_by_field_name("name"))
|
||||
.and_then(|n| n.utf8_text(source.as_bytes()).ok())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
}
|
||||
if language == "zig" {
|
||||
class_name =
|
||||
zig_find_parent_variable_declaration_name(&node, source.as_bytes())
|
||||
@ -512,7 +603,7 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
|
||||
if class_name.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if !name.is_empty() && language == "go" && !is_first_letter_uppercase(name) {
|
||||
if !name.is_empty() && language == "go" && !is_first_letter_uppercase(&name) {
|
||||
continue;
|
||||
}
|
||||
ensure_class_def(&class_name, &mut class_def_map);
|
||||
@ -542,27 +633,40 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
|
||||
}
|
||||
}
|
||||
|
||||
if !name.is_empty() && language == "go" && !is_first_letter_uppercase(name) {
|
||||
if !name.is_empty() && language == "go" && !is_first_letter_uppercase(&name) {
|
||||
continue;
|
||||
}
|
||||
let impl_item_node = find_ancestor_by_type(&node, "impl_item");
|
||||
if impl_item_node.is_some() {
|
||||
continue;
|
||||
}
|
||||
let class_specifier_node = find_ancestor_by_type(&node, "class_specifier");
|
||||
if class_specifier_node.is_some() {
|
||||
continue;
|
||||
}
|
||||
let struct_specifier_node = find_ancestor_by_type(&node, "struct_specifier");
|
||||
if struct_specifier_node.is_some() {
|
||||
continue;
|
||||
}
|
||||
let function_node = find_ancestor_by_type(&node, "function_declaration")
|
||||
.or_else(|| find_ancestor_by_type(&node, "function_definition"));
|
||||
if function_node.is_some() {
|
||||
continue;
|
||||
}
|
||||
let params_node = node.child_by_field_name("parameters");
|
||||
let params_node = node
|
||||
.child_by_field_name("parameters")
|
||||
.or_else(|| find_descendant_by_type(&node, "parameter_list"));
|
||||
let params = params_node
|
||||
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
|
||||
.unwrap_or("()");
|
||||
let mut return_type_node = node.child_by_field_name("return_type");
|
||||
if return_type_node.is_none() {
|
||||
return_type_node = node.child_by_field_name("result");
|
||||
}
|
||||
|
||||
let mut return_type = "void".to_string();
|
||||
let return_type_node = match language {
|
||||
"cpp" => node.child_by_field_name("type"),
|
||||
_ => node
|
||||
.child_by_field_name("return_type")
|
||||
.or_else(|| node.child_by_field_name("result")),
|
||||
};
|
||||
if return_type_node.is_some() {
|
||||
return_type = get_node_type(&return_type_node.unwrap(), source.as_bytes());
|
||||
if return_type.is_empty() {
|
||||
@ -689,7 +793,7 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
|
||||
continue;
|
||||
};
|
||||
}
|
||||
if !name.is_empty() && language == "go" && !is_first_letter_uppercase(name) {
|
||||
if !name.is_empty() && language == "go" && !is_first_letter_uppercase(&name) {
|
||||
continue;
|
||||
}
|
||||
let variable = Variable {
|
||||
@ -836,6 +940,7 @@ fn avante_repo_map(lua: &Lua) -> LuaResult<LuaTable> {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_rust() {
|
||||
let source = r#"
|
||||
// This is a test comment
|
||||
@ -1201,6 +1306,72 @@ mod tests {
|
||||
assert_eq!(stringified, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cpp() {
|
||||
let source = r#"
|
||||
// This is a test comment
|
||||
#include <iostream>
|
||||
|
||||
namespace {
|
||||
constexpr int TEST_CONSTEXPR = 1;
|
||||
const int TEST_CONST = 1;
|
||||
}; // namespace
|
||||
|
||||
int test_var = 2;
|
||||
|
||||
int TestFunc(bool b) { return b ? 42 : -1; }
|
||||
|
||||
template <typename T> class TestClass {
|
||||
public:
|
||||
TestClass();
|
||||
TestClass(T a, T b);
|
||||
~TestClass();
|
||||
bool operator==(const TestClass &other);
|
||||
T testMethod(T x, T y) { return x + y; }
|
||||
T c;
|
||||
|
||||
private:
|
||||
void privateMethod();
|
||||
T a = 0;
|
||||
T b;
|
||||
};
|
||||
|
||||
struct TestStruct {
|
||||
public:
|
||||
TestStruct(int a, int b);
|
||||
~TestStruct();
|
||||
bool operator==(const TestStruct &other);
|
||||
int testMethod(int x, int y) { return x + y; }
|
||||
static int c;
|
||||
|
||||
private:
|
||||
int a = 0;
|
||||
int b;
|
||||
};
|
||||
|
||||
bool TestStruct::operator==(const TestStruct &other) { return true; }
|
||||
|
||||
int TestStruct::c = 0;
|
||||
|
||||
int testFunction(int a, int b) { return a + b; }
|
||||
|
||||
namespace TestNamespace {
|
||||
class InnerClass {
|
||||
public:
|
||||
bool innerMethod(int a) const;
|
||||
};
|
||||
bool InnerClass::innerMethod(int a) const { return doSomething(a * 2); }
|
||||
} // namespace TestNamespace
|
||||
|
||||
enum TestEnum { ENUM_VALUE_1, ENUM_VALUE_2 };
|
||||
"#;
|
||||
let definitions = extract_definitions("cpp", source).unwrap();
|
||||
let stringified = stringify_definitions(&definitions);
|
||||
println!("{}", stringified);
|
||||
let expected = "var TEST_CONSTEXPR:int;var TEST_CONST:int;var test_var:int;func TestFunc(bool b) -> int;func TestStruct::operator==(const TestStruct &other) -> bool;var TestStruct::c:int;func testFunction(int a, int b) -> int;func InnerClass::innerMethod(int a) -> bool;class InnerClass{func innerMethod(int a) -> bool;};class TestClass{func TestClass() -> TestClass;func operator==(const TestClass &other) -> bool;func testMethod(T x, T y) -> T;func privateMethod() -> void;func TestClass(T a, T b) -> TestClass;var c:T;var a:T;var b:T;};class TestStruct{func TestStruct(int a, int b) -> void;func operator==(const TestStruct &other) -> bool;func testMethod(int x, int y) -> int;var c:int;var a:int;var b:int;};enum TestEnum{ENUM_VALUE_1;ENUM_VALUE_2;};";
|
||||
assert_eq!(stringified, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unsupported_language() {
|
||||
let source = "print('Hello, world!')";
|
||||
|
Loading…
x
Reference in New Issue
Block a user