diff --git a/keras_nlp/layers/masked_lm_head.py b/keras_nlp/layers/masked_lm_head.py
index 2a31aa7c91..2f749eb0c7 100644
--- a/keras_nlp/layers/masked_lm_head.py
+++ b/keras_nlp/layers/masked_lm_head.py
@@ -140,7 +140,10 @@ def __init__(
         self.vocabulary_size = shape[0]

     def build(self, input_shapes):
-        feature_size = input_shapes[-1]
+        if self.embedding_weights is not None:
+            feature_size = self.embedding_weights.shape[-1]
+        else:
+            feature_size = input_shapes[-1]

         self._dense = keras.layers.Dense(
             feature_size,
diff --git a/keras_nlp/layers/masked_lm_head_test.py b/keras_nlp/layers/masked_lm_head_test.py
index 15734c88c8..f66f0b1918 100644
--- a/keras_nlp/layers/masked_lm_head_test.py
+++ b/keras_nlp/layers/masked_lm_head_test.py
@@ -46,15 +46,17 @@ def test_valid_call_with_embedding_weights(self):
             embedding_weights=embedding.embeddings,
             activation="softmax",
         )
-        encoded_tokens = keras.Input(shape=(10, 16))
+        # Use a different "hidden dim" for the model than "embedding dim"; we
+        # need to support this in the layer.
+        sequence = keras.Input(shape=(10, 32))
         positions = keras.Input(shape=(5,), dtype="int32")
-        outputs = head(encoded_tokens, mask_positions=positions)
-        model = keras.Model((encoded_tokens, positions), outputs)
-        token_data = tf.random.uniform(shape=(4, 10, 16))
+        outputs = head(sequence, mask_positions=positions)
+        model = keras.Model((sequence, positions), outputs)
+        sequence_data = tf.random.uniform(shape=(4, 10, 32))
         position_data = tf.random.uniform(
             shape=(4, 5), maxval=10, dtype="int32"
         )
-        model((token_data, position_data))
+        model((sequence_data, position_data))

     def test_get_config_and_from_config(self):
         head = masked_lm_head.MaskedLMHead(
diff --git a/keras_nlp/models/__init__.py b/keras_nlp/models/__init__.py
index 893d93000b..a490a113d3 100644
--- a/keras_nlp/models/__init__.py
+++ b/keras_nlp/models/__init__.py
@@ -13,6 +13,10 @@
 # limitations under the License.

 from keras_nlp.models.albert.albert_backbone import AlbertBackbone
+from keras_nlp.models.albert.albert_masked_lm import AlbertMaskedLM
+from keras_nlp.models.albert.albert_masked_lm_preprocessor import (
+    AlbertMaskedLMPreprocessor,
+)
 from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor
 from keras_nlp.models.albert.albert_tokenizer import AlbertTokenizer
 from keras_nlp.models.bart.bart_backbone import BartBackbone
diff --git a/keras_nlp/models/albert/albert_classifier_test.py b/keras_nlp/models/albert/albert_classifier_test.py
index 622e8bae3b..40fec53486 100644
--- a/keras_nlp/models/albert/albert_classifier_test.py
+++ b/keras_nlp/models/albert/albert_classifier_test.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Tests for BERT classification model."""
+"""Tests for ALBERT classification model."""

 import io
 import os
@@ -57,6 +57,7 @@ def setUp(self):
             unk_piece="<unk>",
             bos_piece="[CLS]",
             eos_piece="[SEP]",
+            user_defined_symbols="[MASK]",
         )
         self.proto = bytes_io.getvalue()

diff --git a/keras_nlp/models/albert/albert_masked_lm.py b/keras_nlp/models/albert/albert_masked_lm.py
new file mode 100644
index 0000000000..2fe248856c
--- /dev/null
+++ b/keras_nlp/models/albert/albert_masked_lm.py
@@ -0,0 +1,154 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""ALBERT masked LM model."""
+
+import copy
+
+from tensorflow import keras
+
+from keras_nlp.layers.masked_lm_head import MaskedLMHead
+from keras_nlp.models.albert.albert_backbone import AlbertBackbone
+from keras_nlp.models.albert.albert_backbone import albert_kernel_initializer
+from keras_nlp.models.albert.albert_masked_lm_preprocessor import (
+    AlbertMaskedLMPreprocessor,
+)
+from keras_nlp.models.albert.albert_presets import backbone_presets
+from keras_nlp.models.task import Task
+from keras_nlp.utils.python_utils import classproperty
+
+
+@keras.utils.register_keras_serializable(package="keras_nlp")
+class AlbertMaskedLM(Task):
+    """An end-to-end ALBERT model for the masked language modeling task.
+
+    This model will train ALBERT on a masked language modeling task.
+    The model will predict labels for a number of masked tokens in the
+    input data. For usage of this model with pre-trained weights, see the
+    `from_preset()` method.
+
+    This model can optionally be configured with a `preprocessor` layer, in
+    which case inputs can be raw string features during `fit()`, `predict()`,
+    and `evaluate()`. Inputs will be tokenized and dynamically masked during
+    training and evaluation. This is done by default when creating the model
+    with `from_preset()`.
+
+    Disclaimer: Pre-trained models are provided on an "as is" basis, without
+    warranties or conditions of any kind.
+
+    Args:
+        backbone: A `keras_nlp.models.AlbertBackbone` instance.
+        preprocessor: A `keras_nlp.models.AlbertMaskedLMPreprocessor` or
+            `None`. If `None`, this model will not apply preprocessing, and
+            inputs should be preprocessed before calling the model.
+
+    Example usage:
+
+    Raw string inputs and pretrained backbone.
+    ```python
+    # Create a dataset with raw string features. Labels are inferred.
+    features = ["The quick brown fox jumped.", "I forgot my homework."]
+
+    # Create an AlbertMaskedLM with a pretrained backbone and further train
+    # on an MLM task.
+    masked_lm = keras_nlp.models.AlbertMaskedLM.from_preset(
+        "albert_base_en_uncased",
+    )
+    masked_lm.compile(
+        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+    )
+    masked_lm.fit(x=features, batch_size=2)
+    ```
+
+    Preprocessed inputs and custom backbone.
+    ```python
+    # Create a preprocessed dataset where 0 is the mask token.
+    preprocessed_features = {
+        "segment_ids": tf.constant(
+            [[0, 0, 0, 0, 0, 0, 0, 0]] * 2, shape=(2, 8)
+        ),
+        "token_ids": tf.constant(
+            [[1, 2, 0, 4, 0, 6, 7, 8]] * 2, shape=(2, 8)
+        ),
+        "padding_mask": tf.constant(
+            [[1, 1, 1, 1, 1, 1, 1, 1]] * 2, shape=(2, 8)
+        ),
+        "mask_positions": tf.constant([[2, 4]] * 2, shape=(2, 2))
+    }
+    # Labels are the original masked values.
+    labels = [[3, 5]] * 2
+
+    # Randomly initialize an ALBERT encoder
+    backbone = keras_nlp.models.AlbertBackbone(
+        vocabulary_size=1000,
+        num_layers=2,
+        num_heads=2,
+        embedding_dim=64,
+        hidden_dim=64,
+        intermediate_dim=128,
+        max_sequence_length=128)
+
+    # Create an ALBERT masked LM and fit the data.
+    masked_lm = keras_nlp.models.AlbertMaskedLM(
+        backbone,
+        preprocessor=None,
+    )
+    masked_lm.compile(
+        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        jit_compile=True
+    )
+    masked_lm.fit(x=preprocessed_features, y=labels, batch_size=2)
+    ```
+    """
+
+    def __init__(self, backbone, preprocessor=None, **kwargs):
+        inputs = {
+            **backbone.input,
+            "mask_positions": keras.Input(
+                shape=(None,), dtype="int32", name="mask_positions"
+            ),
+        }
+
+        backbone_outputs = backbone(backbone.input)
+        outputs = MaskedLMHead(
+            vocabulary_size=backbone.vocabulary_size,
+            embedding_weights=backbone.token_embedding.embeddings,
+            intermediate_activation=lambda x: keras.activations.gelu(
+                x, approximate=True
+            ),
+            kernel_initializer=albert_kernel_initializer(),
+            name="mlm_head",
+        )(backbone_outputs["sequence_output"], inputs["mask_positions"])
+
+        super().__init__(
+            inputs=inputs,
+            outputs=outputs,
+            include_preprocessing=preprocessor is not None,
+            **kwargs
+        )
+
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+
+    @classproperty
+    def backbone_cls(cls):
+        return AlbertBackbone
+
+    @classproperty
+    def preprocessor_cls(cls):
+        return AlbertMaskedLMPreprocessor
+
+    @classproperty
+    def presets(cls):
+        return copy.deepcopy(backbone_presets)
diff --git a/keras_nlp/models/albert/albert_masked_lm_preprocessor.py b/keras_nlp/models/albert/albert_masked_lm_preprocessor.py
new file mode 100644
index 0000000000..1c874a0b7b
--- /dev/null
+++ b/keras_nlp/models/albert/albert_masked_lm_preprocessor.py
@@ -0,0 +1,198 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""ALBERT masked language model preprocessor layer."""
+
+from absl import logging
+from tensorflow import keras
+
+from keras_nlp.layers.masked_lm_mask_generator import MaskedLMMaskGenerator
+from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor
+from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
+
+
+@keras.utils.register_keras_serializable(package="keras_nlp")
+class AlbertMaskedLMPreprocessor(AlbertPreprocessor):
+    """ALBERT preprocessing for the masked language modeling task.
+
+    This preprocessing layer will prepare inputs for a masked language modeling
+    task. It is primarily intended for use with the
+    `keras_nlp.models.AlbertMaskedLM` task model. Preprocessing will occur in
+    multiple steps.
+
+    - Tokenize any number of input segments using the `tokenizer`.
+    - Pack the inputs together with the appropriate `"[CLS]"`, `"[SEP]"` and
+      `"<pad>"` tokens, i.e., adding a single `"[CLS]"` at the start of the
+      entire sequence, `"[SEP]"` between each segment,
+      and a `"[SEP]"` at the end of the entire sequence.
+    - Randomly select non-special tokens to mask, controlled by
+      `mask_selection_rate`.
+    - Construct a `(x, y, sample_weight)` tuple suitable for training with a
+      `keras_nlp.models.AlbertMaskedLM` task model.
+
+    Args:
+        tokenizer: A `keras_nlp.models.AlbertTokenizer` instance.
+        sequence_length: The length of the packed inputs.
+ mask_selection_rate: The probability an input token will be dynamically + masked. + mask_selection_length: The maximum number of masked tokens supported + by the layer. + mask_token_rate: float, defaults to 0.8. `mask_token_rate` must be + between 0 and 1 which indicates how often the mask_token is + substituted for tokens selected for masking. + random_token_rate: float, defaults to 0.1. `random_token_rate` must be + between 0 and 1 which indicates how often a random token is + substituted for tokens selected for masking. Default is 0.1. + Note: mask_token_rate + random_token_rate <= 1, and for + (1 - mask_token_rate - random_token_rate), the token will not be + changed. + truncate: string. The algorithm to truncate a list of batched segments + to fit within `sequence_length`. The value can be either + `round_robin` or `waterfall`: + - `"round_robin"`: Available space is assigned one token at a + time in a round-robin fashion to the inputs that still need + some, until the limit is reached. + - `"waterfall"`: The allocation of the budget is done using a + "waterfall" algorithm that allocates quota in a + left-to-right manner and fills up the buckets until we run + out of budget. It supports an arbitrary number of segments. + + Examples: + ```python + # Load the preprocessor from a preset. + preprocessor = keras_nlp.models.AlbertMaskedLMPreprocessor.from_preset( + "albert_base_en_uncased" + ) + + # Tokenize and mask a single sentence. + sentence = tf.constant("The quick brown fox jumped.") + preprocessor(sentence) + + # Tokenize and mask a batch of sentences. + sentences = tf.constant( + ["The quick brown fox jumped.", "Call me Ishmael."] + ) + preprocessor(sentences) + + # Tokenize and mask a dataset of sentences. + features = tf.constant( + ["The quick brown fox jumped.", "Call me Ishmael."] + ) + ds = tf.data.Dataset.from_tensor_slices((features)) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + + # Alternatively, you can create a preprocessor from your own vocabulary. 
+ vocab_data = tf.data.Dataset.from_tensor_slices( + ["the quick brown fox", "the earth is round"] + ) + + # Creating sentencepiece tokenizer for ALBERT LM preprocessor + bytes_io = io.BytesIO() + + sentencepiece.SentencePieceTrainer.train( + sentence_iterator=vocab_data.as_numpy_iterator(), + model_writer=bytes_io, + vocab_size=12, + model_type="WORD", + pad_id=0, + unk_id=1, + bos_id=2, + eos_id=3, + pad_piece="", + unk_piece="", + bos_piece="[CLS]", + eos_piece="[SEP]", + user_defined_symbols="[MASK]" + ) + + proto = bytes_io.getvalue() + + tokenizer = AlbertTokenizer(proto=proto) + + preprocessor = AlbertMaskedLMPreprocessor( + tokenizer=tokenizer + ) + + ``` + """ + + def __init__( + self, + tokenizer, + sequence_length=512, + truncate="round_robin", + mask_selection_rate=0.15, + mask_selection_length=96, + mask_token_rate=0.8, + random_token_rate=0.1, + **kwargs, + ): + super().__init__( + tokenizer, + sequence_length=sequence_length, + truncate=truncate, + **kwargs, + ) + + self.masker = MaskedLMMaskGenerator( + mask_selection_rate=mask_selection_rate, + mask_selection_length=mask_selection_length, + mask_token_rate=mask_token_rate, + random_token_rate=random_token_rate, + vocabulary_size=tokenizer.vocabulary_size(), + mask_token_id=tokenizer.mask_token_id, + unselectable_token_ids=[ + tokenizer.cls_token_id, + tokenizer.sep_token_id, + tokenizer.pad_token_id, + ], + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "mask_selection_rate": self.masker.mask_selection_rate, + "mask_selection_length": self.masker.mask_selection_length, + "mask_token_rate": self.masker.mask_token_rate, + "random_token_rate": self.masker.random_token_rate, + } + ) + return config + + def call(self, x, y=None, sample_weight=None): + if y is not None or sample_weight is not None: + logging.warning( + f"{self.__class__.__name__} generates `y` and `sample_weight` " + "based on your input data, but your data already contains `y` " + "or `sample_weight`. Your `y` and `sample_weight` will be " + "ignored." + ) + + x = super().call(x) + token_ids, segment_ids, padding_mask = ( + x["token_ids"], + x["segment_ids"], + x["padding_mask"], + ) + masker_outputs = self.masker(token_ids) + x = { + "token_ids": masker_outputs["token_ids"], + "segment_ids": segment_ids, + "padding_mask": padding_mask, + "mask_positions": masker_outputs["mask_positions"], + } + y = masker_outputs["mask_ids"] + sample_weight = masker_outputs["mask_weights"] + return pack_x_y_sample_weight(x, y, sample_weight) diff --git a/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py b/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py new file mode 100644 index 0000000000..c77f24c213 --- /dev/null +++ b/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py @@ -0,0 +1,162 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for ALBERT masked language model preprocessor layer.""" + +import io +import os + +import sentencepiece +import tensorflow as tf +from absl.testing import parameterized +from tensorflow import keras + +from keras_nlp.models.albert.albert_masked_lm_preprocessor import ( + AlbertMaskedLMPreprocessor, +) +from keras_nlp.models.albert.albert_tokenizer import AlbertTokenizer + + +class AlbertMaskedLMPreprocessorTest(tf.test.TestCase, parameterized.TestCase): + def setUp(self): + vocab_data = tf.data.Dataset.from_tensor_slices( + ["the quick brown fox", "the earth is round"] + ) + + bytes_io = io.BytesIO() + sentencepiece.SentencePieceTrainer.train( + sentence_iterator=vocab_data.as_numpy_iterator(), + model_writer=bytes_io, + vocab_size=12, + model_type="WORD", + pad_id=0, + unk_id=1, + bos_id=2, + eos_id=3, + pad_piece="", + unk_piece="", + bos_piece="[CLS]", + eos_piece="[SEP]", + user_defined_symbols="[MASK]", + ) + + proto = bytes_io.getvalue() + + tokenizer = AlbertTokenizer(proto=proto) + + self.preprocessor = AlbertMaskedLMPreprocessor( + tokenizer=tokenizer, + # Simplify out testing by masking every available token. + mask_selection_rate=1.0, + mask_token_rate=1.0, + random_token_rate=0.0, + mask_selection_length=4, + sequence_length=12, + ) + + def test_preprocess_strings(self): + input_data = "the quick brown fox" + + x, y, sw = self.preprocessor(input_data) + self.assertAllEqual( + x["token_ids"], [2, 4, 4, 4, 4, 3, 0, 0, 0, 0, 0, 0] + ) + self.assertAllEqual( + x["padding_mask"], [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0] + ) + self.assertAllEqual(x["mask_positions"], [1, 2, 3, 4]) + self.assertAllEqual(y, [5, 10, 6, 8]) + self.assertAllEqual(sw, [1.0, 1.0, 1.0, 1.0]) + + def test_preprocess_list_of_strings(self): + input_data = ["the quick brown fox"] * 4 + + x, y, sw = self.preprocessor(input_data) + self.assertAllEqual( + x["token_ids"], [[2, 4, 4, 4, 4, 3, 0, 0, 0, 0, 0, 0]] * 4 + ) + self.assertAllEqual( + x["padding_mask"], [[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]] * 4 + ) + self.assertAllEqual(x["mask_positions"], [[1, 2, 3, 4]] * 4) + self.assertAllEqual(y, [[5, 10, 6, 8]] * 4) + self.assertAllEqual(sw, [[1.0, 1.0, 1.0, 1.0]] * 4) + + def test_preprocess_dataset(self): + sentences = tf.constant(["the quick brown fox"] * 4) + ds = tf.data.Dataset.from_tensor_slices(sentences) + ds = ds.map(self.preprocessor) + x, y, sw = ds.batch(4).take(1).get_single_element() + self.assertAllEqual( + x["token_ids"], [[2, 4, 4, 4, 4, 3, 0, 0, 0, 0, 0, 0]] * 4 + ) + self.assertAllEqual( + x["padding_mask"], [[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]] * 4 + ) + self.assertAllEqual(x["mask_positions"], [[1, 2, 3, 4]] * 4) + self.assertAllEqual(y, [[5, 10, 6, 8]] * 4) + self.assertAllEqual(sw, [[1.0, 1.0, 1.0, 1.0]] * 4) + + def test_mask_multiple_sentences(self): + sentence_one = tf.constant("the quick") + sentence_two = tf.constant("brown fox") + + x, y, sw = self.preprocessor((sentence_one, sentence_two)) + self.assertAllEqual( + x["token_ids"], [2, 4, 4, 3, 4, 4, 3, 0, 0, 0, 0, 0] + ) + self.assertAllEqual( + x["padding_mask"], [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0] + ) + self.assertAllEqual(x["mask_positions"], [1, 2, 4, 5]) + self.assertAllEqual(y, [5, 10, 6, 8]) + self.assertAllEqual(sw, [1.0, 1.0, 1.0, 1.0]) + + def test_no_masking_zero_rate(self): + no_mask_preprocessor = AlbertMaskedLMPreprocessor( + self.preprocessor.tokenizer, + mask_selection_rate=0.0, + mask_selection_length=4, + sequence_length=12, + ) + input_data = "the quick brown fox" + + x, y, sw = no_mask_preprocessor(input_data) 
+ self.assertAllEqual( + x["token_ids"], [2, 5, 10, 6, 8, 3, 0, 0, 0, 0, 0, 0] + ) + self.assertAllEqual( + x["padding_mask"], [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0] + ) + self.assertAllEqual(x["mask_positions"], [0, 0, 0, 0]) + self.assertAllEqual(y, [0, 0, 0, 0]) + self.assertAllEqual(sw, [0.0, 0.0, 0.0, 0.0]) + + @parameterized.named_parameters( + ("tf_format", "tf", "model"), + ("keras_format", "keras_v3", "model.keras"), + ) + def test_saved_model(self, save_format, filename): + input_data = tf.constant(["the quick brown fox"]) + + inputs = keras.Input(dtype="string", shape=()) + outputs = self.preprocessor(inputs) + model = keras.Model(inputs, outputs) + + path = os.path.join(self.get_temp_dir(), filename) + model.save(path, save_format=save_format) + + restored_model = keras.models.load_model(path) + outputs = model(input_data)[0]["token_ids"] + restored_outputs = restored_model(input_data)[0]["token_ids"] + self.assertAllEqual(outputs, restored_outputs) diff --git a/keras_nlp/models/albert/albert_masked_lm_test.py b/keras_nlp/models/albert/albert_masked_lm_test.py new file mode 100644 index 0000000000..557b1bd411 --- /dev/null +++ b/keras_nlp/models/albert/albert_masked_lm_test.py @@ -0,0 +1,152 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for ALBERT masked language model.""" + +import io +import os + +import sentencepiece +import tensorflow as tf +from absl.testing import parameterized +from tensorflow import keras + +from keras_nlp.models.albert.albert_backbone import AlbertBackbone +from keras_nlp.models.albert.albert_masked_lm import AlbertMaskedLM +from keras_nlp.models.albert.albert_masked_lm_preprocessor import ( + AlbertMaskedLMPreprocessor, +) +from keras_nlp.models.albert.albert_tokenizer import AlbertTokenizer + + +class AlbertMaskedLMTest(tf.test.TestCase, parameterized.TestCase): + def setUp(self): + self.backbone = AlbertBackbone( + vocabulary_size=1000, + num_layers=2, + num_heads=2, + embedding_dim=128, + hidden_dim=64, + intermediate_dim=128, + max_sequence_length=128, + ) + vocab_data = tf.data.Dataset.from_tensor_slices( + ["the quick brown fox", "the earth is round", "an eagle flew"] + ) + + bytes_io = io.BytesIO() + sentencepiece.SentencePieceTrainer.train( + sentence_iterator=vocab_data.as_numpy_iterator(), + model_writer=bytes_io, + vocab_size=15, + model_type="WORD", + pad_id=0, + unk_id=1, + bos_id=2, + eos_id=3, + pad_piece="", + unk_piece="", + bos_piece="[CLS]", + eos_piece="[SEP]", + user_defined_symbols="[MASK]", + ) + + proto = bytes_io.getvalue() + + tokenizer = AlbertTokenizer(proto=proto) + + self.preprocessor = AlbertMaskedLMPreprocessor( + tokenizer=tokenizer, + # Simplify out testing by masking every available token. 
+ mask_selection_rate=1.0, + mask_token_rate=1.0, + random_token_rate=0.0, + mask_selection_length=5, + sequence_length=12, + ) + self.masked_lm = AlbertMaskedLM( + self.backbone, + preprocessor=self.preprocessor, + ) + self.masked_lm_no_preprocessing = AlbertMaskedLM( + self.backbone, + preprocessor=None, + ) + + self.raw_batch = tf.constant( + [ + "quick brown fox", + "eagle flew over fox", + "the eagle flew quick", + "a brown eagle", + ] + ) + self.preprocessed_batch = self.preprocessor(self.raw_batch)[0] + self.raw_dataset = tf.data.Dataset.from_tensor_slices( + self.raw_batch + ).batch(2) + self.preprocessed_dataset = self.raw_dataset.map(self.preprocessor) + + def test_valid_call_masked_lm(self): + self.masked_lm(self.preprocessed_batch) + + @parameterized.named_parameters( + ("jit_compile_false", False), ("jit_compile_true", True) + ) + def test_albert_masked_lm_predict(self, jit_compile): + self.masked_lm.compile(jit_compile=jit_compile) + self.masked_lm.predict(self.raw_batch) + + @parameterized.named_parameters( + ("jit_compile_false", False), ("jit_compile_true", True) + ) + def test_albert_masked_lm_predict_no_preprocessing(self, jit_compile): + self.masked_lm_no_preprocessing.compile(jit_compile=jit_compile) + self.masked_lm_no_preprocessing.predict(self.preprocessed_batch) + + @parameterized.named_parameters( + ("jit_compile_false", False), ("jit_compile_true", True) + ) + def test_albert_masked_lm_fit(self, jit_compile): + self.masked_lm.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + jit_compile=jit_compile, + ) + self.masked_lm.fit(self.raw_dataset) + + @parameterized.named_parameters( + ("jit_compile_false", False), ("jit_compile_true", True) + ) + def test_albert_masked_lm_fit_no_preprocessing(self, jit_compile): + self.masked_lm_no_preprocessing.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + jit_compile=jit_compile, + ) + self.masked_lm_no_preprocessing.fit(self.preprocessed_dataset) + + @parameterized.named_parameters( + ("tf_format", "tf", "model"), + ("keras_format", "keras_v3", "model.keras"), + ) + def test_saved_model(self, save_format, filename): + save_path = os.path.join(self.get_temp_dir(), filename) + self.masked_lm.save(save_path, save_format=save_format) + restored_model = keras.models.load_model(save_path) + + # Check we got the real object back. 
+ self.assertIsInstance(restored_model, AlbertMaskedLM) + + model_output = self.masked_lm(self.preprocessed_batch) + restored_output = restored_model(self.preprocessed_batch) + + self.assertAllClose(model_output, restored_output) diff --git a/keras_nlp/models/albert/albert_preprocessor_test.py b/keras_nlp/models/albert/albert_preprocessor_test.py index ee6038f839..53639517ea 100644 --- a/keras_nlp/models/albert/albert_preprocessor_test.py +++ b/keras_nlp/models/albert/albert_preprocessor_test.py @@ -35,7 +35,7 @@ def setUp(self): sentencepiece.SentencePieceTrainer.train( sentence_iterator=vocab_data.as_numpy_iterator(), model_writer=bytes_io, - vocab_size=10, + vocab_size=12, model_type="WORD", pad_id=0, unk_id=1, @@ -45,6 +45,7 @@ def setUp(self): unk_piece="", bos_piece="[CLS]", eos_piece="[SEP]", + user_defined_symbols="[MASK]", ) self.proto = bytes_io.getvalue() @@ -57,7 +58,7 @@ def test_tokenize_strings(self): input_data = "the quick brown fox" output = self.preprocessor(input_data) self.assertAllEqual( - output["token_ids"], [2, 4, 9, 5, 7, 3, 0, 0, 0, 0, 0, 0] + output["token_ids"], [2, 5, 10, 6, 8, 3, 0, 0, 0, 0, 0, 0] ) self.assertAllEqual( output["segment_ids"], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] @@ -72,7 +73,7 @@ def test_tokenize_list_of_strings(self): output = self.preprocessor(input_data) self.assertAllEqual( output["token_ids"], - [[2, 4, 9, 5, 7, 3, 0, 0, 0, 0, 0, 0]] * 4, + [[2, 5, 10, 6, 8, 3, 0, 0, 0, 0, 0, 0]] * 4, ) self.assertAllEqual( output["segment_ids"], [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] * 4 @@ -88,7 +89,7 @@ def test_tokenize_labeled_batch(self): x_out, y_out, sw_out = self.preprocessor(x, y, sw) self.assertAllEqual( x_out["token_ids"], - [[2, 4, 9, 5, 7, 3, 0, 0, 0, 0, 0, 0]] * 4, + [[2, 5, 10, 6, 8, 3, 0, 0, 0, 0, 0, 0]] * 4, ) self.assertAllEqual( x_out["segment_ids"], [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] * 4 @@ -108,7 +109,7 @@ def test_tokenize_labeled_dataset(self): x_out, y_out, sw_out = ds.batch(4).take(1).get_single_element() self.assertAllEqual( x_out["token_ids"], - [[2, 4, 9, 5, 7, 3, 0, 0, 0, 0, 0, 0]] * 4, + [[2, 5, 10, 6, 8, 3, 0, 0, 0, 0, 0, 0]] * 4, ) self.assertAllEqual( x_out["segment_ids"], [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] * 4 @@ -125,7 +126,7 @@ def test_tokenize_multiple_sentences(self): output = self.preprocessor((sentence_one, sentence_two)) self.assertAllEqual( output["token_ids"], - [2, 4, 9, 5, 7, 3, 4, 6, 3, 0, 0, 0], + [2, 5, 10, 6, 8, 3, 5, 7, 3, 0, 0, 0], ) self.assertAllEqual( output["segment_ids"], [0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0] @@ -142,7 +143,7 @@ def test_tokenize_multiple_batched_sentences(self): output = self.preprocessor((sentence_one, sentence_two)) self.assertAllEqual( output["token_ids"], - [[2, 4, 9, 5, 7, 3, 4, 6, 3, 0, 0, 0]] * 4, + [[2, 5, 10, 6, 8, 3, 5, 7, 3, 0, 0, 0]] * 4, ) self.assertAllEqual( output["segment_ids"], [[0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0]] * 4 diff --git a/keras_nlp/models/albert/albert_tokenizer.py b/keras_nlp/models/albert/albert_tokenizer.py index 35dc7ec8d5..5ba47af58d 100644 --- a/keras_nlp/models/albert/albert_tokenizer.py +++ b/keras_nlp/models/albert/albert_tokenizer.py @@ -72,7 +72,8 @@ def __init__(self, proto, **kwargs): cls_token = "[CLS]" sep_token = "[SEP]" pad_token = "" - for token in [cls_token, sep_token, pad_token]: + mask_token = "[MASK]" + for token in [cls_token, sep_token, pad_token, mask_token]: if token not in self.get_vocabulary(): raise ValueError( f"Cannot find token `'{token}'` in the provided " @@ -83,6 +84,7 @@ def __init__(self, proto, **kwargs): 
self.cls_token_id = self.token_to_id(cls_token) self.sep_token_id = self.token_to_id(sep_token) self.pad_token_id = self.token_to_id(pad_token) + self.mask_token_id = self.token_to_id(mask_token) @classproperty def presets(cls): diff --git a/keras_nlp/models/albert/albert_tokenizer_test.py b/keras_nlp/models/albert/albert_tokenizer_test.py index d80657dbde..e97af4ee3a 100644 --- a/keras_nlp/models/albert/albert_tokenizer_test.py +++ b/keras_nlp/models/albert/albert_tokenizer_test.py @@ -34,7 +34,7 @@ def setUp(self): sentencepiece.SentencePieceTrainer.train( sentence_iterator=vocab_data.as_numpy_iterator(), model_writer=bytes_io, - vocab_size=10, + vocab_size=12, model_type="WORD", pad_id=0, unk_id=1, @@ -44,6 +44,7 @@ def setUp(self): unk_piece="", bos_piece="[CLS]", eos_piece="[SEP]", + user_defined_symbols="[MASK]", ) self.proto = bytes_io.getvalue() @@ -52,20 +53,21 @@ def setUp(self): def test_tokenize(self): input_data = "the quick brown fox" output = self.tokenizer(input_data) - self.assertAllEqual(output, [4, 9, 5, 7]) + self.assertAllEqual(output, [5, 10, 6, 8]) def test_tokenize_batch(self): input_data = tf.constant(["the quick brown fox", "the earth is round"]) output = self.tokenizer(input_data) - self.assertAllEqual(output, [[4, 9, 5, 7], [4, 6, 8, 1]]) + self.assertAllEqual(output, [[5, 10, 6, 8], [5, 7, 9, 11]]) def test_detokenize(self): - input_data = tf.constant([[4, 9, 5, 7]]) + input_data = tf.constant([[5, 10, 6, 8]]) output = self.tokenizer.detokenize(input_data) self.assertEqual(output, tf.constant(["the quick brown fox"])) def test_vocabulary_size(self): - self.assertEqual(self.tokenizer.vocabulary_size(), 10) + tokenizer = AlbertTokenizer(proto=self.proto) + self.assertEqual(tokenizer.vocabulary_size(), 12) def test_errors_missing_special_tokens(self): bytes_io = io.BytesIO()
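
Taken together, the two new classes added in this diff are intended to be used as an end-to-end pipeline. The snippet below is a minimal usage sketch assembled from the docstring examples above; it assumes the `albert_base_en_uncased` preset referenced there is available and is not itself part of the change.

```python
import keras_nlp
import tensorflow as tf
from tensorflow import keras

features = ["The quick brown fox jumped.", "I forgot my homework."]

# The preprocessor turns raw strings into (x, y, sample_weight) tuples,
# dynamically masking tokens and emitting mask positions and mask labels.
preprocessor = keras_nlp.models.AlbertMaskedLMPreprocessor.from_preset(
    "albert_base_en_uncased"
)
x, y, sample_weight = preprocessor(tf.constant(features))

# The task model attaches a MaskedLMHead to the ALBERT backbone; with a
# preprocessor attached (the default from_preset behavior), it can be fit
# directly on raw strings.
masked_lm = keras_nlp.models.AlbertMaskedLM.from_preset(
    "albert_base_en_uncased"
)
masked_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
masked_lm.fit(x=features, batch_size=2)
```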