feat: tokenizers (#407)
* feat: autobuild tiktoken lib and schenanigans Signed-off-by: Aaron Pham <contact@aarnphm.xyz> * chore: revert readme changes Signed-off-by: Aaron Pham <contact@aarnphm.xyz> * fix(build): windows Signed-off-by: Hanchin Hsieh <me@yuchanns.xyz> * chore(plugin): early load commands and base setup Signed-off-by: Aaron Pham <contact@aarnphm.xyz> * fix(build): make sync Signed-off-by: Aaron Pham <contact@aarnphm.xyz> * feat: rust go vroom vroom Signed-off-by: Aaron Pham <contact@aarnphm.xyz> * feat: scuffed afaf implementation binding go brrrr Signed-off-by: Aaron Pham <contact@aarnphm.xyz> * chore: remove dups Signed-off-by: Aaron Pham <contact@aarnphm.xyz> * fix(tokens): calculate whether we should do prompt_caching (fixes #416) Signed-off-by: Aaron Pham <contact@aarnphm.xyz> * chore: ignore lockfiles Signed-off-by: Aaron Pham <contact@aarnphm.xyz> * Update README.md * Update crates/avante-tokenizers/README.md * chore: remove unused Signed-off-by: Aaron Pham <contact@aarnphm.xyz> * chore: remove auto build Signed-off-by: Aaron Pham <contact@aarnphm.xyz> --------- Signed-off-by: Aaron Pham <contact@aarnphm.xyz> Signed-off-by: Hanchin Hsieh <me@yuchanns.xyz> Co-authored-by: yuchanns <me@yuchanns.xyz>
This commit is contained in:
parent
81b44e4533
commit
d2095ba267
5
.cargo/config.toml
Normal file
5
.cargo/config.toml
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
[target.x86_64-apple-darwin]
|
||||||
|
rustflags = ["-C", "link-arg=-undefined", "-C", "link-arg=dynamic_lookup"]
|
||||||
|
|
||||||
|
[target.aarch64-apple-darwin]
|
||||||
|
rustflags = ["-C", "link-arg=-undefined", "-C", "link-arg=dynamic_lookup"]
|
2
.gitattributes
vendored
Normal file
2
.gitattributes
vendored
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
* text=auto eol=lf
|
||||||
|
**/*.lock linguist-generated=true
|
3
.gitignore
vendored
3
.gitignore
vendored
@ -43,5 +43,4 @@ temp/
|
|||||||
# If you have any personal configuration files, you should ignore them too
|
# If you have any personal configuration files, you should ignore them too
|
||||||
config.personal.lua
|
config.personal.lua
|
||||||
|
|
||||||
# Avante chat history
|
target
|
||||||
.avante_chat_history
|
|
||||||
|
29
Build.ps1
Normal file
29
Build.ps1
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
param (
|
||||||
|
[string]$Version = "luajit"
|
||||||
|
)
|
||||||
|
|
||||||
|
$BuildDir = "build"
|
||||||
|
$BuildFromSource = $true
|
||||||
|
|
||||||
|
function Build-FromSource($feature) {
|
||||||
|
if (-not (Test-Path $BuildDir)) {
|
||||||
|
New-Item -ItemType Directory -Path $BuildDir | Out-Null
|
||||||
|
}
|
||||||
|
|
||||||
|
cargo build --release --features=$feature
|
||||||
|
|
||||||
|
$targetFile = "avante_tokenizers.dll"
|
||||||
|
Copy-Item (Join-Path "target\release\libavante_tokenizers.dll") (Join-Path $BuildDir $targetFile)
|
||||||
|
|
||||||
|
Remove-Item -Recurse -Force "target"
|
||||||
|
}
|
||||||
|
|
||||||
|
function Main {
|
||||||
|
Set-Location $PSScriptRoot
|
||||||
|
Write-Host "Building for $Version..."
|
||||||
|
Build-FromSource $Version
|
||||||
|
Write-Host "Completed!"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Run the main function
|
||||||
|
Main
|
1580
Cargo.lock
generated
Normal file
1580
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
44
Cargo.toml
Normal file
44
Cargo.toml
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
[workspace]
|
||||||
|
members = ["crates/*"]
|
||||||
|
resolver = "2"
|
||||||
|
|
||||||
|
[workspace.package]
|
||||||
|
edition = "2021"
|
||||||
|
rust-version = "1.80"
|
||||||
|
license = "Apache-2.0"
|
||||||
|
version = "0.1.0"
|
||||||
|
|
||||||
|
[workspace.dependencies]
|
||||||
|
avante-tokenizers = { path = "crates/avante-tokenizers" }
|
||||||
|
|
||||||
|
[workspace.lints.rust]
|
||||||
|
unsafe_code = "warn"
|
||||||
|
unreachable_pub = "warn"
|
||||||
|
|
||||||
|
[workspace.lints.clippy]
|
||||||
|
pedantic = { level = "warn", priority = -2 }
|
||||||
|
# Allowed pedantic lints
|
||||||
|
char_lit_as_u8 = "allow"
|
||||||
|
collapsible_else_if = "allow"
|
||||||
|
collapsible_if = "allow"
|
||||||
|
implicit_hasher = "allow"
|
||||||
|
map_unwrap_or = "allow"
|
||||||
|
match_same_arms = "allow"
|
||||||
|
missing_errors_doc = "allow"
|
||||||
|
missing_panics_doc = "allow"
|
||||||
|
module_name_repetitions = "allow"
|
||||||
|
must_use_candidate = "allow"
|
||||||
|
similar_names = "allow"
|
||||||
|
too_many_lines = "allow"
|
||||||
|
too_many_arguments = "allow"
|
||||||
|
# Disallowed restriction lints
|
||||||
|
print_stdout = "warn"
|
||||||
|
print_stderr = "warn"
|
||||||
|
dbg_macro = "warn"
|
||||||
|
empty_drop = "warn"
|
||||||
|
empty_structs_with_brackets = "warn"
|
||||||
|
exit = "warn"
|
||||||
|
get_unwrap = "warn"
|
||||||
|
rc_buffer = "warn"
|
||||||
|
rc_mutex = "warn"
|
||||||
|
rest_pat_in_fully_bound_structs = "warn"
|
47
Makefile
47
Makefile
@ -1,5 +1,48 @@
|
|||||||
hello:
|
UNAME := $(shell uname)
|
||||||
@echo Hello avante.nvim!
|
ARCH := $(shell uname -m)
|
||||||
|
|
||||||
|
ifeq ($(UNAME), Linux)
|
||||||
|
OS := linux
|
||||||
|
EXT := so
|
||||||
|
else ifeq ($(UNAME), Darwin)
|
||||||
|
OS := macOS
|
||||||
|
EXT := dylib
|
||||||
|
else
|
||||||
|
$(error Unsupported operating system: $(UNAME))
|
||||||
|
endif
|
||||||
|
|
||||||
|
LUA_VERSIONS := luajit lua51
|
||||||
|
BUILD_DIR := build
|
||||||
|
|
||||||
|
all: luajit
|
||||||
|
|
||||||
|
luajit: $(BUILD_DIR)/libavante_tokenizers.$(EXT)
|
||||||
|
lua51: $(BUILD_DIR)/libavante_tokenizers-lua51.$(EXT)
|
||||||
|
lua52: $(BUILD_DIR)/libavante_tokenizers-lua52.$(EXT)
|
||||||
|
lua53: $(BUILD_DIR)/libavante_tokenizers-lua53.$(EXT)
|
||||||
|
lua54: $(BUILD_DIR)/libavante_tokenizers-lua54.$(EXT)
|
||||||
|
|
||||||
|
define build_from_source
|
||||||
|
cargo build --release --features=$1
|
||||||
|
cp target/release/libavante_tokenizers.$(EXT) $(BUILD_DIR)/avante_tokenizers.$(EXT)
|
||||||
|
endef
|
||||||
|
|
||||||
|
$(BUILD_DIR)/libavante_tokenizers.$(EXT): $(BUILD_DIR)
|
||||||
|
$(call build_from_source,luajit)
|
||||||
|
$(BUILD_DIR)/libavante_tokenizers-lua51.$(EXT): $(BUILD_DIR)
|
||||||
|
$(call build_from_source,lua51)
|
||||||
|
$(BUILD_DIR)/libavante_tokenizers-lua52.$(EXT): $(BUILD_DIR)
|
||||||
|
$(call build_from_source,lua52)
|
||||||
|
$(BUILD_DIR)/libavante_tokenizers-lua53.$(EXT): $(BUILD_DIR)
|
||||||
|
$(call build_from_source,lua53)
|
||||||
|
$(BUILD_DIR)/libavante_tokenizers-lua54.$(EXT): $(BUILD_DIR)
|
||||||
|
$(call build_from_source,lua54)
|
||||||
|
|
||||||
|
$(BUILD_DIR):
|
||||||
|
mkdir -p $(BUILD_DIR)
|
||||||
|
|
||||||
|
clean:
|
||||||
|
rm -rf $(BUILD_DIR)
|
||||||
|
|
||||||
luacheck:
|
luacheck:
|
||||||
luacheck `find -name "*.lua"` --codes
|
luacheck `find -name "*.lua"` --codes
|
||||||
|
25
README.md
25
README.md
@ -30,11 +30,7 @@ https://github.com/user-attachments/assets/86140bfd-08b4-483d-a887-1b701d9e37dd
|
|||||||
opts = {
|
opts = {
|
||||||
-- add any opts here
|
-- add any opts here
|
||||||
},
|
},
|
||||||
keys = {
|
build = ":AvanteBuild", -- This is optional, recommended tho. Also note that this will block the startup for a bit since we are compiling bindings in Rust.
|
||||||
{ "<leader>aa", function() require("avante.api").ask() end, desc = "avante: ask", mode = { "n", "v" } },
|
|
||||||
{ "<leader>ar", function() require("avante.api").refresh() end, desc = "avante: refresh" },
|
|
||||||
{ "<leader>ae", function() require("avante.api").edit() end, desc = "avante: edit", mode = "v" },
|
|
||||||
},
|
|
||||||
dependencies = {
|
dependencies = {
|
||||||
"stevearc/dressing.nvim",
|
"stevearc/dressing.nvim",
|
||||||
"nvim-lua/plenary.nvim",
|
"nvim-lua/plenary.nvim",
|
||||||
@ -90,9 +86,14 @@ Plug 'HakonHarnes/img-clip.nvim'
|
|||||||
Plug 'zbirenbaum/copilot.lua'
|
Plug 'zbirenbaum/copilot.lua'
|
||||||
|
|
||||||
" Yay
|
" Yay
|
||||||
Plug 'yetone/avante.nvim'
|
Plug 'yetone/avante.nvim', { 'branch': 'main', 'do': ':AvanteBuild', 'on': 'AvanteAsk' }
|
||||||
```
|
```
|
||||||
|
|
||||||
|
> [!important]
|
||||||
|
>
|
||||||
|
> For `avante.tokenizers` to work, make sure to call `require('avante_lib').load()` somewhere when entering the editor.
|
||||||
|
> We will leave the users to decide where it fits to do this, as this varies among configurations. (But we do recommend running this after where you set your colorscheme)
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
@ -100,31 +101,31 @@ Plug 'yetone/avante.nvim'
|
|||||||
<summary><a href="https://github.com/echasnovski/mini.deps">mini.deps</a></summary>
|
<summary><a href="https://github.com/echasnovski/mini.deps">mini.deps</a></summary>
|
||||||
|
|
||||||
```lua
|
```lua
|
||||||
local add, later = MiniDeps.add, MiniDeps.later
|
local add, later, now = MiniDeps.add, MiniDeps.later, MiniDeps.now
|
||||||
|
|
||||||
add({
|
add({
|
||||||
source = 'yetone/avante.nvim',
|
source = 'yetone/avante.nvim',
|
||||||
|
monitor = 'main',
|
||||||
depends = {
|
depends = {
|
||||||
'stevearc/dressing.nvim',
|
'stevearc/dressing.nvim',
|
||||||
'nvim-lua/plenary.nvim',
|
'nvim-lua/plenary.nvim',
|
||||||
'MunifTanjim/nui.nvim',
|
'MunifTanjim/nui.nvim',
|
||||||
'echasnovski/mini.icons'
|
'echasnovski/mini.icons'
|
||||||
},
|
},
|
||||||
|
hooks = { post_checkout = function() vim.cmd('AvanteBuild') end }
|
||||||
})
|
})
|
||||||
--- optional
|
--- optional
|
||||||
add({ source = 'zbirenbaum/copilot.lua' })
|
add({ source = 'zbirenbaum/copilot.lua' })
|
||||||
add({ source = 'HakonHarnes/img-clip.nvim' })
|
add({ source = 'HakonHarnes/img-clip.nvim' })
|
||||||
add({ source = 'MeanderingProgrammer/render-markdown.nvim' })
|
add({ source = 'MeanderingProgrammer/render-markdown.nvim' })
|
||||||
|
|
||||||
|
now(function() require('avante_lib').load() end)
|
||||||
later(function() require('render-markdown').setup({...}) end)
|
later(function() require('render-markdown').setup({...}) end)
|
||||||
later(function()
|
later(function()
|
||||||
require('img-clip').setup({...}) -- config img-clip
|
require('img-clip').setup({...}) -- config img-clip
|
||||||
require("copilot").setup({...}) -- setup copilot to your liking
|
require("copilot").setup({...}) -- setup copilot to your liking
|
||||||
require("avante").setup({...}) -- config for avante.nvim
|
require("avante").setup({...}) -- config for avante.nvim
|
||||||
end)
|
end)
|
||||||
|
|
||||||
```
|
|
||||||
```
|
|
||||||
```
|
```
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
@ -171,6 +172,10 @@ require('avante').setup ({
|
|||||||
> }
|
> }
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
|
> [!TIP]
|
||||||
|
>
|
||||||
|
> Any rendering plugins that support markdown should work Avante as long as you add the supported filetype `Avante`. See https://github.com/yetone/avante.nvim/issues/175 and [this comment](https://github.com/yetone/avante.nvim/issues/175#issuecomment-2313749363) for more information.
|
||||||
|
|
||||||
### Default setup configuration
|
### Default setup configuration
|
||||||
|
|
||||||
_See [config.lua#L9](./lua/avante/config.lua) for the full config_
|
_See [config.lua#L9](./lua/avante/config.lua) for the full config_
|
||||||
|
1487
crates/avante-tokenizers/Cargo.lock
generated
Normal file
1487
crates/avante-tokenizers/Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
32
crates/avante-tokenizers/Cargo.toml
Normal file
32
crates/avante-tokenizers/Cargo.toml
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
[lib]
|
||||||
|
crate-type = ["cdylib"]
|
||||||
|
|
||||||
|
[package]
|
||||||
|
name = "avante-tokenizers"
|
||||||
|
edition = { workspace = true }
|
||||||
|
version = { workspace = true }
|
||||||
|
rust-version = { workspace = true }
|
||||||
|
license = { workspace = true }
|
||||||
|
|
||||||
|
[lints]
|
||||||
|
workspace = true
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
mlua = { version = "0.10.0-beta.1", features = [
|
||||||
|
"module",
|
||||||
|
"serialize",
|
||||||
|
], git = "https://github.com/mlua-rs/mlua.git", branch = "main" }
|
||||||
|
tiktoken-rs = "0.5.9"
|
||||||
|
tokenizers = { version = "0.20.0", features = [
|
||||||
|
"esaxx_fast",
|
||||||
|
"http",
|
||||||
|
"unstable_wasm",
|
||||||
|
"onig",
|
||||||
|
], default-features = false }
|
||||||
|
|
||||||
|
[features]
|
||||||
|
lua51 = ["mlua/lua51"]
|
||||||
|
lua52 = ["mlua/lua52"]
|
||||||
|
lua53 = ["mlua/lua53"]
|
||||||
|
lua54 = ["mlua/lua54"]
|
||||||
|
luajit = ["mlua/luajit"]
|
1
crates/avante-tokenizers/README.md
Normal file
1
crates/avante-tokenizers/README.md
Normal file
@ -0,0 +1 @@
|
|||||||
|
A simple crate to unify hf/tokenizers and tiktoken-rs
|
96
crates/avante-tokenizers/src/lib.rs
Normal file
96
crates/avante-tokenizers/src/lib.rs
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
use mlua::prelude::*;
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
use tiktoken_rs::{get_bpe_from_model, CoreBPE};
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
struct Tiktoken {
|
||||||
|
bpe: CoreBPE,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Tiktoken {
|
||||||
|
fn new(model: String) -> Self {
|
||||||
|
let bpe = get_bpe_from_model(&model).unwrap();
|
||||||
|
Tiktoken { bpe }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn encode(&self, text: String) -> (Vec<usize>, usize, usize) {
|
||||||
|
let tokens = self.bpe.encode_with_special_tokens(&text);
|
||||||
|
let num_tokens = tokens.len();
|
||||||
|
let num_chars = text.chars().count();
|
||||||
|
(tokens, num_tokens, num_chars)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct HuggingFaceTokenizer {
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HuggingFaceTokenizer {
|
||||||
|
fn new(model: String) -> Self {
|
||||||
|
let tokenizer = Tokenizer::from_pretrained(model, None).unwrap();
|
||||||
|
HuggingFaceTokenizer { tokenizer }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn encode(&self, text: String) -> (Vec<usize>, usize, usize) {
|
||||||
|
let encoding = self.tokenizer.encode(text, false).unwrap();
|
||||||
|
let tokens: Vec<usize> = encoding.get_ids().iter().map(|x| *x as usize).collect();
|
||||||
|
let num_tokens = tokens.len();
|
||||||
|
let num_chars = encoding.get_offsets().last().unwrap().1;
|
||||||
|
(tokens, num_tokens, num_chars)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
enum TokenizerType {
|
||||||
|
Tiktoken(Tiktoken),
|
||||||
|
HuggingFace(HuggingFaceTokenizer),
|
||||||
|
}
|
||||||
|
|
||||||
|
struct State {
|
||||||
|
tokenizer: Mutex<Option<TokenizerType>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl State {
|
||||||
|
fn new() -> Self {
|
||||||
|
State {
|
||||||
|
tokenizer: Mutex::new(None),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn encode(state: &State, text: String) -> LuaResult<(Vec<usize>, usize, usize)> {
|
||||||
|
let tokenizer = state.tokenizer.lock().unwrap();
|
||||||
|
match tokenizer.as_ref() {
|
||||||
|
Some(TokenizerType::Tiktoken(tokenizer)) => Ok(tokenizer.encode(text)),
|
||||||
|
Some(TokenizerType::HuggingFace(tokenizer)) => Ok(tokenizer.encode(text)),
|
||||||
|
None => Err(LuaError::RuntimeError(
|
||||||
|
"Tokenizer not initialized".to_string(),
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn from_pretrained(state: &State, model: String) -> LuaResult<()> {
|
||||||
|
let mut tokenizer_mutex = state.tokenizer.lock().unwrap();
|
||||||
|
*tokenizer_mutex = Some(match model.as_str() {
|
||||||
|
"gpt-4o" => TokenizerType::Tiktoken(Tiktoken::new(model)),
|
||||||
|
_ => TokenizerType::HuggingFace(HuggingFaceTokenizer::new(model)),
|
||||||
|
});
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[mlua::lua_module]
|
||||||
|
fn avante_tokenizers(lua: &Lua) -> LuaResult<LuaTable> {
|
||||||
|
let core = State::new();
|
||||||
|
let state = Arc::new(core);
|
||||||
|
let state_clone = Arc::clone(&state);
|
||||||
|
|
||||||
|
let exports = lua.create_table()?;
|
||||||
|
exports.set(
|
||||||
|
"from_pretrained",
|
||||||
|
lua.create_function(move |_, model: String| from_pretrained(&state, model))?,
|
||||||
|
)?;
|
||||||
|
exports.set(
|
||||||
|
"encode",
|
||||||
|
lua.create_function(move |_, text: String| encode(&state_clone, text))?,
|
||||||
|
)?;
|
||||||
|
Ok(exports)
|
||||||
|
}
|
@ -10,6 +10,7 @@ local Utils = require("avante.utils")
|
|||||||
---@field ask fun(): boolean
|
---@field ask fun(): boolean
|
||||||
---@field edit fun(): nil
|
---@field edit fun(): nil
|
||||||
---@field refresh fun(): nil
|
---@field refresh fun(): nil
|
||||||
|
---@field build fun(): boolean
|
||||||
---@field toggle avante.ApiToggle
|
---@field toggle avante.ApiToggle
|
||||||
|
|
||||||
return setmetatable({}, {
|
return setmetatable({}, {
|
||||||
|
@ -34,6 +34,9 @@ H.commands = function()
|
|||||||
cmd("Refresh", function()
|
cmd("Refresh", function()
|
||||||
M.refresh()
|
M.refresh()
|
||||||
end, { desc = "avante: refresh windows" })
|
end, { desc = "avante: refresh windows" })
|
||||||
|
cmd("Build", function()
|
||||||
|
M.build()
|
||||||
|
end, { desc = "avante: build dependencies" })
|
||||||
end
|
end
|
||||||
|
|
||||||
H.keymaps = function()
|
H.keymaps = function()
|
||||||
@ -91,6 +94,34 @@ end
|
|||||||
H.augroup = api.nvim_create_augroup("avante_autocmds", { clear = true })
|
H.augroup = api.nvim_create_augroup("avante_autocmds", { clear = true })
|
||||||
|
|
||||||
H.autocmds = function()
|
H.autocmds = function()
|
||||||
|
local ok, LazyConfig = pcall(require, "lazy.core.config")
|
||||||
|
|
||||||
|
if ok then
|
||||||
|
local name = "avante.nvim"
|
||||||
|
local load_path = function()
|
||||||
|
require("avante_lib").load()
|
||||||
|
end
|
||||||
|
|
||||||
|
if LazyConfig.plugins[name] and LazyConfig.plugins[name]._.loaded then
|
||||||
|
vim.schedule(load_path)
|
||||||
|
else
|
||||||
|
api.nvim_create_autocmd("User", {
|
||||||
|
pattern = "LazyLoad",
|
||||||
|
callback = function(event)
|
||||||
|
if event.data == name then
|
||||||
|
load_path()
|
||||||
|
return true
|
||||||
|
end
|
||||||
|
end,
|
||||||
|
})
|
||||||
|
end
|
||||||
|
|
||||||
|
api.nvim_create_autocmd("User", {
|
||||||
|
pattern = "VeryLazy",
|
||||||
|
callback = load_path,
|
||||||
|
})
|
||||||
|
end
|
||||||
|
|
||||||
api.nvim_create_autocmd("TabEnter", {
|
api.nvim_create_autocmd("TabEnter", {
|
||||||
group = H.augroup,
|
group = H.augroup,
|
||||||
pattern = "*",
|
pattern = "*",
|
||||||
@ -221,6 +252,53 @@ setmetatable(M.toggle, {
|
|||||||
end,
|
end,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
local function to_windows_path(path)
|
||||||
|
local winpath = path:gsub("/", "\\")
|
||||||
|
|
||||||
|
if winpath:match("^%a:") then
|
||||||
|
winpath = winpath:sub(1, 2):upper() .. winpath:sub(3)
|
||||||
|
end
|
||||||
|
|
||||||
|
winpath = winpath:gsub("\\$", "")
|
||||||
|
|
||||||
|
return winpath
|
||||||
|
end
|
||||||
|
|
||||||
|
M.build = H.api(function()
|
||||||
|
local dirname = Utils.trim(string.sub(debug.getinfo(1).source, 2, #"/init.lua" * -1), { suffix = "/" })
|
||||||
|
local git_root = vim.fs.find(".git", { path = dirname, upward = true })[1]
|
||||||
|
local build_directory = git_root and vim.fn.fnamemodify(git_root, ":h") or (dirname .. "/../../")
|
||||||
|
|
||||||
|
if not vim.fn.executable("cargo") then
|
||||||
|
error("Building avante.nvim requires cargo to be installed.", 2)
|
||||||
|
end
|
||||||
|
|
||||||
|
---@type string[]
|
||||||
|
local cmd
|
||||||
|
local os_name = Utils.get_os_name()
|
||||||
|
|
||||||
|
if vim.tbl_contains({ "linux", "darwin" }, os_name) then
|
||||||
|
cmd = { "sh", "-c", ("make -C %s"):format(build_directory) }
|
||||||
|
elseif os_name == "windows" then
|
||||||
|
build_directory = to_windows_path(build_directory)
|
||||||
|
cmd = {
|
||||||
|
"powershell",
|
||||||
|
"-ExecutionPolicy",
|
||||||
|
"Bypass",
|
||||||
|
"-File",
|
||||||
|
("%s\\Build.ps1"):format(build_directory),
|
||||||
|
"-WorkingDirectory",
|
||||||
|
build_directory,
|
||||||
|
}
|
||||||
|
else
|
||||||
|
error("Unsupported operating system: " .. os_name, 2)
|
||||||
|
end
|
||||||
|
|
||||||
|
local job = vim.system(cmd, { text = true }):wait()
|
||||||
|
|
||||||
|
return vim.tbl_contains({ 0 }, job.code) and true or false
|
||||||
|
end)
|
||||||
|
|
||||||
M.ask = H.api(function()
|
M.ask = H.api(function()
|
||||||
M.toggle()
|
M.toggle()
|
||||||
end)
|
end)
|
||||||
@ -283,7 +361,7 @@ function M.setup(opts)
|
|||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
require("avante.history").setup()
|
require("avante.path").setup()
|
||||||
require("avante.highlights").setup()
|
require("avante.highlights").setup()
|
||||||
require("avante.diff").setup()
|
require("avante.diff").setup()
|
||||||
require("avante.providers").setup()
|
require("avante.providers").setup()
|
||||||
|
@ -2,10 +2,14 @@ local fn, api = vim.fn, vim.api
|
|||||||
local Path = require("plenary.path")
|
local Path = require("plenary.path")
|
||||||
local Config = require("avante.config")
|
local Config = require("avante.config")
|
||||||
|
|
||||||
|
---@class avante.Path
|
||||||
|
---@field history_path Path
|
||||||
|
---@field cache_path Path
|
||||||
|
local P = {}
|
||||||
|
|
||||||
local M = {}
|
local M = {}
|
||||||
|
|
||||||
local H = {}
|
local H = {}
|
||||||
|
|
||||||
---@param bufnr integer
|
---@param bufnr integer
|
||||||
---@return string
|
---@return string
|
||||||
H.filename = function(bufnr)
|
H.filename = function(bufnr)
|
||||||
@ -39,11 +43,20 @@ M.save = function(bufnr, history)
|
|||||||
history_file:write(vim.json.encode(history), "w")
|
history_file:write(vim.json.encode(history), "w")
|
||||||
end
|
end
|
||||||
|
|
||||||
M.setup = function()
|
P.history = M
|
||||||
local history_dir = Path:new(Config.history.storage_path)
|
|
||||||
if not history_dir:exists() then
|
P.setup = function()
|
||||||
history_dir:mkdir({ parents = true })
|
local history_path = Path:new(Config.history.storage_path)
|
||||||
|
if not history_path:exists() then
|
||||||
|
history_path:mkdir({ parents = true })
|
||||||
end
|
end
|
||||||
|
P.history_path = history_path
|
||||||
|
|
||||||
|
local cache_path = Path:new(vim.fn.stdpath("cache") .. "/avante")
|
||||||
|
if not cache_path:exists() then
|
||||||
|
cache_path:mkdir({ parents = true })
|
||||||
|
end
|
||||||
|
P.cache_path = cache_path
|
||||||
end
|
end
|
||||||
|
|
||||||
return M
|
return P
|
@ -12,6 +12,7 @@ local O = require("avante.providers").openai
|
|||||||
local M = {}
|
local M = {}
|
||||||
|
|
||||||
M.api_key_name = "AZURE_OPENAI_API_KEY"
|
M.api_key_name = "AZURE_OPENAI_API_KEY"
|
||||||
|
M.tokenizer_id = "gpt-4o"
|
||||||
|
|
||||||
M.parse_message = O.parse_message
|
M.parse_message = O.parse_message
|
||||||
M.parse_response = O.parse_response
|
M.parse_response = O.parse_response
|
||||||
|
@ -6,6 +6,7 @@ local P = require("avante.providers")
|
|||||||
local M = {}
|
local M = {}
|
||||||
|
|
||||||
M.api_key_name = "ANTHROPIC_API_KEY"
|
M.api_key_name = "ANTHROPIC_API_KEY"
|
||||||
|
M.tokenizer_id = "gpt-4o"
|
||||||
|
|
||||||
---@param prompt_opts AvantePromptOptions
|
---@param prompt_opts AvantePromptOptions
|
||||||
M.parse_message = function(prompt_opts)
|
M.parse_message = function(prompt_opts)
|
||||||
@ -28,8 +29,10 @@ M.parse_message = function(prompt_opts)
|
|||||||
local user_prompt_obj = {
|
local user_prompt_obj = {
|
||||||
type = "text",
|
type = "text",
|
||||||
text = user_prompt,
|
text = user_prompt,
|
||||||
cache_control = { type = "ephemeral" },
|
|
||||||
}
|
}
|
||||||
|
if Utils.tokens.calculate_tokens(user_prompt_obj.text) > 1024 then
|
||||||
|
user_prompt_obj.cache_control = { type = "ephemeral" }
|
||||||
|
end
|
||||||
|
|
||||||
table.insert(message_content, user_prompt_obj)
|
table.insert(message_content, user_prompt_obj)
|
||||||
end
|
end
|
||||||
|
@ -29,6 +29,7 @@ local P = require("avante.providers")
|
|||||||
local M = {}
|
local M = {}
|
||||||
|
|
||||||
M.api_key_name = "CO_API_KEY"
|
M.api_key_name = "CO_API_KEY"
|
||||||
|
M.tokenizer_id = "CohereForAI/c4ai-command-r-plus-08-2024"
|
||||||
|
|
||||||
M.parse_message = function(opts)
|
M.parse_message = function(opts)
|
||||||
local user_prompt = table.concat(opts.user_prompts, "\n\n")
|
local user_prompt = table.concat(opts.user_prompts, "\n\n")
|
||||||
|
@ -127,6 +127,7 @@ end
|
|||||||
M.state = nil
|
M.state = nil
|
||||||
|
|
||||||
M.api_key_name = P.AVANTE_INTERNAL_KEY
|
M.api_key_name = P.AVANTE_INTERNAL_KEY
|
||||||
|
M.tokenizer_id = "gpt-4o"
|
||||||
|
|
||||||
M.parse_message = function(opts)
|
M.parse_message = function(opts)
|
||||||
return {
|
return {
|
||||||
@ -166,6 +167,7 @@ M.setup = function()
|
|||||||
M.state = { github_token = nil, oauth_token = H.get_oauth_token() }
|
M.state = { github_token = nil, oauth_token = H.get_oauth_token() }
|
||||||
H.refresh_token()
|
H.refresh_token()
|
||||||
end
|
end
|
||||||
|
require("avante.tokenizers").setup(M.tokenizer_id)
|
||||||
vim.g.avante_login = true
|
vim.g.avante_login = true
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -6,6 +6,7 @@ local Clipboard = require("avante.clipboard")
|
|||||||
local M = {}
|
local M = {}
|
||||||
|
|
||||||
M.api_key_name = "GEMINI_API_KEY"
|
M.api_key_name = "GEMINI_API_KEY"
|
||||||
|
M.tokenizer_id = "google/gemma-2b"
|
||||||
|
|
||||||
M.parse_message = function(opts)
|
M.parse_message = function(opts)
|
||||||
local message_content = {}
|
local message_content = {}
|
||||||
|
@ -69,6 +69,7 @@ local Dressing = require("avante.ui.dressing")
|
|||||||
---@field setup fun(): nil
|
---@field setup fun(): nil
|
||||||
---@field has fun(): boolean
|
---@field has fun(): boolean
|
||||||
---@field api_key_name string
|
---@field api_key_name string
|
||||||
|
---@field tokenizer_id string | "gpt-4o"
|
||||||
---@field model? string
|
---@field model? string
|
||||||
---@field parse_api_key fun(): string | nil
|
---@field parse_api_key fun(): string | nil
|
||||||
---@field parse_stream_data? AvanteStreamParser
|
---@field parse_stream_data? AvanteStreamParser
|
||||||
@ -269,6 +270,11 @@ M = setmetatable(M, {
|
|||||||
return E.parse_envvar(t[k])
|
return E.parse_envvar(t[k])
|
||||||
end
|
end
|
||||||
|
|
||||||
|
-- default to gpt-4o as tokenizer
|
||||||
|
if t[k].tokenizer_id == nil then
|
||||||
|
t[k].tokenizer_id = "gpt-4o"
|
||||||
|
end
|
||||||
|
|
||||||
if t[k].has == nil then
|
if t[k].has == nil then
|
||||||
t[k].has = function()
|
t[k].has = function()
|
||||||
return E.parse_envvar(t[k]) ~= nil
|
return E.parse_envvar(t[k]) ~= nil
|
||||||
@ -280,6 +286,7 @@ M = setmetatable(M, {
|
|||||||
if not E.is_local(k) then
|
if not E.is_local(k) then
|
||||||
t[k].parse_api_key()
|
t[k].parse_api_key()
|
||||||
end
|
end
|
||||||
|
require("avante.tokenizers").setup(t[k].tokenizer_id)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -26,6 +26,7 @@ local P = require("avante.providers")
|
|||||||
local M = {}
|
local M = {}
|
||||||
|
|
||||||
M.api_key_name = "OPENAI_API_KEY"
|
M.api_key_name = "OPENAI_API_KEY"
|
||||||
|
M.tokenizer_id = "gpt-4o"
|
||||||
|
|
||||||
---@param opts AvantePromptOptions
|
---@param opts AvantePromptOptions
|
||||||
M.get_user_message = function(opts)
|
M.get_user_message = function(opts)
|
||||||
|
@ -4,7 +4,7 @@ local fn = vim.fn
|
|||||||
local Split = require("nui.split")
|
local Split = require("nui.split")
|
||||||
local event = require("nui.utils.autocmd").event
|
local event = require("nui.utils.autocmd").event
|
||||||
|
|
||||||
local History = require("avante.history")
|
local Path = require("avante.path")
|
||||||
local Config = require("avante.config")
|
local Config = require("avante.config")
|
||||||
local Diff = require("avante.diff")
|
local Diff = require("avante.diff")
|
||||||
local Llm = require("avante.llm")
|
local Llm = require("avante.llm")
|
||||||
@ -1170,7 +1170,7 @@ function Sidebar:get_commands()
|
|||||||
end,
|
end,
|
||||||
clear = function(args, cb)
|
clear = function(args, cb)
|
||||||
local chat_history = {}
|
local chat_history = {}
|
||||||
History.save(self.code.bufnr, chat_history)
|
Path.history.save(self.code.bufnr, chat_history)
|
||||||
self:update_content("Chat history cleared", { focus = false, scroll = false })
|
self:update_content("Chat history cleared", { focus = false, scroll = false })
|
||||||
vim.defer_fn(function()
|
vim.defer_fn(function()
|
||||||
self:close()
|
self:close()
|
||||||
@ -1239,7 +1239,7 @@ function Sidebar:create_input()
|
|||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
local chat_history = History.load(self.code.bufnr)
|
local chat_history = Path.history.load(self.code.bufnr)
|
||||||
|
|
||||||
---@param request string
|
---@param request string
|
||||||
local function handle_submit(request)
|
local function handle_submit(request)
|
||||||
@ -1356,7 +1356,7 @@ function Sidebar:create_input()
|
|||||||
request = request,
|
request = request,
|
||||||
response = full_response,
|
response = full_response,
|
||||||
})
|
})
|
||||||
History.save(self.code.bufnr, chat_history)
|
Path.history.save(self.code.bufnr, chat_history)
|
||||||
end
|
end
|
||||||
|
|
||||||
Llm.stream({
|
Llm.stream({
|
||||||
@ -1564,7 +1564,7 @@ function Sidebar:get_selected_code_size()
|
|||||||
end
|
end
|
||||||
|
|
||||||
function Sidebar:render()
|
function Sidebar:render()
|
||||||
local chat_history = History.load(self.code.bufnr)
|
local chat_history = Path.history.load(self.code.bufnr)
|
||||||
|
|
||||||
local sidebar_height = api.nvim_win_get_height(self.code.winid)
|
local sidebar_height = api.nvim_win_get_height(self.code.winid)
|
||||||
local selected_code_size = self:get_selected_code_size()
|
local selected_code_size = self:get_selected_code_size()
|
||||||
|
66
lua/avante/tokenizers.lua
Normal file
66
lua/avante/tokenizers.lua
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
local Utils = require("avante.utils")
|
||||||
|
|
||||||
|
---@class AvanteTokenizer
|
||||||
|
---@field from_pretrained fun(model: string): nil
|
||||||
|
---@field encode fun(string): integer[]
|
||||||
|
local tokenizers = nil
|
||||||
|
|
||||||
|
local M = {}
|
||||||
|
|
||||||
|
---@param model "gpt-4o" | string
|
||||||
|
M.setup = function(model)
|
||||||
|
local ok, core = pcall(require, "avante_tokenizers")
|
||||||
|
if not ok then
|
||||||
|
return
|
||||||
|
end
|
||||||
|
---@cast core AvanteTokenizer
|
||||||
|
if tokenizers == nil then
|
||||||
|
tokenizers = core
|
||||||
|
end
|
||||||
|
|
||||||
|
local HF_TOKEN = os.getenv("HF_TOKEN")
|
||||||
|
if HF_TOKEN == nil and model ~= "gpt-4o" then
|
||||||
|
Utils.warn(
|
||||||
|
"Please set HF_TOKEN environment variable to use HuggingFace tokenizer if " .. model .. " is gated",
|
||||||
|
{ once = true }
|
||||||
|
)
|
||||||
|
end
|
||||||
|
vim.env.HF_HUB_DISABLE_PROGRESS_BARS = 1
|
||||||
|
|
||||||
|
---@cast core AvanteTokenizer
|
||||||
|
core.from_pretrained(model)
|
||||||
|
end
|
||||||
|
|
||||||
|
M.available = function()
|
||||||
|
return tokenizers ~= nil
|
||||||
|
end
|
||||||
|
|
||||||
|
---@param prompt string
|
||||||
|
M.encode = function(prompt)
|
||||||
|
if not tokenizers then
|
||||||
|
return nil
|
||||||
|
end
|
||||||
|
if not prompt or prompt == "" then
|
||||||
|
return nil
|
||||||
|
end
|
||||||
|
if type(prompt) ~= "string" then
|
||||||
|
error("Prompt is not type string", 2)
|
||||||
|
end
|
||||||
|
|
||||||
|
return tokenizers.encode(prompt)
|
||||||
|
end
|
||||||
|
|
||||||
|
---@param prompt string
|
||||||
|
M.count = function(prompt)
|
||||||
|
if not tokenizers then
|
||||||
|
return math.ceil(#prompt * 0.5)
|
||||||
|
end
|
||||||
|
|
||||||
|
local tokens = M.encode(prompt)
|
||||||
|
if not tokens then
|
||||||
|
return 0
|
||||||
|
end
|
||||||
|
return #tokens
|
||||||
|
end
|
||||||
|
|
||||||
|
return M
|
@ -1,4 +1,6 @@
|
|||||||
--Taken from https://github.com/jackMort/ChatGPT.nvim/blob/main/lua/chatgpt/flows/chat/tokens.lua
|
--Taken from https://github.com/jackMort/ChatGPT.nvim/blob/main/lua/chatgpt/flows/chat/tokens.lua
|
||||||
|
local Tokenizer = require("avante.tokenizers")
|
||||||
|
|
||||||
---@class avante.utils.tokens
|
---@class avante.utils.tokens
|
||||||
local Tokens = {}
|
local Tokens = {}
|
||||||
|
|
||||||
@ -11,6 +13,10 @@ local cost_per_token = {
|
|||||||
---@param text string The text to calculate the number of tokens for.
|
---@param text string The text to calculate the number of tokens for.
|
||||||
---@return integer The number of tokens in the given text.
|
---@return integer The number of tokens in the given text.
|
||||||
function Tokens.calculate_tokens(text)
|
function Tokens.calculate_tokens(text)
|
||||||
|
if Tokenizer.available() then
|
||||||
|
return Tokenizer.count(text)
|
||||||
|
end
|
||||||
|
|
||||||
local tokens = 0
|
local tokens = 0
|
||||||
local current_token = ""
|
local current_token = ""
|
||||||
for char in text:gmatch(".") do
|
for char in text:gmatch(".") do
|
||||||
|
22
lua/avante_lib.lua
Normal file
22
lua/avante_lib.lua
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
local M = {}
|
||||||
|
|
||||||
|
local function get_library_path()
|
||||||
|
local os_name = require("avante.utils").get_os_name()
|
||||||
|
local ext = os_name == "linux" and "so" or (os_name == "darwin" and "dylib" or "dll")
|
||||||
|
local dirname = string.sub(debug.getinfo(1).source, 2, #"/avante_lib.lua" * -1)
|
||||||
|
return dirname .. ("../build/?.%s"):format(ext)
|
||||||
|
end
|
||||||
|
|
||||||
|
---@type fun(s: string): string
|
||||||
|
local trim_semicolon = function(s)
|
||||||
|
return s:sub(-1) == ";" and s:sub(1, -2) or s
|
||||||
|
end
|
||||||
|
|
||||||
|
M.load = function()
|
||||||
|
local library_path = get_library_path()
|
||||||
|
if not string.find(package.cpath, library_path, 1, true) then
|
||||||
|
package.cpath = trim_semicolon(package.cpath) .. ";" .. library_path
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
return M
|
@ -1,5 +1,7 @@
|
|||||||
--- NOTE: We will override vim.paste if img-clip.nvim is available to work with avante.nvim internal logic paste
|
--- NOTE: We will override vim.paste if img-clip.nvim is available to work with avante.nvim internal logic paste
|
||||||
|
|
||||||
|
require("avante").setup()
|
||||||
|
|
||||||
local Clipboard = require("avante.clipboard")
|
local Clipboard = require("avante.clipboard")
|
||||||
local Config = require("avante.config")
|
local Config = require("avante.config")
|
||||||
|
|
Loading…
x
Reference in New Issue
Block a user