Add ALBERT Backbone #622
New file containing only the license header (`@@ -0,0 +1,13 @@`):

```python
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```
New file: the ALBERT backbone model (`@@ -0,0 +1,260 @@`):

````python
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ALBERT backbone model."""

import tensorflow as tf
from tensorflow import keras

from keras_nlp.layers.position_embedding import PositionEmbedding
from keras_nlp.models.albert.albert_group_layer import AlbertGroupLayer
from keras_nlp.utils.python_utils import classproperty


def albert_kernel_initializer(stddev=0.02):
    return keras.initializers.TruncatedNormal(stddev=stddev)


@keras.utils.register_keras_serializable(package="keras_nlp")
class AlbertBackbone(keras.Model):
    """ALBERT encoder network.

    This class implements a bi-directional Transformer-based encoder as
    described in
    ["ALBERT: A Lite BERT for Self-supervised Learning of Language Representations"](https://arxiv.org/abs/1909.11942).
    ALBERT is a more efficient variant of BERT, and uses parameter reduction
    techniques such as cross-layer parameter sharing and factorized embedding
    parameterization. This model class includes the embedding lookups and
    transformer layers, but not the masked language model or sentence order
    prediction heads.

    The default constructor gives a fully customizable, randomly initialized
    ALBERT encoder with any number of layers, heads, and embedding dimensions.
    To load preset architectures and weights, use the `from_preset`
    constructor.

    Disclaimer: Pre-trained models are provided on an "as is" basis, without
    warranties or conditions of any kind.

    Args:
        vocabulary_size: int. The size of the token vocabulary.
        num_layers: int. The number of "virtual" layers, i.e., the total
            number of times the input sequence is fed through the
            Transformer layers in one forward pass.
        num_heads: int. The number of attention heads for each transformer.
            `hidden_dim` must be divisible by `num_heads`.
        num_hidden_groups: int. The number of groups of Transformer layers.
            Layers within a group share parameters and may be applied
            multiple times in one forward pass.
        num_layers_per_group: int. The number of Transformer layers per
            group.
        embedding_dim: int. The size of the embeddings.
        hidden_dim: int. The size of the transformer encoding and pooler
            layers.
        intermediate_dim: int. The output dimension of the first Dense layer
            in a two-layer feedforward network for each transformer.
        dropout: float. Dropout probability for the Transformer encoder.
        max_sequence_length: int. The maximum sequence length that this
            encoder can consume. This determines the variable shape for
            positional embeddings. Defaults to `512`.
        num_segments: int. The number of types that the `segment_ids` input
            can take.

    Examples:
    ```python
    input_data = {
        "token_ids": tf.ones(shape=(1, 12), dtype=tf.int64),
        "segment_ids": tf.constant(
            [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12)
        ),
        "padding_mask": tf.constant(
            [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12)
        ),
    }

    # Randomly initialized ALBERT encoder
    model = keras_nlp.models.AlbertBackbone(
        vocabulary_size=30000,
        num_layers=12,
        num_heads=12,
        num_hidden_groups=1,
        num_layers_per_group=1,
        embedding_dim=128,
        hidden_dim=768,
        intermediate_dim=3072,
        max_sequence_length=12,
    )
    output = model(input_data)
    ```
    """

    def __init__(
        self,
        vocabulary_size,
        num_layers,
        num_heads,
        num_hidden_groups,
        num_layers_per_group,
        embedding_dim,
        hidden_dim,
        intermediate_dim,
        dropout=0.0,
        max_sequence_length=512,
        num_segments=2,
        **kwargs,
    ):
        # Index of classification token in the vocabulary
        cls_token_index = 0
        # Inputs
        token_id_input = keras.Input(
            shape=(None,), dtype="int32", name="token_ids"
        )
        segment_id_input = keras.Input(
            shape=(None,), dtype="int32", name="segment_ids"
        )
        padding_mask = keras.Input(
            shape=(None,), dtype="int32", name="padding_mask"
        )

        # Embed tokens, positions, and segment ids.
        token_embedding_layer = keras.layers.Embedding(
            input_dim=vocabulary_size,
            output_dim=embedding_dim,
            embeddings_initializer=albert_kernel_initializer(),
            name="token_embedding",
        )
        token_embedding = token_embedding_layer(token_id_input)
        position_embedding = PositionEmbedding(
            initializer=albert_kernel_initializer(),
            sequence_length=max_sequence_length,
            name="position_embedding",
        )(token_embedding)
        segment_embedding = keras.layers.Embedding(
            input_dim=num_segments,
            output_dim=embedding_dim,
            embeddings_initializer=albert_kernel_initializer(),
            name="segment_embedding",
        )(segment_id_input)

        # Sum, normalize and apply dropout to embeddings.
        x = keras.layers.Add()(
            (token_embedding, position_embedding, segment_embedding)
        )
        x = keras.layers.LayerNormalization(
            name="embeddings_layer_norm",
            axis=-1,
            epsilon=1e-12,
            dtype=tf.float32,
        )(x)
        x = keras.layers.Dropout(
            dropout,
            name="embeddings_dropout",
        )(x)

        # Project the embedding to `hidden_dim`.
        x = keras.layers.Dense(
            hidden_dim,
            kernel_initializer=albert_kernel_initializer(),
            name="embedding_projection",
        )(x)

        albert_group_layers = [
            AlbertGroupLayer(
                num_layers=num_layers_per_group,
                num_heads=num_heads,
                intermediate_dim=intermediate_dim,
                activation=lambda x: keras.activations.gelu(
                    x, approximate=True
                ),
                dropout=dropout,
                kernel_initializer=albert_kernel_initializer(),
                name=f"group_layer_{i}",
            )
            for i in range(num_hidden_groups)
        ]

        # Apply successive transformer encoder blocks.
        for i in range(num_layers):
            # Index of the hidden group
            group_idx = int(i / (num_layers / num_hidden_groups))
            x = albert_group_layers[group_idx](x, padding_mask)

        # Construct the two ALBERT outputs. The pooled output is a dense
        # layer on top of the [CLS] token.
        sequence_output = x
        pooled_output = keras.layers.Dense(
            hidden_dim,
            kernel_initializer=albert_kernel_initializer(),
            activation="tanh",
            name="pooled_dense",
        )(x[:, cls_token_index, :])

        # Instantiate using Functional API Model constructor
        super().__init__(
            inputs={
                "token_ids": token_id_input,
                "segment_ids": segment_id_input,
                "padding_mask": padding_mask,
            },
            outputs={
                "sequence_output": sequence_output,
                "pooled_output": pooled_output,
            },
            **kwargs,
        )
        # All references to `self` below this line
        self.vocabulary_size = vocabulary_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.num_hidden_groups = num_hidden_groups
        self.num_layers_per_group = num_layers_per_group
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.dropout = dropout
        self.max_sequence_length = max_sequence_length
        self.num_segments = num_segments
        self.cls_token_index = cls_token_index

    def get_config(self):
        return {
            "vocabulary_size": self.vocabulary_size,
            "num_layers": self.num_layers,
            "num_heads": self.num_heads,
            "num_hidden_groups": self.num_hidden_groups,
            "num_layers_per_group": self.num_layers_per_group,
            "embedding_dim": self.embedding_dim,
            "hidden_dim": self.hidden_dim,
            "intermediate_dim": self.intermediate_dim,
            "dropout": self.dropout,
            "max_sequence_length": self.max_sequence_length,
            "num_segments": self.num_segments,
            "name": self.name,
            "trainable": self.trainable,
        }

    @classmethod
    def from_config(cls, config):
        return cls(**config)

    @classproperty
    def presets(cls):
        return {}

    @classmethod
    def from_preset(
        cls,
        preset,
        load_weights=True,
        **kwargs,
    ):
        raise NotImplementedError
````
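For clarity on how the loop over `num_layers` realizes ALBERT's cross-layer parameter sharing: each `AlbertGroupLayer` is built once and then reused for several of the `num_layers` iterations, so its weights are shared across those "virtual" layers, and the encoder's parameter count is independent of `num_layers`. A minimal sketch of the index arithmetic (plain Python, no Keras needed; the three-group configuration shown is hypothetical, not a preset):

```python
# Mirrors `group_idx = int(i / (num_layers / num_hidden_groups))`
# from the backbone's encoder loop.
def group_for_layer(i, num_layers, num_hidden_groups):
    return int(i / (num_layers / num_hidden_groups))

# Classic ALBERT setting: one group shared across all 12 virtual layers.
print([group_for_layer(i, 12, 1) for i in range(12)])
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] -> one set of weights, applied 12x

# A hypothetical 12-layer, 3-group config: each group is reused 4 times.
print([group_for_layer(i, 12, 3) for i in range(12)])
# [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
```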
New file: tests for the ALBERT backbone model (`@@ -0,0 +1,107 @@`):

```python
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test for ALBERT backbone model."""

import os

import tensorflow as tf
from absl.testing import parameterized
from tensorflow import keras

from keras_nlp.models.albert.albert_backbone import AlbertBackbone


class AlbertBackboneTest(tf.test.TestCase, parameterized.TestCase):
    def setUp(self):
        self.model = AlbertBackbone(
            vocabulary_size=1000,
            num_layers=2,
            num_heads=2,
            num_hidden_groups=1,
            num_layers_per_group=1,
            embedding_dim=16,
            hidden_dim=64,
            intermediate_dim=128,
            max_sequence_length=128,
        )
        self.batch_size = 8
        self.input_batch = {
            "token_ids": tf.ones(
                (self.batch_size, self.model.max_sequence_length), dtype="int32"
            ),
            "segment_ids": tf.ones(
                (self.batch_size, self.model.max_sequence_length), dtype="int32"
            ),
            "padding_mask": tf.ones(
                (self.batch_size, self.model.max_sequence_length), dtype="int32"
            ),
        }

        self.input_dataset = tf.data.Dataset.from_tensor_slices(
            self.input_batch
        ).batch(2)

    def test_valid_call_albert(self):
        self.model(self.input_batch)

        # Check default name passed through.
        self.assertRegex(self.model.name, "albert_backbone")

    def test_variable_sequence_length_call_albert(self):
        for seq_length in (25, 50, 75):
            input_data = {
                "token_ids": tf.ones(
                    (self.batch_size, seq_length), dtype="int32"
                ),
                "segment_ids": tf.ones(
                    (self.batch_size, seq_length), dtype="int32"
                ),
                "padding_mask": tf.ones(
                    (self.batch_size, seq_length), dtype="int32"
                ),
            }
            self.model(input_data)

    @parameterized.named_parameters(
        ("jit_compile_false", False), ("jit_compile_true", True)
    )
    def test_compile(self, jit_compile):
        self.model.compile(jit_compile=jit_compile)
        self.model.predict(self.input_batch)

    @parameterized.named_parameters(
        ("jit_compile_false", False), ("jit_compile_true", True)
    )
    def test_compile_batched_ds(self, jit_compile):
        self.model.compile(jit_compile=jit_compile)
        self.model.predict(self.input_dataset)

    @parameterized.named_parameters(
        ("tf_format", "tf", "model"),
        ("keras_format", "keras_v3", "model.keras"),
    )
    def test_saved_model(self, save_format, filename):
        model_output = self.model(self.input_batch)
        save_path = os.path.join(self.get_temp_dir(), filename)
        self.model.save(save_path, save_format=save_format)
        restored_model = keras.models.load_model(save_path)

        # Check we got the real object back.
        self.assertIsInstance(restored_model, AlbertBackbone)

        # Check that output matches.
        restored_output = restored_model(self.input_batch)
        self.assertAllClose(
            model_output["pooled_output"], restored_output["pooled_output"]
        )
```
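The test configuration above uses `embedding_dim=16` with `hidden_dim=64`, exercising the backbone's `embedding_projection` layer. A back-of-the-envelope sketch of why this factorized embedding parameterization saves parameters, using the numbers from the backbone docstring's example (bias terms ignored for simplicity):

```python
# Factorized embedding: a V x E token table plus an E x H projection,
# versus a direct V x H table as in BERT.
V, E, H = 30000, 128, 768  # vocabulary_size, embedding_dim, hidden_dim

factorized = V * E + E * H  # 3,938,304 parameters
direct = V * H              # 23,040,000 parameters
print(factorized, direct, f"{factorized / direct:.1%}")
# ~3.9M vs ~23.0M -> the factorized embedding is ~17% of the direct size
```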
Review comment: Depending on the timing of this landing vs #621, you may want to update this to follow the new base class.