chore(rust): fix current clippy lint (#504)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
Aaron Pham 2024-09-03 22:18:53 -04:00 committed by GitHub
parent e57a3f27df
commit 41c78127e5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 35 additions and 18 deletions

View File

@@ -41,3 +41,15 @@ jobs:
components: clippy, rustfmt components: clippy, rustfmt
- name: Run rustfmt - name: Run rustfmt
run: make ruststylecheck run: make ruststylecheck
rustlint:
name: Lint Rust
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: Swatinem/rust-cache@v2
- uses: dtolnay/rust-toolchain@master
with:
toolchain: stable
components: clippy, rustfmt
- name: Run rustfmt
run: make rustlint

View File

@@ -15,7 +15,7 @@ impl<'a> State<'a> {
} }
} }
#[derive(Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
struct TemplateContext { struct TemplateContext {
use_xml_format: bool, use_xml_format: bool,
ask: bool, ask: bool,
@@ -29,16 +29,17 @@ struct TemplateContext {
// Given the file name registered after add, the context table in Lua, resulted in a formatted // Given the file name registered after add, the context table in Lua, resulted in a formatted
// Lua string // Lua string
fn render(state: &State, template: String, context: TemplateContext) -> LuaResult<String> { #[allow(clippy::needless_pass_by_value)]
fn render(state: &State, template: &str, context: TemplateContext) -> LuaResult<String> {
let environment = state.environment.lock().unwrap(); let environment = state.environment.lock().unwrap();
match environment.as_ref() { match environment.as_ref() {
Some(environment) => { Some(environment) => {
let template = environment let jinja_template = environment
.get_template(&template) .get_template(template)
.map_err(LuaError::external) .map_err(LuaError::external)
.unwrap(); .unwrap();
Ok(template Ok(jinja_template
.render(context! { .render(context! {
use_xml_format => context.use_xml_format, use_xml_format => context.use_xml_format,
ask => context.ask, ask => context.ask,
@@ -84,7 +85,7 @@ fn avante_templates(lua: &Lua) -> LuaResult<LuaTable> {
"render", "render",
lua.create_function_mut(move |lua, (template, context): (String, LuaValue)| { lua.create_function_mut(move |lua, (template, context): (String, LuaValue)| {
let ctx = lua.from_value(context)?; let ctx = lua.from_value(context)?;
render(&state_clone, template, ctx) render(&state_clone, template.as_str(), ctx)
})?, })?,
)?; )?;
Ok(exports) Ok(exports)

View File

@@ -8,13 +8,13 @@ struct Tiktoken {
} }
impl Tiktoken { impl Tiktoken {
fn new(model: String) -> Self { fn new(model: &str) -> Self {
let bpe = get_bpe_from_model(&model).unwrap(); let bpe = get_bpe_from_model(model).unwrap();
Tiktoken { bpe } Tiktoken { bpe }
} }
fn encode(&self, text: String) -> (Vec<usize>, usize, usize) { fn encode(&self, text: &str) -> (Vec<usize>, usize, usize) {
let tokens = self.bpe.encode_with_special_tokens(&text); let tokens = self.bpe.encode_with_special_tokens(text);
let num_tokens = tokens.len(); let num_tokens = tokens.len();
let num_chars = text.chars().count(); let num_chars = text.chars().count();
(tokens, num_tokens, num_chars) (tokens, num_tokens, num_chars)
@@ -26,13 +26,17 @@ struct HuggingFaceTokenizer {
} }
impl HuggingFaceTokenizer { impl HuggingFaceTokenizer {
fn new(model: String) -> Self { fn new(model: &str) -> Self {
let tokenizer = Tokenizer::from_pretrained(model, None).unwrap(); let tokenizer = Tokenizer::from_pretrained(model, None).unwrap();
HuggingFaceTokenizer { tokenizer } HuggingFaceTokenizer { tokenizer }
} }
fn encode(&self, text: String) -> (Vec<usize>, usize, usize) { fn encode(&self, text: &str) -> (Vec<usize>, usize, usize) {
let encoding = self.tokenizer.encode(text, false).unwrap(); let encoding = self
.tokenizer
.encode(text, false)
.map_err(LuaError::external)
.unwrap();
let tokens: Vec<usize> = encoding.get_ids().iter().map(|x| *x as usize).collect(); let tokens: Vec<usize> = encoding.get_ids().iter().map(|x| *x as usize).collect();
let num_tokens = tokens.len(); let num_tokens = tokens.len();
let num_chars = encoding.get_offsets().last().unwrap().1; let num_chars = encoding.get_offsets().last().unwrap().1;
@@ -57,7 +61,7 @@ impl State {
} }
} }
fn encode(state: &State, text: String) -> LuaResult<(Vec<usize>, usize, usize)> { fn encode(state: &State, text: &str) -> LuaResult<(Vec<usize>, usize, usize)> {
let tokenizer = state.tokenizer.lock().unwrap(); let tokenizer = state.tokenizer.lock().unwrap();
match tokenizer.as_ref() { match tokenizer.as_ref() {
Some(TokenizerType::Tiktoken(tokenizer)) => Ok(tokenizer.encode(text)), Some(TokenizerType::Tiktoken(tokenizer)) => Ok(tokenizer.encode(text)),
@@ -68,9 +72,9 @@ fn encode(state: &State, text: String) -> LuaResult<(Vec<usize>, usize, usize)>
} }
} }
fn from_pretrained(state: &State, model: String) { fn from_pretrained(state: &State, model: &str) {
let mut tokenizer_mutex = state.tokenizer.lock().unwrap(); let mut tokenizer_mutex = state.tokenizer.lock().unwrap();
*tokenizer_mutex = Some(match model.as_str() { *tokenizer_mutex = Some(match model {
"gpt-4o" => TokenizerType::Tiktoken(Tiktoken::new(model)), "gpt-4o" => TokenizerType::Tiktoken(Tiktoken::new(model)),
_ => TokenizerType::HuggingFace(HuggingFaceTokenizer::new(model)), _ => TokenizerType::HuggingFace(HuggingFaceTokenizer::new(model)),
}); });
@@ -86,13 +90,13 @@ fn avante_tokenizers(lua: &Lua) -> LuaResult<LuaTable> {
exports.set( exports.set(
"from_pretrained", "from_pretrained",
lua.create_function(move |_, model: String| { lua.create_function(move |_, model: String| {
from_pretrained(&state, model); from_pretrained(&state, model.as_str());
Ok(()) Ok(())
})?, })?,
)?; )?;
exports.set( exports.set(
"encode", "encode",
lua.create_function(move |_, text: String| encode(&state_clone, text))?, lua.create_function(move |_, text: String| encode(&state_clone, text.as_str()))?,
)?; )?;
Ok(exports) Ok(exports)
} }