1 change: 1 addition & 0 deletions bindings/node/Cargo.toml
@@ -12,6 +12,7 @@ crate-type = ["cdylib"]
[dependencies]
napi = "2"
napi-derive = "2"
+rustc-hash = "2.1.1"
serde = { version = "1.0.163", features = ["derive"] }
tokenizers = { path = "../../tokenizers/" }

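For context on the dependency added above: `rustc_hash::FxHashMap` is the standard library `HashMap` behind the much faster, non-cryptographic Fx hasher, so the swaps in the files below are drop-in at the type level. A minimal standalone sketch (not part of this diff) of the only visible API difference, which is constructing with `default()` instead of `new()`:

```rust
use rustc_hash::FxHashMap;
use std::collections::HashMap;

fn main() {
    // FxHashMap is std's HashMap parameterized with the Fx hasher, so the
    // read/write API is identical; only construction differs.
    let mut fx: FxHashMap<String, u32> = FxHashMap::default();
    fx.insert("hello".to_string(), 0);

    let mut std_map: HashMap<String, u32> = HashMap::new();
    std_map.insert("hello".to_string(), 0);

    assert_eq!(fx.get("hello"), std_map.get("hello"));
}
```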
4 changes: 2 additions & 2 deletions bindings/node/src/models.rs
@@ -3,8 +3,8 @@ use crate::tasks::models::{BPEFromFilesTask, WordLevelFromFilesTask, WordPieceFr
use crate::trainers::Trainer;
use napi::bindgen_prelude::*;
use napi_derive::napi;
+use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
-use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock};
use tokenizers as tk;
@@ -95,7 +95,7 @@ impl tk::Model for Model {
self.model.as_ref()?.read().unwrap().id_to_token(id)
}

-fn get_vocab(&self) -> HashMap<String, u32> {
+fn get_vocab(&self) -> FxHashMap<String, u32> {
self
.model
.as_ref()
4 changes: 2 additions & 2 deletions bindings/node/src/tokenizer.rs
@@ -6,7 +6,7 @@ use crate::pre_tokenizers::PreTokenizer;
use crate::processors::Processor;
use crate::tasks::tokenizer::{DecodeBatchTask, DecodeTask, EncodeBatchTask, EncodeTask};
use crate::trainers::Trainer;
-use std::collections::HashMap;
+use rustc_hash::FxHashMap;
use tokenizers::Model as ModelTrait;

use napi::bindgen_prelude::*;
@@ -433,7 +433,7 @@ impl Tokenizer {
}

#[napi]
-pub fn get_vocab(&self, with_added_tokens: Option<bool>) -> HashMap<String, u32> {
+pub fn get_vocab(&self, with_added_tokens: Option<bool>) -> FxHashMap<String, u32> {
let with_added_tokens = with_added_tokens.unwrap_or(true);
self.tokenizer.read().unwrap().get_vocab(with_added_tokens)
}
1 change: 1 addition & 0 deletions bindings/python/Cargo.toml
@@ -18,6 +18,7 @@ pyo3 = { version = "0.23", features = ["abi3", "abi3-py39", "py-clone"] }
numpy = "0.23"
ndarray = "0.16"
itertools = "0.12"
+rustc-hash = "2.1.1"

[dependencies.tokenizers]
path = "../../tokenizers"
4 changes: 2 additions & 2 deletions bindings/python/src/models.rs
@@ -1,4 +1,4 @@
-use std::collections::HashMap;
+use rustc_hash::FxHashMap;
use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock};

@@ -70,7 +70,7 @@ impl Model for PyModel {
self.model.read().unwrap().id_to_token(id)
}

-fn get_vocab(&self) -> HashMap<String, u32> {
+fn get_vocab(&self) -> FxHashMap<String, u32> {
self.model.read().unwrap().get_vocab()
}

6 changes: 3 additions & 3 deletions bindings/python/src/tokenizer.rs
@@ -1,5 +1,5 @@
+use rustc_hash::{FxHashMap, FxHasher};
use serde::Serialize;
-use std::collections::{hash_map::DefaultHasher, HashMap};
use std::hash::{Hash, Hasher};

use numpy::{npyffi, PyArray1, PyArrayMethods};
@@ -255,7 +255,7 @@ impl PyAddedToken {
}

fn __hash__(&self) -> u64 {
-let mut hasher = DefaultHasher::new();
+let mut hasher = FxHasher::default();
self.get_token().hash(&mut hasher);
hasher.finish()
}
@@ -675,7 +675,7 @@ impl PyTokenizer {
/// :obj:`Dict[str, int]`: The vocabulary
#[pyo3(signature = (with_added_tokens = true))]
#[pyo3(text_signature = "(self, with_added_tokens=True)")]
-fn get_vocab(&self, with_added_tokens: bool) -> HashMap<String, u32> {
+fn get_vocab(&self, with_added_tokens: bool) -> FxHashMap<String, u32> {
self.tokenizer.get_vocab(with_added_tokens)
}

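The `__hash__` hunk above keeps the standard hash-then-finish pattern; `FxHasher` implements `std::hash::Hasher`, so it replaces `DefaultHasher` with no other changes. A small standalone sketch of that pattern (the helper name `fx_hash_of` is illustrative, not from the PR):

```rust
use rustc_hash::FxHasher;
use std::hash::{Hash, Hasher};

// Illustrative helper: the same hash-then-finish sequence as
// PyAddedToken::__hash__ above, with FxHasher standing in for DefaultHasher.
fn fx_hash_of<T: Hash>(value: &T) -> u64 {
    let mut hasher = FxHasher::default();
    value.hash(&mut hasher);
    hasher.finish()
}

fn main() {
    // Fx hashing is deterministic, so equal tokens always hash identically.
    assert_eq!(fx_hash_of(&"[CLS]"), fx_hash_of(&"[CLS]"));
}
```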
1 change: 1 addition & 0 deletions tokenizers/Cargo.toml
@@ -67,6 +67,7 @@ fancy-regex = { version = "0.14", optional = true}
getrandom = { version = "0.2.10" }
esaxx-rs = { version = "0.1.10", default-features = false, features=[]}
monostate = "0.1.12"
+rustc-hash = "2.1.1"

[features]
default = ["progressbar", "onig", "esaxx_fast"]
29 changes: 18 additions & 11 deletions tokenizers/src/models/bpe/model.rs
@@ -2,19 +2,22 @@ use super::{super::OrderedVocabIter, trainer::BpeTrainer, Error, Pair, Word};
use crate::tokenizer::{Model, Result, Token};
use crate::utils::cache::{Cache, DEFAULT_CACHE_CAPACITY, MAX_LENGTH};
use crate::utils::iter::ResultShunt;
+use rustc_hash::FxHashMap;
use serde_json::Value;
use std::borrow::Cow;
+use std::collections::HashMap;
+use std::hash::BuildHasher;
+use std::iter::FromIterator;
use std::{
-collections::HashMap,
fs::File,
io::prelude::*,
io::{BufRead, BufReader},
path::{Path, PathBuf},
};

-pub type Vocab = HashMap<String, u32>;
-type VocabR = HashMap<u32, String>;
-pub type MergeMap = HashMap<Pair, (u32, u32)>;
+pub type Vocab = FxHashMap<String, u32>;
+type VocabR = FxHashMap<u32, String>;
+pub type MergeMap = FxHashMap<Pair, (u32, u32)>;
pub type Merges = Vec<(String, String)>;

struct Config {
@@ -41,7 +44,7 @@ impl Default for BpeBuilder {
Self {
config: Config {
files: None,
-vocab: HashMap::new(),
+vocab: FxHashMap::default(),
merges: vec![],
cache_capacity: DEFAULT_CACHE_CAPACITY,
dropout: None,
@@ -71,8 +74,12 @@ impl BpeBuilder {

/// Set the vocab (token -> ID) and merges mappings.
#[must_use]
-pub fn vocab_and_merges(mut self, vocab: Vocab, merges: Merges) -> Self {
-self.config.vocab = vocab;
+pub fn vocab_and_merges<S: BuildHasher>(
+mut self,
+vocab: HashMap<String, u32, S>,
+merges: Merges,
+) -> Self {
+self.config.vocab = FxHashMap::from_iter(vocab);
self.config.merges = merges;
self
}
@@ -324,7 +331,7 @@ impl BPE {
let mut buffer = String::new();
vocab_file.read_to_string(&mut buffer)?;
let json: Value = serde_json::from_str(&buffer)?;
-let mut vocab = HashMap::new();
+let mut vocab = FxHashMap::default();
match json {
Value::Object(m) => {
for (token, id) in m {
@@ -493,7 +500,7 @@ impl BPE {
impl Model for BPE {
type Trainer = BpeTrainer;

-fn get_vocab(&self) -> HashMap<String, u32> {
+fn get_vocab(&self) -> FxHashMap<String, u32> {
self.vocab.clone()
}

@@ -533,7 +540,7 @@ impl Model for BPE {
.iter()
.collect();
let mut vocab_file = File::create(&vocab_path)?;
-let order_vocab_iter = OrderedVocabIter::new(&self.vocab_r);
+let order_vocab_iter = OrderedVocabIter::new(self.vocab_r.clone());
let serialized = serde_json::to_string(&order_vocab_iter)?;
vocab_file.write_all(serialized.as_bytes())?;

@@ -587,7 +594,7 @@ mod tests {
.iter()
.cloned()
.collect();
-let order_vocab_iter = OrderedVocabIter::new(&vocab_r);
+let order_vocab_iter = OrderedVocabIter::new(vocab_r.clone());
let serialized = serde_json::to_string(&order_vocab_iter).unwrap();
assert_eq!(serialized, "{\"a\":0,\"b\":1,\"c\":2,\"ab\":3}");
}
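The `vocab_and_merges` hunk above is the one public-facing API change in this file: the builder now accepts a vocab map with any hasher (`S: BuildHasher`) and converts it into the crate's `FxHashMap`-backed `Vocab`, so existing callers holding a std `HashMap` keep compiling. A minimal sketch of what callers can now pass, assuming the public `BpeBuilder` API exported from `tokenizers::models::bpe`:

```rust
use rustc_hash::FxHashMap;
use std::collections::HashMap;
use tokenizers::models::bpe::BpeBuilder;

fn main() -> tokenizers::Result<()> {
    let merges = vec![("a".to_string(), "b".to_string())];

    // A std HashMap still works: the builder is generic over the hasher.
    let std_vocab: HashMap<String, u32> =
        [("a".into(), 0), ("b".into(), 1), ("ab".into(), 2)]
            .into_iter()
            .collect();
    let _from_std = BpeBuilder::new()
        .vocab_and_merges(std_vocab, merges.clone())
        .build()?;

    // ...and so does an FxHashMap, with no conversion at the call site.
    let fx_vocab: FxHashMap<String, u32> =
        [("a".into(), 0), ("b".into(), 1), ("ab".into(), 2)]
            .into_iter()
            .collect();
    let _from_fx = BpeBuilder::new()
        .vocab_and_merges(fx_vocab, merges)
        .build()?;

    Ok(())
}
```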
6 changes: 3 additions & 3 deletions tokenizers/src/models/bpe/serialization.rs
@@ -1,10 +1,10 @@
use super::{super::OrderedVocabIter, convert_merges_to_hashmap, BpeBuilder, Pair, BPE};
+use rustc_hash::FxHashMap;
use serde::{
de::{Error, MapAccess, Visitor},
ser::SerializeStruct,
Deserialize, Deserializer, Serialize, Serializer,
};
-use std::collections::HashMap;

impl Serialize for BPE {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
@@ -34,7 +34,7 @@ impl Serialize for BPE {
.into_iter()
.map(|(pair, _)| (self.vocab_r[&pair.0].clone(), self.vocab_r[&pair.1].clone()))
.collect::<Vec<_>>();
-let ordered_vocab = OrderedVocabIter::new(&self.vocab_r);
+let ordered_vocab = OrderedVocabIter::new(self.vocab_r.clone());

model.serialize_field("vocab", &ordered_vocab)?;
model.serialize_field("merges", &merges)?;
@@ -80,7 +80,7 @@ impl<'de> Visitor<'de> for BPEVisitor {
V: MapAccess<'de>,
{
let mut builder = BpeBuilder::new();
-let mut vocab: Option<HashMap<String, u32>> = None;
+let mut vocab: Option<FxHashMap<String, u32>> = None;

#[derive(Debug, Deserialize)]
#[serde(untagged)]