keras_nlp/models/__init__.py (1 addition, 0 deletions)
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_nlp.models.albert.albert_backbone import AlbertBackbone
from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor
from keras_nlp.models.albert.albert_tokenizer import AlbertTokenizer
from keras_nlp.models.bert.bert_backbone import BertBackbone
keras_nlp/models/albert/albert_backbone.py (271 additions, 0 deletions)
@@ -0,0 +1,271 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""ALBERT backbone model."""

import tensorflow as tf
from tensorflow import keras

from keras_nlp.layers.position_embedding import PositionEmbedding
from keras_nlp.layers.transformer_encoder import TransformerEncoder
from keras_nlp.models.backbone import Backbone
from keras_nlp.utils.python_utils import classproperty
from keras_nlp.utils.python_utils import format_docstring


def albert_kernel_initializer(stddev=0.02):
    return keras.initializers.TruncatedNormal(stddev=stddev)


@keras.utils.register_keras_serializable(package="keras_nlp")
class AlbertBackbone(Backbone):
"""ALBERT encoder network.

This class implements a bi-directional Transformer-based encoder as
described in
["ALBERT: A Lite BERT for Self-supervised Learning of Language Representations"](https://arxiv.org/abs/1909.11942).
ALBERT is a more efficient variant of BERT, and uses parameter reduction
techniques such as cross-layer parameter sharing and factorized embedding
parameterization. This model class includes the embedding lookups and
transformer layers, but not the masked language model or sentence order
prediction heads.

The default constructor gives a fully customizable, randomly initialized
ALBERT encoder with any number of layers, heads, and embedding dimensions.
To load preset architectures and weights, use the `from_preset` constructor.

Disclaimer: Pre-trained models are provided on an "as is" basis, without
warranties or conditions of any kind.

Args:
vocabulary_size: int. The size of the token vocabulary.
num_layers: int. The number of "virtual" layers, i.e., the total number
of times the input sequence will be fed through the Transformer
layers in one forward pass. The input will be routed to the correct
group based on the layer index.
num_heads: int. The number of attention heads for each transformer.
The hidden size must be divisible by the number of attention heads.
num_groups: int. Number of groups, with each group having a certain
number of Transformer layers.
num_layers_per_group: int. Number of Transformer layers per group.
Member:
Can we remove this? The original implementation does not seem to have it, and it seems like we are overparameterized here. Let's just take in num_layers and num_groups.

https://github.com/google-research/albert/blob/master/modeling.py#L47

I think this could also lead to a simplification of the loop as well; will comment below.

abheesht17 (Collaborator, Author), Jan 6, 2023:
Same comment as #622 (comment).

        embedding_dim: int. The size of the embeddings.
        hidden_dim: int. The size of the transformer encoding and pooler layers.
        intermediate_dim: int. The output dimension of the first Dense layer in
            a two-layer feedforward network for each transformer.
        dropout: float. Dropout probability for the Transformer encoder.
        max_sequence_length: int. The maximum sequence length that this encoder
            can consume. If None, `max_sequence_length` uses the value from the
            input sequence length. This determines the variable shape for
            positional embeddings.
        num_segments: int. The number of types that the `segment_ids` input can
            take.

    Examples:
    ```python
    input_data = {
        "token_ids": tf.ones(shape=(1, 12), dtype=tf.int64),
        "segment_ids": tf.constant(
            [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12)
        ),
        "padding_mask": tf.constant(
            [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12)
        ),
    }

    # Randomly initialized ALBERT encoder
    model = keras_nlp.models.AlbertBackbone(
        vocabulary_size=30000,
        num_layers=12,
        num_heads=12,
        num_groups=1,
        num_layers_per_group=1,
        embedding_dim=128,
        hidden_dim=768,
        intermediate_dim=3072,
        max_sequence_length=12,
    )
    output = model(input_data)
    ```
    """

    def __init__(
        self,
        vocabulary_size,
        num_layers,
        num_heads,
        num_groups,
Member:
Let's give num_groups and num_inner_repetitions defaults of 1. That will be a good way to indicate to users how this is used in practice (this means we should also drop them below the dim arguments).

        num_layers_per_group,
        embedding_dim,
        hidden_dim,
        intermediate_dim,
        dropout=0.0,
        max_sequence_length=512,
        num_segments=2,
        **kwargs,
    ):

        # Index of classification token in the vocabulary
        cls_token_index = 0
        # Inputs
        token_id_input = keras.Input(
            shape=(None,), dtype="int32", name="token_ids"
        )
        segment_id_input = keras.Input(
            shape=(None,), dtype="int32", name="segment_ids"
        )
        padding_mask = keras.Input(
            shape=(None,), dtype="int32", name="padding_mask"
        )

        # Embed tokens, positions, and segment ids.
        token_embedding_layer = keras.layers.Embedding(
            input_dim=vocabulary_size,
            output_dim=embedding_dim,
            embeddings_initializer=albert_kernel_initializer(),
            name="token_embedding",
        )
        token_embedding = token_embedding_layer(token_id_input)
        position_embedding = PositionEmbedding(
            initializer=albert_kernel_initializer(),
            sequence_length=max_sequence_length,
            name="position_embedding",
        )(token_embedding)
        segment_embedding = keras.layers.Embedding(
            input_dim=num_segments,
            output_dim=embedding_dim,
            embeddings_initializer=albert_kernel_initializer(),
            name="segment_embedding",
        )(segment_id_input)

        # Sum, normalize and apply dropout to embeddings.
        x = keras.layers.Add()(
            (token_embedding, position_embedding, segment_embedding)
        )
        x = keras.layers.LayerNormalization(
            name="embeddings_layer_norm",
            axis=-1,
            epsilon=1e-12,
            dtype=tf.float32,
        )(x)
        x = keras.layers.Dropout(
            dropout,
            name="embeddings_dropout",
        )(x)

        # Project the embedding to `hidden_dim`.
        x = keras.layers.Dense(
            hidden_dim,
            kernel_initializer=albert_kernel_initializer(),
            name="embedding_projection",
        )(x)

        layer_idx = 0
        for i in range(num_groups):
jbischof (Contributor), Jan 5, 2023:
This loop is pretty inscrutable. Can we at least offer a comment? I particularly don't understand the while loop logic and wonder if there's a more readable approach.

mattdangerw (Member), Jan 5, 2023:
I was looking at the original implementation, and I think we have overparameterized. We just want to take in the number of groups, and split out groups evenly so that any remainder layers are distributed evenly among the first groups.

I tried to write this out as readably as possible in a nested loop.

for group_idx in range(num_groups):
    # Create a single encoder block with shared weights for the group.
    shared_encoder = TransformerEncoder(...)
    # Split our total layers evenly among groups.
    layers_per_group = num_layers // num_groups
    # Any remainder layers go into the earlier groups.
    if group_idx < num_layers % num_groups:
        layers_per_group += 1
    # Apply our shared encoder block once per layer in the group.
    for _ in range(layers_per_group):
        x = shared_encoder(x, padding_mask=padding_mask)

@abheesht17 would this work?

abheesht17 (Collaborator, Author), Jan 6, 2023:
Hmmm, is this what's happening? 🤔

They have an argument called "inner_group_num" (which I have renamed as num_layers_per_group because inner_group_num isn't exactly easy to understand). Check this notebook out; the graph is similar to what I have done, and the number of parameters is the same for both.

https://colab.research.google.com/drive/17m3ozVYBTuodsSRxb8ycn9g37i6q48gc?usp=sharing

abheesht17 (Collaborator, Author), Jan 6, 2023:
@jbischof, yeah. The alternative is to have a separate loop for this. Let me know which one you guys prefer. Something like:

        group_layers = []
        for i in range(num_groups):
            transformer_layers = [
                TransformerEncoder(
                    num_heads=num_heads,
                    intermediate_dim=intermediate_dim,
                    activation=lambda x: keras.activations.gelu(
                        x, approximate=True
                    ),
                    dropout=dropout,
                    kernel_initializer=albert_kernel_initializer(),
                    name=f"group_{i}_transformer_layer_{j}",
                )
                for j in range(num_layers_per_group)
            ]
            group_layers.append(transformer_layers)

        for layer_idx in range(num_layers):
            group_idx = int(layer_idx / (num_layers / num_groups))
            transformer_layers = group_layers[group_idx]
            for transformer_layer in transformer_layers:
                x = transformer_layer(x, padding_mask=padding_mask)

Member:
Interesting! I missed the "inner_group_num", but that seems fairly separate? A single "layer" according to the num_layers argument is just a stack of our encoders, right? I think my solution could still stand, just slightly expanded...

def get_layer():
    # A "layer" in Albert terminology is actually any number of repeated attention and
    # feed-forward blocks, controlled by the `inner_group_num` parameter.
    layer = keras.Sequential()
    for _ in range(inner_group_num):
        layer.add(TransformerEncoder(...))
    return layer

for group_idx in range(num_groups):
    # Create a single encoder block with shared weights for the group.
    shared_encoder = get_layer()
    # Split our total layers evenly among groups.
    layers_per_group = num_layers // num_groups
    # Any remainder layers go into the earlier groups.
    if group_idx < num_layers % num_groups:
        layers_per_group += 1
    # Apply our shared encoder block once per layer in the group.
    for _ in range(layers_per_group):
        x = shared_encoder(x, padding_mask=padding_mask)

Maybe I am still missing something though; let's sync tomorrow!

abheesht17 (Collaborator, Author):
Hmmm, not quite. There is a difference between the two:

>>> for group_idx in range(5):
...   layers_per_group = 12 // 5
...   if group_idx < 12 % 5:
...     layers_per_group += 1
...   for _ in range(layers_per_group):
...     print(group_idx, end=", ")
...
0, 0, 0, 1, 1, 1, 2, 2, 3, 3, 4, 4,
>>> for layer_idx in range(12):
...   group_idx = int(layer_idx / (12 / 5))
...   print(group_idx, end=", ")
...
0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 4, 4,

https://github.com/huggingface/transformers/blob/35a7052b61579cfe8df1a059d4cd3359310ec2d1/src/transformers/models/albert/modeling_albert.py#L473-L486

Member:
Thanks! This might be it?

layer_idx = 0
for group_idx in range(num_groups):
    # Create a single block with shared weights for the group.
    shared_block = get_shared_block()
    # Apply the shared block until the fraction of total layers exceeds
    # the fractional boundary of the next group.
    while layer_idx / num_layers < (group_idx + 1) / num_groups:
        x = shared_block(x, padding_mask=padding_mask)
        layer_idx += 1

abheesht17 (Collaborator, Author):
Yep, this should work!

Also,

    while layer_idx / num_layers < (group_idx + 1) / num_groups:
        x = shared_block(x, padding_mask=padding_mask)
        layer_idx += 1

is the same as what is on the PR currently, if we just shift num_groups to the other side:
`while int(layer_idx / (num_layers / num_groups)) == group_idx`.
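A standalone sanity check of that equivalence (illustrative only, not part of the diff), using the num_layers=12, num_groups=5 example from the thread above:

```python
# Compare the group index each "virtual" layer gets under both conditions.
num_layers, num_groups = 12, 5

# Condition currently in the PR: int(layer_idx / (num_layers / num_groups)) == group_idx
pr_groups = [
    int(layer_idx / (num_layers / num_groups)) for layer_idx in range(num_layers)
]

# Suggested rewrite: advance while layer_idx / num_layers < (group_idx + 1) / num_groups
suggested_groups = []
layer_idx = 0
for group_idx in range(num_groups):
    while layer_idx / num_layers < (group_idx + 1) / num_groups:
        suggested_groups.append(group_idx)
        layer_idx += 1

print(pr_groups)         # [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 4, 4]
print(suggested_groups)  # [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 4, 4]
assert pr_groups == suggested_groups
```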

            # Build this group's shared Transformer layers.
            transformer_layers = [
                TransformerEncoder(
                    num_heads=num_heads,
                    intermediate_dim=intermediate_dim,
                    activation=lambda x: keras.activations.gelu(
                        x, approximate=True
                    ),
                    dropout=dropout,
                    kernel_initializer=albert_kernel_initializer(),
                    name=f"group_{i}_transformer_layer_{j}",
                )
                for j in range(num_layers_per_group)
            ]

            # Apply the group's shared layers for every "virtual" layer index
            # that maps to this group.
            while int(layer_idx / (num_layers / num_groups)) == i:
                for transformer_layer in transformer_layers:
                    x = transformer_layer(x, padding_mask=padding_mask)
                layer_idx += 1

        # Construct the two ALBERT outputs. The pooled output is a dense layer on
        # top of the [CLS] token.
        sequence_output = x
        pooled_output = keras.layers.Dense(
            hidden_dim,
            kernel_initializer=albert_kernel_initializer(),
            activation="tanh",
            name="pooled_dense",
        )(x[:, cls_token_index, :])

        # Instantiate using Functional API Model constructor
        super().__init__(
            inputs={
                "token_ids": token_id_input,
                "segment_ids": segment_id_input,
                "padding_mask": padding_mask,
            },
            outputs={
                "sequence_output": sequence_output,
                "pooled_output": pooled_output,
            },
            **kwargs,
        )
        # All references to `self` below this line
        self.vocabulary_size = vocabulary_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.num_layers_per_group = num_layers_per_group
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.dropout = dropout
        self.max_sequence_length = max_sequence_length
        self.num_segments = num_segments
        self.cls_token_index = cls_token_index

    def get_config(self):
        return {
            "vocabulary_size": self.vocabulary_size,
            "num_layers": self.num_layers,
            "num_heads": self.num_heads,
            "num_groups": self.num_groups,
            "num_layers_per_group": self.num_layers_per_group,
            "embedding_dim": self.embedding_dim,
            "hidden_dim": self.hidden_dim,
            "intermediate_dim": self.intermediate_dim,
            "dropout": self.dropout,
            "max_sequence_length": self.max_sequence_length,
            "num_segments": self.num_segments,
            "name": self.name,
            "trainable": self.trainable,
        }

    @classmethod
    def from_config(cls, config):
        return cls(**config)

    @classproperty
    def presets(cls):
        return {}

    @classmethod
    def from_preset(
        cls,
        preset,
        load_weights=True,
        **kwargs,
    ):
        return super().from_preset(preset, load_weights, **kwargs)


AlbertBackbone.from_preset.__func__.__doc__ = Backbone.from_preset.__doc__
format_docstring(
    model_name=AlbertBackbone.__name__,
    example_preset_name="",  # TODO: Add example preset name.
    preset_names='", "'.join(AlbertBackbone.presets),
)(AlbertBackbone.from_preset.__func__)
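As an aside, a rough parameter count illustrating the factorized embedding parameterization mentioned in the class docstring (a back-of-the-envelope sketch; the 30000/128/768 values come from the docstring example, not from any preset):

```python
# Token embedding parameters without factorization: V * H.
# With ALBERT's factorization: V * E for the embedding table plus E * H for
# the dense projection into the hidden size (the "embedding_projection" layer above).
V, E, H = 30000, 128, 768  # vocabulary_size, embedding_dim, hidden_dim

full_size_embedding = V * H            # 23,040,000 parameters
factorized_embedding = V * E + E * H   # 3,938,304 parameters
print(full_size_embedding, factorized_embedding)
```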
keras_nlp/models/albert/albert_backbone_test.py (107 additions, 0 deletions)
@@ -0,0 +1,107 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test for ALBERT backbone model."""

import os

import tensorflow as tf
from absl.testing import parameterized
from tensorflow import keras

from keras_nlp.models.albert.albert_backbone import AlbertBackbone


class AlbertBackboneTest(tf.test.TestCase, parameterized.TestCase):
    def setUp(self):
        self.model = AlbertBackbone(
            vocabulary_size=1000,
            num_layers=2,
            num_heads=2,
            num_groups=1,
            num_layers_per_group=1,
            embedding_dim=16,
            hidden_dim=64,
            intermediate_dim=128,
            max_sequence_length=128,
        )
        self.batch_size = 8
        self.input_batch = {
            "token_ids": tf.ones(
                (self.batch_size, self.model.max_sequence_length), dtype="int32"
            ),
            "segment_ids": tf.ones(
                (self.batch_size, self.model.max_sequence_length), dtype="int32"
            ),
            "padding_mask": tf.ones(
                (self.batch_size, self.model.max_sequence_length), dtype="int32"
            ),
        }

        self.input_dataset = tf.data.Dataset.from_tensor_slices(
            self.input_batch
        ).batch(2)

    def test_valid_call_albert(self):
        self.model(self.input_batch)

        # Check default name passed through
        self.assertRegexpMatches(self.model.name, "albert_backbone")

    def test_variable_sequence_length_call_albert(self):
        for seq_length in (25, 50, 75):
            input_data = {
                "token_ids": tf.ones(
                    (self.batch_size, seq_length), dtype="int32"
                ),
                "segment_ids": tf.ones(
                    (self.batch_size, seq_length), dtype="int32"
                ),
                "padding_mask": tf.ones(
                    (self.batch_size, seq_length), dtype="int32"
                ),
            }
            self.model(input_data)

    @parameterized.named_parameters(
        ("jit_compile_false", False), ("jit_compile_true", True)
    )
    def test_compile(self, jit_compile):
        self.model.compile(jit_compile=jit_compile)
        self.model.predict(self.input_batch)

    @parameterized.named_parameters(
        ("jit_compile_false", False), ("jit_compile_true", True)
    )
    def test_compile_batched_ds(self, jit_compile):
        self.model.compile(jit_compile=jit_compile)
        self.model.predict(self.input_dataset)

    @parameterized.named_parameters(
        ("tf_format", "tf", "model"),
        ("keras_format", "keras_v3", "model.keras"),
    )
    def test_saved_model(self, save_format, filename):
        model_output = self.model(self.input_batch)
        save_path = os.path.join(self.get_temp_dir(), filename)
        self.model.save(save_path, save_format=save_format)
        restored_model = keras.models.load_model(save_path)

        # Check we got the real object back.
        self.assertIsInstance(restored_model, AlbertBackbone)

        # Check that output matches.
        restored_output = restored_model(self.input_batch)
        self.assertAllClose(
            model_output["pooled_output"], restored_output["pooled_output"]
        )
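
For completeness, a minimal sketch (not part of the PR) of the `get_config()` / `from_config()` round trip that the serialization above relies on, using the same small configuration as `AlbertBackboneTest.setUp`:

```python
from keras_nlp.models.albert.albert_backbone import AlbertBackbone

model = AlbertBackbone(
    vocabulary_size=1000,
    num_layers=2,
    num_heads=2,
    num_groups=1,
    num_layers_per_group=1,
    embedding_dim=16,
    hidden_dim=64,
    intermediate_dim=128,
    max_sequence_length=128,
)

# `get_config()` returns the constructor arguments; `from_config()` simply
# forwards them back to `__init__`, so the restored model has the same
# architecture (weights are freshly initialized, not copied).
config = model.get_config()
restored = AlbertBackbone.from_config(config)
assert restored.num_layers == model.num_layers
assert restored.hidden_dim == model.hidden_dim
```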