diff --git a/keras_nlp/models/f_net/f_net_classifier.py b/keras_nlp/models/f_net/f_net_classifier.py new file mode 100644 index 0000000000..04b1fef98d --- /dev/null +++ b/keras_nlp/models/f_net/f_net_classifier.py @@ -0,0 +1,144 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""FNet classification model.""" + +import copy + +from tensorflow import keras + +from keras_nlp.models.f_net.f_net_backbone import FNetBackbone +from keras_nlp.models.f_net.f_net_backbone import f_net_kernel_initializer +from keras_nlp.models.f_net.f_net_preprocessor import FNetPreprocessor +from keras_nlp.models.f_net.f_net_presets import backbone_presets +from keras_nlp.models.task import Task +from keras_nlp.utils.python_utils import classproperty + + +@keras.utils.register_keras_serializable(package="keras_nlp") +class FNetClassifier(Task): + """An end-to-end f_net model for classification tasks. + + This model attaches a classification head to a + `keras_nlp.model.FNetBackbone` model, mapping from the backbone + outputs to logit output suitable for a classification task. For usage of + this model with pre-trained weights, see the `from_preset()` method. + + This model can optionally be configured with a `preprocessor` layer, in + which case it will automatically apply preprocessing to raw inputs during + `fit()`, `predict()`, and `evaluate()`. This is done by default when + creating the model with `from_preset()`. + + Disclaimer: Pre-trained models are provided on an "as is" basis, without + warranties or conditions of any kind. + + Args: + backbone: A `keras_nlp.models.FNetBackbone` instance. + num_classes: int. Number of classes to predict. + hidden_dim: int. The size of the pooler layer. + dropout: float. The dropout probability value, applied after the dense + layer. + preprocessor: A `keras_nlp.models.FNetPreprocessor` or `None`. If + `None`, this model will not apply preprocessing, and inputs should + be preprocessed before calling the model. + + Example usage: + ```python + preprocessed_features = { + "token_ids": tf.ones(shape=(2, 12), dtype=tf.int64), + "segment_ids": tf.constant( + [[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]] * 2, shape=(2, 12) + ), + "padding_mask": tf.constant( + [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]] * 2, shape=(2, 12) + ), + } + labels = [0, 3] + + # Randomly initialize a Fnet backbone + backbone = keras_nlp.models.FNetBackbone( + vocabulary_size=32000, + num_layers=12, + num_heads=12, + hidden_dim=768, + intermediate_dim=3072, + max_sequence_length=12, + ) + + # Create a Fnet classifier and fit your data. + classifier = keras_nlp.models.FnetClassifier( + backbone, + num_classes=4, + preprocessor=None, + ) + classifier.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + ) + classifier.fit(x=preprocessed_features, y=labels, batch_size=2) + + # Access backbone programatically (e.g., to change `trainable`) + classifier.backbone.trainable = False + ``` + """ + + def __init__( + self, + backbone, + num_classes=2, + dropout=0.1, + preprocessor=None, + **kwargs, + ): + inputs = backbone.input + pooled = backbone(inputs)["pooled_output"] + pooled = keras.layers.Dropout(dropout)(pooled) + outputs = keras.layers.Dense( + num_classes, + kernel_initializer=f_net_kernel_initializer(), + name="logits", + )(pooled) + # Instantiate using Functional API Model constructor + super().__init__( + inputs=inputs, + outputs=outputs, + include_preprocessing=preprocessor is not None, + **kwargs, + ) + # All references to `self` below this line + self._backbone = backbone + self._preprocessor = preprocessor + self.num_classes = num_classes + self.dropout = dropout + + def get_config(self): + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + "dropout": self.dropout, + } + ) + return config + + @classproperty + def backbone_cls(cls): + return FNetBackbone + + @classproperty + def preprocessor_cls(cls): + return FNetPreprocessor + + @classproperty + def presets(cls): + return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/f_net/f_net_classifier_test.py b/keras_nlp/models/f_net/f_net_classifier_test.py new file mode 100644 index 0000000000..cf4120fe83 --- /dev/null +++ b/keras_nlp/models/f_net/f_net_classifier_test.py @@ -0,0 +1,145 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for FNet classification model.""" + +import io +import os + +import sentencepiece +import tensorflow as tf +from absl.testing import parameterized +from tensorflow import keras + +from keras_nlp.models.f_net.f_net_backbone import FNetBackbone +from keras_nlp.models.f_net.f_net_classifier import FNetClassifier +from keras_nlp.models.f_net.f_net_preprocessor import FNetPreprocessor +from keras_nlp.models.f_net.f_net_tokenizer import FNetTokenizer + + +class FNetClassifierTest(tf.test.TestCase, parameterized.TestCase): + def setUp(self): + self.backbone = FNetBackbone( + vocabulary_size=1000, + num_layers=2, + hidden_dim=64, + intermediate_dim=128, + max_sequence_length=128, + name="encoder", + ) + + bytes_io = io.BytesIO() + vocab_data = tf.data.Dataset.from_tensor_slices( + ["the quick brown fox", "the earth is round"] + ) + + sentencepiece.SentencePieceTrainer.train( + sentence_iterator=vocab_data.as_numpy_iterator(), + model_writer=bytes_io, + vocab_size=10, + model_type="WORD", + pad_id=3, + unk_id=0, + bos_id=4, + eos_id=5, + pad_piece="", + unk_piece="", + bos_piece="[CLS]", + eos_piece="[SEP]", + ) + + self.proto = bytes_io.getvalue() + + self.preprocessor = FNetPreprocessor( + tokenizer=FNetTokenizer(proto=self.proto), + sequence_length=12, + ) + + self.classifier = FNetClassifier( + self.backbone, + 4, + preprocessor=self.preprocessor, + ) + self.classifier_no_preprocessing = FNetClassifier( + self.backbone, + 4, + preprocessor=None, + ) + + self.raw_batch = tf.constant( + [ + "the quick brown fox.", + "the slow brown fox.", + "the smelly brown fox.", + "the old brown fox.", + ] + ) + self.preprocessed_batch = self.preprocessor(self.raw_batch) + self.raw_dataset = tf.data.Dataset.from_tensor_slices( + (self.raw_batch, tf.ones((4,))) + ).batch(2) + self.preprocessed_dataset = self.raw_dataset.map(self.preprocessor) + + def test_valid_call_classifier(self): + self.classifier(self.preprocessed_batch) + + @parameterized.named_parameters( + ("jit_compile_false", False), ("jit_compile_true", True) + ) + def test_fnet_classifier_predict(self, jit_compile): + self.classifier.compile(jit_compile=jit_compile) + self.classifier.predict(self.raw_batch) + + @parameterized.named_parameters( + ("jit_compile_false", False), ("jit_compile_true", True) + ) + def test_fnet_classifier_predict_no_preprocessing(self, jit_compile): + self.classifier_no_preprocessing.compile(jit_compile=jit_compile) + self.classifier_no_preprocessing.predict(self.preprocessed_batch) + + @parameterized.named_parameters( + ("jit_compile_false", False), ("jit_compile_true", True) + ) + def test_fnet_classifier_fit(self, jit_compile): + self.classifier.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + jit_compile=jit_compile, + ) + self.classifier.fit(self.raw_dataset) + + @parameterized.named_parameters( + ("jit_compile_false", False), ("jit_compile_true", True) + ) + def test_fnet_classifier_fit_no_preprocessing(self, jit_compile): + self.classifier_no_preprocessing.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + jit_compile=jit_compile, + ) + self.classifier_no_preprocessing.fit(self.preprocessed_dataset) + + @parameterized.named_parameters( + ("tf_format", "tf", "model"), + ("keras_format", "keras_v3", "model.keras"), + ) + def test_saved_model(self, save_format, filename): + model_output = self.classifier.predict(self.raw_batch) + save_path = os.path.join(self.get_temp_dir(), filename) + self.classifier.save(save_path, save_format=save_format) + restored_model = keras.models.load_model(save_path) + + # Check we got the real object back. + self.assertIsInstance(restored_model, FNetClassifier) + + # Check that output matches. + restored_output = restored_model.predict(self.raw_batch) + self.assertAllClose(model_output, restored_output) diff --git a/keras_nlp/models/f_net/f_net_presets_test.py b/keras_nlp/models/f_net/f_net_presets_test.py index b73b3d22e9..e92bd6d2c2 100644 --- a/keras_nlp/models/f_net/f_net_presets_test.py +++ b/keras_nlp/models/f_net/f_net_presets_test.py @@ -18,6 +18,7 @@ from absl.testing import parameterized from keras_nlp.models.f_net.f_net_backbone import FNetBackbone +from keras_nlp.models.f_net.f_net_classifier import FNetClassifier from keras_nlp.models.f_net.f_net_preprocessor import FNetPreprocessor from keras_nlp.models.f_net.f_net_tokenizer import FNetTokenizer @@ -49,26 +50,46 @@ def test_preprocessor_output(self): self.assertAllEqual(outputs, expected_outputs) @parameterized.named_parameters( - ("preset_weights", True), ("random_weights", False) + ("load_weights", True), ("no_load_weights", False) ) def test_backbone_output(self, load_weights): input_data = { - "token_ids": tf.constant([[4, 97, 1467, 5]]), + "token_ids": tf.constant([[101, 1996, 4248, 102]]), "segment_ids": tf.constant([[0, 0, 0, 0]]), + "padding_mask": tf.constant([[1, 1, 1, 1]]), } model = FNetBackbone.from_preset( "f_net_base_en", load_weights=load_weights ) - outputs = model(input_data) + outputs = model(input_data)["sequence_output"] if load_weights: - outputs = outputs["sequence_output"][0, 0, :5] - expected = [4.182479, -0.072181, -0.138097, -0.036582, -0.521765] + # The forward pass from a preset should be stable! + # This test should catch cases where we unintentionally change our + # network code in a way that would invalidate our preset weights. + # We should only update these numbers if we are updating a weights + # file, or have found a discrepancy with the upstream source. + outputs = outputs[0, 0, :5] + expected = [4.157282, -0.096616, -0.244943, -0.068104, -0.559592] + # Keep a high tolerance, so we are robust to different hardware. self.assertAllClose(outputs, expected, atol=0.01, rtol=0.01) + @parameterized.named_parameters( + ("load_weights", True), ("no_load_weights", False) + ) + def test_classifier_output(self, load_weights): + input_data = tf.constant(["The quick brown fox."]) + model = FNetClassifier.from_preset( + "f_net_base_en", + load_weights=load_weights, + ) + # We don't assert output values, as the head weights are random. + model.predict(input_data) + @parameterized.named_parameters( ("f_net_tokenizer", FNetTokenizer), ("f_net_preprocessor", FNetPreprocessor), ("f_net", FNetBackbone), + ("f_net_classifier", FNetClassifier), ) def test_preset_docstring(self, cls): """Check we did our docstring formatting correctly.""" @@ -79,6 +100,7 @@ def test_preset_docstring(self, cls): ("f_net_tokenizer", FNetTokenizer), ("f_net_preprocessor", FNetPreprocessor), ("f_net", FNetBackbone), + ("f_net_classifier", FNetClassifier), ) def test_unknown_preset_error(self, cls): # Not a preset name @@ -112,6 +134,41 @@ def test_load_f_net(self, load_weights): } model(input_data) + @parameterized.named_parameters( + ("load_weights", True), ("no_load_weights", False) + ) + def test_load_fnet_classifier(self, load_weights): + for preset in FNetClassifier.presets: + classifier = FNetClassifier.from_preset( + preset, + load_weights=load_weights, + ) + input_data = tf.constant(["This quick brown fox"]) + classifier.predict(input_data) + + @parameterized.named_parameters( + ("load_weights", True), ("no_load_weights", False) + ) + def test_load_fnet_classifier_without_preprocessing(self, load_weights): + for preset in FNetClassifier.presets: + classifier = FNetClassifier.from_preset( + preset, + preprocessor=None, + load_weights=load_weights, + ) + input_data = { + "token_ids": tf.random.uniform( + shape=(1, 512), + dtype=tf.int64, + maxval=classifier.backbone.vocabulary_size, + ), + "segment_ids": tf.constant( + [0] * 200 + [1] * 312, shape=(1, 512) + ), + "padding_mask": tf.constant([1] * 512, shape=(1, 512)), + } + classifier.predict(input_data) + def test_load_tokenizers(self): for preset in FNetTokenizer.presets: tokenizer = FNetTokenizer.from_preset(preset)