
Commit c747ad8

Handle trainable and name in the backbone base class (#680)
* Handle trainable and name in the backbone base class
* Add comments explaining get_config/from_config
1 parent: c9e5040

File tree

9 files changed: 111 additions, 83 deletions


keras_nlp/models/albert/albert_backbone.py

Lines changed: 17 additions & 15 deletions
@@ -253,21 +253,23 @@ def call(x, padding_mask):
         self.cls_token_index = cls_token_index

     def get_config(self):
-        return {
-            "vocabulary_size": self.vocabulary_size,
-            "num_layers": self.num_layers,
-            "num_heads": self.num_heads,
-            "num_groups": self.num_groups,
-            "num_inner_repetitions": self.num_inner_repetitions,
-            "embedding_dim": self.embedding_dim,
-            "hidden_dim": self.hidden_dim,
-            "intermediate_dim": self.intermediate_dim,
-            "dropout": self.dropout,
-            "max_sequence_length": self.max_sequence_length,
-            "num_segments": self.num_segments,
-            "name": self.name,
-            "trainable": self.trainable,
-        }
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "num_groups": self.num_groups,
+                "num_inner_repetitions": self.num_inner_repetitions,
+                "embedding_dim": self.embedding_dim,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout": self.dropout,
+                "max_sequence_length": self.max_sequence_length,
+                "num_segments": self.num_segments,
+            }
+        )
+        return config

     @classproperty
     def presets(cls):
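
Together with the base-class change in keras_nlp/models/backbone.py below, this lets a backbone round-trip through its own config. A minimal usage sketch, assuming the keras_nlp API at this commit (the constructor values are arbitrary illustrations, not a preset):

    import keras_nlp

    # Small arbitrary configuration, chosen only so the example runs quickly.
    backbone = keras_nlp.models.AlbertBackbone(
        vocabulary_size=1000,
        num_layers=2,
        num_heads=2,
        embedding_dim=16,
        hidden_dim=32,
        intermediate_dim=64,
        name="toy_albert",
        trainable=False,
    )

    # `name` and `trainable` now come from the base class and survive the trip.
    config = backbone.get_config()
    restored = keras_nlp.models.AlbertBackbone.from_config(config)
    assert restored.name == "toy_albert"
    assert restored.trainable is False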

keras_nlp/models/backbone.py

Lines changed: 10 additions & 0 deletions
@@ -27,8 +27,18 @@ class Backbone(keras.Model):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

+    def get_config(self):
+        # Don't chain to super here. The default `get_config()` for functional
+        # models is nested and cannot be passed to our Backbone constructors.
+        return {
+            "name": self.name,
+            "trainable": self.trainable,
+        }
+
     @classmethod
     def from_config(cls, config):
+        # The default `from_config()` for functional models will return a
+        # vanilla `keras.Model`. We override it to get a subclass instance back.
         return cls(**config)

     @classproperty
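
The comment about nesting is the crux: a functional model's default config describes the whole layer graph, not constructor arguments. A quick sketch of the contrast, assuming plain Keras (the exact config keys vary by Keras version):

    import keras

    # A trivial functional model, standing in for any backbone's internals.
    inputs = keras.Input(shape=(4,))
    outputs = keras.layers.Dense(2)(inputs)
    functional = keras.Model(inputs, outputs)

    # The default config is a nested graph description; typical keys include
    # "layers", "input_layers", and "output_layers", none of which a
    # constructor like `AlbertBackbone(**config)` could accept.
    print(sorted(functional.get_config().keys()))

    # The flat override above avoids this by storing only constructor kwargs,
    # so `cls(**config)` in `from_config()` just works.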

keras_nlp/models/bert/bert_backbone.py

Lines changed: 14 additions & 12 deletions
@@ -203,18 +203,20 @@ def __init__(
         self.cls_token_index = cls_token_index

     def get_config(self):
-        return {
-            "vocabulary_size": self.vocabulary_size,
-            "hidden_dim": self.hidden_dim,
-            "intermediate_dim": self.intermediate_dim,
-            "num_layers": self.num_layers,
-            "num_heads": self.num_heads,
-            "max_sequence_length": self.max_sequence_length,
-            "num_segments": self.num_segments,
-            "dropout": self.dropout,
-            "name": self.name,
-            "trainable": self.trainable,
-        }
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "max_sequence_length": self.max_sequence_length,
+                "num_segments": self.num_segments,
+                "dropout": self.dropout,
+            }
+        )
+        return config

     @classproperty
     def presets(cls):

keras_nlp/models/deberta_v3/deberta_v3_backbone.py

Lines changed: 14 additions & 12 deletions
@@ -185,18 +185,20 @@ def __init__(
         self.start_token_index = 0

     def get_config(self):
-        return {
-            "vocabulary_size": self.vocabulary_size,
-            "num_layers": self.num_layers,
-            "num_heads": self.num_heads,
-            "hidden_dim": self.hidden_dim,
-            "intermediate_dim": self.intermediate_dim,
-            "dropout": self.dropout,
-            "max_sequence_length": self.max_sequence_length,
-            "bucket_size": self.bucket_size,
-            "name": self.name,
-            "trainable": self.trainable,
-        }
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout": self.dropout,
+                "max_sequence_length": self.max_sequence_length,
+                "bucket_size": self.bucket_size,
+            }
+        )
+        return config

     @classproperty
     def presets(cls):

keras_nlp/models/distil_bert/distil_bert_backbone.py

Lines changed: 13 additions & 11 deletions
@@ -166,17 +166,19 @@ def __init__(
         self.cls_token_index = 0

     def get_config(self):
-        return {
-            "vocabulary_size": self.vocabulary_size,
-            "num_layers": self.num_layers,
-            "num_heads": self.num_heads,
-            "hidden_dim": self.hidden_dim,
-            "intermediate_dim": self.intermediate_dim,
-            "dropout": self.dropout,
-            "max_sequence_length": self.max_sequence_length,
-            "name": self.name,
-            "trainable": self.trainable,
-        }
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout": self.dropout,
+                "max_sequence_length": self.max_sequence_length,
+            }
+        )
+        return config

     @classproperty
     def presets(cls):

keras_nlp/models/f_net/f_net_backbone.py

Lines changed: 13 additions & 11 deletions
@@ -202,17 +202,19 @@ def __init__(
         self.cls_token_index = cls_token_index

     def get_config(self):
-        return {
-            "vocabulary_size": self.vocabulary_size,
-            "num_layers": self.num_layers,
-            "hidden_dim": self.hidden_dim,
-            "intermediate_dim": self.intermediate_dim,
-            "dropout": self.dropout,
-            "max_sequence_length": self.max_sequence_length,
-            "num_segments": self.num_segments,
-            "name": self.name,
-            "trainable": self.trainable,
-        }
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "num_layers": self.num_layers,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout": self.dropout,
+                "max_sequence_length": self.max_sequence_length,
+                "num_segments": self.num_segments,
+            }
+        )
+        return config

     @classproperty
     def presets(cls):

keras_nlp/models/gpt2/gpt2_backbone.py

Lines changed: 13 additions & 11 deletions
@@ -173,17 +173,19 @@ def __init__(
         self.max_sequence_length = max_sequence_length

     def get_config(self):
-        return {
-            "vocabulary_size": self.vocabulary_size,
-            "num_layers": self.num_layers,
-            "num_heads": self.num_heads,
-            "hidden_dim": self.hidden_dim,
-            "intermediate_dim": self.intermediate_dim,
-            "dropout": self.dropout,
-            "max_sequence_length": self.max_sequence_length,
-            "name": self.name,
-            "trainable": self.trainable,
-        }
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout": self.dropout,
+                "max_sequence_length": self.max_sequence_length,
+            }
+        )
+        return config

     @classproperty
     def presets(cls):

keras_nlp/models/roberta/roberta_backbone.py

Lines changed: 13 additions & 11 deletions
@@ -160,17 +160,19 @@ def __init__(
         self.start_token_index = 0

     def get_config(self):
-        return {
-            "vocabulary_size": self.vocabulary_size,
-            "num_layers": self.num_layers,
-            "num_heads": self.num_heads,
-            "hidden_dim": self.hidden_dim,
-            "intermediate_dim": self.intermediate_dim,
-            "dropout": self.dropout,
-            "max_sequence_length": self.max_sequence_length,
-            "name": self.name,
-            "trainable": self.trainable,
-        }
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout": self.dropout,
+                "max_sequence_length": self.max_sequence_length,
+            }
+        )
+        return config

     @classproperty
     def presets(cls):

keras_nlp/models/task.py

Lines changed: 4 additions & 0 deletions
@@ -40,6 +40,8 @@ def preprocessor(self):
         return self._preprocessor

     def get_config(self):
+        # Don't chain to super here. The default `get_config()` for functional
+        # models is nested and cannot be passed to our Task constructors.
         return {
             "backbone": keras.layers.serialize(self.backbone),
             "preprocessor": keras.layers.serialize(self.preprocessor),
@@ -49,6 +51,8 @@ def get_config(self):

     @classmethod
     def from_config(cls, config):
+        # The default `from_config()` for functional models will return a
+        # vanilla `keras.Model`. We override it to get a subclass instance back.
         if "backbone" in config and isinstance(config["backbone"], dict):
             config["backbone"] = keras.layers.deserialize(config["backbone"])
         if "preprocessor" in config and isinstance(
