4 changes: 4 additions & 0 deletions keras_nlp/models/albert/albert_backbone.py
@@ -271,6 +271,10 @@ def get_config(self):
        )
        return config

    @property
@jbischof (Contributor):

I agree with @NusretOzates, the self._token_embedding approach seems cleaner than relying on a specific layer name. We already have several other similar class variables.

@abheesht17 (Collaborator, Author) commented on Jan 22, 2023:

@jbischof - it doesn't work for models with a TokenAndPositionEmbedding layer. Keras considers self._token_embedding a separate embedding layer, and errors out when we try to load preset checkpoints. Hence, this elaborate-ish solution.

@mattdangerw (Member) commented on Jan 24, 2023:

Yeah, I think what we are facing here is setattr tracking on all Keras layers. Basically, anytime you set a layer attribute on self, it gets added to a list of resources used for serialization. It looks like this can affect our checkpoint compatibility, which is not good: we don't want to be affecting our checkpoints just to expose something like this. Relevant code -> https://github.com/keras-team/keras/blob/2727df09aa284a94ce8234ad1279d9659cdf2064/keras/engine/base_layer.py#L3215-L3229

The solution laid out here seems like a nice way to avoid the setattr tracking entirely. This LGTM.

The alternative I see would be to add a line self._auto_track_sub_layers = False to the backbone base class. But this could run us into hot water if we ever had non-functional Backbones (not everything can be a functional model -> https://keras.io/guides/functional_api/#functional-api-weakness). So the solution here seems most robust!

@mattdangerw (Member):

I know I am dumping too much context, but for those interested in going deeper...

The __setattr__ tracking is deduped, so for Bert, where the token embedding is already a direct sublayer of the model, there is no issue here: self.some_property = direct_layer_of_model causes no problems. But Roberta, for example, will have the token embedding as a nested layer, and self.some_property = nested_layer_of_model will change our checkpoint structure! This is what @NusretOzates was mentioning above.

Also, thanks @NusretOzates for that writeup! Very helpful!
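To make the tracking difference concrete, below is a minimal standalone sketch (not part of this PR; the Composite layer, the build_backbone helper, the expose_via_attribute flag, and the 100 x 8 embedding shape are made-up stand-ins), written against TF/Keras 2.x as described in the comments above. It contrasts exposing a nested embedding by assigning it to an attribute, which goes through __setattr__ tracking, with the property-plus-get_layer() pattern adopted here, which resolves the layer lazily and leaves checkpoints untouched.

# Illustrative sketch only, not from the PR. All names below are hypothetical
# stand-ins for a TokenAndPositionEmbedding-style nested layer.
from tensorflow import keras


class Composite(keras.layers.Layer):
    # Stand-in for a layer that nests an Embedding (as in RoBERTa's
    # token-and-position embedding block).
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.token_embedding = keras.layers.Embedding(100, 8)

    def call(self, x):
        return self.token_embedding(x)


def build_backbone(expose_via_attribute):
    token_ids = keras.Input(shape=(None,), dtype="int32")
    outputs = Composite(name="embeddings")(token_ids)
    model = keras.Model(token_ids, outputs)
    if expose_via_attribute:
        # Setting a layer-valued attribute on the model goes through
        # __setattr__ tracking; because this Embedding is nested rather than
        # a direct sublayer of the model, it can be registered as an extra
        # tracked object and so change the saved checkpoint structure.
        model.token_embedding = model.get_layer("embeddings").token_embedding
    return model


# The pattern adopted in this PR instead resolves the layer on demand,
# which never touches __setattr__ and leaves checkpoints unchanged:
backbone = build_backbone(expose_via_attribute=False)
token_embedding = backbone.get_layer("embeddings").token_embedding
print(token_embedding.input_dim, token_embedding.output_dim)  # 100 8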

    def token_embedding(self):
        return self.get_layer("token_embedding")

    @classproperty
    def presets(cls):
        return copy.deepcopy(backbone_presets)
5 changes: 5 additions & 0 deletions keras_nlp/models/backbone.py
@@ -27,6 +27,11 @@ class Backbone(keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @property
    def token_embedding(self):
        """A `keras.layers.Embedding` instance for embedding token ids."""
        raise NotImplementedError

    def get_config(self):
        # Don't chain to super here. The default `get_config()` for functional
        # models is nested and cannot be passed to our Backbone constructors.
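For reference, here is a short usage sketch of the new accessor (illustrative only, not part of the diff; the constructor argument values are arbitrary small numbers chosen so the model builds quickly). Every concrete backbone now resolves token_embedding through its own layers rather than through a tracked attribute.

# Illustrative usage sketch, not part of this PR's diff.
from keras_nlp.models import BertBackbone

backbone = BertBackbone(
    vocabulary_size=1000,
    num_layers=2,
    num_heads=2,
    hidden_dim=64,
    intermediate_dim=128,
    max_sequence_length=128,
)

# Resolved via get_layer("token_embedding") under the hood, so no extra
# attribute is tracked on the model itself.
embedding = backbone.token_embedding
print(embedding.input_dim, embedding.output_dim)  # 1000 64
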
4 changes: 4 additions & 0 deletions keras_nlp/models/bart/bart_backbone.py
@@ -244,3 +244,7 @@ def get_config(self):
            "name": self.name,
            "trainable": self.trainable,
        }

    @property
    def token_embedding(self):
        return self.get_layer("token_embedding")
18 changes: 11 additions & 7 deletions keras_nlp/models/bert/bert_backbone.py
@@ -190,34 +190,38 @@ def __init__(
            },
            **kwargs,
        )

        # All references to `self` below this line
        self.vocabulary_size = vocabulary_size
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.dropout = dropout
        self.max_sequence_length = max_sequence_length
        self.num_segments = num_segments
        self.dropout = dropout
        self.token_embedding = token_embedding_layer
        self.cls_token_index = cls_token_index

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vocabulary_size": self.vocabulary_size,
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "num_layers": self.num_layers,
                "num_heads": self.num_heads,
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "dropout": self.dropout,
                "max_sequence_length": self.max_sequence_length,
                "num_segments": self.num_segments,
                "dropout": self.dropout,
            }
        )
        return config

    @property
    def token_embedding(self):
        return self.get_layer("token_embedding")

    @classproperty
    def presets(cls):
        return copy.deepcopy(backbone_presets)
4 changes: 4 additions & 0 deletions keras_nlp/models/deberta_v3/deberta_v3_backbone.py
@@ -200,6 +200,10 @@ def get_config(self):
        )
        return config

    @property
    def token_embedding(self):
        return self.get_layer("token_embedding")

    @classproperty
    def presets(cls):
        return copy.deepcopy(backbone_presets)
4 changes: 4 additions & 0 deletions keras_nlp/models/distil_bert/distil_bert_backbone.py
@@ -180,6 +180,10 @@ def get_config(self):
        )
        return config

    @property
    def token_embedding(self):
        return self.get_layer("token_and_position_embedding").token_embedding

    @classproperty
    def presets(cls):
        return copy.deepcopy(backbone_presets)
4 changes: 4 additions & 0 deletions keras_nlp/models/f_net/f_net_backbone.py
@@ -216,6 +216,10 @@ def get_config(self):
        )
        return config

    @property
    def token_embedding(self):
        return self.get_layer("token_embedding")

    @classproperty
    def presets(cls):
        return copy.deepcopy(backbone_presets)
4 changes: 4 additions & 0 deletions keras_nlp/models/gpt2/gpt2_backbone.py
@@ -187,6 +187,10 @@ def get_config(self):
        )
        return config

    @property
    def token_embedding(self):
        return self.get_layer("token_embedding")

    @classproperty
    def presets(cls):
        return copy.deepcopy(backbone_presets)
4 changes: 4 additions & 0 deletions keras_nlp/models/roberta/roberta_backbone.py
@@ -174,6 +174,10 @@ def get_config(self):
        )
        return config

    @property
    def token_embedding(self):
        return self.get_layer("embeddings").token_embedding

    @classproperty
    def presets(cls):
        return copy.deepcopy(backbone_presets)