feat(tokenizers): support parsing from public URL (#765)
commit bdbbdec88c
parent a8e2b9a00c
Cargo.lock (generated, 97 changed lines)
@@ -64,9 +64,13 @@ dependencies = [
 name = "avante-tokenizers"
 version = "0.1.0"
 dependencies = [
+ "dirs",
+ "hf-hub",
  "mlua",
+ "regex",
  "tiktoken-rs",
  "tokenizers",
+ "ureq",
 ]
 
 [[package]]
@@ -343,16 +347,6 @@ dependencies = [
  "cc",
 ]
 
-[[package]]
-name = "fancy-regex"
-version = "0.12.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7493d4c459da9f84325ad297371a6b2b8a162800873a22e3b6b6512e61d18c05"
-dependencies = [
- "bit-set",
- "regex",
-]
-
 [[package]]
 name = "fancy-regex"
 version = "0.13.0"
@@ -585,9 +579,9 @@ checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b"
 
 [[package]]
 name = "minijinja"
-version = "2.2.0"
+version = "2.4.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6d7d3e3a3eece1fa4618237ad41e1de855ced47eab705cec1c9a920e1d1c5aad"
+checksum = "c9ca8daf4b0b4029777f1bc6e1aedd1aec7b74c276a43bc6f620a8e1a1c0a90e"
 dependencies = [
  "aho-corasick",
  "memo-map",
@@ -616,10 +610,12 @@ dependencies = [
 
 [[package]]
 name = "mlua"
-version = "0.10.0-beta.1"
-source = "git+https://github.com/mlua-rs/mlua.git?branch=main#1634c43f0afaf7a71dc555cb6b3624250e5ff209"
+version = "0.10.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0f6ddbd668297c46be4bdea6c599dcc1f001a129586272d53170b7ac0a62961e"
 dependencies = [
  "bstr",
+ "either",
  "erased-serde",
  "mlua-sys",
  "mlua_derive",
@@ -632,8 +628,9 @@ dependencies = [
 
 [[package]]
 name = "mlua-sys"
-version = "0.6.2"
-source = "git+https://github.com/mlua-rs/mlua.git?branch=main#1634c43f0afaf7a71dc555cb6b3624250e5ff209"
+version = "0.6.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e9eebac25c35a13285456c88ee2fde93d9aee8bcfdaf03f9d6d12be3391351ec"
 dependencies = [
  "cc",
  "cfg-if",
@@ -642,8 +639,9 @@ dependencies = [
 
 [[package]]
 name = "mlua_derive"
-version = "0.9.3"
-source = "git+https://github.com/mlua-rs/mlua.git?branch=main#1634c43f0afaf7a71dc555cb6b3624250e5ff209"
+version = "0.10.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2cfc5faa2e0d044b3f5f0879be2920e0a711c97744c42cf1c295cb183668933e"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -957,9 +955,9 @@ dependencies = [
 
 [[package]]
 name = "regex"
-version = "1.10.6"
+version = "1.11.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619"
+checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191"
 dependencies = [
  "aho-corasick",
  "memchr",
@@ -969,9 +967,9 @@ dependencies = [
 
 [[package]]
 name = "regex-automata"
-version = "0.4.7"
+version = "0.4.8"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df"
+checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3"
 dependencies = [
  "aho-corasick",
  "memchr",
@@ -980,9 +978,9 @@ dependencies = [
 
 [[package]]
 name = "regex-syntax"
-version = "0.8.4"
+version = "0.8.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b"
+checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
 
 [[package]]
 name = "ring"
@@ -1160,6 +1158,17 @@ version = "1.13.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67"
 
+[[package]]
+name = "socks"
+version = "0.3.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b"
+dependencies = [
+ "byteorder",
+ "libc",
+ "winapi",
+]
+
 [[package]]
 name = "spin"
 version = "0.9.8"
@@ -1216,18 +1225,18 @@ dependencies = [
 
 [[package]]
 name = "thiserror"
-version = "1.0.63"
+version = "1.0.65"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724"
+checksum = "5d11abd9594d9b38965ef50805c5e469ca9cc6f197f883f717e0269a3057b3d5"
 dependencies = [
  "thiserror-impl",
 ]
 
 [[package]]
 name = "thiserror-impl"
-version = "1.0.63"
+version = "1.0.65"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261"
+checksum = "ae71770322cbd277e69d762a16c444af02aa0575ac0d174f0b9562d3b37f8602"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -1236,16 +1245,17 @@ dependencies = [
 
 [[package]]
 name = "tiktoken-rs"
-version = "0.5.9"
+version = "0.6.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c314e7ce51440f9e8f5a497394682a57b7c323d0f4d0a6b1b13c429056e0e234"
+checksum = "44075987ee2486402f0808505dd65692163d243a337fc54363d49afac41087f6"
 dependencies = [
  "anyhow",
  "base64 0.21.7",
  "bstr",
- "fancy-regex 0.12.0",
+ "fancy-regex",
  "lazy_static",
  "parking_lot",
+ "regex",
  "rustc-hash 1.1.0",
 ]
 
@@ -1273,7 +1283,7 @@ dependencies = [
  "aho-corasick",
  "derive_builder",
  "esaxx-rs",
- "fancy-regex 0.13.0",
+ "fancy-regex",
  "getrandom",
  "hf-hub",
  "itertools 0.12.1",
@@ -1499,6 +1509,7 @@ dependencies = [
  "rustls-pki-types",
  "serde",
  "serde_json",
+ "socks",
  "url",
  "webpki-roots",
 ]
@@ -1602,6 +1613,28 @@ dependencies = [
  "rustls-pki-types",
 ]
 
+[[package]]
+name = "winapi"
+version = "0.3.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
+dependencies = [
+ "winapi-i686-pc-windows-gnu",
+ "winapi-x86_64-pc-windows-gnu",
+]
+
+[[package]]
+name = "winapi-i686-pc-windows-gnu"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
+
+[[package]]
+name = "winapi-x86_64-pc-windows-gnu"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
+
 [[package]]
 name = "windows-sys"
 version = "0.48.0"

Cargo.toml
@@ -12,7 +12,7 @@ version = "0.1.0"
 avante-tokenizers = { path = "crates/avante-tokenizers" }
 avante-templates = { path = "crates/avante-templates" }
 avante-repo-map = { path = "crates/avante-repo-map" }
-minijinja = { version = "2.2.0", features = [
+minijinja = { version = "2.4.0", features = [
   "loader",
   "json",
   "fuel",
@@ -21,11 +21,8 @@ minijinja = { version = "2.2.0", features = [
   "custom_syntax",
   "loop_controls",
 ] }
-mlua = { version = "0.10.0-beta.1", features = [
-  "module",
-  "serialize",
-], git = "https://github.com/mlua-rs/mlua.git", branch = "main" }
-tiktoken-rs = { version = "0.5.9" }
+mlua = { version = "0.10.0", features = ["module", "serialize"] }
+tiktoken-rs = { version = "0.6.0" }
 tokenizers = { version = "0.20.0", features = [
   "esaxx_fast",
   "http",

crates/avante-tokenizers/Cargo.toml
@@ -12,6 +12,10 @@ license = { workspace = true }
 workspace = true
 
 [dependencies]
+dirs = "5.0.1"
+regex = "1.11.1"
+hf-hub = { version = "0.3.2", features = ["default"] }
+ureq = { version = "2.10.1", features = ["json", "socks-proxy"] }
 mlua = { workspace = true }
 tiktoken-rs = { workspace = true }
 tokenizers = { workspace = true }

crates/avante-tokenizers/src/lib.rs
@@ -1,4 +1,7 @@
+use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
 use mlua::prelude::*;
+use regex::Regex;
+use std::path::PathBuf;
 use std::sync::{Arc, Mutex};
 use tiktoken_rs::{get_bpe_from_model, CoreBPE};
 use tokenizers::Tokenizer;
@@ -10,10 +13,10 @@ struct Tiktoken {
 impl Tiktoken {
     fn new(model: &str) -> Self {
         let bpe = get_bpe_from_model(model).unwrap();
-        Tiktoken { bpe }
+        Self { bpe }
     }
 
-    fn encode(&self, text: &str) -> (Vec<usize>, usize, usize) {
+    fn encode(&self, text: &str) -> (Vec<u32>, usize, usize) {
         let tokens = self.bpe.encode_with_special_tokens(text);
         let num_tokens = tokens.len();
         let num_chars = text.chars().count();
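
The `Vec<usize>` to `Vec<u32>` change in `encode` follows from the tiktoken-rs 0.5.9 → 0.6.0 bump in the lockfile: `encode_with_special_tokens` now yields `u32` token ids, which the unchanged function body returns directly. A minimal sketch of what callers see (not part of the commit; assumes tiktoken-rs 0.6 as pinned above):

```rust
use tiktoken_rs::get_bpe_from_model;

fn main() {
    let bpe = get_bpe_from_model("gpt-4o").unwrap();
    // Token ids come back as u32 under 0.6; convert explicitly if usize is needed.
    let tokens: Vec<u32> = bpe.encode_with_special_tokens("Hello, world!");
    let as_usize: Vec<usize> = tokens.iter().map(|&t| t as usize).collect();
    println!("{} tokens: {:?}", as_usize.len(), tokens);
}
```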
@@ -25,23 +28,53 @@ struct HuggingFaceTokenizer {
     tokenizer: Tokenizer,
 }
 
-impl HuggingFaceTokenizer {
-    fn new(model: &str) -> Self {
-        let tokenizer = Tokenizer::from_pretrained(model, None).unwrap();
-        HuggingFaceTokenizer { tokenizer }
-    }
-
-    fn encode(&self, text: &str) -> (Vec<usize>, usize, usize) {
-        let encoding = self
-            .tokenizer
-            .encode(text, false)
-            .map_err(LuaError::external)
-            .unwrap();
-        let tokens: Vec<usize> = encoding.get_ids().iter().map(|x| *x as usize).collect();
+fn is_valid_url(url: &str) -> bool {
+    let url_regex = Regex::new(r"^https?://[^\s/$.?#].[^\s]*$").unwrap();
+    url_regex.is_match(url)
+}
+
+impl HuggingFaceTokenizer {
+    fn new(model: &str) -> Self {
+        let tokenizer_path = if is_valid_url(model) {
+            Self::get_cached_tokenizer(model)
+        } else {
+            // Use existing HuggingFace Hub logic for model names
+            let identifier = model.to_string();
+            let api = ApiBuilder::new().with_progress(false).build().unwrap();
+            let repo = Repo::new(identifier, RepoType::Model);
+            let api = api.repo(repo);
+            api.get("tokenizer.json").unwrap()
+        };
+
+        let tokenizer = Tokenizer::from_file(tokenizer_path).unwrap();
+        Self { tokenizer }
+    }
+
+    fn encode(&self, text: &str) -> (Vec<u32>, usize, usize) {
+        let encoding = self.tokenizer.encode(text, false).unwrap();
+        let tokens = encoding.get_ids().to_vec();
         let num_tokens = tokens.len();
         let num_chars = encoding.get_offsets().last().unwrap().1;
         (tokens, num_tokens, num_chars)
     }
+
+    fn get_cached_tokenizer(url: &str) -> PathBuf {
+        let cache_dir = dirs::home_dir()
+            .map(|h| h.join(".cache").join("avante"))
+            .unwrap();
+        std::fs::create_dir_all(&cache_dir).unwrap();
+
+        // Extract filename from URL
+        let filename = url.split('/').last().unwrap();
+
+        let cached_path = cache_dir.join(filename);
+
+        if !cached_path.exists() {
+            let response = ureq::get(url).call().unwrap();
+            let _ = std::fs::write(&cached_path, response.into_string().unwrap());
+        }
+        cached_path
+    }
 }
 
 enum TokenizerType {
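
Everything in the new constructor hinges on the `is_valid_url` check: URLs take the ureq download-and-cache path, anything else falls through to hf-hub. A small sketch (not part of the commit) of what the pattern accepts, using the same regex as the hunk above:

```rust
use regex::Regex;

fn main() {
    // The exact pattern added in lib.rs above.
    let url_regex = Regex::new(r"^https?://[^\s/$.?#].[^\s]*$").unwrap();

    // Public tokenizer URLs match and go down the download-and-cache path...
    assert!(url_regex
        .is_match("https://storage.googleapis.com/cohere-public/tokenizers/command-r-08-2024.json"));
    // ...while bare HuggingFace repo ids do not, and fall through to hf-hub.
    assert!(!url_regex.is_match("CohereForAI/c4ai-command-r-plus-08-2024"));
    // Whitespace anywhere after the scheme fails the [^\s]*$ tail.
    assert!(!url_regex.is_match("https://bad host/tokenizer.json"));
}
```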
@@ -61,7 +94,7 @@ impl State {
     }
 }
 
-fn encode(state: &State, text: &str) -> LuaResult<(Vec<usize>, usize, usize)> {
+fn encode(state: &State, text: &str) -> LuaResult<(Vec<u32>, usize, usize)> {
     let tokenizer = state.tokenizer.lock().unwrap();
     match tokenizer.as_ref() {
         Some(TokenizerType::Tiktoken(tokenizer)) => Ok(tokenizer.encode(text)),
@@ -100,3 +133,62 @@ fn avante_tokenizers(lua: &Lua) -> LuaResult<LuaTable> {
     )?;
     Ok(exports)
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_tiktoken() {
+        let model = "gpt-4o";
+        let source = "Hello, world!";
+        let tokenizer = Tiktoken::new(model);
+        let (tokens, num_tokens, num_chars) = tokenizer.encode(source);
+        assert_eq!(tokens, vec![13225, 11, 2375, 0]);
+        assert_eq!(num_tokens, 4);
+        assert_eq!(num_chars, source.chars().count());
+    }
+
+    #[test]
+    fn test_hf() {
+        let model = "gpt2";
+        let source = "Hello, world!";
+        let tokenizer = HuggingFaceTokenizer::new(model);
+        let (tokens, num_tokens, num_chars) = tokenizer.encode(source);
+        assert_eq!(tokens, vec![15496, 11, 995, 0]);
+        assert_eq!(num_tokens, 4);
+        assert_eq!(num_chars, source.chars().count());
+    }
+
+    #[test]
+    fn test_roundtrip() {
+        let state = State::new();
+        let source = "Hello, world!";
+        let model = "gpt2";
+
+        from_pretrained(&state, model);
+        let (tokens, num_tokens, num_chars) = encode(&state, "Hello, world!").unwrap();
+        assert_eq!(tokens, vec![15496, 11, 995, 0]);
+        assert_eq!(num_tokens, 4);
+        assert_eq!(num_chars, source.chars().count());
+    }
+
+    // For example: https://storage.googleapis.com/cohere-public/tokenizers/command-r-08-2024.json
+    // Disable testing on GitHub Actions to avoid rate limiting and file size limits
+    #[test]
+    fn test_public_url() {
+        if std::env::var("GITHUB_ACTIONS").is_ok() {
+            return;
+        }
+        let state = State::new();
+        let source = "Hello, world!";
+        let model =
+            "https://storage.googleapis.com/cohere-public/tokenizers/command-r-08-2024.json";
+
+        from_pretrained(&state, model);
+        let (tokens, num_tokens, num_chars) = encode(&state, "Hello, world!").unwrap();
+        assert_eq!(tokens, vec![28339, 19, 3845, 8]);
+        assert_eq!(num_tokens, 4);
+        assert_eq!(num_chars, source.chars().count());
+    }
+}
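
`test_public_url` exercises the download path for real, which is why it is skipped on CI: the first run fetches the file over HTTP and leaves it in the cache directory built by `get_cached_tokenizer`; subsequent runs reuse it. A sketch of where a given URL ends up (mirrors the path construction above; assumes the `dirs` crate pinned in Cargo.toml):

```rust
use std::path::PathBuf;

// Same construction as get_cached_tokenizer: ~/.cache/avante/<last URL segment>.
fn cache_destination(url: &str) -> PathBuf {
    let cache_dir = dirs::home_dir()
        .map(|h| h.join(".cache").join("avante"))
        .expect("home directory should exist");
    let filename = url.split('/').last().expect("URL has a final segment");
    cache_dir.join(filename)
}

fn main() {
    let dest = cache_destination(
        "https://storage.googleapis.com/cohere-public/tokenizers/command-r-08-2024.json",
    );
    // e.g. /home/user/.cache/avante/command-r-08-2024.json
    println!("{}", dest.display());
}
```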

lua/avante/config.lua
@@ -76,7 +76,7 @@ Respect and use existing conventions, libraries, etc that are already present in
   },
   ---@type AvanteSupportedProvider
   cohere = {
-    endpoint = "https://api.cohere.com/v1",
+    endpoint = "https://api.cohere.com/v2",
     model = "command-r-plus-08-2024",
     timeout = 30000, -- Timeout in milliseconds
     temperature = 0,

lua/avante/providers/cohere.lua
@@ -2,57 +2,69 @@ local Utils = require("avante.utils")
 local P = require("avante.providers")
 
 ---@alias CohereFinishReason "COMPLETE" | "LENGTH" | "ERROR"
+---@alias CohereStreamType "message-start" | "content-start" | "content-delta" | "content-end" | "message-end"
 ---
----@class CohereChatStreamResponse
----@field event_type "stream-start" | "text-generation" | "stream-end"
----@field is_finished boolean
----
----@class CohereTextGenerationResponse: CohereChatStreamResponse
+---@class CohereChatContent
+---@field type? CohereStreamType
 ---@field text string
 ---
----@class CohereStreamEndResponse: CohereChatStreamResponse
----@field response CohereChatResponse
----@field finish_reason CohereFinishReason
+---@class CohereChatMessage
+---@field content CohereChatContent
 ---
----@class CohereChatResponse
----@field text string
----@field generation_id string
----@field chat_history CohereMessage[]
----@field finish_reason CohereFinishReason
----@field meta {api_version: {version: integer}, billed_units: {input_tokens: integer, output_tokens: integer}, tokens: {input_tokens: integer, output_tokens: integer}}
+---@class CohereChatStreamBase
+---@field type CohereStreamType
+---@field index integer
+---
+---@class CohereChatContentDelta: CohereChatStreamBase
+---@field type "content-delta" | "content-start" | "content-end"
+---@field delta? { message: CohereChatMessage }
+---
+---@class CohereChatMessageStart: CohereChatStreamBase
+---@field type "message-start"
+---@field delta { message: { role: "assistant" } }
+---
+---@class CohereChatMessageEnd: CohereChatStreamBase
+---@field type "message-end"
+---@field delta { finish_reason: CohereFinishReason, usage: CohereChatUsage }
+---
+---@class CohereChatUsage
+---@field billed_units { input_tokens: integer, output_tokens: integer }
+---@field tokens { input_tokens: integer, output_tokens: integer }
+---
+---@alias CohereChatResponse CohereChatContentDelta | CohereChatMessageStart | CohereChatMessageEnd
 ---
 ---@class CohereMessage
----@field role? "USER" | "SYSTEM" | "CHATBOT"
----@field message string
+---@field type "text"
+---@field text string
 ---
 ---@class AvanteProviderFunctor
 local M = {}
 
 M.api_key_name = "CO_API_KEY"
-M.tokenizer_id = "CohereForAI/c4ai-command-r-plus-08-2024"
+M.tokenizer_id = "https://storage.googleapis.com/cohere-public/tokenizers/command-r-08-2024.json"
 
 M.parse_message = function(opts)
-  return {
-    preamble = opts.system_prompt,
-    message = table.concat(opts.user_prompts, "\n"),
+  ---@type CohereMessage[]
+  local user_content = vim.iter(opts.user_prompts):fold({}, function(acc, prompt)
+    table.insert(acc, { type = "text", text = prompt })
+    return acc
+  end)
+  local messages = {
+    { role = "system", content = opts.system_prompt },
+    { role = "user", content = user_content },
   }
+  return { messages = messages }
 end
 
 M.parse_stream_data = function(data, opts)
-  ---@type CohereChatStreamResponse
+  ---@type CohereChatResponse
   local json = vim.json.decode(data)
-  if json.is_finished then
-    opts.on_complete(nil)
-    return
-  end
-  if json.event_type ~= nil then
-    ---@cast json CohereStreamEndResponse
-    if json.event_type == "stream-end" and json.finish_reason == "COMPLETE" then
+  if json.type ~= nil then
+    if json.type == "message-end" and json.delta.finish_reason == "COMPLETE" then
       opts.on_complete(nil)
       return
     end
-    ---@cast json CohereTextGenerationResponse
-    if json.event_type == "text-generation" then opts.on_chunk(json.text) end
+    if json.type == "content-delta" then opts.on_chunk(json.delta.message.content.text) end
   end
 end
 
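
The rewritten parser consumes Cohere's v2 streaming events instead of the v1 `event_type` payloads. A hedged sketch of one `content-delta` event and the field walk the Lua code performs, written in Rust with serde_json for illustration (the event shape is taken from the type annotations above, not from Cohere's docs; serde_json is already in this commit's dependency tree via ureq):

```rust
use serde_json::Value;

fn main() {
    // Shape per the CohereChatContentDelta annotation above.
    let data = r#"{"type":"content-delta","index":0,
        "delta":{"message":{"content":{"type":"text","text":"Hello"}}}}"#;
    let json: Value = serde_json::from_str(data).unwrap();

    if json["type"] == "content-delta" {
        // Same walk as the Lua side: json.delta.message.content.text
        let text = json["delta"]["message"]["content"]["text"]
            .as_str()
            .unwrap_or_default();
        println!("{text}"); // -> Hello
    }
}
```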
@@ -83,4 +95,10 @@ M.parse_curl_args = function(provider, code_opts)
   }
 end
 
+M.setup = function()
+  P.env.parse_envvar(M)
+  require("avante.tokenizers").setup(M.tokenizer_id, false)
+  vim.g.avante_login = true
+end
+
 return M

lua/avante/tokenizers.lua
@@ -8,7 +8,9 @@ local tokenizers = nil
 local M = {}
 
 ---@param model "gpt-4o" | string
-M.setup = function(model)
+---@param warning? boolean
+M.setup = function(model, warning)
+  warning = warning or true
   vim.defer_fn(function()
     local ok, core = pcall(require, "avante_tokenizers")
     if not ok then return end
@@ -19,6 +21,7 @@ M.setup = function(model)
     core.from_pretrained(model)
   end, 1000)
 
+  if warning then
     local HF_TOKEN = os.getenv("HF_TOKEN")
     if HF_TOKEN == nil and model ~= "gpt-4o" then
       Utils.warn(
@@ -26,7 +29,7 @@ M.setup = function(model)
       { once = true }
     )
   end
-  vim.env.HF_HUB_DISABLE_PROGRESS_BARS = 1
+  end
 end
 
 M.available = function() return tokenizers ~= nil end