diff --git a/keras_nlp/models/__init__.py b/keras_nlp/models/__init__.py
index 893d93000b..1cedd0c688 100644
--- a/keras_nlp/models/__init__.py
+++ b/keras_nlp/models/__init__.py
@@ -25,6 +25,12 @@ from keras_nlp.models.distil_bert.distil_bert_classifier import (
     DistilBertClassifier,
 )
+from keras_nlp.models.distil_bert.distil_bert_masked_lm import (
+    DistilBertMaskedLM,
+)
+from keras_nlp.models.distil_bert.distil_bert_masked_lm_preprocessor import (
+    DistilBertMaskedLMPreprocessor,
+)
 from keras_nlp.models.distil_bert.distil_bert_preprocessor import (
     DistilBertPreprocessor,
 )
diff --git a/keras_nlp/models/distil_bert/distil_bert_masked_lm.py b/keras_nlp/models/distil_bert/distil_bert_masked_lm.py
new file mode 100644
index 0000000000..386504f569
--- /dev/null
+++ b/keras_nlp/models/distil_bert/distil_bert_masked_lm.py
@@ -0,0 +1,155 @@
+# Copyright 2022 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""DistilBERT masked lm model."""
+
+import copy
+
+from tensorflow import keras
+
+from keras_nlp.layers.masked_lm_head import MaskedLMHead
+from keras_nlp.models.distil_bert.distil_bert_backbone import DistilBertBackbone
+from keras_nlp.models.distil_bert.distil_bert_backbone import (
+    distilbert_kernel_initializer,
+)
+from keras_nlp.models.distil_bert.distil_bert_masked_lm_preprocessor import (
+    DistilBertMaskedLMPreprocessor,
+)
+from keras_nlp.models.distil_bert.distil_bert_presets import backbone_presets
+from keras_nlp.models.task import Task
+from keras_nlp.utils.python_utils import classproperty
+
+
+@keras.utils.register_keras_serializable(package="keras_nlp")
+class DistilBertMaskedLM(Task):
+    """An end-to-end DistilBERT model for the masked language modeling task.
+
+    This model will train DistilBERT on a masked language modeling task.
+    The model will predict labels for a number of masked tokens in the
+    input data. For usage of this model with pre-trained weights, see the
+    `from_preset()` method.
+
+    This model can optionally be configured with a `preprocessor` layer, in
+    which case inputs can be raw string features during `fit()`, `predict()`,
+    and `evaluate()`. Inputs will be tokenized and dynamically masked during
+    training and evaluation. This is done by default when creating the model
+    with `from_preset()`.
+
+    Disclaimer: Pre-trained models are provided on an "as is" basis, without
+    warranties or conditions of any kind. The underlying model is provided by a
+    third party and subject to a separate license, available
+    [here](https://github.com/huggingface/transformers).
+
+    Args:
+        backbone: A `keras_nlp.models.DistilBertBackbone` instance.
+        preprocessor: A `keras_nlp.models.DistilBertMaskedLMPreprocessor` or
+            `None`. If `None`, this model will not apply preprocessing, and
+            inputs should be preprocessed before calling the model.
+
+    Example usage:
+
+    Raw string inputs and pretrained backbone.
+    ```python
+    # Create a dataset with raw string features. Labels are inferred.
+ features = ["The quick brown fox jumped.", "I forgot my homework."] + + # Create a DistilBertMaskedLM with a pretrained backbone and further train + # on an MLM task. + masked_lm = keras_nlp.models.DistilBertMaskedLM.from_preset( + "distil_bert_base_en", + ) + masked_lm.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + ) + masked_lm.fit(x=features, batch_size=2) + ``` + + Preprocessed inputs and custom backbone. + ```python + # Create a preprocessed dataset where 0 is the mask token. + preprocessed_features = { + "token_ids": tf.constant( + [[1, 2, 0, 4, 0, 6, 7, 8]] * 2, shape=(2, 8) + ), + "padding_mask": tf.constant( + [[1, 1, 1, 1, 1, 1, 1, 1]] * 2, shape=(2, 8) + ), + "mask_positions": tf.constant([[2, 4]] * 2, shape=(2, 2)) + } + # Labels are the original masked values. + labels = [[3, 5]] * 2 + + # Randomly initialize a DistilBERT encoder + backbone = keras_nlp.models.DistilBertBackbone( + vocabulary_size=50265, + num_layers=12, + num_heads=12, + hidden_dim=768, + intermediate_dim=3072, + max_sequence_length=12 + ) + # Create a DistilBERT masked_lm and fit the data. + masked_lm = keras_nlp.models.DistilBertMaskedLM( + backbone, + preprocessor=None, + ) + masked_lm.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + ) + masked_lm.fit(x=preprocessed_features, y=labels, batch_size=2) + ``` + """ + + def __init__( + self, + backbone, + preprocessor=None, + **kwargs, + ): + inputs = { + **backbone.input, + "mask_positions": keras.Input( + shape=(None,), dtype="int32", name="mask_positions" + ), + } + backbone_outputs = backbone(backbone.input) + outputs = MaskedLMHead( + vocabulary_size=backbone.vocabulary_size, + embedding_weights=backbone.token_embedding.embeddings, + intermediate_activation="gelu", + kernel_initializer=distilbert_kernel_initializer(), + name="mlm_head", + )(backbone_outputs, inputs["mask_positions"]) + + # Instantiate using Functional API Model constructor + super().__init__( + inputs=inputs, + outputs=outputs, + include_preprocessing=preprocessor is not None, + **kwargs, + ) + # All references to `self` below this line + self.backbone = backbone + self.preprocessor = preprocessor + + @classproperty + def backbone_cls(cls): + return DistilBertBackbone + + @classproperty + def preprocessor_cls(cls): + return DistilBertMaskedLMPreprocessor + + @classproperty + def presets(cls): + return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/distil_bert/distil_bert_masked_lm_preprocessor.py b/keras_nlp/models/distil_bert/distil_bert_masked_lm_preprocessor.py new file mode 100644 index 0000000000..bf19585b99 --- /dev/null +++ b/keras_nlp/models/distil_bert/distil_bert_masked_lm_preprocessor.py @@ -0,0 +1,143 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""DistilBERT masked language model preprocessor layer.""" + +from absl import logging +from tensorflow import keras + +from keras_nlp.layers.masked_lm_mask_generator import MaskedLMMaskGenerator +from keras_nlp.models.distil_bert.distil_bert_preprocessor import ( + DistilBertPreprocessor, +) +from keras_nlp.utils.keras_utils import pack_x_y_sample_weight + + +@keras.utils.register_keras_serializable(package="keras_nlp") +class DistilBertMaskedLMPreprocessor(DistilBertPreprocessor): + """DistilBERT preprocessing for the masked language modeling task. + + This preprocessing layer will prepare inputs for a masked language modeling + task. It is primarily intended for use with the + `keras_nlp.models.DistilBertMaskedLM` task model. Preprocessing will occur in + multiple steps. + + - Tokenize any number of input segments using the `tokenizer`. + - Pack the inputs together using a `keras_nlp.layers.MultiSegmentPacker`. + with the appropriate `"[CLS]"`, `"[SEP]"` and `"[PAD]"` tokens. + - Randomly select non-special tokens to mask, controlled by + `mask_selection_rate`. + - Construct a `(x, y, sample_weight)` tuple suitable for training with a + `keras_nlp.models.DistilBertMaskedLM` task model. + + Examples: + ```python + # Load the preprocessor from a preset. + preprocessor = keras_nlp.models.DistilBertMaskedLMPreprocessor.from_preset( + "distil_bert_base_en" + ) + + # Tokenize and mask a single sentence. + sentence = tf.constant("The quick brown fox jumped.") + preprocessor(sentence) + + # Tokenize and mask a batch of sentences. + sentences = tf.constant( + ["The quick brown fox jumped.", "Call me Ishmael."] + ) + preprocessor(sentences) + + # Tokenize and mask a dataset of sentences. + features = tf.constant( + ["The quick brown fox jumped.", "Call me Ishmael."] + ) + ds = tf.data.Dataset.from_tensor_slices((features)) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + + # Alternatively, you can create a preprocessor from your own vocabulary. + # The usage is exactly the same as above. 
+ vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]"] + vocab += ["The", "qu", "##ick", "br", "##own", "fox", "tripped"] + vocab += ["Call", "me", "Ish", "##mael", "."] + vocab += ["Oh", "look", "a", "whale"] + vocab += ["I", "forgot", "my", "home", "##work"] + tokenizer = keras_nlp.models.DistilBertTokenizer(vocabulary=vocab) + preprocessor = keras_nlp.models.DistilBertMaskedLMPreprocessor(tokenizer) + ``` + """ + + def __init__( + self, + tokenizer, + sequence_length=512, + truncate="round_robin", + mask_selection_rate=0.15, + mask_selection_length=96, + mask_token_rate=0.8, + random_token_rate=0.1, + **kwargs, + ): + super().__init__( + tokenizer, + sequence_length=sequence_length, + truncate=truncate, + **kwargs, + ) + + self.masker = MaskedLMMaskGenerator( + mask_selection_rate=mask_selection_rate, + mask_selection_length=mask_selection_length, + mask_token_rate=mask_token_rate, + random_token_rate=random_token_rate, + vocabulary_size=tokenizer.vocabulary_size(), + mask_token_id=tokenizer.mask_token_id, + unselectable_token_ids=[ + tokenizer.cls_token_id, + tokenizer.sep_token_id, + tokenizer.pad_token_id, + ], + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "mask_selection_rate": self.masker.mask_selection_rate, + "mask_selection_length": self.masker.mask_selection_length, + "mask_token_rate": self.masker.mask_token_rate, + "random_token_rate": self.masker.random_token_rate, + } + ) + return config + + def call(self, x, y=None, sample_weight=None): + if y is not None or sample_weight is not None: + logging.warning( + f"{self.__class__.__name__} generates `y` and `sample_weight` " + "based on your input data, but your data already contains `y` " + "or `sample_weight`. Your `y` and `sample_weight` will be " + "ignored." + ) + + x = super().call(x) + token_ids, padding_mask = x["token_ids"], x["padding_mask"] + masker_outputs = self.masker(token_ids) + x = { + "token_ids": masker_outputs["token_ids"], + "padding_mask": padding_mask, + "mask_positions": masker_outputs["mask_positions"], + } + y = masker_outputs["mask_ids"] + sample_weight = masker_outputs["mask_weights"] + return pack_x_y_sample_weight(x, y, sample_weight) diff --git a/keras_nlp/models/distil_bert/distil_bert_masked_lm_preprocessor_test.py b/keras_nlp/models/distil_bert/distil_bert_masked_lm_preprocessor_test.py new file mode 100644 index 0000000000..62eadb5e81 --- /dev/null +++ b/keras_nlp/models/distil_bert/distil_bert_masked_lm_preprocessor_test.py @@ -0,0 +1,126 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for DistilBERT masked language model preprocessor layer.""" + +import os + +import tensorflow as tf +from absl.testing import parameterized +from tensorflow import keras + +from keras_nlp.models.distil_bert.distil_bert_masked_lm_preprocessor import ( + DistilBertMaskedLMPreprocessor, +) +from keras_nlp.models.distil_bert.distil_bert_tokenizer import ( + DistilBertTokenizer, +) + + +class DistilBertMaskedLMPreprocessorTest( + tf.test.TestCase, parameterized.TestCase +): + def setUp(self): + self.vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"] + self.vocab += ["THE", "QUICK", "BROWN", "FOX"] + self.vocab += ["the", "quick", "brown", "fox"] + + self.preprocessor = DistilBertMaskedLMPreprocessor( + tokenizer=DistilBertTokenizer( + vocabulary=self.vocab, + ), + # Simplify out testing by masking every available token. + mask_selection_rate=1.0, + mask_token_rate=1.0, + random_token_rate=0.0, + mask_selection_length=5, + sequence_length=8, + ) + + def test_preprocess_strings(self): + input_data = " THE QUICK BROWN FOX." + + x, y, sw = self.preprocessor(input_data) + self.assertAllEqual(x["token_ids"], [2, 4, 4, 4, 4, 4, 3, 0]) + self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 1, 1, 0]) + self.assertAllEqual(x["mask_positions"], [1, 2, 3, 4, 5]) + self.assertAllEqual(y, [5, 6, 7, 8, 1]) + self.assertAllEqual(sw, [1.0, 1.0, 1.0, 1.0, 1.0]) + + def test_preprocess_list_of_strings(self): + input_data = [" THE QUICK BROWN FOX."] * 4 + + x, y, sw = self.preprocessor(input_data) + self.assertAllEqual(x["token_ids"], [[2, 4, 4, 4, 4, 4, 3, 0]] * 4) + self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0]] * 4) + self.assertAllEqual(x["mask_positions"], [[1, 2, 3, 4, 5]] * 4) + self.assertAllEqual(y, [[5, 6, 7, 8, 1]] * 4) + self.assertAllEqual(sw, [[1.0, 1.0, 1.0, 1.0, 1.0]] * 4) + + def test_preprocess_dataset(self): + sentences = tf.constant([" THE QUICK BROWN FOX."] * 4) + ds = tf.data.Dataset.from_tensor_slices(sentences) + ds = ds.map(self.preprocessor) + x, y, sw = ds.batch(4).take(1).get_single_element() + self.assertAllEqual(x["token_ids"], [[2, 4, 4, 4, 4, 4, 3, 0]] * 4) + self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0]] * 4) + self.assertAllEqual(x["mask_positions"], [[1, 2, 3, 4, 5]] * 4) + self.assertAllEqual(y, [[5, 6, 7, 8, 1]] * 4) + self.assertAllEqual(sw, [[1.0, 1.0, 1.0, 1.0, 1.0]] * 4) + + def test_mask_multiple_sentences(self): + sentence_one = tf.constant(" THE QUICK") + sentence_two = tf.constant(" BROWN FOX.") + + x, y, sw = self.preprocessor((sentence_one, sentence_two)) + self.assertAllEqual(x["token_ids"], [2, 4, 4, 3, 4, 4, 4, 3]) + self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 1, 1, 1]) + self.assertAllEqual(x["mask_positions"], [1, 2, 4, 5, 6]) + self.assertAllEqual(y, [5, 6, 7, 8, 1]) + self.assertAllEqual(sw, [1.0, 1.0, 1.0, 1.0, 1.0]) + + def test_no_masking_zero_rate(self): + no_mask_preprocessor = DistilBertMaskedLMPreprocessor( + self.preprocessor.tokenizer, + mask_selection_rate=0.0, + mask_selection_length=5, + sequence_length=8, + ) + input_data = " THE QUICK BROWN FOX." 
+
+        x, y, sw = no_mask_preprocessor(input_data)
+        self.assertAllEqual(x["token_ids"], [2, 5, 6, 7, 8, 1, 3, 0])
+        self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 1, 1, 0])
+        self.assertAllEqual(x["mask_positions"], [0, 0, 0, 0, 0])
+        self.assertAllEqual(y, [0, 0, 0, 0, 0])
+        self.assertAllEqual(sw, [0.0, 0.0, 0.0, 0.0, 0.0])
+
+    @parameterized.named_parameters(
+        ("tf_format", "tf", "model"),
+        ("keras_format", "keras_v3", "model.keras"),
+    )
+    def test_saved_model(self, save_format, filename):
+        input_data = tf.constant([" THE QUICK BROWN FOX."])
+
+        inputs = keras.Input(dtype="string", shape=())
+        outputs = self.preprocessor(inputs)
+        model = keras.Model(inputs, outputs)
+
+        path = os.path.join(self.get_temp_dir(), filename)
+        model.save(path, save_format=save_format)
+
+        restored_model = keras.models.load_model(path)
+        outputs = model(input_data)[0]["token_ids"]
+        restored_outputs = restored_model(input_data)[0]["token_ids"]
+        self.assertAllEqual(outputs, restored_outputs)
diff --git a/keras_nlp/models/distil_bert/distil_bert_masked_lm_test.py b/keras_nlp/models/distil_bert/distil_bert_masked_lm_test.py
new file mode 100644
index 0000000000..76a5774731
--- /dev/null
+++ b/keras_nlp/models/distil_bert/distil_bert_masked_lm_test.py
@@ -0,0 +1,127 @@
+# Copyright 2022 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for DistilBERT masked language model.""" + +import os + +import tensorflow as tf +from absl.testing import parameterized +from tensorflow import keras + +from keras_nlp.models.distil_bert.distil_bert_backbone import DistilBertBackbone +from keras_nlp.models.distil_bert.distil_bert_masked_lm import ( + DistilBertMaskedLM, +) +from keras_nlp.models.distil_bert.distil_bert_masked_lm_preprocessor import ( + DistilBertMaskedLMPreprocessor, +) +from keras_nlp.models.distil_bert.distil_bert_tokenizer import ( + DistilBertTokenizer, +) + + +class DistilBertMaskedLMTest(tf.test.TestCase, parameterized.TestCase): + def setUp(self): + self.backbone = DistilBertBackbone( + vocabulary_size=1000, + num_layers=2, + num_heads=2, + hidden_dim=64, + intermediate_dim=128, + max_sequence_length=128, + ) + self.vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"] + self.vocab += ["THE", "QUICK", "BROWN", "FOX"] + self.vocab += ["the", "quick", "brown", "fox"] + self.preprocessor = DistilBertMaskedLMPreprocessor( + DistilBertTokenizer(vocabulary=self.vocab), + sequence_length=8, + mask_selection_length=2, + ) + self.masked_lm = DistilBertMaskedLM( + self.backbone, + preprocessor=self.preprocessor, + ) + self.masked_lm_no_preprocessing = DistilBertMaskedLM( + self.backbone, + preprocessor=None, + ) + + self.raw_batch = tf.constant( + [ + "the quick brown fox.", + "the slow brown fox.", + "the smelly brown fox.", + "the old brown fox.", + ] + ) + self.preprocessed_batch = self.preprocessor(self.raw_batch)[0] + self.raw_dataset = tf.data.Dataset.from_tensor_slices( + self.raw_batch + ).batch(2) + self.preprocessed_dataset = self.raw_dataset.map(self.preprocessor) + + def test_valid_call_masked_lm(self): + self.masked_lm(self.preprocessed_batch) + + @parameterized.named_parameters( + ("jit_compile_false", False), ("jit_compile_true", True) + ) + def test_distilbert_masked_lm_predict(self, jit_compile): + self.masked_lm.compile(jit_compile=jit_compile) + self.masked_lm.predict(self.raw_batch) + + @parameterized.named_parameters( + ("jit_compile_false", False), ("jit_compile_true", True) + ) + def test_distilbert_masked_lm_predict_no_preprocessing(self, jit_compile): + self.masked_lm_no_preprocessing.compile(jit_compile=jit_compile) + self.masked_lm_no_preprocessing.predict(self.preprocessed_batch) + + @parameterized.named_parameters( + ("jit_compile_false", False), ("jit_compile_true", True) + ) + def test_distilbert_masked_lm_fit(self, jit_compile): + self.masked_lm.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + jit_compile=jit_compile, + ) + self.masked_lm.fit(self.raw_dataset) + + @parameterized.named_parameters( + ("jit_compile_false", False), ("jit_compile_true", True) + ) + def test_distilbert_masked_lm_fit_no_preprocessing(self, jit_compile): + self.masked_lm_no_preprocessing.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + jit_compile=jit_compile, + ) + self.masked_lm_no_preprocessing.fit(self.preprocessed_dataset) + + @parameterized.named_parameters( + ("tf_format", "tf", "model"), + ("keras_format", "keras_v3", "model.keras"), + ) + def test_saved_model(self, save_format, filename): + save_path = os.path.join(self.get_temp_dir(), filename) + self.masked_lm.save(save_path, save_format=save_format) + restored_model = keras.models.load_model(save_path) + + # Check we got the real object back. 
+        self.assertIsInstance(restored_model, DistilBertMaskedLM)
+
+        model_output = self.masked_lm(self.preprocessed_batch)
+        restored_output = restored_model(self.preprocessed_batch)
+
+        self.assertAllClose(model_output, restored_output)
diff --git a/keras_nlp/models/distil_bert/distil_bert_tokenizer.py b/keras_nlp/models/distil_bert/distil_bert_tokenizer.py
index 3d0dd7d794..7fc402e379 100644
--- a/keras_nlp/models/distil_bert/distil_bert_tokenizer.py
+++ b/keras_nlp/models/distil_bert/distil_bert_tokenizer.py
@@ -52,23 +52,23 @@ class DistilBertTokenizer(WordPieceTokenizer):
     Examples:
 
     Batched input.
-    >>> vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]"]
+    >>> vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
     >>> vocab += ["The", "qu", "##ick", "brown", "fox", "."]
     >>> inputs = ["The quick brown fox.", "The fox."]
     >>> tokenizer = keras_nlp.models.DistilBertTokenizer(vocabulary=vocab)
     >>> tokenizer(inputs)
-
+
 
     Unbatched input.
-    >>> vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]"]
+    >>> vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
     >>> vocab += ["The", "qu", "##ick", "brown", "fox", "."]
     >>> inputs = "The fox."
     >>> tokenizer = keras_nlp.models.DistilBertTokenizer(vocabulary=vocab)
     >>> tokenizer(inputs)
-
+
 
     Detokenization.
-    >>> vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]"]
+    >>> vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
     >>> vocab += ["The", "qu", "##ick", "brown", "fox", "."]
     >>> inputs = "The quick brown fox."
     >>> tokenizer = keras_nlp.models.DistilBertTokenizer(vocabulary=vocab)
@@ -92,7 +92,8 @@ def __init__(
         cls_token = "[CLS]"
         sep_token = "[SEP]"
         pad_token = "[PAD]"
-        for token in [cls_token, pad_token, sep_token]:
+        mask_token = "[MASK]"
+        for token in [cls_token, pad_token, sep_token, mask_token]:
             if token not in self.get_vocabulary():
                 raise ValueError(
                     f"Cannot find token `'{token}'` in the provided "
@@ -103,6 +104,7 @@ def __init__(
         self.cls_token_id = self.token_to_id(cls_token)
         self.sep_token_id = self.token_to_id(sep_token)
         self.pad_token_id = self.token_to_id(pad_token)
+        self.mask_token_id = self.token_to_id(mask_token)
 
     @classproperty
     def presets(cls):
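A minimal sketch (not part of the patch above) of what the new `DistilBertMaskedLMPreprocessor` emits, reusing the toy vocabulary from the tests; the aggressive masking rates are assumptions chosen only to make the behavior deterministic, real training would keep the defaults (15% selection rate).

```python
# Sketch, assuming a KerasNLP build that includes the classes added above.
import keras_nlp

vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]
vocab += ["THE", "QUICK", "BROWN", "FOX"]
tokenizer = keras_nlp.models.DistilBertTokenizer(vocabulary=vocab)
preprocessor = keras_nlp.models.DistilBertMaskedLMPreprocessor(
    tokenizer,
    sequence_length=8,
    mask_selection_rate=1.0,  # mask every eligible token (testing only)
    mask_token_rate=1.0,  # always substitute "[MASK]", never a random id
    random_token_rate=0.0,
    mask_selection_length=5,
)

x, y, sw = preprocessor(" THE QUICK BROWN FOX.")
# x["token_ids"]      -> packed ids with "[MASK]" at the selected positions
# x["padding_mask"]   -> 1 for real tokens, 0 for "[PAD]" positions
# x["mask_positions"] -> indices of the masked tokens, padded to length 5
# y                   -> original ids at those positions (the MLM labels)
# sw                  -> 1.0 for real mask positions, 0.0 for padding
```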