feat(repo_map): add elixir support (#894)

This commit is contained in:
Radosław Woźniak 2024-11-24 10:29:30 +01:00 committed by GitHub
parent e60ccd2db4
commit 890fd92594
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 124 additions and 16 deletions

11
Cargo.lock generated
View File

@ -40,6 +40,7 @@ dependencies = [
"tree-sitter",
"tree-sitter-c",
"tree-sitter-cpp",
"tree-sitter-elixir",
"tree-sitter-go",
"tree-sitter-javascript",
"tree-sitter-language",
@ -1340,6 +1341,16 @@ dependencies = [
"tree-sitter-language",
]
[[package]]
name = "tree-sitter-elixir"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97bf0efa4be41120018f23305b105ad4dfd3be1b7f302dc4071d0e6c2dec3a32"
dependencies = [
"cc",
"tree-sitter-language",
]
[[package]]
name = "tree-sitter-go"
version = "0.23.1"

View File

@ -28,6 +28,7 @@ tree-sitter-lua = "0.2"
tree-sitter-ruby = "0.23"
tree-sitter-zig = "1.0.2"
tree-sitter-scala = "0.23"
tree-sitter-elixir = "0.3.1"
[lints]
workspace = true

View File

@ -0,0 +1,21 @@
; * modules and protocols
(call
target: (identifier) @ignore
(arguments (alias) @class)
(#match? @ignore "^(defmodule|defprotocol)$"))
; * functions
(call
target: (identifier) @ignore
(arguments
[
; zero-arity functions with no parentheses
(identifier) @method
; regular function clause
(call target: (identifier) @method)
; function clause with a guard clause
(binary_operator
left: (call target: (identifier) @method)
operator: "when")
])
(#match? @ignore "^(def|defdelegate|defguard|defn)$"))

View File

@ -14,6 +14,7 @@ pub struct Func {
#[derive(Debug, Clone)]
pub struct Class {
pub type_name: String,
pub name: String,
pub methods: Vec<Func>,
pub properties: Vec<Variable>,
@ -61,6 +62,7 @@ fn get_ts_language(language: &str) -> Option<LanguageFn> {
"ruby" => Some(tree_sitter_ruby::LANGUAGE),
"zig" => Some(tree_sitter_zig::LANGUAGE),
"scala" => Some(tree_sitter_scala::LANGUAGE),
"elixir" => Some(tree_sitter_elixir::LANGUAGE),
_ => None,
}
}
@ -76,6 +78,7 @@ const ZIG_QUERY: &str = include_str!("../queries/tree-sitter-zig-defs.scm");
const TYPESCRIPT_QUERY: &str = include_str!("../queries/tree-sitter-typescript-defs.scm");
const RUBY_QUERY: &str = include_str!("../queries/tree-sitter-ruby-defs.scm");
const SCALA_QUERY: &str = include_str!("../queries/tree-sitter-scala-defs.scm");
const ELIXIR_QUERY: &str = include_str!("../queries/tree-sitter-elixir-defs.scm");
fn get_definitions_query(language: &str) -> Result<Query, String> {
let ts_language = get_ts_language(language);
@ -95,6 +98,7 @@ fn get_definitions_query(language: &str) -> Result<Query, String> {
"typescript" => TYPESCRIPT_QUERY,
"ruby" => RUBY_QUERY,
"scala" => SCALA_QUERY,
"elixir" => ELIXIR_QUERY,
_ => return Err(format!("Unsupported language: {language}")),
};
let query = Query::new(&ts_language.into(), contents)
@ -185,6 +189,23 @@ fn zig_find_type_in_parent<'a>(node: &'a Node, source: &'a [u8]) -> Option<Strin
None
}
fn ex_find_parent_module_declaration_name<'a>(node: &'a Node, source: &'a [u8],) -> Option<String> {
let mut parent = node.parent();
while let Some(parent_node) = parent {
if parent_node.kind() == "call" {
let text = get_node_text(&parent_node, source);
if text.starts_with("defmodule ") {
let arguments_node = find_child_by_type(&parent_node, "arguments");
if let Some(arguments_node) = arguments_node {
return Some(get_node_text(&arguments_node, source));
}
}
}
parent = parent_node.parent();
}
None
}
fn get_node_text<'a>(node: &'a Node, source: &'a [u8]) -> String {
node.utf8_text(source).unwrap_or_default().to_string()
}
@ -235,9 +256,14 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
let mut enum_def_map: BTreeMap<String, RefCell<Enum>> = BTreeMap::new();
let mut union_def_map: BTreeMap<String, RefCell<Union>> = BTreeMap::new();
let ensure_class_def = |name: &str, class_def_map: &mut BTreeMap<String, RefCell<Class>>| {
let ensure_class_def = |language: &str, name: &str, class_def_map: &mut BTreeMap<String, RefCell<Class>>| {
let mut type_name = "class";
if language == "elixir" {
type_name = "module"
}
class_def_map.entry(name.to_string()).or_insert_with(|| {
RefCell::new(Class {
type_name: type_name.to_string(),
name: name.to_string(),
methods: vec![],
properties: vec![],
@ -337,7 +363,7 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
if language == "go" && !is_first_letter_uppercase(&name) {
continue;
}
ensure_class_def(&name, &mut class_def_map);
ensure_class_def(&language, &name, &mut class_def_map);
let visibility_modifier_node =
find_child_by_type(&node, "visibility_modifier");
let visibility_modifier = visibility_modifier_node
@ -449,12 +475,18 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
.child_by_field_name("parameters")
.or_else(|| find_descendant_by_type(&node, "parameter_list"));
let function_node = find_ancestor_by_type(&node, "function_declaration");
let zig_function_node = find_ancestor_by_type(&node, "function_declaration");
if language == "zig" {
params_node = function_node
params_node = zig_function_node
.as_ref()
.and_then(|n| find_child_by_type(n, "parameters"));
}
let ex_function_node = find_ancestor_by_type(&node, "call");
if language == "elixir" {
params_node = ex_function_node
.as_ref()
.and_then(|n| find_child_by_type(n, "arguments"));
}
let params = params_node
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
@ -480,6 +512,9 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
return_type_node = node.child_by_field_name("result");
}
let mut return_type = "void".to_string();
if language == "elixir" {
return_type = "".to_string();
}
if return_type_node.is_some() {
return_type = get_node_type(&return_type_node.unwrap(), source.as_bytes());
if return_type.is_empty() {
@ -496,6 +531,9 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
let class_name = if language == "zig" {
zig_find_parent_variable_declaration_name(&node, source.as_bytes())
.unwrap_or_default()
} else if language == "elixir" {
ex_find_parent_module_declaration_name(&node, source.as_bytes())
.unwrap_or_default()
} else if language == "cpp" {
find_ancestor_by_type(&node, "class_specifier")
.or_else(|| find_ancestor_by_type(&node, "struct_specifier"))
@ -524,7 +562,7 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
continue;
}
ensure_class_def(&class_name, &mut class_def_map);
ensure_class_def(&language, &class_name, &mut class_def_map);
let class_def = class_def_map.get_mut(&class_name).unwrap();
let accessibility_modifier_node =
@ -569,7 +607,7 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
if class_name.is_empty() {
continue;
}
ensure_class_def(&class_name, &mut class_def_map);
ensure_class_def(&language, &class_name, &mut class_def_map);
let class_def = class_def_map.get_mut(&class_name).unwrap();
let variable = Variable {
name: left.to_string(),
@ -623,7 +661,7 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
if !name.is_empty() && language == "go" && !is_first_letter_uppercase(&name) {
continue;
}
ensure_class_def(&class_name, &mut class_def_map);
ensure_class_def(&language, &class_name, &mut class_def_map);
let class_def = class_def_map.get_mut(&class_name).unwrap();
let variable = Variable {
name: name.to_string(),
@ -888,7 +926,7 @@ fn stringify_union_item(item: &Variable) -> String {
}
fn stringify_class(class: &Class) -> String {
let mut res = format!("class {}{{", class.name);
let mut res = format!("{} {}{{", class.type_name, class.name);
for method in &class.methods {
let method_str = stringify_function(method);
res = format!("{res}{method_str}");
@ -1428,6 +1466,45 @@ mod tests {
assert_eq!(stringified, expected);
}
#[test]
fn test_elixir() {
let source = r#"
defmodule TestModule do
@moduledoc """
This is a test module
"""
@test_const "test"
@other_const 123
def test_func(a, b) do
a + b
end
defp private_func(x) do
x * 2
end
defmacro test_macro(expr) do
quote do
unquote(expr)
end
end
end
defmodule AnotherModule do
def another_func() do
:ok
end
end
"#;
let definitions = extract_definitions("elixir", source).unwrap();
let stringified = stringify_definitions(&definitions);
println!("{stringified}");
let expected = "module AnotherModule{func another_func();};module TestModule{func test_func(a, b);};";
assert_eq!(stringified, expected);
}
#[test]
fn test_unsupported_language() {
let source = "print('Hello, world!')";

View File

@ -2,7 +2,6 @@ local Popup = require("nui.popup")
local Utils = require("avante.utils")
local event = require("nui.utils.autocmd").event
local Config = require("avante.config")
local fn = vim.fn
local filetype_map = {
["javascriptreact"] = "javascript",
@ -34,14 +33,13 @@ function RepoMap.get_ts_lang(filepath)
end
function RepoMap.get_filetype(filepath)
local filetype = vim.filetype.match({ filename = filepath })
-- TypeScript files are sometimes not detected correctly
-- Some files are sometimes not detected correctly when buffer is not included
-- https://github.com/neovim/neovim/issues/27265
if not filetype then
local ext = fn.fnamemodify(filepath, ":e")
if ext == "tsx" then filetype = "typescriptreact" end
if ext == "ts" then filetype = "typescript" end
end
local buf = vim.api.nvim_create_buf(false, true)
local filetype = vim.filetype.match({ filename = filepath, buf = buf })
vim.api.nvim_buf_delete(buf, { force = true })
return filetype
end