diff --git a/Cargo.lock b/Cargo.lock index c748bff..4e9ab8c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -40,6 +40,7 @@ dependencies = [ "tree-sitter", "tree-sitter-c", "tree-sitter-cpp", + "tree-sitter-elixir", "tree-sitter-go", "tree-sitter-javascript", "tree-sitter-language", @@ -1340,6 +1341,16 @@ dependencies = [ "tree-sitter-language", ] +[[package]] +name = "tree-sitter-elixir" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97bf0efa4be41120018f23305b105ad4dfd3be1b7f302dc4071d0e6c2dec3a32" +dependencies = [ + "cc", + "tree-sitter-language", +] + [[package]] name = "tree-sitter-go" version = "0.23.1" diff --git a/crates/avante-repo-map/Cargo.toml b/crates/avante-repo-map/Cargo.toml index 49a36f1..001c073 100644 --- a/crates/avante-repo-map/Cargo.toml +++ b/crates/avante-repo-map/Cargo.toml @@ -28,6 +28,7 @@ tree-sitter-lua = "0.2" tree-sitter-ruby = "0.23" tree-sitter-zig = "1.0.2" tree-sitter-scala = "0.23" +tree-sitter-elixir = "0.3.1" [lints] workspace = true diff --git a/crates/avante-repo-map/queries/tree-sitter-elixir-defs.scm b/crates/avante-repo-map/queries/tree-sitter-elixir-defs.scm new file mode 100644 index 0000000..822efc9 --- /dev/null +++ b/crates/avante-repo-map/queries/tree-sitter-elixir-defs.scm @@ -0,0 +1,21 @@ +; * modules and protocols +(call + target: (identifier) @ignore + (arguments (alias) @class) + (#match? @ignore "^(defmodule|defprotocol)$")) + +; * functions +(call + target: (identifier) @ignore + (arguments + [ + ; zero-arity functions with no parentheses + (identifier) @method + ; regular function clause + (call target: (identifier) @method) + ; function clause with a guard clause + (binary_operator + left: (call target: (identifier) @method) + operator: "when") + ]) + (#match? @ignore "^(def|defdelegate|defguard|defn)$")) diff --git a/crates/avante-repo-map/src/lib.rs b/crates/avante-repo-map/src/lib.rs index 69e5cc4..98e5b7f 100644 --- a/crates/avante-repo-map/src/lib.rs +++ b/crates/avante-repo-map/src/lib.rs @@ -14,6 +14,7 @@ pub struct Func { #[derive(Debug, Clone)] pub struct Class { + pub type_name: String, pub name: String, pub methods: Vec, pub properties: Vec, @@ -61,6 +62,7 @@ fn get_ts_language(language: &str) -> Option { "ruby" => Some(tree_sitter_ruby::LANGUAGE), "zig" => Some(tree_sitter_zig::LANGUAGE), "scala" => Some(tree_sitter_scala::LANGUAGE), + "elixir" => Some(tree_sitter_elixir::LANGUAGE), _ => None, } } @@ -76,6 +78,7 @@ const ZIG_QUERY: &str = include_str!("../queries/tree-sitter-zig-defs.scm"); const TYPESCRIPT_QUERY: &str = include_str!("../queries/tree-sitter-typescript-defs.scm"); const RUBY_QUERY: &str = include_str!("../queries/tree-sitter-ruby-defs.scm"); const SCALA_QUERY: &str = include_str!("../queries/tree-sitter-scala-defs.scm"); +const ELIXIR_QUERY: &str = include_str!("../queries/tree-sitter-elixir-defs.scm"); fn get_definitions_query(language: &str) -> Result { let ts_language = get_ts_language(language); @@ -95,6 +98,7 @@ fn get_definitions_query(language: &str) -> Result { "typescript" => TYPESCRIPT_QUERY, "ruby" => RUBY_QUERY, "scala" => SCALA_QUERY, + "elixir" => ELIXIR_QUERY, _ => return Err(format!("Unsupported language: {language}")), }; let query = Query::new(&ts_language.into(), contents) @@ -185,6 +189,23 @@ fn zig_find_type_in_parent<'a>(node: &'a Node, source: &'a [u8]) -> Option(node: &'a Node, source: &'a [u8],) -> Option { + let mut parent = node.parent(); + while let Some(parent_node) = parent { + if parent_node.kind() == "call" { + let text = get_node_text(&parent_node, source); + if text.starts_with("defmodule ") { + let arguments_node = find_child_by_type(&parent_node, "arguments"); + if let Some(arguments_node) = arguments_node { + return Some(get_node_text(&arguments_node, source)); + } + } + } + parent = parent_node.parent(); + } + None +} + fn get_node_text<'a>(node: &'a Node, source: &'a [u8]) -> String { node.utf8_text(source).unwrap_or_default().to_string() } @@ -235,9 +256,14 @@ fn extract_definitions(language: &str, source: &str) -> Result, let mut enum_def_map: BTreeMap> = BTreeMap::new(); let mut union_def_map: BTreeMap> = BTreeMap::new(); - let ensure_class_def = |name: &str, class_def_map: &mut BTreeMap>| { + let ensure_class_def = |language: &str, name: &str, class_def_map: &mut BTreeMap>| { + let mut type_name = "class"; + if language == "elixir" { + type_name = "module" + } class_def_map.entry(name.to_string()).or_insert_with(|| { RefCell::new(Class { + type_name: type_name.to_string(), name: name.to_string(), methods: vec![], properties: vec![], @@ -337,7 +363,7 @@ fn extract_definitions(language: &str, source: &str) -> Result, if language == "go" && !is_first_letter_uppercase(&name) { continue; } - ensure_class_def(&name, &mut class_def_map); + ensure_class_def(&language, &name, &mut class_def_map); let visibility_modifier_node = find_child_by_type(&node, "visibility_modifier"); let visibility_modifier = visibility_modifier_node @@ -449,12 +475,18 @@ fn extract_definitions(language: &str, source: &str) -> Result, .child_by_field_name("parameters") .or_else(|| find_descendant_by_type(&node, "parameter_list")); - let function_node = find_ancestor_by_type(&node, "function_declaration"); + let zig_function_node = find_ancestor_by_type(&node, "function_declaration"); if language == "zig" { - params_node = function_node + params_node = zig_function_node .as_ref() .and_then(|n| find_child_by_type(n, "parameters")); } + let ex_function_node = find_ancestor_by_type(&node, "call"); + if language == "elixir" { + params_node = ex_function_node + .as_ref() + .and_then(|n| find_child_by_type(n, "arguments")); + } let params = params_node .map(|n| n.utf8_text(source.as_bytes()).unwrap()) @@ -480,6 +512,9 @@ fn extract_definitions(language: &str, source: &str) -> Result, return_type_node = node.child_by_field_name("result"); } let mut return_type = "void".to_string(); + if language == "elixir" { + return_type = "".to_string(); + } if return_type_node.is_some() { return_type = get_node_type(&return_type_node.unwrap(), source.as_bytes()); if return_type.is_empty() { @@ -496,6 +531,9 @@ fn extract_definitions(language: &str, source: &str) -> Result, let class_name = if language == "zig" { zig_find_parent_variable_declaration_name(&node, source.as_bytes()) .unwrap_or_default() + } else if language == "elixir" { + ex_find_parent_module_declaration_name(&node, source.as_bytes()) + .unwrap_or_default() } else if language == "cpp" { find_ancestor_by_type(&node, "class_specifier") .or_else(|| find_ancestor_by_type(&node, "struct_specifier")) @@ -524,7 +562,7 @@ fn extract_definitions(language: &str, source: &str) -> Result, continue; } - ensure_class_def(&class_name, &mut class_def_map); + ensure_class_def(&language, &class_name, &mut class_def_map); let class_def = class_def_map.get_mut(&class_name).unwrap(); let accessibility_modifier_node = @@ -569,7 +607,7 @@ fn extract_definitions(language: &str, source: &str) -> Result, if class_name.is_empty() { continue; } - ensure_class_def(&class_name, &mut class_def_map); + ensure_class_def(&language, &class_name, &mut class_def_map); let class_def = class_def_map.get_mut(&class_name).unwrap(); let variable = Variable { name: left.to_string(), @@ -623,7 +661,7 @@ fn extract_definitions(language: &str, source: &str) -> Result, if !name.is_empty() && language == "go" && !is_first_letter_uppercase(&name) { continue; } - ensure_class_def(&class_name, &mut class_def_map); + ensure_class_def(&language, &class_name, &mut class_def_map); let class_def = class_def_map.get_mut(&class_name).unwrap(); let variable = Variable { name: name.to_string(), @@ -888,7 +926,7 @@ fn stringify_union_item(item: &Variable) -> String { } fn stringify_class(class: &Class) -> String { - let mut res = format!("class {}{{", class.name); + let mut res = format!("{} {}{{", class.type_name, class.name); for method in &class.methods { let method_str = stringify_function(method); res = format!("{res}{method_str}"); @@ -1428,6 +1466,45 @@ mod tests { assert_eq!(stringified, expected); } + #[test] + fn test_elixir() { + let source = r#" + defmodule TestModule do + @moduledoc """ + This is a test module + """ + + @test_const "test" + @other_const 123 + + def test_func(a, b) do + a + b + end + + defp private_func(x) do + x * 2 + end + + defmacro test_macro(expr) do + quote do + unquote(expr) + end + end + end + + defmodule AnotherModule do + def another_func() do + :ok + end + end + "#; + let definitions = extract_definitions("elixir", source).unwrap(); + let stringified = stringify_definitions(&definitions); + println!("{stringified}"); + let expected = "module AnotherModule{func another_func();};module TestModule{func test_func(a, b);};"; + assert_eq!(stringified, expected); + } + #[test] fn test_unsupported_language() { let source = "print('Hello, world!')"; diff --git a/lua/avante/repo_map.lua b/lua/avante/repo_map.lua index bd8a029..b7bd477 100644 --- a/lua/avante/repo_map.lua +++ b/lua/avante/repo_map.lua @@ -2,7 +2,6 @@ local Popup = require("nui.popup") local Utils = require("avante.utils") local event = require("nui.utils.autocmd").event local Config = require("avante.config") -local fn = vim.fn local filetype_map = { ["javascriptreact"] = "javascript", @@ -34,14 +33,13 @@ function RepoMap.get_ts_lang(filepath) end function RepoMap.get_filetype(filepath) - local filetype = vim.filetype.match({ filename = filepath }) - -- TypeScript files are sometimes not detected correctly + -- Some files are sometimes not detected correctly when buffer is not included -- https://github.com/neovim/neovim/issues/27265 - if not filetype then - local ext = fn.fnamemodify(filepath, ":e") - if ext == "tsx" then filetype = "typescriptreact" end - if ext == "ts" then filetype = "typescript" end - end + + local buf = vim.api.nvim_create_buf(false, true) + local filetype = vim.filetype.match({ filename = filepath, buf = buf }) + vim.api.nvim_buf_delete(buf, { force = true }) + return filetype end