diff --git a/keras_nlp/models/albert/albert_backbone.py b/keras_nlp/models/albert/albert_backbone.py
index 462361501c..f97fb1ec83 100644
--- a/keras_nlp/models/albert/albert_backbone.py
+++ b/keras_nlp/models/albert/albert_backbone.py
@@ -253,21 +253,23 @@ def call(x, padding_mask):
         self.cls_token_index = cls_token_index
 
     def get_config(self):
-        return {
-            "vocabulary_size": self.vocabulary_size,
-            "num_layers": self.num_layers,
-            "num_heads": self.num_heads,
-            "num_groups": self.num_groups,
-            "num_inner_repetitions": self.num_inner_repetitions,
-            "embedding_dim": self.embedding_dim,
-            "hidden_dim": self.hidden_dim,
-            "intermediate_dim": self.intermediate_dim,
-            "dropout": self.dropout,
-            "max_sequence_length": self.max_sequence_length,
-            "num_segments": self.num_segments,
-            "name": self.name,
-            "trainable": self.trainable,
-        }
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "num_groups": self.num_groups,
+                "num_inner_repetitions": self.num_inner_repetitions,
+                "embedding_dim": self.embedding_dim,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout": self.dropout,
+                "max_sequence_length": self.max_sequence_length,
+                "num_segments": self.num_segments,
+            }
+        )
+        return config
 
     @classproperty
     def presets(cls):
diff --git a/keras_nlp/models/backbone.py b/keras_nlp/models/backbone.py
index b47aa00cd5..51cf3472da 100644
--- a/keras_nlp/models/backbone.py
+++ b/keras_nlp/models/backbone.py
@@ -27,8 +27,18 @@ class Backbone(keras.Model):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
+    def get_config(self):
+        # Don't chain to super here. The default `get_config()` for functional
+        # models is nested and cannot be passed to our Backbone constructors.
+        return {
+            "name": self.name,
+            "trainable": self.trainable,
+        }
+
     @classmethod
     def from_config(cls, config):
+        # The default `from_config()` for functional models will return a
+        # vanilla `keras.Model`. We override it to get a subclass instance back.
         return cls(**config)
 
     @classproperty
diff --git a/keras_nlp/models/bert/bert_backbone.py b/keras_nlp/models/bert/bert_backbone.py
index 752365f9df..924b6f434b 100644
--- a/keras_nlp/models/bert/bert_backbone.py
+++ b/keras_nlp/models/bert/bert_backbone.py
@@ -203,18 +203,20 @@ def __init__(
         self.cls_token_index = cls_token_index
 
     def get_config(self):
-        return {
-            "vocabulary_size": self.vocabulary_size,
-            "hidden_dim": self.hidden_dim,
-            "intermediate_dim": self.intermediate_dim,
-            "num_layers": self.num_layers,
-            "num_heads": self.num_heads,
-            "max_sequence_length": self.max_sequence_length,
-            "num_segments": self.num_segments,
-            "dropout": self.dropout,
-            "name": self.name,
-            "trainable": self.trainable,
-        }
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "max_sequence_length": self.max_sequence_length,
+                "num_segments": self.num_segments,
+                "dropout": self.dropout,
+            }
+        )
+        return config
 
     @classproperty
     def presets(cls):
diff --git a/keras_nlp/models/deberta_v3/deberta_v3_backbone.py b/keras_nlp/models/deberta_v3/deberta_v3_backbone.py
index da8fff2033..1d69c7de42 100644
--- a/keras_nlp/models/deberta_v3/deberta_v3_backbone.py
+++ b/keras_nlp/models/deberta_v3/deberta_v3_backbone.py
@@ -185,18 +185,20 @@ def __init__(
         self.start_token_index = 0
 
     def get_config(self):
-        return {
-            "vocabulary_size": self.vocabulary_size,
-            "num_layers": self.num_layers,
-            "num_heads": self.num_heads,
-            "hidden_dim": self.hidden_dim,
-            "intermediate_dim": self.intermediate_dim,
-            "dropout": self.dropout,
-            "max_sequence_length": self.max_sequence_length,
-            "bucket_size": self.bucket_size,
-            "name": self.name,
-            "trainable": self.trainable,
-        }
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout": self.dropout,
+                "max_sequence_length": self.max_sequence_length,
+                "bucket_size": self.bucket_size,
+            }
+        )
+        return config
 
     @classproperty
     def presets(cls):
diff --git a/keras_nlp/models/distil_bert/distil_bert_backbone.py b/keras_nlp/models/distil_bert/distil_bert_backbone.py
index f124e79220..9b1f14c2fc 100644
--- a/keras_nlp/models/distil_bert/distil_bert_backbone.py
+++ b/keras_nlp/models/distil_bert/distil_bert_backbone.py
@@ -166,17 +166,19 @@ def __init__(
         self.cls_token_index = 0
 
     def get_config(self):
-        return {
-            "vocabulary_size": self.vocabulary_size,
-            "num_layers": self.num_layers,
-            "num_heads": self.num_heads,
-            "hidden_dim": self.hidden_dim,
-            "intermediate_dim": self.intermediate_dim,
-            "dropout": self.dropout,
-            "max_sequence_length": self.max_sequence_length,
-            "name": self.name,
-            "trainable": self.trainable,
-        }
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout": self.dropout,
+                "max_sequence_length": self.max_sequence_length,
+            }
+        )
+        return config
 
     @classproperty
     def presets(cls):
diff --git a/keras_nlp/models/f_net/f_net_backbone.py b/keras_nlp/models/f_net/f_net_backbone.py
index 593cc74564..67ea8c4ec9 100644
--- a/keras_nlp/models/f_net/f_net_backbone.py
+++ b/keras_nlp/models/f_net/f_net_backbone.py
@@ -202,17 +202,19 @@ def __init__(
         self.cls_token_index = cls_token_index
 
     def get_config(self):
-        return {
-            "vocabulary_size": self.vocabulary_size,
-            "num_layers": self.num_layers,
-            "hidden_dim": self.hidden_dim,
-            "intermediate_dim": self.intermediate_dim,
-            "dropout": self.dropout,
-            "max_sequence_length": self.max_sequence_length,
-            "num_segments": self.num_segments,
-            "name": self.name,
-            "trainable": self.trainable,
-        }
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "num_layers": self.num_layers,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout": self.dropout,
+                "max_sequence_length": self.max_sequence_length,
+                "num_segments": self.num_segments,
+            }
+        )
+        return config
 
     @classproperty
     def presets(cls):
diff --git a/keras_nlp/models/gpt2/gpt2_backbone.py b/keras_nlp/models/gpt2/gpt2_backbone.py
index caa354c01d..91e38823c4 100644
--- a/keras_nlp/models/gpt2/gpt2_backbone.py
+++ b/keras_nlp/models/gpt2/gpt2_backbone.py
@@ -173,17 +173,19 @@ def __init__(
         self.max_sequence_length = max_sequence_length
 
     def get_config(self):
-        return {
-            "vocabulary_size": self.vocabulary_size,
-            "num_layers": self.num_layers,
-            "num_heads": self.num_heads,
-            "hidden_dim": self.hidden_dim,
-            "intermediate_dim": self.intermediate_dim,
-            "dropout": self.dropout,
-            "max_sequence_length": self.max_sequence_length,
-            "name": self.name,
-            "trainable": self.trainable,
-        }
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout": self.dropout,
+                "max_sequence_length": self.max_sequence_length,
+            }
+        )
+        return config
 
     @classproperty
     def presets(cls):
diff --git a/keras_nlp/models/roberta/roberta_backbone.py b/keras_nlp/models/roberta/roberta_backbone.py
index 65d2aa3d36..3d16625544 100644
--- a/keras_nlp/models/roberta/roberta_backbone.py
+++ b/keras_nlp/models/roberta/roberta_backbone.py
@@ -160,17 +160,19 @@ def __init__(
         self.start_token_index = 0
 
     def get_config(self):
-        return {
-            "vocabulary_size": self.vocabulary_size,
-            "num_layers": self.num_layers,
-            "num_heads": self.num_heads,
-            "hidden_dim": self.hidden_dim,
-            "intermediate_dim": self.intermediate_dim,
-            "dropout": self.dropout,
-            "max_sequence_length": self.max_sequence_length,
-            "name": self.name,
-            "trainable": self.trainable,
-        }
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout": self.dropout,
+                "max_sequence_length": self.max_sequence_length,
+            }
+        )
+        return config
 
     @classproperty
     def presets(cls):
diff --git a/keras_nlp/models/task.py b/keras_nlp/models/task.py
index 5fffe5218a..f2543aed67 100644
--- a/keras_nlp/models/task.py
+++ b/keras_nlp/models/task.py
@@ -40,6 +40,8 @@ def preprocessor(self):
         return self._preprocessor
 
     def get_config(self):
+        # Don't chain to super here. The default `get_config()` for functional
+        # models is nested and cannot be passed to our Task constructors.
         return {
             "backbone": keras.layers.serialize(self.backbone),
             "preprocessor": keras.layers.serialize(self.preprocessor),
@@ -49,6 +51,8 @@ def get_config(self):
 
     @classmethod
     def from_config(cls, config):
+        # The default `from_config()` for functional models will return a
+        # vanilla `keras.Model`. We override it to get a subclass instance back.
         if "backbone" in config and isinstance(config["backbone"], dict):
             config["backbone"] = keras.layers.deserialize(config["backbone"])
         if "preprocessor" in config and isinstance(
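For reference, a minimal round-trip sketch of what these changes enable (not part of the diff). It assumes `BertBackbone` is exported as `keras_nlp.models.BertBackbone`, and the hyperparameter values are illustrative, not a real preset. Because each backbone's `get_config()` now merges its hyperparameters into the flat dict from `Backbone.get_config()`, and `Backbone.from_config()` calls `cls(**config)`, the config round-trips back to the original subclass rather than a vanilla functional `keras.Model`.

```python
import keras_nlp

# Illustrative hyperparameters, chosen small so this runs quickly.
backbone = keras_nlp.models.BertBackbone(
    vocabulary_size=30522,
    num_layers=2,
    num_heads=2,
    hidden_dim=64,
    intermediate_dim=128,
    max_sequence_length=128,
)

# `config` is a flat dict: `name` and `trainable` from `Backbone.get_config()`
# plus the model hyperparameters merged in via `config.update(...)`.
config = backbone.get_config()

# `Backbone.from_config()` calls `cls(**config)`, so we get a `BertBackbone`
# back instead of a plain `keras.Model`.
restored = keras_nlp.models.BertBackbone.from_config(config)
assert isinstance(restored, keras_nlp.models.BertBackbone)
```

This flat-dict design is also why the base `get_config()` deliberately skips `super().get_config()`: the functional-model default returns a nested layer graph that the backbone constructors cannot accept as keyword arguments.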