From fe706d02ec121a2a030bc424919b3af7c3d95b47 Mon Sep 17 00:00:00 2001
From: Matt Watson
Date: Wed, 18 Jan 2023 13:15:04 -0800
Subject: [PATCH 1/2] Handle trainable and name in the backbone base class

---
 keras_nlp/models/albert/albert_backbone.py   | 32 ++++++++++---------
 keras_nlp/models/backbone.py                 |  6 ++++
 keras_nlp/models/bert/bert_backbone.py       | 26 ++++++++-------
 .../models/deberta_v3/deberta_v3_backbone.py | 26 ++++++++-------
 .../distil_bert/distil_bert_backbone.py      | 24 +++++++-------
 keras_nlp/models/f_net/f_net_backbone.py     | 24 +++++++-------
 keras_nlp/models/gpt2/gpt2_backbone.py       | 24 +++++++-------
 keras_nlp/models/roberta/roberta_backbone.py | 24 +++++++-------
 8 files changed, 103 insertions(+), 83 deletions(-)

diff --git a/keras_nlp/models/albert/albert_backbone.py b/keras_nlp/models/albert/albert_backbone.py
index 462361501c..f97fb1ec83 100644
--- a/keras_nlp/models/albert/albert_backbone.py
+++ b/keras_nlp/models/albert/albert_backbone.py
@@ -253,21 +253,23 @@ def call(x, padding_mask):
         self.cls_token_index = cls_token_index
 
     def get_config(self):
-        return {
-            "vocabulary_size": self.vocabulary_size,
-            "num_layers": self.num_layers,
-            "num_heads": self.num_heads,
-            "num_groups": self.num_groups,
-            "num_inner_repetitions": self.num_inner_repetitions,
-            "embedding_dim": self.embedding_dim,
-            "hidden_dim": self.hidden_dim,
-            "intermediate_dim": self.intermediate_dim,
-            "dropout": self.dropout,
-            "max_sequence_length": self.max_sequence_length,
-            "num_segments": self.num_segments,
-            "name": self.name,
-            "trainable": self.trainable,
-        }
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "num_groups": self.num_groups,
+                "num_inner_repetitions": self.num_inner_repetitions,
+                "embedding_dim": self.embedding_dim,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout": self.dropout,
+                "max_sequence_length": self.max_sequence_length,
+                "num_segments": self.num_segments,
+            }
+        )
+        return config
 
     @classproperty
     def presets(cls):
diff --git a/keras_nlp/models/backbone.py b/keras_nlp/models/backbone.py
index b47aa00cd5..20e2a2699b 100644
--- a/keras_nlp/models/backbone.py
+++ b/keras_nlp/models/backbone.py
@@ -27,6 +27,12 @@ class Backbone(keras.Model):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
+    def get_config(self):
+        return {
+            "name": self.name,
+            "trainable": self.trainable,
+        }
+
     @classmethod
     def from_config(cls, config):
         return cls(**config)
diff --git a/keras_nlp/models/bert/bert_backbone.py b/keras_nlp/models/bert/bert_backbone.py
index 752365f9df..924b6f434b 100644
--- a/keras_nlp/models/bert/bert_backbone.py
+++ b/keras_nlp/models/bert/bert_backbone.py
@@ -203,18 +203,20 @@ def __init__(
         self.cls_token_index = cls_token_index
 
     def get_config(self):
-        return {
-            "vocabulary_size": self.vocabulary_size,
-            "hidden_dim": self.hidden_dim,
-            "intermediate_dim": self.intermediate_dim,
-            "num_layers": self.num_layers,
-            "num_heads": self.num_heads,
-            "max_sequence_length": self.max_sequence_length,
-            "num_segments": self.num_segments,
-            "dropout": self.dropout,
-            "name": self.name,
-            "trainable": self.trainable,
-        }
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "max_sequence_length": self.max_sequence_length,
+                "num_segments": self.num_segments,
+                "dropout": self.dropout,
+            }
+        )
+        return config
 
     @classproperty
     def presets(cls):
diff --git a/keras_nlp/models/deberta_v3/deberta_v3_backbone.py b/keras_nlp/models/deberta_v3/deberta_v3_backbone.py
index da8fff2033..1d69c7de42 100644
--- a/keras_nlp/models/deberta_v3/deberta_v3_backbone.py
+++ b/keras_nlp/models/deberta_v3/deberta_v3_backbone.py
@@ -185,18 +185,20 @@ def __init__(
         self.start_token_index = 0
 
     def get_config(self):
-        return {
-            "vocabulary_size": self.vocabulary_size,
-            "num_layers": self.num_layers,
-            "num_heads": self.num_heads,
-            "hidden_dim": self.hidden_dim,
-            "intermediate_dim": self.intermediate_dim,
-            "dropout": self.dropout,
-            "max_sequence_length": self.max_sequence_length,
-            "bucket_size": self.bucket_size,
-            "name": self.name,
-            "trainable": self.trainable,
-        }
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout": self.dropout,
+                "max_sequence_length": self.max_sequence_length,
+                "bucket_size": self.bucket_size,
+            }
+        )
+        return config
 
     @classproperty
     def presets(cls):
diff --git a/keras_nlp/models/distil_bert/distil_bert_backbone.py b/keras_nlp/models/distil_bert/distil_bert_backbone.py
index f124e79220..9b1f14c2fc 100644
--- a/keras_nlp/models/distil_bert/distil_bert_backbone.py
+++ b/keras_nlp/models/distil_bert/distil_bert_backbone.py
@@ -166,17 +166,19 @@ def __init__(
         self.cls_token_index = 0
 
     def get_config(self):
-        return {
-            "vocabulary_size": self.vocabulary_size,
-            "num_layers": self.num_layers,
-            "num_heads": self.num_heads,
-            "hidden_dim": self.hidden_dim,
-            "intermediate_dim": self.intermediate_dim,
-            "dropout": self.dropout,
-            "max_sequence_length": self.max_sequence_length,
-            "name": self.name,
-            "trainable": self.trainable,
-        }
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout": self.dropout,
+                "max_sequence_length": self.max_sequence_length,
+            }
+        )
+        return config
 
     @classproperty
     def presets(cls):
diff --git a/keras_nlp/models/f_net/f_net_backbone.py b/keras_nlp/models/f_net/f_net_backbone.py
index 593cc74564..67ea8c4ec9 100644
--- a/keras_nlp/models/f_net/f_net_backbone.py
+++ b/keras_nlp/models/f_net/f_net_backbone.py
@@ -202,17 +202,19 @@ def __init__(
         self.cls_token_index = cls_token_index
 
     def get_config(self):
-        return {
-            "vocabulary_size": self.vocabulary_size,
-            "num_layers": self.num_layers,
-            "hidden_dim": self.hidden_dim,
-            "intermediate_dim": self.intermediate_dim,
-            "dropout": self.dropout,
-            "max_sequence_length": self.max_sequence_length,
-            "num_segments": self.num_segments,
-            "name": self.name,
-            "trainable": self.trainable,
-        }
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "num_layers": self.num_layers,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout": self.dropout,
+                "max_sequence_length": self.max_sequence_length,
+                "num_segments": self.num_segments,
+            }
+        )
+        return config
 
     @classproperty
     def presets(cls):
diff --git a/keras_nlp/models/gpt2/gpt2_backbone.py b/keras_nlp/models/gpt2/gpt2_backbone.py
index caa354c01d..91e38823c4 100644
--- a/keras_nlp/models/gpt2/gpt2_backbone.py
+++ b/keras_nlp/models/gpt2/gpt2_backbone.py
@@ -173,17 +173,19 @@ def __init__(
         self.max_sequence_length = max_sequence_length
 
     def get_config(self):
-        return {
-            "vocabulary_size": self.vocabulary_size,
-            "num_layers": self.num_layers,
-            "num_heads": self.num_heads,
-            "hidden_dim": self.hidden_dim,
-            "intermediate_dim": self.intermediate_dim,
-            "dropout": self.dropout,
-            "max_sequence_length": self.max_sequence_length,
-            "name": self.name,
-            "trainable": self.trainable,
-        }
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout": self.dropout,
+                "max_sequence_length": self.max_sequence_length,
+            }
+        )
+        return config
 
     @classproperty
     def presets(cls):
diff --git a/keras_nlp/models/roberta/roberta_backbone.py b/keras_nlp/models/roberta/roberta_backbone.py
index 65d2aa3d36..3d16625544 100644
--- a/keras_nlp/models/roberta/roberta_backbone.py
+++ b/keras_nlp/models/roberta/roberta_backbone.py
@@ -160,17 +160,19 @@ def __init__(
         self.start_token_index = 0
 
     def get_config(self):
-        return {
-            "vocabulary_size": self.vocabulary_size,
-            "num_layers": self.num_layers,
-            "num_heads": self.num_heads,
-            "hidden_dim": self.hidden_dim,
-            "intermediate_dim": self.intermediate_dim,
-            "dropout": self.dropout,
-            "max_sequence_length": self.max_sequence_length,
-            "name": self.name,
-            "trainable": self.trainable,
-        }
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout": self.dropout,
+                "max_sequence_length": self.max_sequence_length,
+            }
+        )
+        return config
 
     @classproperty
     def presets(cls):

From 2d7ed18a0ead164bdee45b9ef0e4d5cb0ab14b39 Mon Sep 17 00:00:00 2001
From: Matt Watson
Date: Wed, 18 Jan 2023 16:03:05 -0800
Subject: [PATCH 2/2] Add comments explaining get_config/from_config

---
 keras_nlp/models/backbone.py | 4 ++++
 keras_nlp/models/task.py     | 4 ++++
 2 files changed, 8 insertions(+)

diff --git a/keras_nlp/models/backbone.py b/keras_nlp/models/backbone.py
index 20e2a2699b..51cf3472da 100644
--- a/keras_nlp/models/backbone.py
+++ b/keras_nlp/models/backbone.py
@@ -28,6 +28,8 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
     def get_config(self):
+        # Don't chain to super here. The default `get_config()` for functional
+        # models is nested and cannot be passed to our Backbone constructors.
         return {
             "name": self.name,
             "trainable": self.trainable,
@@ -35,6 +37,8 @@ def get_config(self):
 
     @classmethod
     def from_config(cls, config):
+        # The default `from_config()` for functional models will return a
+        # vanilla `keras.Model`. We override it to get a subclass instance back.
         return cls(**config)
 
     @classproperty
diff --git a/keras_nlp/models/task.py b/keras_nlp/models/task.py
index 5fffe5218a..f2543aed67 100644
--- a/keras_nlp/models/task.py
+++ b/keras_nlp/models/task.py
@@ -40,6 +40,8 @@ def preprocessor(self):
         return self._preprocessor
 
     def get_config(self):
+        # Don't chain to super here. The default `get_config()` for functional
+        # models is nested and cannot be passed to our Task constructors.
         return {
             "backbone": keras.layers.serialize(self.backbone),
             "preprocessor": keras.layers.serialize(self.preprocessor),
@@ -49,6 +51,8 @@ def get_config(self):
 
     @classmethod
     def from_config(cls, config):
+        # The default `from_config()` for functional models will return a
+        # vanilla `keras.Model`. We override it to get a subclass instance back.
         if "backbone" in config and isinstance(config["backbone"], dict):
            config["backbone"] = keras.layers.deserialize(config["backbone"])
        if "preprocessor" in config and isinstance(