@@ -253,21 +253,23 @@ def call(x, padding_mask):
253253 self .cls_token_index = cls_token_index
254254
255255 def get_config (self ):
256- return {
257- "vocabulary_size" : self .vocabulary_size ,
258- "num_layers" : self .num_layers ,
259- "num_heads" : self .num_heads ,
260- "num_groups" : self .num_groups ,
261- "num_inner_repetitions" : self .num_inner_repetitions ,
262- "embedding_dim" : self .embedding_dim ,
263- "hidden_dim" : self .hidden_dim ,
264- "intermediate_dim" : self .intermediate_dim ,
265- "dropout" : self .dropout ,
266- "max_sequence_length" : self .max_sequence_length ,
267- "num_segments" : self .num_segments ,
268- "name" : self .name ,
269- "trainable" : self .trainable ,
270- }
256+ config = super ().get_config ()
257+ config .update (
258+ {
259+ "vocabulary_size" : self .vocabulary_size ,
260+ "num_layers" : self .num_layers ,
261+ "num_heads" : self .num_heads ,
262+ "num_groups" : self .num_groups ,
263+ "num_inner_repetitions" : self .num_inner_repetitions ,
264+ "embedding_dim" : self .embedding_dim ,
265+ "hidden_dim" : self .hidden_dim ,
266+ "intermediate_dim" : self .intermediate_dim ,
267+ "dropout" : self .dropout ,
268+ "max_sequence_length" : self .max_sequence_length ,
269+ "num_segments" : self .num_segments ,
270+ }
271+ )
272+ return config
271273
272274 @classproperty
273275 def presets (cls ):
0 commit comments