
Fix weight tying logic between _tied_weights_keys and tie_word_embeddings #42385

Open
molbap wants to merge 26 commits into main from fix_weight_tying

Conversation

@molbap
Contributor

@molbap molbap commented Nov 25, 2025

What does this PR do?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@molbap
Contributor Author

molbap commented Nov 25, 2025

run-slow: mistral

@github-actions
Contributor

This comment contains run-slow, running the specified jobs:

models: ["models/mistral"]
quantizations: []

@github-actions
Contributor

CI Results

Workflow Run ⚙️

✅ No failing test specific to this PR 🎉 !

@molbap
Contributor Author

molbap commented Nov 25, 2025

run-slow: mistral3

@github-actions
Contributor

This comment contains run-slow, running the specified jobs:

models: ["models/mistral3"]
quantizations: []

@github-actions
Contributor

CI Results

Workflow Run ⚙️

Model CI Report

❌ Failed tests

  • mistral3:
    tests/models/mistral3/test_modeling_mistral3.py::Mistral3ModelTest::test_config

@molbap
Contributor Author

molbap commented Nov 25, 2025

run-slow: mistral, mistral3

@github-actions
Contributor

This comment contains run-slow, running the specified jobs:

models: ["models/mistral", "models/mistral3"]
quantizations: []

@github-actions
Contributor

CI Results

Workflow Run ⚙️

Model CI Report

❌ Failed tests

  • mistral3:
    tests/models/mistral3/test_modeling_mistral3.py::Mistral3ModelTest::test_config

@molbap
Contributor Author

molbap commented Nov 25, 2025

run-slow: olmoe

@github-actions
Contributor

This comment contains run-slow, running the specified jobs:

models: ["models/olmoe"]
quantizations: []

@github-actions
Contributor

CI Results

Workflow Run ⚙️

✅ No failing test specific to this PR 🎉 !

@molbap
Contributor Author

molbap commented Nov 25, 2025

run-slow: olmoe, mistral, mistral3

@github-actions
Contributor

This comment contains run-slow, running the specified jobs:

models: ["models/mistral", "models/mistral3", "models/olmoe"]
quantizations: []

@github-actions
Contributor

CI Results

Workflow Run ⚙️

✅ No failing test specific to this PR 🎉 !

@molbap
Contributor Author

molbap commented Nov 25, 2025

run-slow: olmoe, mistral, mistral3, phi3, starcoder

@github-actions
Contributor

This comment contains run-slow, running the specified jobs:

models: ["models/mistral", "models/mistral3", "models/olmoe", "models/phi3"]
quantizations: []

@molbap
Contributor Author

molbap commented Nov 25, 2025

run-slow: olmoe, mistral, mistral3, phi3, starcoder, bigbird, bigbird_pegasus

@github-actions
Contributor

CI Results

Workflow Run ⚙️

⚠️ No test being reported (jobs are skipped or cancelled)!

@molbap
Contributor Author

molbap commented Nov 26, 2025

run-slow: olmoe, mistral, mistral3, phi3, starcoder, bigbird, bigbird_pegasus, whisper, internvl, llava, llava_next, llava_next_video, qwen, fsmt, video_llava, deepseek_v3, qwen3_vl_moe

@github-actions
Contributor

This comment contains run-slow, running the specified jobs:

models: ["models/bigbird_pegasus", "models/deepseek_v3", "models/fsmt", "models/internvl", "models/llava", "models/llava_next", "models/llava_next_video", "models/mistral", "models/mistral3", "models/olmoe", "models/phi3", "models/qwen3_vl_moe", "models/video_llava", "models/whisper"]
quantizations: []

@github-actions
Contributor

CI Results

Workflow Run ⚙️

✅ No failing test specific to this PR 🎉 !

@molbap molbap mentioned this pull request Nov 27, 2025
@molbap molbap changed the title from "Various fixes" to "Fix weight tying logic between _tied_weights_keys and tie_word_embeddings" Nov 27, 2025
@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: fsmt

@molbap
Contributor Author

molbap commented Nov 27, 2025

I removed the fallback to the parent config in case tie_word_embeddings is not found in the text config. I also attempted to add a test; might be a bit overkill, let's see

@molbap
Contributor Author

molbap commented Nov 27, 2025

A remaining headscratcher is tie_encoder_decoder, which is now redundant. I did a fallback that looks like

should_tie = tie_encoder_decoder if tie_word_embeddings is None else tie_word_embeddings

and it seems to work and keeps the key as the source of authority. LMK!
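A minimal sketch of that fallback (plain Python with hypothetical names, not the actual modeling_utils code): the per-model tie_word_embeddings flag is treated as the source of authority, and tie_encoder_decoder is only consulted when the flag is genuinely absent.

```python
def resolve_should_tie(tie_word_embeddings, tie_encoder_decoder):
    """Decide whether to tie weights from the two config flags.

    tie_word_embeddings wins whenever it is explicitly set (True or False);
    tie_encoder_decoder only acts as a fallback when it is None.
    """
    return tie_encoder_decoder if tie_word_embeddings is None else tie_word_embeddings

# The explicit flag always wins when set...
assert resolve_should_tie(False, True) is False
# ...and the encoder-decoder flag is only a fallback.
assert resolve_should_tie(None, True) is True
```

Note that this only behaves as a real fallback if tie_word_embeddings can actually be None, which is exactly the concern raised in the review below about the configuration_utils change.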

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: fsmt, kyutai_speech_to_text, musicgen, musicgen_melody

@zucchini-nlp
Member

zucchini-nlp commented Nov 28, 2025

A remaining headscratcher is tie_encoder_decoder which is now redundant. I did a fallback that looks like

CMIIW, tie_encoder_decoder means tying all weights from the encoder to the decoder, i.e. attention and such, doesn't it? In which case

tie_encoder_decoder = config.tie_encoder_decoder if config.is_encoder_decoder else False?

Member

@Cyrilvallez Cyrilvallez left a comment

A few comments/answers! Sorry for the delay!

Comment on lines 506 to -507
mismatch_keys.add((target_name, param_value.shape, ref.shape))
module_obj.param_name._is_hf_initialized = False # Needs to be initialized
Member

Yes, it was completely masked by log_to_misc, along with other issues we currently have on some models. We will need to update the logger to avoid silencing important issues!!

Comment on lines +2257 to +2263
text_config = self.config.get_text_config(decoder=True)
if not hasattr(text_config, "tie_word_embeddings"):
    logger.warning(
        f"Text config {text_config.__class__.__name__} does not have 'tie_word_embeddings' attribute. "
        "This may cause issues with weight tying."
    )
tie_word_embeddings = getattr(text_config, "tie_word_embeddings", None)
Member

I don't think all multimodals rely on the text_config here unfortunately, do they? I.e. what we want is for each model to decide on its own whether it should tie its own weights. So we don't really want to check the text_config...

Contributor Author

hmm, it's the opposite of what we discussed above. In that case, what do you suggest?

Member

Well the issue is that if you have a submodel such as self.text_model = AutoModel._from_config(text_config), you cannot know its weights in advance, as it can be any model. So that model is the only one responsible for its own weights.
We cannot delegate to the top-most model, as this makes it way too hard and not scalable in general
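To illustrate the point above with a toy sketch (hypothetical classes, not the transformers API): the wrapper only instantiates the submodel, and the submodel ties its own weights from its own config, so the wrapper never needs to inspect text_config.

```python
class ToyTextModel:
    """Stand-in for an arbitrary text model: it decides its own tying."""
    def __init__(self, config):
        self.config = config
        self.embed = [0.0] * 4    # stand-in for the input embedding weight
        self.lm_head = [0.0] * 4  # stand-in for the output projection weight
        if config.get("tie_word_embeddings", False):
            self.lm_head = self.embed  # tie: share the same object

class ToyMultimodal:
    """Stand-in for a multimodal wrapper: it only instantiates the submodel,
    mirroring self.text_model = AutoModel._from_config(text_config)."""
    def __init__(self, text_config):
        self.text_model = ToyTextModel(text_config)

tied = ToyMultimodal({"tie_word_embeddings": True})
assert tied.text_model.lm_head is tied.text_model.embed
```

The design choice being argued for is exactly this: tying is local to each module, so composition stays scalable no matter which submodel the config resolves to.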

)
tie_word_embeddings = getattr(text_config, "tie_word_embeddings", None)
tie_encoder_decoder = getattr(self.config, "tie_encoder_decoder", False)
should_tie = tie_encoder_decoder if tie_word_embeddings is None else tie_word_embeddings
Member

So we basically disregard tie_encoder_decoder completely? Because with your change in configuration_utils, tie_word_embeddings will NEVER be None anymore, if I understand correctly

Contributor Author

yeah, I'm reverting this unfortunately, I would indeed like to get rid of tie_encoder_decoder

Comment on lines +3179 to +3183
tied_keys_attr = getattr(self, "all_tied_weights_keys", None)
if tied_keys_attr is not None:
    _tied_weights_keys = set(tied_keys_attr.keys())
else:
    _tied_weights_keys = set(_get_tied_weight_keys(self))
Member

The fact here is that I think some older models were tying using _tie_weights (the one with the underscore), which was ALWAYS run, independently of the configs. So they could have tied weights even though the config technically says not to tie. In save_pretrained, we check the pointers of the weight tensors, so we know the tied weights from there and cannot be wrong about it. But then if we use the good-citizen all_tied_weights_keys, which correctly looks at the config, some of those tied weights will not find their source, as they are not SUPPOSED to be tied (but they were anyway).

So technically, looking at all potential _tied_weights_keys here is not wrong, as we check pointers anyway, and this avoids that kind of issue.
But indeed, in the future we want to always know which weights are tied and simply rely on the internal list (or even better, recompute it with get_expanded_tied_weights_keys(all_submodels=True))
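The pointer check described above can be sketched in plain Python (no torch here; `id()` stands in for comparing tensor storage pointers, and the names are illustrative, not the save_pretrained internals): two parameter entries backed by the same object are tied, regardless of what the config claims.

```python
def find_tied_groups(state_dict):
    """Group parameter names that share the same underlying object,
    i.e. weights that are effectively tied no matter what the config says."""
    by_storage = {}
    for name, tensor in state_dict.items():
        by_storage.setdefault(id(tensor), []).append(name)
    return [names for names in by_storage.values() if len(names) > 1]

shared = [0.1, 0.2, 0.3]  # stand-in for a shared weight tensor
state = {
    "model.embed_tokens.weight": shared,  # same object as lm_head.weight
    "lm_head.weight": shared,
    "model.norm.weight": [1.0],           # not shared with anything
}
assert find_tied_groups(state) == [["model.embed_tokens.weight", "lm_head.weight"]]
```

This is why the pointer-based view is authoritative at save time: it catches weights tied by legacy _tie_weights hooks that the config-driven all_tied_weights_keys would miss.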

Comment on lines +4413 to +4428
for key in missing_keys - self.all_tied_weights_keys.keys():
tied_keys_attr = getattr(self, "all_tied_weights_keys", {}) or {}
tied_keys = set(tied_keys_attr.keys())
for key in missing_keys - tied_keys:
Member

Why do we have this? all_tied_weights_keys is guaranteed to be a dict already at this point from post_init

Contributor Author

ah yes indeed, could be just

tied_keys = set(self.all_tied_weights_keys.keys())
for key in missing_keys - tied_keys:
    ...  # do stuff

will update

Member

No need to cast it again as a set, the keys are already a subclass of set!

Comment on lines +2116 to +2117
if not hasattr(model_tied, "_tied_weights_keys") or not model_tied._tied_weights_keys:
    continue
Contributor Author

@molbap molbap Nov 28, 2025

note from huddle: should recurse through all submodels here to get the tied keys
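The recursion mentioned in the note could look like the following sketch (a hypothetical helper plus a minimal mock module, not the actual transformers code): walk all submodules and union their _tied_weights_keys, prefixing each key with its submodule path.

```python
class MockModule:
    """Minimal stand-in for nn.Module exposing named_children()."""
    def __init__(self, tied=(), **children):
        self._tied_weights_keys = list(tied)
        self._children = children

    def named_children(self):
        return self._children.items()

def collect_tied_keys(module, prefix=""):
    """Union the _tied_weights_keys of a module and all its submodules,
    prefixing each key with the dotted path of the submodule it came from."""
    keys = {prefix + k for k in (module._tied_weights_keys or [])}
    for name, child in module.named_children():
        keys |= collect_tied_keys(child, prefix + name + ".")
    return keys

text = MockModule(tied=["lm_head.weight"])
model = MockModule(text_model=text)
assert collect_tied_keys(model) == {"text_model.lm_head.weight"}
```

With the real nn.Module the same traversal would use module.named_children() directly, so the test in question could compare missing keys against the full recursive set rather than only the top-level model's keys.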

@github-actions
Contributor

github-actions bot commented Dec 1, 2025

[For maintainers] Suggested jobs to run (before merge)

run-slow: fsmt, kyutai_speech_to_text, musicgen, musicgen_melody

@github-actions
Contributor

github-actions bot commented Dec 1, 2025

[For maintainers] Suggested jobs to run (before merge)

run-slow: fsmt, kyutai_speech_to_text, llava_onevision, musicgen, musicgen_melody

@github-actions
Contributor

github-actions bot commented Dec 1, 2025

[For maintainers] Suggested jobs to run (before merge)

run-slow: fsmt, kyutai_speech_to_text, llava_next_video, llava_onevision, musicgen, musicgen_melody

5 participants