 # See the License for the specific language governing permissions and
 # limitations under the License.

-import io
+import os

 import pytest
-import sentencepiece

 from keras_nlp.models.albert.albert_tokenizer import AlbertTokenizer
 from keras_nlp.tests.test_case import TestCase


 class AlbertTokenizerTest(TestCase):
     def setUp(self):
-        vocab_data = ["the quick brown fox", "the earth is round"]
-        bytes_io = io.BytesIO()
-        sentencepiece.SentencePieceTrainer.train(
-            sentence_iterator=iter(vocab_data),
-            model_writer=bytes_io,
-            vocab_size=12,
-            model_type="WORD",
-            pad_id=0,
-            unk_id=1,
-            bos_id=2,
-            eos_id=3,
-            pad_piece="<pad>",
-            unk_piece="<unk>",
-            bos_piece="[CLS]",
-            eos_piece="[SEP]",
-            user_defined_symbols="[MASK]",
-        )
-        self.init_kwargs = {"proto": bytes_io.getvalue()}
+        self.init_kwargs = {
+            # Generated using create_albert_test_proto.py
+            "proto": os.path.join(
+                self.get_test_data_dir(), "albert_test_vocab.spm"
+            )
+        }
         self.input_data = ["the quick brown fox.", "the earth is round."]

     def test_tokenizer_basics(self):
@@ -52,17 +39,13 @@ def test_tokenizer_basics(self):
         )

     def test_errors_missing_special_tokens(self):
-        bytes_io = io.BytesIO()
-        sentencepiece.SentencePieceTrainer.train(
-            sentence_iterator=iter(["abc"]),
-            model_writer=bytes_io,
-            vocab_size=5,
-            pad_id=-1,
-            eos_id=-1,
-            bos_id=-1,
-        )
         with self.assertRaises(ValueError):
-            AlbertTokenizer(proto=bytes_io.getvalue())
+            AlbertTokenizer(
+                # Generated using create_no_special_token_proto.py
+                proto=os.path.join(
+                    self.get_test_data_dir(), "no_special_token_vocab.spm"
+                )
+            )

     @pytest.mark.large
     def test_smallest_preset(self):
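The comments in the new test point at one-off generator scripts for the checked-in `.spm` files. For reference, here is a minimal sketch of what `create_albert_test_proto.py` could look like: the training arguments come directly from the inline setup this diff removes, while the output filename and the bare script structure are assumptions, not the repo's actual script.

```python
# Hypothetical sketch of create_albert_test_proto.py, reconstructed from the
# inline setup removed in this diff. The real script may differ; the output
# path is an assumption.
import io

import sentencepiece

vocab_data = ["the quick brown fox", "the earth is round"]
bytes_io = io.BytesIO()
sentencepiece.SentencePieceTrainer.train(
    sentence_iterator=iter(vocab_data),
    model_writer=bytes_io,
    vocab_size=12,
    model_type="WORD",
    pad_id=0,
    unk_id=1,
    bos_id=2,
    eos_id=3,
    pad_piece="<pad>",
    unk_piece="<unk>",
    bos_piece="[CLS]",
    eos_piece="[SEP]",
    user_defined_symbols="[MASK]",
)
# Persist the serialized model so tests can load it from disk instead of
# retraining it on every run.
with open("albert_test_vocab.spm", "wb") as f:
    f.write(bytes_io.getvalue())
```

Checking the trained proto into the test data directory trades a small binary artifact for faster, deterministic tests: the vocabulary no longer depends on retraining SentencePiece on every run, and the test module drops its direct `sentencepiece` import.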