
Commit

added tokenizer
HMUNACHI committed Mar 14, 2024
1 parent d4184da commit 16cfff5
Showing 9 changed files with 230 additions and 28 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -27,7 +27,7 @@ Developing and training transformer-based models is typically resource-intensive
- GPU/TPU-accelerated classical ML models like PCA, KMeans, Regression, Gaussian Processes etc., akin to SciKit Learn on GPU.
- Modular design so users can blend elements from various models, such as GPT, Mixtral, and LlaMa2, to craft unique hybrid transformer models.
- True random number generators in Jax which do not need the verbose code.
-- A range of advanced algorithms for NLP and computer vision tasks, such as Gaussian Blur, BLEU etc.
+- A range of advanced algorithms for NLP and computer vision tasks, such as Gaussian Blur, BLEU, Tokenizer etc.
- Each model is contained in a single file with no external dependencies, so the source code can also be easily used.
- True random number generators in Jax which do not need the verbose code (examples shown in next sections).

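To make the "no verbose code" claim concrete, here is a small sketch of the seed-based random API, mirroring the calls exercised in `tests/test_random.py` further down and assuming the top-level re-exports listed in `nanodl/__init__.py`:

```python
from nanodl import choice, bits

# Draw 3 values from range(5); a PRNG key is derived internally from
# the integer seed, so no explicit JAX key threading is required.
samples = choice(5, shape=(3,), seed=42)
print(samples.shape)  # (3,)

# Draw a 2x2 tensor of random bits.
raw = bits((2, 2), seed=42)
print(raw.shape)  # (2, 2)
```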
3 changes: 2 additions & 1 deletion nanodl/__init__.py
@@ -3,6 +3,7 @@
from nanodl.__src.sklearn_gpu.bayes import NaiveBayesClassifier
from nanodl.__src.sklearn_gpu.dimensionality_reduction import PCA
from nanodl.__src.sklearn_gpu.clustering import KMeans, GaussianMixtureModel
+from nanodl.__src.utils.tokenizer import Tokenizer
from nanodl.__src.utils.random import *

from nanodl.__src.sklearn_gpu.regression import (
@@ -271,6 +272,7 @@
"Dataset",
"ArrayDataset",
"DataLoader",
"Tokenizer",
"batch_cosine_similarities",
"batch_pearsonr",
"classification_scores",
@@ -312,7 +314,6 @@
"permutation",
"gumbel",
"choice",
"binomial",
"bits",
"exponential",
"triangular",
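After this change the tokenizer is part of the public API while the binomial sampler is not — a one-line check of which imports should now succeed or fail:

```python
from nanodl import Tokenizer   # newly exported in this commit
# from nanodl import binomial  # removed from __all__ (and from random.py) in this commit
```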
22 changes: 0 additions & 22 deletions nanodl/__src/utils/random.py
@@ -164,28 +164,6 @@ def choice(a: Union[int, jnp.ndarray],
                         p=p,
                         axis=axis)

-def binomial(n: int,
-             p: float,
-             shape: Tuple[int, ...] = (),
-             dtype: Any = jnp.float32,
-             seed=None) -> jnp.ndarray:
-    """Draw samples from a binomial distribution.
-
-    Args:
-        n (int): The number of trials.
-        p (float): The probability of success of an individual trial.
-        shape (Tuple[int, ...], optional): The shape of the output tensor. Defaults to ().
-        dtype (Any, optional): The data type of the output tensor. Defaults to jnp.float32.
-        seed (int, optional): Seed for the underlying PRNG key. Defaults to None.
-
-    Returns:
-        jnp.ndarray: A tensor of samples from a binomial distribution.
-    """
-    return random.binomial(time_rng_key(seed),
-                           n,
-                           p,
-                           shape=shape,
-                           dtype=dtype)

def bits(shape: Tuple[int, ...],
         dtype: Any = jnp.uint32,
         seed=None) -> jnp.ndarray:
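With `binomial` gone from both the module and `__all__`, downstream code needs an alternative. Here is a minimal standalone sketch (my own, not part of this commit) that reproduces binomial draws from primitives still in `jax.random`, using the fact that a Binomial(n, p) variable is the sum of n independent Bernoulli(p) trials; recent JAX releases also ship `jax.random.binomial` directly:

```python
import jax
import jax.numpy as jnp

def binomial_via_bernoulli(key, n: int, p: float, shape=()) -> jnp.ndarray:
    # Draw n Bernoulli(p) samples along a trailing axis and sum them:
    # the sum of n Bernoulli(p) trials is Binomial(n, p) by definition.
    draws = jax.random.bernoulli(key, p, shape=tuple(shape) + (n,))
    return draws.sum(axis=-1).astype(jnp.int32)

samples = binomial_via_bernoulli(jax.random.PRNGKey(42), n=10, p=0.5, shape=(2, 2))
print(samples.shape)  # (2, 2)
```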
107 changes: 107 additions & 0 deletions nanodl/__src/utils/tokenizer.py
@@ -0,0 +1,107 @@
import os
from typing import List, Optional
from sentencepiece import SentencePieceProcessor, SentencePieceTrainer


class Tokenizer:
    """
    A tokenizer class that uses SentencePiece to encode and decode text.

    This class can be initialized either with an existing SentencePiece model
    or with a dataset to train a new model. It provides methods to encode a
    string to a list of token ids and to decode a list of token ids back to a
    string.

    Attributes:
        sp_model (SentencePieceProcessor): The SentencePiece processor.
        n_words (int): Number of words in the vocabulary.
        bos_id (int): Token id for the beginning of a sentence.
        eos_id (int): Token id for the end of a sentence.
        pad_id (int): Token id for padding.

    Example usage:
        Training a new model and encoding/decoding a string:

        ```python
        # Initialize tokenizer with training data and train a new model.
        text_paths = ['/Users/mac1/Desktop/nanodl/nanodl/__src/utils/sample.txt']
        tokenizer = Tokenizer(training_data=text_paths,
                              vocab_size=100,
                              model_type='bpe',
                              max_sentence_length=50)

        # Encode a sentence.
        encoded_sentence = tokenizer.encode('Hello, world!')
        print(f'Encoded: {encoded_sentence}')

        # Decode the encoded sentence.
        decoded_sentence = tokenizer.decode(encoded_sentence)
        print(f'Decoded: {decoded_sentence}')
        ```

        Loading an existing model and encoding/decoding a string:

        ```python
        # Initialize tokenizer with a pre-trained model.
        tokenizer = Tokenizer(model_path='path/to/model.model')

        # Encode a sentence.
        encoded_sentence = tokenizer.encode('Hello, world!')
        print(f'Encoded: {encoded_sentence}')

        # Decode the encoded sentence.
        decoded_sentence = tokenizer.decode(encoded_sentence)
        print(f'Decoded: {decoded_sentence}')
        ```
    """
    def __init__(self,
                 training_data: Optional[List[str]] = None,
                 vocab_size: Optional[int] = None,
                 model_type: str = "bpe",
                 max_sentence_length: int = 512,
                 model_path: Optional[str] = None):
        # training_data and vocab_size default to None so that the
        # model_path-only usage shown in the docstring actually works;
        # vocab_size is still required when training a new model.
        if model_path and os.path.isfile(model_path):
            # Load an existing model.
            self.sp_model = SentencePieceProcessor(model_file=model_path)
        elif training_data and all(os.path.isfile(f) for f in training_data):
            # Train a new model from a list of data files.
            input_files = ','.join(training_data)
            model_prefix = "trained_model"
            SentencePieceTrainer.train(
                input=input_files,
                model_prefix=model_prefix,
                vocab_size=vocab_size,
                model_type=model_type,
                max_sentence_length=max_sentence_length,
            )
            self.sp_model = SentencePieceProcessor(model_file=f"{model_prefix}.model")
        else:
            raise ValueError("Must provide either a model_path or a non-empty training_data list")

        # Initialize special token ids from the loaded model.
        self.n_words: int = self.sp_model.vocab_size()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        self.pad_id: int = self.sp_model.pad_id()

        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

    def encode(self,
               s: str,
               bos: bool = True,
               eos: bool = False) -> List[int]:
        """Converts a string into a list of token ids."""
        assert isinstance(s, str)
        t = self.sp_model.encode(s)
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t

    def decode(self, t: List[int]) -> str:
        """Converts a list of token ids back into a string."""
        return self.sp_model.decode(t)
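A short usage sketch (mine, complementing the docstring examples) that exercises the `bos`/`eos` flags, reusing the `tests/files/sample.txt` fixture added in this commit and assuming `sentencepiece` is installed:

```python
tokenizer = Tokenizer(training_data=['tests/files/sample.txt'],
                      vocab_size=100,
                      model_type='bpe',
                      max_sentence_length=50)

ids = tokenizer.encode('Hello, world!', bos=True, eos=True)
assert ids[0] == tokenizer.bos_id and ids[-1] == tokenizer.eos_id

# SentencePiece drops control tokens on decode, so the text round-trips.
print(tokenizer.decode(ids))  # Hello, world!
```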
3 changes: 3 additions & 0 deletions tests/files/sample.txt
@@ -0,0 +1,3 @@
Hello, world! This is a test of the Tokenizer.
Let's see how it tokenizes this file.
Another sentence to check the tokenization process.
4 changes: 0 additions & 4 deletions tests/test_random.py
@@ -54,10 +54,6 @@ def test_choice(self):
        result = choice(5, shape=(3,), seed=42)
        self.assertEqual(result.shape, (3,))

-    def test_binomial(self):
-        result = binomial(10, 0.5, (2, 2), seed=42)
-        self.assertEqual(result.shape, (2, 2))

    def test_bits(self):
        result = bits((2, 2), seed=42)
        self.assertEqual(result.shape, (2, 2))
17 changes: 17 additions & 0 deletions tests/test_utils.py
@@ -245,5 +245,22 @@ def test_random_flip_image(self):
        self.assertEqual(flipped_image.shape, (5, 5, 3))


+class TestTokenizerEncodingDecoding(unittest.TestCase):
+    def setUp(self):
+        """Set up the tokenizer with specific training data."""
+        text_paths = ['tests/files/sample.txt']
+        self.tokenizer = Tokenizer(training_data=text_paths,
+                                   vocab_size=100,
+                                   model_type='bpe',
+                                   max_sentence_length=50)
+
+    def test_encode_decode(self):
+        """Test that encoding followed by decoding returns the original sentence."""
+        test_sentence = "Hello, test"
+        encoded_sentence = self.tokenizer.encode(test_sentence)
+        decoded_sentence = self.tokenizer.decode(encoded_sentence)
+        self.assertEqual(test_sentence, decoded_sentence)


if __name__ == '__main__':
    unittest.main()
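The new test should run with the standard library runner, e.g. `python -m unittest tests.test_utils`, provided `sentencepiece` is installed and the `tests/files/sample.txt` fixture above exists. Note that training in `setUp` writes `trained_model.model` and `trained_model.vocab` into the working directory, which appears to be how the two artifacts below ended up in this commit.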
Binary file added trained_model.model
100 changes: 100 additions & 0 deletions trained_model.vocab
@@ -0,0 +1,100 @@
<unk> 0
<s> 0
</s> 0
▁t -0
is -1
en -2
es -3
iz -4
ok -5
▁T -6
his -7
eniz -8
okeniz -9
He -10
Le -11
ee -12
er -13
fi -14
he -15
ho -16
it -17
ld -18
le -19
ll -20
of -21
or -22
▁a -23
▁s -24
▁w -25
Let -26
est -27
how -28
llo -29
▁He -30
▁fi -31
▁is -32
▁it -33
▁of -34
orld -35
▁Let -36
▁how -37
▁see -38
▁the -39
▁This -40
▁file -41
▁test -42
▁this -43
▁Hello -44
▁world -45
okenizer -46
okenizes -47
▁Tokenizer -48
▁tokenizes -49
Th -50
To -51
el -52
et -53
hi -54
il -55
ke -56
lo -57
ni -58
ow -59
rl -60
se -61
st -62
te -63
th -64
to -65
wo -66
ze -67
▁H -68
▁L -69
▁f -70
▁h -71
▁i -72
▁o -73
▁ -74
e -75
i -76
s -77
t -78
o -79
h -80
l -81
. -82
T -83
f -84
k -85
n -86
r -87
w -88
z -89
! -90
' -91
, -92
H -93
L -94
a -95
d -96
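For reference, a SentencePiece `.vocab` file like the one above is a plain-text listing of tab-separated `piece<TAB>score` pairs; for BPE models the score tracks merge order. A minimal sketch of reading it back, with the filename taken from the hard-coded `model_prefix`:

```python
# Parse the tab-separated .vocab listing shown above.
with open('trained_model.vocab', encoding='utf-8') as f:
    vocab = [line.rstrip('\n').split('\t') for line in f]
print(vocab[:3])  # [['<unk>', '0'], ['<s>', '0'], ['</s>', '0']]
```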
