Fix weight tying logic between _tied_weights_keys and tie_word_embeddings (#42385)
run-slow: mistral

This comment contains models: ["models/mistral"]

CI Results: ✅ No failing test specific to this PR 🎉!

run-slow: mistral3

This comment contains models: ["models/mistral3"]

CI Results (Model CI Report): ❌ Failed tests

run-slow: mistral, mistral3

This comment contains models: ["models/mistral", "models/mistral3"]

CI Results (Model CI Report): ❌ Failed tests

run-slow: olmoe

This comment contains models: ["models/olmoe"]

CI Results: ✅ No failing test specific to this PR 🎉!

run-slow: olmoe, mistral, mistral3

This comment contains models: ["models/mistral", "models/mistral3", "models/olmoe"]

CI Results: ✅ No failing test specific to this PR 🎉!

run-slow: olmoe, mistral, mistral3, phi3, starcoder

This comment contains models: ["models/mistral", "models/mistral3", "models/olmoe", "models/phi3"]

run-slow: olmoe, mistral, mistral3, phi3, starcoder, bigbird, bigbird_pegasus

CI Results

run-slow: olmoe, mistral, mistral3, phi3, starcoder, bigbird, bigbird_pegasus, whisper, internvl, llava, llava_next, llava_next_video, qwen, fsmt, video_llava, deepseek_v3, qwen3_vl_moe

This comment contains models: ["models/bigbird_pegasus", "models/deepseek_v3", "models/fsmt", "models/internvl", "models/llava", "models/llava_next", "models/llava_next_video", "models/mistral", "models/mistral3", "models/olmoe", "models/phi3", "models/qwen3_vl_moe", "models/video_llava", "models/whisper"]

CI Results: ✅ No failing test specific to this PR 🎉!
[For maintainers] Suggested jobs to run (before merge): run-slow: fsmt
I removed the fallback to parent config in case
A remaining headscratcher is `should_tie = tie_encoder_decoder if tie_word_embeddings is None else tie_word_embeddings`, and it seems to work / keep the key as the source of authority. LMK!
[For maintainers] Suggested jobs to run (before merge): run-slow: fsmt, kyutai_speech_to_text, musicgen, musicgen_melody
CMIIW,
Cyrilvallez left a comment:

A few comments/answers! Sorry for the delay!
```python
mismatch_keys.add((target_name, param_value.shape, ref.shape))
module_obj.param_name._is_hf_initialized = False  # Needs to be initialized
```
Yes, this was completely masked by `log_to_misc`, along with other issues we currently have on some models. We will need to update the logger to avoid silencing important issues!!
src/transformers/modeling_utils.py (outdated)

```python
text_config = self.config.get_text_config(decoder=True)
if not hasattr(text_config, "tie_word_embeddings"):
    logger.warning(
        f"Text config {text_config.__class__.__name__} does not have 'tie_word_embeddings' attribute. "
        "This may cause issues with weight tying."
    )
tie_word_embeddings = getattr(text_config, "tie_word_embeddings", None)
```
I don't think all multimodals rely on the text_config here, unfortunately, do they? I.e. what we want is that each model decides on its own whether it should tie its own weights. So we don't really want to check the text_config...
Hmm, it's the opposite of what we discussed above. In that case, what do you suggest?
Well, the issue is that if you have a submodel such as `self.text_model = AutoModel._from_config(text_config)`, you cannot know its weights in advance, as it can be any model. So that model is the only one responsible for its own weights.
We cannot delegate to the top-most model, as this makes it way too hard and not scalable in general.
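To illustrate the composition problem being described (a hypothetical sketch with made-up class names, not the transformers implementation): the wrapper builds its text backbone from a config at runtime, so only the backbone itself can report which of its weights are tied.

```python
import torch.nn as nn

# Hypothetical sketch: a wrapper whose text backbone is chosen at runtime
# (in transformers this would come from AutoModel._from_config(text_config)).
# The wrapper cannot hardcode the tied weights; it can only ask the backbone.
class Wrapper(nn.Module):
    def __init__(self, text_backbone: nn.Module):
        super().__init__()
        self.text_model = text_backbone  # could be any model class

    def tied_keys(self):
        # Delegate: the backbone is the only one responsible for its weights.
        backbone_keys = getattr(self.text_model, "_tied_weights_keys", [])
        return ["text_model." + k for k in backbone_keys]
```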
src/transformers/modeling_utils.py (outdated)

```python
tie_word_embeddings = getattr(text_config, "tie_word_embeddings", None)
tie_encoder_decoder = getattr(self.config, "tie_encoder_decoder", False)
should_tie = tie_encoder_decoder if tie_word_embeddings is None else tie_word_embeddings
```
So we basically disregard tie_encoder_decoder completely? Because from your change in configuration_utils, tie_word_embeddings will NEVER be None anymore, if I understand correctly.
Yeah, I'm reverting this unfortunately; I would indeed like to get rid of tie_encoder_decoder.
src/transformers/modeling_utils.py (outdated)

```python
tied_keys_attr = getattr(self, "all_tied_weights_keys", None)
if tied_keys_attr is not None:
    _tied_weights_keys = set(tied_keys_attr.keys())
else:
    _tied_weights_keys = set(_get_tied_weight_keys(self))
```
The fact here is that I think some older models were tying weights using `_tie_weights` (the one with the underscore), which was ALWAYS run, independently of the configs. So they could have tied weights even though the config technically says not to tie. In `save_pretrained`, we check the pointers of the weight tensors, so we know the tied weights from there and cannot be wrong about it. But then if we use the good-citizen `all_tied_weights_keys`, which correctly looks at the config, some of those tied weights will not find their source, as they are not SUPPOSED to be tied (but they were anyway).
So technically, looking at all potential `_tied_weights_keys` here is not wrong, as we check pointers anyway, and this avoids this kind of issue.
But indeed, in the future we want to always know which weights are tied and simply rely on the internal list (or even better, recompute it with `get_expanded_tied_weights_keys(all_submodels=True)`).
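As a rough illustration of the pointer check mentioned above (a minimal sketch, not the actual `save_pretrained` code): tied parameters share the same storage, so comparing data pointers finds them regardless of what the config says.

```python
import torch.nn as nn

# Minimal sketch of pointer-based tied-weight detection: tied parameters
# share the same storage, so their data_ptr() values are equal.
class TinyLM(nn.Module):
    def __init__(self, vocab=10, dim=4):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.lm_head = nn.Linear(dim, vocab, bias=False)
        self.lm_head.weight = self.embed.weight  # tie, regardless of config

def find_tied(model: nn.Module):
    seen = {}   # data_ptr -> first parameter name that owns this storage
    tied = {}   # duplicate name -> source name
    for name, param in model.named_parameters(remove_duplicate=False):
        ptr = param.data_ptr()
        if ptr in seen:
            tied[name] = seen[ptr]
        else:
            seen[ptr] = name
    return tied

print(find_tied(TinyLM()))  # {'lm_head.weight': 'embed.weight'}
```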
src/transformers/modeling_utils.py (outdated)

```diff
- for key in missing_keys - self.all_tied_weights_keys.keys():
+ tied_keys_attr = getattr(self, "all_tied_weights_keys", {}) or {}
+ tied_keys = set(tied_keys_attr.keys())
+ for key in missing_keys - tied_keys:
```
Why do we have this? all_tied_weights_keys is guaranteed to be a dict already at this point from post_init
Ah yes indeed, it could be just

```python
tied_keys = set(self.all_tied_weights_keys.keys())
for key in missing_keys - tied_keys:
    ...  # do stuff
```

will update
No need to cast it again as a set; the keys already are a subclass of set!
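To illustrate the point (a standalone sketch with made-up key names): dict `.keys()` views already implement the set protocol (`collections.abc.KeysView` is a `Set` subclass), so they can be subtracted from a set directly.

```python
# Sketch with made-up keys: dict .keys() views support set operations,
# so no set() cast is needed before subtracting them from another set.
missing_keys = {"lm_head.weight", "encoder.layer.0.weight"}
all_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}

remaining = missing_keys - all_tied_weights_keys.keys()
print(remaining)  # {'encoder.layer.0.weight'}
```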
```python
if not hasattr(model_tied, "_tied_weights_keys") or not model_tied._tied_weights_keys:
    continue
```
note from huddle: should recurse through all submodels here to get the tied keys
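A possible shape for that recursion (a hypothetical helper, not the actual test code): walk `named_children` and prefix each submodule's keys with its dotted path.

```python
import torch.nn as nn

# Hypothetical helper: recursively collect _tied_weights_keys from every
# submodule, prefixing each key with the submodule's dotted path.
def collect_tied_keys(module: nn.Module, prefix: str = "") -> list:
    keys = []
    tied = getattr(module, "_tied_weights_keys", None)
    if tied:
        keys.extend(prefix + k for k in tied)
    for name, child in module.named_children():
        keys.extend(collect_tied_keys(child, prefix + name + "."))
    return keys
```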
[For maintainers] Suggested jobs to run (before merge): run-slow: fsmt, kyutai_speech_to_text, musicgen, musicgen_melody

[For maintainers] Suggested jobs to run (before merge): run-slow: fsmt, kyutai_speech_to_text, llava_onevision, musicgen, musicgen_melody

[For maintainers] Suggested jobs to run (before merge): run-slow: fsmt, kyutai_speech_to_text, llava_next_video, llava_onevision, musicgen, musicgen_melody
What does this PR do?