feat(repo-map): C++ improvements (#734)

This commit is contained in:
Maddison Hellstrom 2024-10-21 02:20:18 -07:00 committed by GitHub
parent bbfc315eed
commit 9de95f9e02
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 227 additions and 36 deletions

View File

@ -1,11 +1,31 @@
;; Capture extern functions, variables, public classes, and methods
(function_definition
(storage_class_specifier) @extern
) @function
;; Capture functions, variables, nammespaces, classes, methods, and enums
(namespace_definition) @namespace
(function_definition) @function
(class_specifier) @class
(class_specifier
(public) @class
(function_definition) @method
) @class
(declaration
(storage_class_specifier) @extern
) @variable
body: (field_declaration_list
(declaration
declarator: (function_declarator))? @method
(field_declaration
declarator: (function_declarator))? @method
(function_definition)? @method
(function_declarator)? @method
(field_declaration
declarator: (field_identifier))? @class_variable
)
)
(struct_specifier) @struct
(struct_specifier
body: (field_declaration_list
(declaration
declarator: (function_declarator))? @method
(field_declaration
declarator: (function_declarator))? @method
(function_definition)? @method
(function_declarator)? @method
(field_declaration
declarator: (field_identifier))? @class_variable
)
)
((declaration type: (_))) @variable
(enumerator_list ((enumerator) @enum_item))

View File

@ -1,6 +1,6 @@
use mlua::prelude::*;
use std::cell::RefCell;
use std::collections::HashMap;
use std::collections::BTreeMap;
use tree_sitter::{Node, Parser, Query, QueryCursor};
use tree_sitter_language::LanguageFn;
@ -45,6 +45,7 @@ pub enum Definition {
Enum(Enum),
Variable(Variable),
Union(Union),
// TODO: Namespace support
}
fn get_ts_language(language: &str) -> Option<LanguageFn> {
@ -227,11 +228,11 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
let mut query_cursor = QueryCursor::new();
let captures = query_cursor.captures(&query, root_node, source.as_bytes());
let mut class_def_map: HashMap<String, RefCell<Class>> = HashMap::new();
let mut enum_def_map: HashMap<String, RefCell<Enum>> = HashMap::new();
let mut union_def_map: HashMap<String, RefCell<Union>> = HashMap::new();
let mut class_def_map: BTreeMap<String, RefCell<Class>> = BTreeMap::new();
let mut enum_def_map: BTreeMap<String, RefCell<Enum>> = BTreeMap::new();
let mut union_def_map: BTreeMap<String, RefCell<Union>> = BTreeMap::new();
let ensure_class_def = |name: &str, class_def_map: &mut HashMap<String, RefCell<Class>>| {
let ensure_class_def = |name: &str, class_def_map: &mut BTreeMap<String, RefCell<Class>>| {
class_def_map.entry(name.to_string()).or_insert_with(|| {
RefCell::new(Class {
name: name.to_string(),
@ -242,7 +243,7 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
});
};
let ensure_enum_def = |name: &str, enum_def_map: &mut HashMap<String, RefCell<Enum>>| {
let ensure_enum_def = |name: &str, enum_def_map: &mut BTreeMap<String, RefCell<Enum>>| {
enum_def_map.entry(name.to_string()).or_insert_with(|| {
RefCell::new(Enum {
name: name.to_string(),
@ -251,7 +252,7 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
});
};
let ensure_union_def = |name: &str, union_def_map: &mut HashMap<String, RefCell<Union>>| {
let ensure_union_def = |name: &str, union_def_map: &mut BTreeMap<String, RefCell<Union>>| {
union_def_map.entry(name.to_string()).or_insert_with(|| {
RefCell::new(Union {
name: name.to_string(),
@ -260,30 +261,80 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
});
};
// Sometimes, multiple queries capture the same node with the same capture name.
// We need to ensure that we only add the node to the definition map once.
let mut captured_nodes: BTreeMap<String, Vec<usize>> = BTreeMap::new();
for (m, _) in captures {
for capture in m.captures {
let capture_name = &query.capture_names()[capture.index as usize];
let node = capture.node;
let node_text = node.utf8_text(source.as_bytes()).unwrap();
let name_node = node.child_by_field_name("name");
let name = name_node
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
.unwrap_or(node_text);
let node_id = node.id();
if captured_nodes
.get(*capture_name)
.map_or(false, |v| v.contains(&node_id))
{
continue;
}
captured_nodes
.entry(String::from(*capture_name))
.or_default()
.push(node_id);
let name = match language {
"cpp" => {
if *capture_name == "class" {
node.child_by_field_name("name")
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
.unwrap_or(node_text)
.to_string()
} else {
let ident = find_descendant_by_type(&node, "field_identifier")
.or_else(|| find_descendant_by_type(&node, "operator_name"))
.or_else(|| find_descendant_by_type(&node, "identifier"))
.map(|n| n.utf8_text(source.as_bytes()).unwrap());
if let Some(ident) = ident {
let scope = node
.child_by_field_name("declarator")
.and_then(|n| n.child_by_field_name("declarator"))
.and_then(|n| n.child_by_field_name("scope"));
if let Some(scope_node) = scope {
format!(
"{}::{}",
scope_node.utf8_text(source.as_bytes()).unwrap(),
ident
)
} else {
ident.to_string()
}
} else {
node_text.to_string()
}
}
}
_ => node
.child_by_field_name("name")
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
.unwrap_or(node_text)
.to_string(),
};
match *capture_name {
"class" => {
if !name.is_empty() {
if language == "go" && !is_first_letter_uppercase(name) {
if language == "go" && !is_first_letter_uppercase(&name) {
continue;
}
ensure_class_def(name, &mut class_def_map);
ensure_class_def(&name, &mut class_def_map);
let visibility_modifier_node =
find_child_by_type(&node, "visibility_modifier");
let visibility_modifier = visibility_modifier_node
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
.unwrap_or("");
let class_def = class_def_map.get_mut(name).unwrap();
let class_def = class_def_map.get_mut(&name).unwrap();
class_def.borrow_mut().visibility_modifier =
if visibility_modifier.is_empty() {
None
@ -353,6 +404,7 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
union_def.borrow_mut().items.push(variable);
}
"method" => {
// TODO: C++: Skip private/protected class/struct methods
let visibility_modifier_node =
find_descendant_by_type(&node, "visibility_modifier");
let visibility_modifier = visibility_modifier_node
@ -367,11 +419,18 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
{
continue;
}
if !name.is_empty() && language == "go" && !is_first_letter_uppercase(name) {
if language == "cpp"
&& find_descendant_by_type(&node, "destructor_name").is_some()
{
continue;
}
let mut params_node = node.child_by_field_name("parameters");
if !name.is_empty() && language == "go" && !is_first_letter_uppercase(&name) {
continue;
}
let mut params_node = node
.child_by_field_name("parameters")
.or_else(|| find_descendant_by_type(&node, "parameter_list"));
let function_node = find_ancestor_by_type(&node, "function_declaration");
if language == "zig" {
@ -383,7 +442,23 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
let params = params_node
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
.unwrap_or("()");
let mut return_type_node = node.child_by_field_name("return_type");
let mut return_type_node = match language {
"cpp" => node.child_by_field_name("type"),
_ => node.child_by_field_name("return_type"),
};
if language == "cpp" {
let class_specifier_node = find_ancestor_by_type(&node, "class_specifier");
let type_identifier_node =
class_specifier_node.and_then(|n| n.child_by_field_name("name"));
if let Some(type_identifier_node) = type_identifier_node {
let type_identifier_text =
type_identifier_node.utf8_text(source.as_bytes()).unwrap();
if name == type_identifier_text {
return_type_node = Some(type_identifier_node);
}
}
}
if return_type_node.is_none() {
return_type_node = node.child_by_field_name("result");
}
@ -404,6 +479,13 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
let class_name = if language == "zig" {
zig_find_parent_variable_declaration_name(&node, source.as_bytes())
.unwrap_or_default()
} else if language == "cpp" {
find_ancestor_by_type(&node, "class_specifier")
.or_else(|| find_ancestor_by_type(&node, "struct_specifier"))
.and_then(|n| n.child_by_field_name("name"))
.and_then(|n| n.utf8_text(source.as_bytes()).ok())
.unwrap_or("")
.to_string()
} else if let Some(impl_item) = impl_item_node {
let impl_type_node = impl_item.child_by_field_name("type");
impl_type_node
@ -479,6 +561,7 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
class_def.borrow_mut().properties.push(variable);
}
"class_variable" => {
// TODO: C++: Skip private/protected class/struct variables
let visibility_modifier_node =
find_descendant_by_type(&node, "visibility_modifier");
let visibility_modifier = visibility_modifier_node
@ -498,6 +581,14 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
}
let mut class_name = get_closest_ancestor_name(&node, source);
if language == "cpp" {
class_name = find_ancestor_by_type(&node, "class_specifier")
.or_else(|| find_ancestor_by_type(&node, "struct_specifier"))
.and_then(|n| n.child_by_field_name("name"))
.and_then(|n| n.utf8_text(source.as_bytes()).ok())
.unwrap_or("")
.to_string();
}
if language == "zig" {
class_name =
zig_find_parent_variable_declaration_name(&node, source.as_bytes())
@ -512,7 +603,7 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
if class_name.is_empty() {
continue;
}
if !name.is_empty() && language == "go" && !is_first_letter_uppercase(name) {
if !name.is_empty() && language == "go" && !is_first_letter_uppercase(&name) {
continue;
}
ensure_class_def(&class_name, &mut class_def_map);
@ -542,27 +633,40 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
}
}
if !name.is_empty() && language == "go" && !is_first_letter_uppercase(name) {
if !name.is_empty() && language == "go" && !is_first_letter_uppercase(&name) {
continue;
}
let impl_item_node = find_ancestor_by_type(&node, "impl_item");
if impl_item_node.is_some() {
continue;
}
let class_specifier_node = find_ancestor_by_type(&node, "class_specifier");
if class_specifier_node.is_some() {
continue;
}
let struct_specifier_node = find_ancestor_by_type(&node, "struct_specifier");
if struct_specifier_node.is_some() {
continue;
}
let function_node = find_ancestor_by_type(&node, "function_declaration")
.or_else(|| find_ancestor_by_type(&node, "function_definition"));
if function_node.is_some() {
continue;
}
let params_node = node.child_by_field_name("parameters");
let params_node = node
.child_by_field_name("parameters")
.or_else(|| find_descendant_by_type(&node, "parameter_list"));
let params = params_node
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
.unwrap_or("()");
let mut return_type_node = node.child_by_field_name("return_type");
if return_type_node.is_none() {
return_type_node = node.child_by_field_name("result");
}
let mut return_type = "void".to_string();
let return_type_node = match language {
"cpp" => node.child_by_field_name("type"),
_ => node
.child_by_field_name("return_type")
.or_else(|| node.child_by_field_name("result")),
};
if return_type_node.is_some() {
return_type = get_node_type(&return_type_node.unwrap(), source.as_bytes());
if return_type.is_empty() {
@ -689,7 +793,7 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
continue;
};
}
if !name.is_empty() && language == "go" && !is_first_letter_uppercase(name) {
if !name.is_empty() && language == "go" && !is_first_letter_uppercase(&name) {
continue;
}
let variable = Variable {
@ -836,6 +940,7 @@ fn avante_repo_map(lua: &Lua) -> LuaResult<LuaTable> {
mod tests {
use super::*;
#[test]
fn test_rust() {
let source = r#"
// This is a test comment
@ -1201,6 +1306,72 @@ mod tests {
assert_eq!(stringified, expected);
}
#[test]
fn test_cpp() {
let source = r#"
// This is a test comment
#include <iostream>
namespace {
constexpr int TEST_CONSTEXPR = 1;
const int TEST_CONST = 1;
}; // namespace
int test_var = 2;
int TestFunc(bool b) { return b ? 42 : -1; }
template <typename T> class TestClass {
public:
TestClass();
TestClass(T a, T b);
~TestClass();
bool operator==(const TestClass &other);
T testMethod(T x, T y) { return x + y; }
T c;
private:
void privateMethod();
T a = 0;
T b;
};
struct TestStruct {
public:
TestStruct(int a, int b);
~TestStruct();
bool operator==(const TestStruct &other);
int testMethod(int x, int y) { return x + y; }
static int c;
private:
int a = 0;
int b;
};
bool TestStruct::operator==(const TestStruct &other) { return true; }
int TestStruct::c = 0;
int testFunction(int a, int b) { return a + b; }
namespace TestNamespace {
class InnerClass {
public:
bool innerMethod(int a) const;
};
bool InnerClass::innerMethod(int a) const { return doSomething(a * 2); }
} // namespace TestNamespace
enum TestEnum { ENUM_VALUE_1, ENUM_VALUE_2 };
"#;
let definitions = extract_definitions("cpp", source).unwrap();
let stringified = stringify_definitions(&definitions);
println!("{}", stringified);
let expected = "var TEST_CONSTEXPR:int;var TEST_CONST:int;var test_var:int;func TestFunc(bool b) -> int;func TestStruct::operator==(const TestStruct &other) -> bool;var TestStruct::c:int;func testFunction(int a, int b) -> int;func InnerClass::innerMethod(int a) -> bool;class InnerClass{func innerMethod(int a) -> bool;};class TestClass{func TestClass() -> TestClass;func operator==(const TestClass &other) -> bool;func testMethod(T x, T y) -> T;func privateMethod() -> void;func TestClass(T a, T b) -> TestClass;var c:T;var a:T;var b:T;};class TestStruct{func TestStruct(int a, int b) -> void;func operator==(const TestStruct &other) -> bool;func testMethod(int x, int y) -> int;var c:int;var a:int;var b:int;};enum TestEnum{ENUM_VALUE_1;ENUM_VALUE_2;};";
assert_eq!(stringified, expected);
}
#[test]
fn test_unsupported_language() {
let source = "print('Hello, world!')";