32 changes: 17 additions & 15 deletions keras_nlp/models/albert/albert_backbone.py
@@ -253,21 +253,23 @@ def call(x, padding_mask):
self.cls_token_index = cls_token_index

def get_config(self):
-        return {
-            "vocabulary_size": self.vocabulary_size,
-            "num_layers": self.num_layers,
-            "num_heads": self.num_heads,
-            "num_groups": self.num_groups,
-            "num_inner_repetitions": self.num_inner_repetitions,
-            "embedding_dim": self.embedding_dim,
-            "hidden_dim": self.hidden_dim,
-            "intermediate_dim": self.intermediate_dim,
-            "dropout": self.dropout,
-            "max_sequence_length": self.max_sequence_length,
-            "num_segments": self.num_segments,
-            "name": self.name,
-            "trainable": self.trainable,
-        }
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "num_groups": self.num_groups,
+                "num_inner_repetitions": self.num_inner_repetitions,
+                "embedding_dim": self.embedding_dim,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout": self.dropout,
+                "max_sequence_length": self.max_sequence_length,
+                "num_segments": self.num_segments,
+            }
+        )
+        return config

@classproperty
def presets(cls):
10 changes: 10 additions & 0 deletions keras_nlp/models/backbone.py
@@ -27,8 +27,18 @@ class Backbone(keras.Model):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

+    def get_config(self):
Contributor:
Q: Couldn't you technically chain all the way up to keras.Model for this?

@mattdangerw (Member Author), Jan 18, 2023:
We cannot, because of a pain point with functional subclassed models.

By default, get_config on a functional model returns a huge nested description of all layers, and that config can only be consumed by from_config on the functional Keras model class itself. There is no way to use the standard functional get_config and from_config and still get back a BertBackbone; you would just get back a vanilla functional model (with none of the subclass's properties, methods, docstring, etc.).

So essentially, for functional subclasses you have to "break the chain" of get_config and from_config calls to super, which is unusual: everywhere else in Keras you do want to chain to super. But at least we are now able to hide this pain point from our downstream model implementations. And we could potentially prioritize some work in core Keras at some point to improve the "functional subclass" experience and remove this weird gotcha.

I can leave a (much shorter than this) comment explaining this on the backbone class.

Contributor:
Thanks for the explanation! Lifelong learner 🤓

+        # Don't chain to super here. The default `get_config()` for functional
+        # models is nested and cannot be passed to our Backbone constructors.
+        return {
+            "name": self.name,
+            "trainable": self.trainable,
+        }

@classmethod
def from_config(cls, config):
+        # The default `from_config()` for functional models will return a
+        # vanilla `keras.Model`. We override it to get a subclass instance back.
return cls(**config)
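
To make the thread above concrete, here is a minimal sketch of the pattern (a toy model, not KerasNLP code: TinyBackbone and its vocabulary_size/hidden_dim arguments are invented for illustration, assuming TensorFlow's bundled Keras):

from tensorflow import keras

class TinyBackbone(keras.Model):
    """A toy functional subclass standing in for a KerasNLP backbone."""

    def __init__(self, vocabulary_size=100, hidden_dim=8, **kwargs):
        # Build the graph functionally, then hand inputs/outputs to super().
        inputs = keras.Input(shape=(None,), dtype="int32")
        outputs = keras.layers.Embedding(vocabulary_size, hidden_dim)(inputs)
        super().__init__(inputs=inputs, outputs=outputs, **kwargs)
        self.vocabulary_size = vocabulary_size
        self.hidden_dim = hidden_dim

    def get_config(self):
        # Don't chain to super: the functional default is a nested,
        # layer-by-layer config that __init__ above cannot consume.
        return {
            "vocabulary_size": self.vocabulary_size,
            "hidden_dim": self.hidden_dim,
            "name": self.name,
            "trainable": self.trainable,
        }

    @classmethod
    def from_config(cls, config):
        # Rebuild the subclass directly from constructor kwargs.
        return cls(**config)

model = TinyBackbone(vocabulary_size=64, hidden_dim=4)
restored = TinyBackbone.from_config(model.get_config())
assert isinstance(restored, TinyBackbone)
assert restored.hidden_dim == 4

With get_config and from_config overridden once on the base class, the per-model backbones in the hunks above only need to chain to super().get_config() and layer their own constructor arguments on top.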

@classproperty
26 changes: 14 additions & 12 deletions keras_nlp/models/bert/bert_backbone.py
@@ -203,18 +203,20 @@ def __init__(
self.cls_token_index = cls_token_index

def get_config(self):
-        return {
-            "vocabulary_size": self.vocabulary_size,
-            "hidden_dim": self.hidden_dim,
-            "intermediate_dim": self.intermediate_dim,
-            "num_layers": self.num_layers,
-            "num_heads": self.num_heads,
-            "max_sequence_length": self.max_sequence_length,
-            "num_segments": self.num_segments,
-            "dropout": self.dropout,
-            "name": self.name,
-            "trainable": self.trainable,
-        }
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "max_sequence_length": self.max_sequence_length,
+                "num_segments": self.num_segments,
+                "dropout": self.dropout,
+            }
+        )
+        return config

@classproperty
def presets(cls):
26 changes: 14 additions & 12 deletions keras_nlp/models/deberta_v3/deberta_v3_backbone.py
@@ -185,18 +185,20 @@ def __init__(
self.start_token_index = 0

def get_config(self):
-        return {
-            "vocabulary_size": self.vocabulary_size,
-            "num_layers": self.num_layers,
-            "num_heads": self.num_heads,
-            "hidden_dim": self.hidden_dim,
-            "intermediate_dim": self.intermediate_dim,
-            "dropout": self.dropout,
-            "max_sequence_length": self.max_sequence_length,
-            "bucket_size": self.bucket_size,
-            "name": self.name,
-            "trainable": self.trainable,
-        }
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout": self.dropout,
+                "max_sequence_length": self.max_sequence_length,
+                "bucket_size": self.bucket_size,
+            }
+        )
+        return config

@classproperty
def presets(cls):
24 changes: 13 additions & 11 deletions keras_nlp/models/distil_bert/distil_bert_backbone.py
@@ -166,17 +166,19 @@ def __init__(
self.cls_token_index = 0

def get_config(self):
-        return {
-            "vocabulary_size": self.vocabulary_size,
-            "num_layers": self.num_layers,
-            "num_heads": self.num_heads,
-            "hidden_dim": self.hidden_dim,
-            "intermediate_dim": self.intermediate_dim,
-            "dropout": self.dropout,
-            "max_sequence_length": self.max_sequence_length,
-            "name": self.name,
-            "trainable": self.trainable,
-        }
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout": self.dropout,
+                "max_sequence_length": self.max_sequence_length,
+            }
+        )
+        return config

@classproperty
def presets(cls):
24 changes: 13 additions & 11 deletions keras_nlp/models/f_net/f_net_backbone.py
@@ -202,17 +202,19 @@ def __init__(
self.cls_token_index = cls_token_index

def get_config(self):
-        return {
-            "vocabulary_size": self.vocabulary_size,
-            "num_layers": self.num_layers,
-            "hidden_dim": self.hidden_dim,
-            "intermediate_dim": self.intermediate_dim,
-            "dropout": self.dropout,
-            "max_sequence_length": self.max_sequence_length,
-            "num_segments": self.num_segments,
-            "name": self.name,
-            "trainable": self.trainable,
-        }
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "num_layers": self.num_layers,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout": self.dropout,
+                "max_sequence_length": self.max_sequence_length,
+                "num_segments": self.num_segments,
+            }
+        )
+        return config

@classproperty
def presets(cls):
24 changes: 13 additions & 11 deletions keras_nlp/models/gpt2/gpt2_backbone.py
@@ -173,17 +173,19 @@ def __init__(
self.max_sequence_length = max_sequence_length

def get_config(self):
-        return {
-            "vocabulary_size": self.vocabulary_size,
-            "num_layers": self.num_layers,
-            "num_heads": self.num_heads,
-            "hidden_dim": self.hidden_dim,
-            "intermediate_dim": self.intermediate_dim,
-            "dropout": self.dropout,
-            "max_sequence_length": self.max_sequence_length,
-            "name": self.name,
-            "trainable": self.trainable,
-        }
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout": self.dropout,
+                "max_sequence_length": self.max_sequence_length,
+            }
+        )
+        return config

@classproperty
def presets(cls):
24 changes: 13 additions & 11 deletions keras_nlp/models/roberta/roberta_backbone.py
@@ -160,17 +160,19 @@ def __init__(
self.start_token_index = 0

def get_config(self):
-        return {
-            "vocabulary_size": self.vocabulary_size,
-            "num_layers": self.num_layers,
-            "num_heads": self.num_heads,
-            "hidden_dim": self.hidden_dim,
-            "intermediate_dim": self.intermediate_dim,
-            "dropout": self.dropout,
-            "max_sequence_length": self.max_sequence_length,
-            "name": self.name,
-            "trainable": self.trainable,
-        }
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout": self.dropout,
+                "max_sequence_length": self.max_sequence_length,
+            }
+        )
+        return config

@classproperty
def presets(cls):
4 changes: 4 additions & 0 deletions keras_nlp/models/task.py
@@ -40,6 +40,8 @@ def preprocessor(self):
return self._preprocessor

def get_config(self):
+        # Don't chain to super here. The default `get_config()` for functional
+        # models is nested and cannot be passed to our Task constructors.
return {
"backbone": keras.layers.serialize(self.backbone),
"preprocessor": keras.layers.serialize(self.preprocessor),
@@ -49,6 +51,8 @@ def get_config(self):

@classmethod
def from_config(cls, config):
+        # The default `from_config()` for functional models will return a
+        # vanilla `keras.Model`. We override it to get a subclass instance back.
if "backbone" in config and isinstance(config["backbone"], dict):
config["backbone"] = keras.layers.deserialize(config["backbone"])
if "preprocessor" in config and isinstance(