diff --git a/keras_nlp/models/albert/albert_backbone.py b/keras_nlp/models/albert/albert_backbone.py
index 4501799174..462361501c 100644
--- a/keras_nlp/models/albert/albert_backbone.py
+++ b/keras_nlp/models/albert/albert_backbone.py
@@ -14,12 +14,16 @@
 """ALBERT backbone model."""

+import copy
+
 import tensorflow as tf
 from tensorflow import keras

 from keras_nlp.layers.position_embedding import PositionEmbedding
 from keras_nlp.layers.transformer_encoder import TransformerEncoder
+from keras_nlp.models.albert.albert_presets import backbone_presets
 from keras_nlp.models.backbone import Backbone
+from keras_nlp.utils.python_utils import classproperty


 def albert_kernel_initializer(stddev=0.02):
@@ -264,3 +268,7 @@ def get_config(self):
             "name": self.name,
             "trainable": self.trainable,
         }
+
+    @classproperty
+    def presets(cls):
+        return copy.deepcopy(backbone_presets)
diff --git a/keras_nlp/models/albert/albert_preprocessor.py b/keras_nlp/models/albert/albert_preprocessor.py
index e34fe54536..2a14f1e35d 100644
--- a/keras_nlp/models/albert/albert_preprocessor.py
+++ b/keras_nlp/models/albert/albert_preprocessor.py
@@ -13,9 +13,12 @@
 # limitations under the License.
 """ALBERT preprocessor layer."""

+import copy
+
 from tensorflow import keras

 from keras_nlp.layers.multi_segment_packer import MultiSegmentPacker
+from keras_nlp.models.albert.albert_presets import backbone_presets
 from keras_nlp.models.albert.albert_tokenizer import AlbertTokenizer
 from keras_nlp.models.preprocessor import Preprocessor
 from keras_nlp.utils.keras_utils import (
@@ -179,3 +182,7 @@ def call(self, x, y=None, sample_weight=None):
     @classproperty
     def tokenizer_cls(cls):
         return AlbertTokenizer
+
+    @classproperty
+    def presets(cls):
+        return copy.deepcopy(backbone_presets)
diff --git a/keras_nlp/models/albert/albert_presets.py b/keras_nlp/models/albert/albert_presets.py
new file mode 100644
index 0000000000..0931193a5a
--- /dev/null
+++ b/keras_nlp/models/albert/albert_presets.py
@@ -0,0 +1,114 @@
+# Copyright 2022 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""ALBERT model preset configurations."""
+
+
+backbone_presets = {
+    "albert_base_en_uncased": {
+        "config": {
+            "vocabulary_size": 30000,
+            "num_layers": 12,
+            "num_heads": 12,
+            "num_groups": 1,
+            "num_inner_repetitions": 1,
+            "embedding_dim": 128,
+            "hidden_dim": 768,
+            "intermediate_dim": 3072,
+            "dropout": 0.0,
+            "max_sequence_length": 512,
+            "num_segments": 2,
+        },
+        "preprocessor_config": {},
+        "description": (
+            "Base size of ALBERT where all input is lowercased. "
+            "Trained on English Wikipedia + BooksCorpus."
+        ),
+        "weights_url": "https://storage.googleapis.com/keras-nlp/models/albert_base_en_uncased/v1/model.h5",
+        "weights_hash": "b83ccf3418dd84adc569324183176813",
+        "spm_proto_url": "https://storage.googleapis.com/keras-nlp/models/albert_base_en_uncased/v1/vocab.spm",
+        "spm_proto_hash": "73e62ff8e90f951f24c8b907913039a5",
+    },
+    "albert_large_en_uncased": {
+        "config": {
+            "vocabulary_size": 30000,
+            "num_layers": 24,
+            "num_heads": 16,
+            "num_groups": 1,
+            "num_inner_repetitions": 1,
+            "embedding_dim": 128,
+            "hidden_dim": 1024,
+            "intermediate_dim": 4096,
+            "dropout": 0,
+            "max_sequence_length": 512,
+            "num_segments": 2,
+        },
+        "preprocessor_config": {},
+        "description": (
+            "Large size of ALBERT where all input is lowercased. "
+            "Trained on English Wikipedia + BooksCorpus."
+        ),
+        "weights_url": "https://storage.googleapis.com/keras-nlp/models/albert_large_en_uncased/v1/model.h5",
+        "weights_hash": "c7754804efb245f06dd6e7ced32e082c",
+        "spm_proto_url": "https://storage.googleapis.com/keras-nlp/models/albert_large_en_uncased/v1/vocab.spm",
+        "spm_proto_hash": "73e62ff8e90f951f24c8b907913039a5",
+    },
+    "albert_extra_large_en_uncased": {
+        "config": {
+            "vocabulary_size": 30000,
+            "num_layers": 24,
+            "num_heads": 16,
+            "num_groups": 1,
+            "num_inner_repetitions": 1,
+            "embedding_dim": 128,
+            "hidden_dim": 2048,
+            "intermediate_dim": 8192,
+            "dropout": 0,
+            "max_sequence_length": 512,
+            "num_segments": 2,
+        },
+        "preprocessor_config": {},
+        "description": (
+            "Extra Large size of ALBERT where all input is lowercased. "
+            "Trained on English Wikipedia + BooksCorpus."
+        ),
+        "weights_url": "https://storage.googleapis.com/keras-nlp/models/albert_extra_large_en_uncased/v1/model.h5",
+        "weights_hash": "713209be8aadfa614fd79f18c9aeb16d",
+        "spm_proto_url": "https://storage.googleapis.com/keras-nlp/models/albert_extra_large_en_uncased/v1/vocab.spm",
+        "spm_proto_hash": "73e62ff8e90f951f24c8b907913039a5",
+    },
+    "albert_extra_extra_large_en_uncased": {
+        "config": {
+            "vocabulary_size": 30000,
+            "num_layers": 12,
+            "num_heads": 64,
+            "num_groups": 1,
+            "num_inner_repetitions": 1,
+            "embedding_dim": 128,
+            "hidden_dim": 4096,
+            "intermediate_dim": 16384,
+            "dropout": 0,
+            "max_sequence_length": 512,
+            "num_segments": 2,
+        },
+        "preprocessor_config": {},
+        "description": (
+            "Extra Extra Large size of ALBERT where all input is lowercased. "
+            "Trained on English Wikipedia + BooksCorpus."
+        ),
+        "weights_url": "https://storage.googleapis.com/keras-nlp/models/albert_extra_extra_large_en_uncased/v1/model.h5",
+        "weights_hash": "a835177b692fb6a82139f94c66db2f22",
+        "spm_proto_url": "https://storage.googleapis.com/keras-nlp/models/albert_extra_extra_large_en_uncased/v1/vocab.spm",
+        "spm_proto_hash": "73e62ff8e90f951f24c8b907913039a5",
+    },
+}
diff --git a/keras_nlp/models/albert/albert_presets_test.py b/keras_nlp/models/albert/albert_presets_test.py
new file mode 100644
index 0000000000..958e2e8aa6
--- /dev/null
+++ b/keras_nlp/models/albert/albert_presets_test.py
@@ -0,0 +1,127 @@
+# Copyright 2022 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for loading pretrained model presets."""
+
+import pytest
+import tensorflow as tf
+from absl.testing import parameterized
+
+from keras_nlp.models.albert.albert_backbone import AlbertBackbone
+from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor
+from keras_nlp.models.albert.albert_tokenizer import AlbertTokenizer
+
+
+@pytest.mark.large
+class AlbertPresetSmokeTest(tf.test.TestCase, parameterized.TestCase):
+    """
+    A smoke test for ALBERT presets we run continuously.
+
+    This only tests the smallest weights we have available. Run with:
+    `pytest keras_nlp/models/albert/albert_presets_test.py --run_large`
+    """
+
+    def test_tokenizer_output(self):
+        tokenizer = AlbertTokenizer.from_preset(
+            "albert_base_en_uncased",
+        )
+        outputs = tokenizer("The quick brown fox.")
+        expected_outputs = [13, 1, 438, 2231, 886, 2385, 9]
+        self.assertAllEqual(outputs, expected_outputs)
+
+    def test_preprocessor_output(self):
+        preprocessor = AlbertPreprocessor.from_preset(
+            "albert_base_en_uncased",
+            sequence_length=4,
+        )
+        outputs = preprocessor("The quick brown fox.")["token_ids"]
+        expected_outputs = [2, 13, 1, 3]
+        self.assertAllEqual(outputs, expected_outputs)
+
+    @parameterized.named_parameters(
+        ("preset_weights", True), ("random_weights", False)
+    )
+    def test_backbone_output(self, load_weights):
+        input_data = {
+            "token_ids": tf.constant([[2, 13, 1, 3]]),
+            "segment_ids": tf.constant([[0, 0, 0, 0]]),
+            "padding_mask": tf.constant([[1, 1, 1, 1]]),
+        }
+        model = AlbertBackbone.from_preset(
+            "albert_base_en_uncased", load_weights=load_weights
+        )
+        outputs = model(input_data)
+        if load_weights:
+            outputs = outputs["sequence_output"][0, 0, :5]
+            expected = [1.830863, 1.698645, -1.819195, -0.53382, -0.38114]
+            self.assertAllClose(outputs, expected, atol=0.01, rtol=0.01)
+
+    @parameterized.named_parameters(
+        ("albert_tokenizer", AlbertTokenizer),
+        ("albert_preprocessor", AlbertPreprocessor),
+        ("albert", AlbertBackbone),
+    )
+    def test_preset_docstring(self, cls):
+        """Check we did our docstring formatting correctly."""
+        for name in cls.presets:
+            self.assertRegex(cls.from_preset.__doc__, name)
+
+    @parameterized.named_parameters(
+        ("albert_tokenizer", AlbertTokenizer),
+        ("albert_preprocessor", AlbertPreprocessor),
+        ("albert", AlbertBackbone),
+    )
+    def test_unknown_preset_error(self, cls):
+        # Not a preset name
+        with self.assertRaises(ValueError):
+            cls.from_preset("albert_base_en_uncased_clowntown")
+
+
+@pytest.mark.extra_large
+class AlbertPresetFullTest(tf.test.TestCase, parameterized.TestCase):
+    """
+    Test the full enumeration of our presets.
+
+    This tests every ALBERT preset and is only run manually.
+    Run with:
+    `pytest keras_nlp/models/albert/albert_presets_test.py --run_extra_large`
+    """
+
+    @parameterized.named_parameters(
+        ("preset_weights", True), ("random_weights", False)
+    )
+    def test_load_albert(self, load_weights):
+        for preset in AlbertBackbone.presets:
+            model = AlbertBackbone.from_preset(
+                preset, load_weights=load_weights
+            )
+            input_data = {
+                "token_ids": tf.random.uniform(
+                    shape=(1, 512), dtype=tf.int64, maxval=model.vocabulary_size
+                ),
+                "segment_ids": tf.constant(
+                    [0] * 200 + [1] * 312, shape=(1, 512)
+                ),
+                "padding_mask": tf.constant([1] * 512, shape=(1, 512)),
+            }
+            model(input_data)
+
+    def test_load_tokenizers(self):
+        for preset in AlbertTokenizer.presets:
+            tokenizer = AlbertTokenizer.from_preset(preset)
+            tokenizer("The quick brown fox.")
+
+    def test_load_preprocessors(self):
+        for preset in AlbertPreprocessor.presets:
+            preprocessor = AlbertPreprocessor.from_preset(preset)
+            preprocessor("The quick brown fox.")
diff --git a/keras_nlp/models/albert/albert_tokenizer.py b/keras_nlp/models/albert/albert_tokenizer.py
index c020c2759a..f5c373e6b5 100644
--- a/keras_nlp/models/albert/albert_tokenizer.py
+++ b/keras_nlp/models/albert/albert_tokenizer.py
@@ -14,11 +14,15 @@
 """ALBERT tokenizer."""

+import copy
+import os
 from tensorflow import keras

+from keras_nlp.models.albert.albert_presets import backbone_presets
 from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer
 from keras_nlp.utils.python_utils import classproperty
+from keras_nlp.utils.python_utils import format_docstring


 @keras.utils.register_keras_serializable(package="keras_nlp")
@@ -84,12 +88,53 @@ def __init__(self, proto, **kwargs):

     @classproperty
     def presets(cls):
-        return {}
+        return copy.deepcopy(backbone_presets)

     @classmethod
+    @format_docstring(names=", ".join(backbone_presets))
     def from_preset(
         cls,
         preset,
         **kwargs,
     ):
-        raise NotImplementedError
+        """Instantiate an ALBERT tokenizer from preset vocabulary.
+
+        Args:
+            preset: string. Must be one of {{names}}.
+
+        Examples:
+        ```python
+        # Load a preset tokenizer.
+        tokenizer = keras_nlp.models.AlbertTokenizer.from_preset(
+            "albert_base_en_uncased",
+        )
+
+        # Tokenize some input.
+        tokenizer("The quick brown fox tripped.")
+
+        # Detokenize some input.
+        tokenizer.detokenize([5, 6, 7, 8, 9])
+        ```
+        """
+        if preset not in cls.presets:
+            raise ValueError(
+                "`preset` must be one of "
+                f"""{", ".join(cls.presets)}. Received: {preset}."""
+            )
+        metadata = cls.presets[preset]
+
+        spm_proto = keras.utils.get_file(
+            "vocab.spm",
+            metadata["spm_proto_url"],
+            cache_subdir=os.path.join("models", preset),
+            file_hash=metadata["spm_proto_hash"],
+        )
+
+        config = metadata["preprocessor_config"]
+        config.update(
+            {
+                "proto": spm_proto,
+            },
+        )
+
+        return cls.from_config({**config, **kwargs})
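A minimal end-to-end sketch of how the presets added above are intended to be used, not part of the patch itself. It assumes the `albert_base_en_uncased` weights and SentencePiece vocabulary download successfully, and it uses the same import paths as `albert_presets_test.py`:

```python
from keras_nlp.models.albert.albert_backbone import AlbertBackbone
from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor

# Load the matching preprocessor and backbone from the same preset name.
preprocessor = AlbertPreprocessor.from_preset(
    "albert_base_en_uncased",
    sequence_length=128,
)
backbone = AlbertBackbone.from_preset("albert_base_en_uncased")

# Preprocess a batch of raw strings and run the encoder.
inputs = preprocessor(["The quick brown fox jumped."])
outputs = backbone(inputs)

# Expected shape: (1, 128, 768), since the base preset uses hidden_dim=768.
print(outputs["sequence_output"].shape)
```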