Merged

Changes from all commits (28 commits)
134f8a8 Attempting to test automatically the `_keys_to_ignore`. (Narsil, Nov 3, 2022)
6706337 Style. (Narsil, Nov 3, 2022)
9e33f8a First fix pass. (Narsil, Nov 3, 2022)
2225e65 Moving test on its own. (Narsil, Nov 3, 2022)
0d1642c Another batch. (Narsil, Nov 3, 2022)
602edaf Second round removing BatchNorm (Narsil, Nov 3, 2022)
7851df5 Fixing layoutlmv{2,3} + support older Python. (Narsil, Nov 3, 2022)
8f6a6ae Disable miss missing warning. (Narsil, Nov 3, 2022)
f53fb25 Removing dodgy additions. (Narsil, Nov 3, 2022)
7a7b378 Big pass. (Narsil, Nov 3, 2022)
e7cb3df mbart. (Narsil, Nov 3, 2022)
1d8811c More corrections. (Narsil, Nov 4, 2022)
e66d960 Fixup. (Narsil, Nov 4, 2022)
7039418 Updating test_correct_missing_keys (Narsil, Nov 4, 2022)
8e7c939 Add escape hatch for when the head has no extra params so doesn't need (Narsil, Nov 4, 2022)
3f38678 Fixing test. (Narsil, Nov 4, 2022)
79c7e6e Greener. (Narsil, Nov 4, 2022)
6c399d4 Green ! (except for weird splinter bug). (Narsil, Nov 4, 2022)
c3bfad1 Adding a test about `named_parameters` usage. (Narsil, Nov 4, 2022)
81b3692 Shorten message. (Narsil, Nov 4, 2022)
fb6007b Apply suggestions from code review (Narsil, Nov 4, 2022)
56411ac After rebase modifications. (Narsil, Nov 4, 2022)
d845b3a More explicit condition checking. (Narsil, Nov 4, 2022)
85c4bb2 Fixing slow tests issues. (Narsil, Nov 8, 2022)
e2f6a84 Remove extra pdb. (Narsil, Nov 8, 2022)
3354c8e Remove print. (Narsil, Nov 8, 2022)
f552a91 Attempt to make failure consistent + fixing roc_bert. (Narsil, Nov 9, 2022)
bb3673c Removing the seed (all tests passing with it). (Narsil, Nov 9, 2022)
12 changes: 8 additions & 4 deletions src/transformers/modeling_utils.py
@@ -2421,8 +2421,9 @@ def _fix_key(key):
add_prefix_to_model = has_prefix_module and not expects_prefix_module

if remove_prefix_from_model:
-     expected_keys_not_prefixed = [s for s in expected_keys if not s.startswith(prefix)]
-     expected_keys = [".".join(s.split(".")[1:]) if s.startswith(prefix) else s for s in expected_keys]
+     _prefix = f"{prefix}."
+     expected_keys_not_prefixed = [s for s in expected_keys if not s.startswith(_prefix)]
+     expected_keys = [s[len(_prefix) :] if s.startswith(_prefix) else s for s in expected_keys]
elif add_prefix_to_model:
expected_keys = [".".join([prefix, s]) for s in expected_keys]

@@ -2641,13 +2642,16 @@ def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False):

# torch.nn.ParameterList is a special case where two parameter keywords
# are appended to the module name, *e.g.* bert.special_embeddings.0
- module_keys = module_keys.union(set([".".join(key.split(".")[:-2]) for key in names if key[-1].isdigit()]))
+ module_keys = module_keys.union(
+     set([".".join(key.split(".")[:-2]) for key in names if len(key) > 0 and key[-1].isdigit()])
+ )

retrieved_modules = []
# retrieve all modules that has at least one missing weight name
for name, module in self.named_modules():
if remove_prefix:
-     name = ".".join(name.split(".")[1:]) if name.startswith(self.base_model_prefix) else name
+     _prefix = f"{self.base_model_prefix}."
+     name = name[len(_prefix) :] if name.startswith(_prefix) else name
elif add_prefix:
name = ".".join([self.base_model_prefix, name]) if len(name) > 0 else self.base_model_prefix
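For context on the `_prefix` change above, a minimal standalone sketch (illustrative key names, not code from this PR) of why the trailing dot matters when stripping the base-model prefix from checkpoint keys:

prefix = "bert"
keys = ["bert.encoder.layer.0.weight", "bertabc.weight"]

# Old behaviour: a bare startswith(prefix) also matches "bertabc.weight",
# so its first dotted component gets stripped by mistake.
old = [".".join(k.split(".")[1:]) if k.startswith(prefix) else k for k in keys]

# New behaviour: match the exact "bert." prefix and slice it off,
# leaving unrelated keys untouched.
_prefix = f"{prefix}."
new = [k[len(_prefix):] if k.startswith(_prefix) else k for k in keys]

print(old)  # ['encoder.layer.0.weight', 'weight']
print(new)  # ['encoder.layer.0.weight', 'bertabc.weight']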
11 changes: 11 additions & 0 deletions src/transformers/models/albert/modeling_albert.py
@@ -762,6 +762,12 @@ def forward(
ALBERT_START_DOCSTRING,
)
class AlbertForPreTraining(AlbertPreTrainedModel):
_keys_to_ignore_on_load_missing = [
"predictions.decoder.weight",
"predictions.decoder.bias",
"embeddings.position_ids",
]

def __init__(self, config: AlbertConfig):
super().__init__(config)

@@ -910,6 +916,11 @@ def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
class AlbertForMaskedLM(AlbertPreTrainedModel):

_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [
"predictions.decoder.weight",
"predictions.decoder.bias",
"embeddings.position_ids",
]

def __init__(self, config):
super().__init__(config)
15 changes: 14 additions & 1 deletion src/transformers/models/bart/modeling_bart.py
@@ -1153,6 +1153,8 @@ def custom_forward(*inputs):
BART_START_DOCSTRING,
)
class BartModel(BartPretrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

def __init__(self, config: BartConfig):
super().__init__(config)

@@ -1281,7 +1283,12 @@ def forward(
)
class BartForConditionalGeneration(BartPretrainedModel):
base_model_prefix = "model"
_keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head.weight"]
_keys_to_ignore_on_load_missing = [
r"final_logits_bias",
r"lm_head.weight",
"encoder.embed_tokens.weight",
"decoder.embed_tokens.weight",
]

def __init__(self, config: BartConfig):
super().__init__(config)
@@ -1451,6 +1458,8 @@ def _reorder_cache(past, beam_idx):
BART_START_DOCSTRING,
)
class BartForSequenceClassification(BartPretrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

def __init__(self, config: BartConfig, **kwargs):
super().__init__(config, **kwargs)
self.model = BartModel(config)
@@ -1578,6 +1587,8 @@ def forward(
BART_START_DOCSTRING,
)
class BartForQuestionAnswering(BartPretrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

def __init__(self, config):
super().__init__(config)

@@ -1714,6 +1725,8 @@ def forward(self, *args, **kwargs):
BART_START_DOCSTRING,
)
class BartForCausalLM(BartPretrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]

def __init__(self, config):
config = copy.deepcopy(config)
config.is_decoder = True
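Why the embed_tokens keys above are safe to ignore: in Bart-style models the encoder and decoder token embeddings are tied to the shared embedding, so a checkpoint only needs to store one copy. A hedged sketch of how one could check the tie locally (assumes torch and transformers are installed; the tiny config values are arbitrary):

from transformers import BartConfig, BartModel

config = BartConfig(
    vocab_size=100, d_model=16,
    encoder_layers=1, decoder_layers=1,
    encoder_attention_heads=1, decoder_attention_heads=1,
    encoder_ffn_dim=16, decoder_ffn_dim=16,
)
model = BartModel(config)

# All three embedding parameters share the same storage, so only one of them
# has to be serialized; the other two names show up as "missing" on load.
print(model.shared.weight.data_ptr() == model.encoder.embed_tokens.weight.data_ptr())  # True
print(model.shared.weight.data_ptr() == model.decoder.embed_tokens.weight.data_ptr())  # True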
6 changes: 4 additions & 2 deletions src/transformers/models/bert/modeling_bert.py
@@ -1047,6 +1047,8 @@ def forward(
BERT_START_DOCSTRING,
)
class BertForPreTraining(BertPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", r"cls.predictions.decoder.weight"]

def __init__(self, config):
super().__init__(config)

@@ -1153,7 +1155,7 @@ def forward(
class BertLMHeadModel(BertPreTrainedModel):

_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", r"cls.predictions.decoder.weight"]
Collaborator:

Let me review this slowly 🙏 and verify a few things. But do you think r"predictions.decoder.bias" is a mistake and should be r"cls.predictions.decoder.bias"? I am going to check myself anyway.

Collaborator:

A re match is done, so while the exact name is indeed r"cls.predictions.decoder.bias", this works. But it would be great to fix it, just in case a weight named predictions.decoder.bias that should not be ignored appears one day ;-)

Contributor Author:

Maybe. I'm not sure, but the test does yell if we're hiding a valid key (though I don't try to yell when we have an unused _keys_to_ignore). Want me to try and add it to the test?
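For illustration, a minimal standalone sketch of the kind of re-based matching being discussed (key names are illustrative; the actual filtering in modeling_utils.py works along these lines), which is why the un-prefixed pattern still hides the fully qualified key:

import re

missing_keys = ["cls.predictions.decoder.bias", "classifier.bias"]
keys_to_ignore = [r"position_ids", r"predictions.decoder.bias"]

# re.search matches anywhere in the key, so r"predictions.decoder.bias"
# also covers "cls.predictions.decoder.bias" (and "." matches any character).
missing_keys = [k for k in missing_keys if not any(re.search(p, k) for p in keys_to_ignore)]
print(missing_keys)  # ['classifier.bias']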


def __init__(self, config):
super().__init__(config)
@@ -1288,7 +1290,7 @@ def _reorder_cache(self, past, beam_idx):
class BertForMaskedLM(BertPreTrainedModel):

_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", r"cls.predictions.decoder.weight"]

def __init__(self, config):
super().__init__(config)
src/transformers/models/bert_generation/modeling_bert_generation.py
@@ -855,6 +855,8 @@ def _tie_weights(self):
BERT_GENERATION_START_DOCSTRING,
)
class BertGenerationDecoder(BertGenerationPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.decoder.weight", "lm_head.decoder.bias", "embeddings.position_ids"]

def __init__(self, config):
super().__init__(config)

12 changes: 10 additions & 2 deletions src/transformers/models/big_bird/modeling_big_bird.py
@@ -2262,6 +2262,8 @@ def _pad_to_block_size(


class BigBirdForPreTraining(BigBirdPreTrainedModel):
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]

def __init__(self, config):
super().__init__(config)

@@ -2366,6 +2368,8 @@ def forward(

@add_start_docstrings("""BigBird Model with a `language modeling` head on top.""", BIG_BIRD_START_DOCSTRING)
class BigBirdForMaskedLM(BigBirdPreTrainedModel):
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]

def __init__(self, config):
super().__init__(config)

@@ -2508,8 +2512,12 @@ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_
"""BigBird Model with a `language modeling` head on top for CLM fine-tuning.""", BIG_BIRD_START_DOCSTRING
)
class BigBirdForCausalLM(BigBirdPreTrainedModel):

_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
_keys_to_ignore_on_load_missing = [
r"position_ids",
r"predictions.decoder.bias",
"cls.predictions.decoder.weight",
"cls.predictions.decoder.bias",
]

def __init__(self, config):
super().__init__(config)
src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
@@ -2350,6 +2350,8 @@ def custom_forward(*inputs):
)
# Copied from transformers.models.bart.modeling_bart.BartModel with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS
class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

def __init__(self, config: BigBirdPegasusConfig):
super().__init__(config)

@@ -2480,7 +2482,12 @@ def forward(
# Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS
class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
base_model_prefix = "model"
_keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head.weight"]
_keys_to_ignore_on_load_missing = [
r"final_logits_bias",
r"lm_head.weight",
"encoder.embed_tokens.weight",
"decoder.embed_tokens.weight",
]

def __init__(self, config: BigBirdPegasusConfig):
super().__init__(config)
@@ -2651,6 +2658,8 @@ def _reorder_cache(past, beam_idx):
)
# Copied from transformers.models.bart.modeling_bart.BartForSequenceClassification with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS
class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

def __init__(self, config: BigBirdPegasusConfig, **kwargs):
super().__init__(config, **kwargs)
self.model = BigBirdPegasusModel(config)
@@ -2779,6 +2788,8 @@ def forward(
)
# Copied from transformers.models.bart.modeling_bart.BartForQuestionAnswering with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS
class BigBirdPegasusForQuestionAnswering(BigBirdPegasusPreTrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

def __init__(self, config):
super().__init__(config)

@@ -2910,6 +2921,8 @@ def forward(self, *args, **kwargs):


class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]

def __init__(self, config):
config = copy.deepcopy(config)
config.is_decoder = True
6 changes: 6 additions & 0 deletions src/transformers/models/blenderbot/modeling_blenderbot.py
@@ -1087,6 +1087,8 @@ def custom_forward(*inputs):
BLENDERBOT_START_DOCSTRING,
)
class BlenderbotModel(BlenderbotPreTrainedModel):
_keys_to_ignore_on_load_missing = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]

def __init__(self, config: BlenderbotConfig):
super().__init__(config)

@@ -1231,6 +1233,8 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
r"encoder.version",
r"decoder.version",
r"lm_head.weight",
"decoder.embed_tokens.weight",
"encoder.embed_tokens.weight",
]

def __init__(self, config: BlenderbotConfig):
@@ -1420,6 +1424,8 @@ def forward(self, *args, **kwargs):

# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Blenderbot, facebook/bart-base->facebook/blenderbot-400M-distill
class BlenderbotForCausalLM(BlenderbotPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]

def __init__(self, config):
config = copy.deepcopy(config)
config.is_decoder = True
src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
@@ -1081,6 +1081,8 @@ def custom_forward(*inputs):
BLENDERBOT_SMALL_START_DOCSTRING,
)
class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

def __init__(self, config: BlenderbotSmallConfig):
super().__init__(config)

@@ -1213,6 +1215,8 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
r"encoder.version",
r"decoder.version",
r"lm_head.weight",
"encoder.embed_tokens.weight",
"decoder.embed_tokens.weight",
]

def __init__(self, config: BlenderbotSmallConfig):
@@ -1387,6 +1391,8 @@ def forward(self, *args, **kwargs):

# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->BlenderbotSmall, facebook/bart-base->facebook/blenderbot_small-90M
class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]

def __init__(self, config):
config = copy.deepcopy(config)
config.is_decoder = True
12 changes: 12 additions & 0 deletions src/transformers/models/convbert/modeling_convbert.py
@@ -763,6 +763,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
CONVBERT_START_DOCSTRING,
)
class ConvBertModel(ConvBertPreTrainedModel):
_keys_to_ignore_on_load_missing = ["embeddings.position_ids"]

def __init__(self, config):
super().__init__(config)
self.embeddings = ConvBertEmbeddings(config)
@@ -877,6 +879,8 @@ def forward(self, generator_hidden_states: torch.FloatTensor) -> torch.FloatTens

@add_start_docstrings("""ConvBERT Model with a `language modeling` head on top.""", CONVBERT_START_DOCSTRING)
class ConvBertForMaskedLM(ConvBertPreTrainedModel):
_keys_to_ignore_on_load_missing = ["embeddings.position_ids", "generator.lm_head.weight"]

def __init__(self, config):
super().__init__(config)

@@ -987,6 +991,8 @@ def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
CONVBERT_START_DOCSTRING,
)
class ConvBertForSequenceClassification(ConvBertPreTrainedModel):
_keys_to_ignore_on_load_missing = ["embeddings.position_ids"]

def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
@@ -1083,6 +1089,8 @@ def forward(
CONVBERT_START_DOCSTRING,
)
class ConvBertForMultipleChoice(ConvBertPreTrainedModel):
_keys_to_ignore_on_load_missing = ["embeddings.position_ids"]

def __init__(self, config):
super().__init__(config)

@@ -1177,6 +1185,8 @@ def forward(
CONVBERT_START_DOCSTRING,
)
class ConvBertForTokenClassification(ConvBertPreTrainedModel):
_keys_to_ignore_on_load_missing = ["embeddings.position_ids"]

def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
@@ -1259,6 +1269,8 @@ def forward(
CONVBERT_START_DOCSTRING,
)
class ConvBertForQuestionAnswering(ConvBertPreTrainedModel):
_keys_to_ignore_on_load_missing = ["embeddings.position_ids"]

def __init__(self, config):
super().__init__(config)

2 changes: 2 additions & 0 deletions src/transformers/models/ctrl/modeling_ctrl.py
@@ -509,6 +509,8 @@ def forward(
CTRL_START_DOCSTRING,
)
class CTRLLMHeadModel(CTRLPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]

def __init__(self, config):
super().__init__(config)
self.transformer = CTRLModel(config)
2 changes: 1 addition & 1 deletion src/transformers/models/deberta/modeling_deberta.py
@@ -1038,7 +1038,7 @@ def forward(
@add_start_docstrings("""DeBERTa Model with a `language modeling` head on top.""", DEBERTA_START_DOCSTRING)
class DebertaForMaskedLM(DebertaPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "cls.predictions.decoder.weight"]

def __init__(self, config):
super().__init__(config)
2 changes: 1 addition & 1 deletion src/transformers/models/deberta_v2/modeling_deberta_v2.py
@@ -1139,7 +1139,7 @@ def forward(
# Copied from transformers.models.deberta.modeling_deberta.DebertaForMaskedLM with Deberta->DebertaV2
class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "cls.predictions.decoder.weight"]

def __init__(self, config):
super().__init__(config)
src/transformers/models/deformable_detr/modeling_deformable_detr.py
@@ -1788,6 +1788,9 @@ def forward(
DEFORMABLE_DETR_START_DOCSTRING,
)
class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
# When using clones, all layers > 0 will be clones, but layer 0 *is* required
_keys_to_ignore_on_load_missing = ["bbox_embed\.[1-9]\d*", "class_embed\.[1-9]\d*"]

def __init__(self, config: DeformableDetrConfig):
super().__init__(config)

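A quick standalone sanity check (illustrative key names, not part of the diff) of how the clone-index patterns above behave: clone indices 1 and up are hidden, while the required layer 0 would still be reported if missing:

import re

patterns = [r"bbox_embed\.[1-9]\d*", r"class_embed\.[1-9]\d*"]
for key in ["bbox_embed.0.layers.0.weight", "bbox_embed.1.layers.0.weight", "class_embed.10.weight"]:
    ignored = any(re.search(p, key) for p in patterns)
    print(key, "ignored" if ignored else "kept")

# bbox_embed.0.layers.0.weight kept
# bbox_embed.1.layers.0.weight ignored
# class_embed.10.weight ignored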
2 changes: 2 additions & 0 deletions src/transformers/models/distilbert/modeling_distilbert.py
@@ -579,6 +579,8 @@ def forward(
DISTILBERT_START_DOCSTRING,
)
class DistilBertForMaskedLM(DistilBertPreTrainedModel):
_keys_to_ignore_on_load_missing = ["vocab_projector.weight"]

def __init__(self, config: PretrainedConfig):
super().__init__(config)

4 changes: 4 additions & 0 deletions src/transformers/models/electra/modeling_electra.py
@@ -1161,6 +1161,8 @@ def forward(
ELECTRA_START_DOCSTRING,
)
class ElectraForMaskedLM(ElectraPreTrainedModel):
_keys_to_ignore_on_load_missing = ["generator_lm_head.weight"]

def __init__(self, config):
super().__init__(config)

@@ -1530,6 +1532,8 @@ def forward(
"""ELECTRA Model with a `language modeling` head on top for CLM fine-tuning.""", ELECTRA_START_DOCSTRING
)
class ElectraForCausalLM(ElectraPreTrainedModel):
_keys_to_ignore_on_load_missing = ["generator_lm_head.weight"]

def __init__(self, config):
super().__init__(config)
