-
Notifications
You must be signed in to change notification settings. Fork 306
Remove the use of SentencePieceTrainer from tests
#1283
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
c7c24d4
4261061
dba4ac5
e316ff7
8993672
5e6c26c
4aa6347
0134799
ec944b8
a76c577
a77b362
bf0b044
0ce2274
234b908
75ead27
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,10 +12,9 @@ | |
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import io | ||
| import pathlib | ||
|
|
||
| import pytest | ||
| import sentencepiece | ||
|
|
||
| from keras_nlp.models.albert.albert_backbone import AlbertBackbone | ||
| from keras_nlp.models.albert.albert_classifier import AlbertClassifier | ||
|
|
@@ -27,26 +26,16 @@ | |
| class AlbertClassifierTest(TestCase): | ||
| def setUp(self): | ||
| # Setup model. | ||
| vocab_data = ["the quick brown fox", "the earth is round"] | ||
| bytes_io = io.BytesIO() | ||
| sentencepiece.SentencePieceTrainer.train( | ||
| sentence_iterator=iter(vocab_data), | ||
| model_writer=bytes_io, | ||
| vocab_size=12, | ||
| model_type="WORD", | ||
| pad_id=0, | ||
| unk_id=1, | ||
| bos_id=2, | ||
| eos_id=3, | ||
| pad_piece="<pad>", | ||
| unk_piece="<unk>", | ||
| bos_piece="[CLS]", | ||
| eos_piece="[SEP]", | ||
| user_defined_symbols="[MASK]", | ||
| ) | ||
| self.preprocessor = AlbertPreprocessor( | ||
| AlbertTokenizer(proto=bytes_io.getvalue()), | ||
| sequence_length=5, | ||
| AlbertTokenizer( | ||
| proto=str( | ||
| pathlib.Path(__file__).parent.parent.parent | ||
| / "tests" | ||
| / "test_data" | ||
| / "albert_sentencepiece.proto" | ||
|
||
| ), | ||
| sequence_length=5, | ||
| ) | ||
| ) | ||
| self.backbone = AlbertBackbone( | ||
| vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(), | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,35 +12,24 @@ | |
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import io | ||
| import pathlib | ||
|
|
||
| import pytest | ||
| import sentencepiece | ||
|
|
||
| from keras_nlp.models.albert.albert_tokenizer import AlbertTokenizer | ||
| from keras_nlp.tests.test_case import TestCase | ||
|
|
||
|
|
||
| class AlbertTokenizerTest(TestCase): | ||
| def setUp(self): | ||
| vocab_data = ["the quick brown fox", "the earth is round"] | ||
| bytes_io = io.BytesIO() | ||
| sentencepiece.SentencePieceTrainer.train( | ||
| sentence_iterator=iter(vocab_data), | ||
| model_writer=bytes_io, | ||
| vocab_size=12, | ||
| model_type="WORD", | ||
| pad_id=0, | ||
| unk_id=1, | ||
| bos_id=2, | ||
| eos_id=3, | ||
| pad_piece="<pad>", | ||
| unk_piece="<unk>", | ||
| bos_piece="[CLS]", | ||
| eos_piece="[SEP]", | ||
| user_defined_symbols="[MASK]", | ||
| ) | ||
| self.init_kwargs = {"proto": bytes_io.getvalue()} | ||
| self.init_kwargs = { | ||
| "proto": str( | ||
| pathlib.Path(__file__).parent.parent.parent | ||
| / "tests" | ||
| / "test_data" | ||
| / "albert_sentencepiece.proto" | ||
| ) | ||
| } | ||
| self.input_data = ["the quick brown fox.", "the earth is round."] | ||
|
|
||
| def test_tokenizer_basics(self): | ||
|
|
@@ -52,17 +41,15 @@ def test_tokenizer_basics(self): | |
| ) | ||
|
|
||
| def test_errors_missing_special_tokens(self): | ||
| bytes_io = io.BytesIO() | ||
| sentencepiece.SentencePieceTrainer.train( | ||
| sentence_iterator=iter(["abc"]), | ||
| model_writer=bytes_io, | ||
| vocab_size=5, | ||
| pad_id=-1, | ||
| eos_id=-1, | ||
| bos_id=-1, | ||
| ) | ||
| with self.assertRaises(ValueError): | ||
| AlbertTokenizer(proto=bytes_io.getvalue()) | ||
| AlbertTokenizer( | ||
| proto=str( | ||
| pathlib.Path(__file__).parent.parent.parent | ||
| / "tests" | ||
| / "test_data" | ||
| / "sentencepiece_bad.proto" | ||
|
||
| ) | ||
| ) | ||
|
|
||
| @pytest.mark.large | ||
| def test_smallest_preset(self): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a bit of a mouthful. Can we maybe add this to our base class for tests in
test_case.py? E.g. `proto=os.path.join(self.test_data_dir(), "albert_test_vocab.spm")`

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.