Commit 24bb087

Add a tokenizer for the Mistral backbone (#1383)
1 parent 1cf5c39 commit 24bb087

File tree: 4 files changed, +153 −0 lines changed
keras_nlp/models/mistral/mistral_tokenizer.py
Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keras_nlp.api_export import keras_nlp_export
from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer


@keras_nlp_export("keras_nlp.models.MistralTokenizer")
class MistralTokenizer(SentencePieceTokenizer):
    """Mistral tokenizer layer based on SentencePiece.

    This tokenizer class will tokenize raw strings into integer sequences and
    is based on `keras_nlp.tokenizers.SentencePieceTokenizer`. Unlike the
    underlying tokenizer, it will check for all special tokens needed by
    Mistral models and provides a `from_preset()` method to automatically
    download a matching vocabulary for a Mistral preset.

    This tokenizer does not provide truncation or padding of inputs. It can
    be combined with a `keras_nlp.models.MistralPreprocessor` layer for input
    packing.

    If input is a batch of strings (rank > 0), the layer will output a
    `tf.RaggedTensor` where the last dimension of the output is ragged.

    If input is a scalar string (rank == 0), the layer will output a dense
    `tf.Tensor` with static shape `[None]`.

    Args:
        proto: Either a `string` path to a SentencePiece proto file, or a
            `bytes` object with a serialized SentencePiece proto. See the
            [SentencePiece repository](https://github.com/google/sentencepiece)
            for more details on the format.

    Examples:
    ```python
    # Unbatched input.
    tokenizer = keras_nlp.models.MistralTokenizer.from_preset(
        "mistral_base_en",
    )
    tokenizer("The quick brown fox jumped.")

    # Batched input.
    tokenizer(["The quick brown fox jumped.", "The fox slept."])

    # Detokenization.
    tokenizer.detokenize(tokenizer("The quick brown fox jumped."))
    ```
    """

    def __init__(self, proto, **kwargs):
        super().__init__(proto=proto, **kwargs)

        # Check for necessary special tokens.
        start_token = "<s>"
        end_token = "</s>"
        for token in [start_token, end_token]:
            if token not in self.get_vocabulary():
                raise ValueError(
                    f"Cannot find token `'{token}'` in the provided "
                    f"`vocabulary`. Please provide `'{token}'` in your "
                    "`vocabulary` or use a pretrained `vocabulary` name."
                )

        self.start_token_id = self.token_to_id(start_token)
        self.end_token_id = self.token_to_id(end_token)
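The constructor above validates the vocabulary eagerly: a SentencePiece proto missing `<s>` or `</s>` fails at construction time rather than silently emitting wrong IDs later. The cached `start_token_id` and `end_token_id` are what a downstream preprocessor uses to frame sequences. A minimal sketch of that framing, assuming a compatible `.spm` file on disk (the path is illustrative, and `keras_nlp.models.MistralPreprocessor` is the supported way to do this):

```python
import tensorflow as tf
import keras_nlp

# Illustrative path; any SentencePiece proto containing <s> and </s> works.
tokenizer = keras_nlp.models.MistralTokenizer(proto="mistral_test_vocab.spm")

ids = tokenizer("the quick brown fox")  # dense int tensor, shape [None]
# Frame the sequence the way a preprocessor would, using the cached IDs.
framed = tf.concat([[tokenizer.start_token_id], ids, [tokenizer.end_token_id]], axis=0)
print(framed.numpy())  # e.g. [1, 3, 8, 4, 6, 2] with the test vocab below
```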
keras_nlp/models/mistral/mistral_tokenizer_test.py
Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from keras_nlp.models.mistral.mistral_tokenizer import MistralTokenizer
from keras_nlp.tests.test_case import TestCase


class MistralTokenizerTest(TestCase):
    def setUp(self):
        self.init_kwargs = {
            # Generated using create_mistral_test_proto.py
            "proto": os.path.join(
                self.get_test_data_dir(), "mistral_test_vocab.spm"
            )
        }
        self.input_data = ["the quick brown fox", "the earth is round"]

    def test_tokenizer_basics(self):
        self.run_preprocessing_layer_test(
            cls=MistralTokenizer,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output=[[3, 8, 4, 6], [3, 5, 7, 9]],
        )

    def test_errors_missing_special_tokens(self):
        with self.assertRaises(ValueError):
            MistralTokenizer(
                # Generated using create_no_special_token_proto.py
                proto=os.path.join(
                    self.get_test_data_dir(), "no_special_token_vocab.spm"
                )
            )
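A note on the expected IDs: the test proto is trained with `unk_id=0`, `bos_id=1`, `eos_id=2`, and `pad_id=-1` (see the generation script below), so the word-level pieces occupy IDs 3 through 9, and "the", the only repeated word, lands at ID 3 in both expected sequences. The vocabulary can be inspected directly with the `sentencepiece` package, assuming the `.spm` file is on disk (a quick sketch, not part of this commit):

```python
import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="mistral_test_vocab.spm")
# List every piece with its ID; IDs 0-2 are <unk>, <s>, </s>.
print([(i, sp.id_to_piece(i)) for i in range(sp.get_piece_size())])
print(sp.encode("the quick brown fox"))  # expected: [3, 8, 4, 6]
```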
mistral_test_vocab.spm (232 KB): binary file not shown.
tools/sentencepiece_testing/create_mistral_test_proto.py
Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from tools.sentencepiece_testing.utils import train_sentencepiece


def main():
    train_sentencepiece(
        ["the quick brown fox", "the earth is round"],
        "mistral_test_vocab.spm",
        vocab_size=10,
        model_type="WORD",
        pad_id=-1,
        unk_id=0,
        bos_id=1,
        eos_id=2,
    )


if __name__ == "__main__":
    main()
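`train_sentencepiece` comes from `tools/sentencepiece_testing/utils` and is not part of this diff. A plausible sketch of what such a helper does, using the `sentencepiece` trainer API (an assumption; the real helper may differ in details):

```python
# Hypothetical re-implementation of the helper imported above; the actual
# code lives in tools/sentencepiece_testing/utils.py and may differ.
import io
import pathlib

import sentencepiece as spm


def train_sentencepiece(data, filename, **kwargs):
    bytes_io = io.BytesIO()
    # Train an in-memory model from the iterator of example sentences,
    # forwarding keyword arguments such as vocab_size and model_type.
    spm.SentencePieceTrainer.train(
        sentence_iterator=iter(data), model_writer=bytes_io, **kwargs
    )
    # Persist the serialized proto so tests can load it from disk.
    pathlib.Path(filename).write_bytes(bytes_io.getvalue())
```

With `model_type="WORD"` and `vocab_size=10`, the resulting proto tokenizes at word granularity, which keeps the expected IDs in the test above small and stable.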
