126 changes: 57 additions & 69 deletions keras_nlp/models/albert/albert_backbone.py
@@ -118,67 +118,47 @@ def __init__(
f"`num_layers={num_layers}` and `num_groups={num_groups}`."
)

# Index of classification token in the vocabulary
cls_token_index = 0
# Inputs
token_id_input = keras.Input(
shape=(None,), dtype="int32", name="token_ids"
)
segment_id_input = keras.Input(
shape=(None,), dtype="int32", name="segment_ids"
)
padding_mask = keras.Input(
shape=(None,), dtype="int32", name="padding_mask"
)

# Embed tokens, positions, and segment ids.
token_embedding_layer = ReversibleEmbedding(
# === Layers ===
self.token_embedding = ReversibleEmbedding(
input_dim=vocabulary_size,
output_dim=embedding_dim,
embeddings_initializer=albert_kernel_initializer(),
name="token_embedding",
)
token_embedding = token_embedding_layer(token_id_input)
position_embedding = PositionEmbedding(
self.position_embedding = PositionEmbedding(
initializer=albert_kernel_initializer(),
sequence_length=max_sequence_length,
name="position_embedding",
)(token_embedding)
segment_embedding = keras.layers.Embedding(
)
self.segment_embedding = keras.layers.Embedding(
input_dim=num_segments,
output_dim=embedding_dim,
embeddings_initializer=albert_kernel_initializer(),
name="segment_embedding",
)(segment_id_input)

# Sum, normalize and apply dropout to embeddings.
x = keras.layers.Add()(
(token_embedding, position_embedding, segment_embedding)
)
x = keras.layers.LayerNormalization(
self.embeddings_add = keras.layers.Add(
name="embeddings_add",
)
self.embeddings_layer_norm = keras.layers.LayerNormalization(
name="embeddings_layer_norm",
axis=-1,
epsilon=1e-12,
dtype="float32",
)(x)
x = keras.layers.Dropout(
)
self.embeddings_dropout = keras.layers.Dropout(
dropout,
name="embeddings_dropout",
)(x)

# Project the embedding to `hidden_dim`.
x = keras.layers.Dense(
)
self.embeddings_projection = keras.layers.Dense(
hidden_dim,
kernel_initializer=albert_kernel_initializer(),
name="embedding_projection",
)(x)

def get_group_layer(group_idx):
"""Defines a group `num_inner_repetitions` transformer layers and
returns the callable.
"""
transformer_layers = [
TransformerEncoder(
)
self.transformer_layers = []
for group_idx in range(num_groups):
inner_layers = []
for inner_idx in range(num_inner_repetitions):
layer = TransformerEncoder(
num_heads=num_heads,
intermediate_dim=intermediate_dim,
activation=gelu_approximate,
@@ -187,51 +167,60 @@ def get_group_layer(group_idx):
kernel_initializer=albert_kernel_initializer(),
name=f"group_{group_idx}_inner_layer_{inner_idx}",
)
for inner_idx in range(num_inner_repetitions)
]

def call(x, padding_mask):
for transformer_layer in transformer_layers:
x = transformer_layer(x, padding_mask=padding_mask)
return x

return call

num_calls_per_group = num_layers // num_groups
for group_idx in range(num_groups):
# Define the group. A group in ALBERT terminology is any number of
# repeated attention and FFN blocks.
group_layer = get_group_layer(group_idx)

# Assume num_layers = 8, num_groups = 4. Then, the order of group
# calls will be 0, 0, 1, 1, 2, 2, 3, 3.
for call in range(num_calls_per_group):
x = group_layer(x, padding_mask=padding_mask)

# Construct the two ALBERT outputs. The pooled output is a dense layer on
# top of the [CLS] token.
sequence_output = x
pooled_output = keras.layers.Dense(
inner_layers.append(layer)
self.transformer_layers.append(inner_layers)
self.pooled_dense = keras.layers.Dense(
hidden_dim,
kernel_initializer=albert_kernel_initializer(),
activation="tanh",
name="pooled_dense",
)(x[:, cls_token_index, :])
)

# Instantiate using Functional API Model constructor
# === Functional Model ===
# Inputs
token_id_input = keras.Input(
shape=(None,), dtype="int32", name="token_ids"
)
segment_id_input = keras.Input(
shape=(None,), dtype="int32", name="segment_ids"
)
padding_mask_input = keras.Input(
shape=(None,), dtype="int32", name="padding_mask"
)
# Embed tokens, positions, and segment ids.
tokens = self.token_embedding(token_id_input)
positions = self.position_embedding(tokens)
segments = self.segment_embedding(segment_id_input)
# Sum, normalize and apply dropout to embeddings.
x = self.embeddings_add((tokens, positions, segments))
x = self.embeddings_layer_norm(x)
x = self.embeddings_dropout(x)
x = self.embeddings_projection(x)
# Call transformer layers with repeated groups.
num_calls_per_group = num_layers // num_groups
for group in self.transformer_layers:
for _ in range(num_calls_per_group):
for transformer_layer in group:
x = transformer_layer(x, padding_mask=padding_mask_input)
# Construct the two ALBERT outputs. The pooled output is a dense layer
# on top of the [CLS] token.
sequence_output = x
cls_token_index = 0
pooled_output = self.pooled_dense(x[:, cls_token_index, :])
super().__init__(
inputs={
"token_ids": token_id_input,
"segment_ids": segment_id_input,
"padding_mask": padding_mask,
"padding_mask": padding_mask_input,
},
outputs={
"sequence_output": sequence_output,
"pooled_output": pooled_output,
},
**kwargs,
)
# All references to `self` below this line

# === Config ===
self.vocabulary_size = vocabulary_size
self.num_layers = num_layers
self.num_heads = num_heads
@@ -244,7 +233,6 @@ def call(x, padding_mask):
self.max_sequence_length = max_sequence_length
self.num_segments = num_segments
self.cls_token_index = cls_token_index
self.token_embedding = token_embedding_layer

def get_config(self):
config = super().get_config()
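As context for the grouped transformer calls above: ALBERT shares one set of layers per group and calls each group `num_layers // num_groups` times, so with `num_layers=8` and `num_groups=4` the call order is 0, 0, 1, 1, 2, 2, 3, 3. A standalone sketch of that repetition pattern (layer names stand in for the real `TransformerEncoder` instances):

```python
# Minimal sketch of ALBERT-style cross-layer parameter sharing. The backbone
# asserts that num_layers is divisible by num_groups; each group's shared
# inner layers are then reused num_layers // num_groups times in order.
num_layers, num_groups, num_inner_repetitions = 8, 4, 1

# One shared stack of inner layers per group (names stand in for layers).
transformer_layers = [
    [f"group_{g}_inner_layer_{i}" for i in range(num_inner_repetitions)]
    for g in range(num_groups)
]

call_order = []
num_calls_per_group = num_layers // num_groups
for group in transformer_layers:
    for _ in range(num_calls_per_group):
        for layer in group:
            call_order.append(layer)  # the same shared layer each time

print(call_order)
# ['group_0_inner_layer_0', 'group_0_inner_layer_0',
#  'group_1_inner_layer_0', 'group_1_inner_layer_0',
#  'group_2_inner_layer_0', 'group_2_inner_layer_0',
#  'group_3_inner_layer_0', 'group_3_inner_layer_0']
```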
29 changes: 18 additions & 11 deletions keras_nlp/models/albert/albert_classifier.py
@@ -155,30 +155,37 @@ def __init__(
dropout=0.1,
**kwargs,
):
inputs = backbone.input
pooled = backbone(inputs)["pooled_output"]
pooled = keras.layers.Dropout(dropout)(pooled)
outputs = keras.layers.Dense(
# === Layers ===
self.preprocessor = preprocessor
self.backbone = backbone
self.output_dense = keras.layers.Dense(
num_classes,
kernel_initializer=albert_kernel_initializer(),
activation=activation,
name="logits",
)(pooled)
# Instantiate using Functional API Model constructor
)
self.output_dropout = keras.layers.Dropout(
dropout,
name="output_dropout",
)

# === Functional Model ===
inputs = backbone.input
pooled = backbone(inputs)["pooled_output"]
pooled = self.output_dropout(pooled)
outputs = self.output_dense(pooled)
super().__init__(
inputs=inputs,
outputs=outputs,
include_preprocessing=preprocessor is not None,
**kwargs,
)
# All references to `self` below this line
self._backbone = backbone
self._preprocessor = preprocessor

# === Config ===
self.num_classes = num_classes
self.activation = keras.activations.get(activation)
self.dropout = dropout

# Default compilation
# === Default compilation ===
logit_output = self.activation == keras.activations.linear
self.compile(
loss=keras.losses.SparseCategoricalCrossentropy(
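The same layers / functional-model / config split applies to the classifier. A usage sketch under illustrative settings (tiny dimensions, no preprocessor, so raw token tensors are passed directly):

```python
import numpy as np
from keras_nlp.models import AlbertBackbone, AlbertClassifier

# Tiny, illustrative configuration; real presets are much larger.
backbone = AlbertBackbone(
    vocabulary_size=1000,
    num_layers=2,
    num_heads=2,
    embedding_dim=16,
    hidden_dim=32,
    intermediate_dim=64,
    max_sequence_length=64,
)
classifier = AlbertClassifier(backbone, num_classes=2, preprocessor=None)

features = {
    "token_ids": np.ones((2, 12), dtype="int32"),
    "segment_ids": np.zeros((2, 12), dtype="int32"),
    "padding_mask": np.ones((2, 12), dtype="int32"),
}
logits = classifier.predict(features)  # shape (2, 2)
```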
31 changes: 17 additions & 14 deletions keras_nlp/models/albert/albert_masked_lm.py
@@ -97,32 +97,35 @@ class AlbertMaskedLM(Task):
"""

def __init__(self, backbone, preprocessor=None, **kwargs):
# === Layers ===
self.backbone = backbone
self.preprocessor = preprocessor
self.masked_lm_head = MaskedLMHead(
vocabulary_size=backbone.vocabulary_size,
token_embedding=backbone.token_embedding,
intermediate_activation=gelu_approximate,
kernel_initializer=albert_kernel_initializer(),
name="mlm_head",
)

# === Functional Model ===
inputs = {
**backbone.input,
"mask_positions": keras.Input(
shape=(None,), dtype="int32", name="mask_positions"
),
}

backbone_outputs = backbone(backbone.input)
outputs = MaskedLMHead(
vocabulary_size=backbone.vocabulary_size,
token_embedding=backbone.token_embedding,
intermediate_activation=gelu_approximate,
kernel_initializer=albert_kernel_initializer(),
name="mlm_head",
)(backbone_outputs["sequence_output"], inputs["mask_positions"])

outputs = self.masked_lm_head(
backbone_outputs["sequence_output"], inputs["mask_positions"]
)
super().__init__(
inputs=inputs,
outputs=outputs,
include_preprocessing=preprocessor is not None,
**kwargs
**kwargs,
)

self.backbone = backbone
self.preprocessor = preprocessor

# === Default compilation ===
self.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=keras.optimizers.Adam(5e-5),
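For intuition, the `MaskedLMHead` wired in above gathers the encoder output at each masked position and projects it back to vocabulary logits through the tied `token_embedding`. A rough NumPy sketch of that flow (simplified: the real head also applies an intermediate dense layer and normalization before the projection):

```python
import numpy as np

# Illustrative shapes: (batch, seq_len, hidden_dim) encoder outputs,
# (batch, num_masks) mask positions, (vocab, hidden_dim) tied embeddings.
batch, seq_len, hidden_dim, vocab = 2, 8, 4, 10
sequence_output = np.random.rand(batch, seq_len, hidden_dim)
mask_positions = np.array([[1, 3, 5], [0, 2, 7]])
embedding_matrix = np.random.rand(vocab, hidden_dim)

# Gather the hidden state at each masked position.
masked_states = np.take_along_axis(
    sequence_output, mask_positions[..., None], axis=1
)  # -> (batch, num_masks, hidden_dim)

# Project back to the vocabulary with the transposed embedding matrix,
# as ReversibleEmbedding does when called in reverse.
logits = masked_states @ embedding_matrix.T  # -> (batch, num_masks, vocab)
```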
34 changes: 23 additions & 11 deletions keras_nlp/models/backbone.py
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_nlp.backend import config
from keras_nlp.backend import keras
from keras_nlp.utils.preset_utils import check_preset_class
from keras_nlp.utils.preset_utils import load_from_preset
@@ -23,43 +24,54 @@
class Backbone(keras.Model):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._token_embedding = None
self._functional_layer_ids = set(
id(layer) for layer in self._flatten_layers()
)
self._initialized = True

def __dir__(self):
# Temporary fixes for weight saving. This mimics the following PR for
if config.keras_3():
return super().__dir__()

# Temporary fixes for Keras 2 saving. This mimics the following PR for
# older version of Keras: https://github.com/keras-team/keras/pull/18982
def filter_fn(attr):
if attr == "_layer_checkpoint_dependencies":
if attr in [
"_layer_checkpoint_dependencies",
"transformer_layers",
"encoder_transformer_layers",
"decoder_transformer_layers",
]:
return False
return id(getattr(self, attr)) not in self._functional_layer_ids

return filter(filter_fn, super().__dir__())

def __setattr__(self, name, value):
# Work around torch setattr for properties.
if name in ["token_embedding"]:
# Work around setattr issues for Keras 2 and Keras 3 torch backend.
# Since all our state is covered by the functional model we can route
# around custom setattr calls.
is_property = isinstance(getattr(type(self), name, None), property)
is_uninitialized = not hasattr(self, "_initialized")
is_torch = config.backend() == "torch"
is_keras_2 = not config.keras_3()
if is_torch and (is_property or is_uninitialized):
return object.__setattr__(self, name, value)
if is_keras_2 and is_uninitialized:
return object.__setattr__(self, name, value)
return super().__setattr__(name, value)

@property
def token_embedding(self):
"""A `keras.layers.Embedding` instance for embedding token ids.

This layer integer token ids to the hidden dim of the model.
This layer embeds integer token ids to the hidden dim of the model.
"""
return self._token_embedding

@token_embedding.setter
def token_embedding(self, value):
# Work around tf.keras h5 checkpoint loading, which is sensitive to layer
# count mismatches and does not deduplicate layers. This could go away
# if we update our checkpoints to the newer `.weights.h5` format.
self._setattr_tracking = False
self._token_embedding = value
self._setattr_tracking = True

def get_config(self):
# Don't chain to super here. The default `get_config()` for functional
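A loose illustration of the `__setattr__` routing above, using hypothetical stand-in classes rather than the actual Keras or torch internals: when a base class intercepts attribute assignment (as `torch.nn.Module` does for layer values), routing property names through `object.__setattr__` restores the normal descriptor protocol so the property setter still runs.

```python
class InterceptingBase:
    """Stands in for a framework base class that overrides __setattr__."""

    def __setattr__(self, name, value):
        # Stash values in a private slot instead of setting the attribute
        # normally, which would bypass any property on a subclass.
        object.__setattr__(self, "_stash_" + name, value)


class TinyBackbone(InterceptingBase):
    def __setattr__(self, name, value):
        # Route property names around the intercepting base class so the
        # property setter below actually runs.
        if isinstance(getattr(type(self), name, None), property):
            return object.__setattr__(self, name, value)
        return super().__setattr__(name, value)

    @property
    def token_embedding(self):
        return self._token_embedding

    @token_embedding.setter
    def token_embedding(self, value):
        object.__setattr__(self, "_token_embedding", value)


model = TinyBackbone()
model.token_embedding = "embedding-layer"
print(model.token_embedding)  # "embedding-layer" -- the setter ran
```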