
Commit 725f4ad

Add tie_weights() to LM heads and set bias in set_output_embeddings() (#28948)
* Add tie_weights() to LM heads and set the bias in set_output_embeddings(). The bias was not tied correctly in some LM heads, and this change should fix that.
* Move test_save_and_load_low_cpu_mem_usage to ModelTesterMixin (a sketch of this kind of check follows after the file summary below).
* Add _tie_weights() to MPNet and Vilt.
* Skip the low CPU memory usage test for Deta/DeformableDetr, since they cannot be initialized on the meta device.
* Rename the test to save_load to match the convention.
1 parent 3f4e79d commit 725f4ad

File tree

20 files changed: +104 -0 lines changed
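
As noted in the commit message, the relocated test exercises a save/load round trip with low_cpu_mem_usage=True. Below is a minimal, hedged sketch of that kind of check; the checkpoint choice and the exact assertions are illustrative assumptions rather than the test as merged, and loading with low_cpu_mem_usage=True requires the accelerate package.

# Hedged sketch of a save/load round trip with low_cpu_mem_usage=True, the scenario
# the relocated ModelTesterMixin test covers. Details here are assumptions.
import tempfile

import torch
from transformers import BertForMaskedLM

model = BertForMaskedLM.from_pretrained("bert-base-uncased")

with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir)
    # low_cpu_mem_usage=True first instantiates the model on the meta device, the
    # code path where an untied LM-head bias could previously be left behind.
    reloaded = BertForMaskedLM.from_pretrained(tmp_dir, low_cpu_mem_usage=True)

# After loading, no parameter should remain on the meta device, and the head bias
# should be tied to the decoder bias again.
assert all(not p.is_meta for p in reloaded.parameters())
assert reloaded.cls.predictions.decoder.bias is reloaded.cls.predictions.bias
torch.testing.assert_close(reloaded.cls.predictions.bias, model.cls.predictions.bias)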

src/transformers/models/bert/modeling_bert.py

Lines changed: 6 additions & 0 deletions
@@ -692,6 +692,9 @@ def __init__(self, config):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)

@@ -1062,6 +1065,7 @@ def get_output_embeddings(self):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)

@@ -1171,6 +1175,7 @@ def get_output_embeddings(self):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(

@@ -1324,6 +1329,7 @@ def get_output_embeddings(self):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
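
The set_output_embeddings() additions above keep the head's stored bias pointing at the new decoder's bias. A hedged illustration of that invariant for BERT (the checkpoint and the standalone nn.Linear are assumptions made for the example, not part of the commit):

# Illustrative check: after swapping the output embeddings, the prediction head bias
# and the new decoder's bias are the same tensor, per the added line above.
import torch
from transformers import BertForMaskedLM

model = BertForMaskedLM.from_pretrained("bert-base-uncased")

new_decoder = torch.nn.Linear(model.config.hidden_size, model.config.vocab_size)
model.set_output_embeddings(new_decoder)

assert model.cls.predictions.decoder is new_decoder
assert model.cls.predictions.bias is new_decoder.bias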

src/transformers/models/big_bird/modeling_big_bird.py

Lines changed: 6 additions & 0 deletions
@@ -1707,6 +1707,9 @@ def __init__(self, config):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)

@@ -2266,6 +2269,7 @@ def get_output_embeddings(self):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=BigBirdForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)

@@ -2378,6 +2382,7 @@ def get_output_embeddings(self):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)

@@ -2519,6 +2524,7 @@ def get_output_embeddings(self):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(

src/transformers/models/blip/modeling_blip_text.py

Lines changed: 4 additions & 0 deletions
@@ -523,6 +523,9 @@ def __init__(self, config):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)

@@ -816,6 +819,7 @@ def get_output_embeddings(self):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     def forward(
         self,

src/transformers/models/ernie/modeling_ernie.py

Lines changed: 6 additions & 0 deletions
@@ -608,6 +608,9 @@ def __init__(self, config):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)

@@ -995,6 +998,7 @@ def get_output_embeddings(self):
     # Copied from transformers.models.bert.modeling_bert.BertForPreTraining.set_output_embeddings
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=ErnieForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)

@@ -1109,6 +1113,7 @@ def get_output_embeddings(self):
     # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.set_output_embeddings
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(

@@ -1269,6 +1274,7 @@ def get_output_embeddings(self):
     # Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.set_output_embeddings
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(

src/transformers/models/layoutlm/modeling_layoutlm.py

Lines changed: 4 additions & 0 deletions
@@ -589,6 +589,9 @@ def __init__(self, config):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)

@@ -869,6 +872,7 @@ def get_output_embeddings(self):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)

src/transformers/models/markuplm/modeling_markuplm.py

Lines changed: 3 additions & 0 deletions
@@ -318,6 +318,9 @@ def __init__(self, config):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)
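
MarkupLM only gains the _tie_weights() hook here; there is no set_output_embeddings() change in this file. For readers unfamiliar with the pattern, a small stand-alone sketch of the same bias-tying idea follows (ToyLMPredictionHead is a hypothetical module written for illustration, not the library class):

import torch
from torch import nn


class ToyLMPredictionHead(nn.Module):
    """Hypothetical stand-in mirroring the prediction heads touched by this commit."""

    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(vocab_size))
        # Link the two attributes so they always refer to the same tensor,
        # mirroring the library heads above.
        self.decoder.bias = self.bias

    def _tie_weights(self):
        # Re-establish the link whenever the decoder or its bias has been replaced,
        # e.g. after parameters are re-materialized from the meta device.
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        return self.decoder(hidden_states)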

src/transformers/models/megatron_bert/modeling_megatron_bert.py

Lines changed: 6 additions & 0 deletions
@@ -659,6 +659,9 @@ def __init__(self, config):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)

@@ -1023,6 +1026,7 @@ def get_output_embeddings(self):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=MegatronBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)

@@ -1132,6 +1136,7 @@ def get_output_embeddings(self):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)

@@ -1290,6 +1295,7 @@ def get_output_embeddings(self):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(

src/transformers/models/mpnet/modeling_mpnet.py

Lines changed: 4 additions & 0 deletions
@@ -587,6 +587,7 @@ def get_output_embeddings(self):

     def set_output_embeddings(self, new_embeddings):
         self.lm_head.decoder = new_embeddings
+        self.lm_head.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(

@@ -659,6 +660,9 @@ def __init__(self, config):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, features, **kwargs):
         x = self.dense(features)
         x = gelu(x)

src/transformers/models/mra/modeling_mra.py

Lines changed: 4 additions & 0 deletions
@@ -820,6 +820,9 @@ def __init__(self, config):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)

@@ -1053,6 +1056,7 @@ def get_output_embeddings(self):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(MRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(

src/transformers/models/nezha/modeling_nezha.py

Lines changed: 5 additions & 0 deletions
@@ -679,6 +679,9 @@ def __init__(self, config):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)

@@ -1044,6 +1047,7 @@ def get_output_embeddings(self):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=NezhaForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)

@@ -1152,6 +1156,7 @@ def get_output_embeddings(self):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
