-
Notifications
You must be signed in to change notification settings - Fork 33.6k
Fix weight tying logic between _tied_weights_keys and tie_word_embeddings
#42385
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 14 commits
47b0e0e
70fb402
6ce2bc6
60665c6
779012a
aa7ab80
afa33c5
22a258f
076f20e
58ec6e6
846bb69
4abe6d7
5a05c97
9c81dfc
e22acad
f939769
9960aa2
24f0b09
d2a03d2
8c4fb33
c8d8ad0
589e89c
65dca34
f85d0f4
b5537a3
587fef4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2248,8 +2248,10 @@ def get_expanded_tied_weights_keys(self, all_submodels: bool = False) -> dict: | |
| return expanded_tied_weights | ||
|
|
||
| tied_mapping = self._tied_weights_keys | ||
| text_config = self.config.get_text_config(decoder=True) | ||
| tie_word_embeddings = getattr(text_config, "tie_word_embeddings", self.config.tie_word_embeddings) | ||
| # If the config does not specify any tying, return empty dict | ||
| if not self.config.tie_word_embeddings and not self.config.tie_encoder_decoder: | ||
| if not tie_word_embeddings and not self.config.tie_encoder_decoder: | ||
| return {} | ||
| # If None, return empty dict | ||
| elif tied_mapping is None: | ||
|
|
@@ -3174,7 +3176,11 @@ def save_pretrained( | |
| shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1} | ||
|
|
||
| # Recursively descend to find tied weight keys | ||
| _tied_weights_keys = set(_get_tied_weight_keys(self)) | ||
| tied_keys_attr = getattr(self, "all_tied_weights_keys", None) | ||
| if tied_keys_attr is not None: | ||
| _tied_weights_keys = set(tied_keys_attr.keys()) | ||
| else: | ||
| _tied_weights_keys = set(_get_tied_weight_keys(self)) | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the main change. IMO it's why
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah, i think we fill it in in many cases just to pass some tests even if the released models don't tie weights
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, I am a bit lost here. Do we snow try to tie weights with
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, it's what I observed. Tried to get a reproducer here @zucchini-nlp @ArthurZucker from transformers import PaliGemmaConfig, PaliGemmaForConditionalGeneration
from transformers.models.gemma.configuration_gemma import GemmaConfig
from transformers.models.siglip.configuration_siglip import SiglipVisionConfig
vision_config = SiglipVisionConfig(
hidden_size=32,
intermediate_size=64,
num_hidden_layers=1,
num_attention_heads=2,
image_size=32,
patch_size=4,
vocab_size=64,
projection_dim=32,
).to_dict()
text_config = GemmaConfig(
vocab_size=64,
hidden_size=32,
intermediate_size=64,
num_hidden_layers=1,
num_attention_heads=4,
num_key_value_heads=4,
tie_word_embeddings=False,
).to_dict()
config = PaliGemmaConfig(vision_config=vision_config, text_config=text_config)
print("Config tie flags:")
print(f" config.tie_word_embeddings -> {config.tie_word_embeddings}")
print(f" config.text_config.tie_word_embeddings -> {config.text_config.tie_word_embeddings}")
model = PaliGemmaForConditionalGeneration(config)
lm_head_weight = model.lm_head.weight
input_embed_weight = model.model.language_model.embed_tokens.weight
tied = lm_head_weight.data_ptr() == input_embed_weight.data_ptr()
print("\nModel tying details:")
print(f" class _tied_weights_keys -> {model._tied_weights_keys}")
print(f" model.all_tied_weights_keys -> {model.all_tied_weights_keys}")
print(f" lm_head shares embedding tensor? -> {tied}")this will show tied weights. However if you go to modeling_paligemma and change _tied_weights_keys = {}# {"lm_head.weight": "model.language_model.embed_tokens.weight"}then it will work and weights will not be tied. I think the solution is to use and for instance for There might be a simpler fix, not seeing it right now
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The fact here is that I think some older models were tying using So technically looking at all potential
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Hmm ok, but then why are models like mistral3 broken by the current setup? Ah because you mean we did it in Not super sure of the way forward there, can you elaborate a bit? |
||
| error_names = [] | ||
| to_delete_names = set() | ||
| for names in shared_ptrs.values(): | ||
|
|
@@ -4408,7 +4414,9 @@ def _move_missing_keys_from_meta_to_cpu( | |
| # The tied weight keys are in the "missing" usually, but they should not be moved (they will be tied anyway) | ||
| # This is especially important because if they are moved, they will lose the `_is_hf_initialized` flag, and they | ||
| # will be re-initialized for nothing (which can be quite long) | ||
| for key in missing_keys - self.all_tied_weights_keys.keys(): | ||
| tied_keys_attr = getattr(self, "all_tied_weights_keys", {}) or {} | ||
| tied_keys = set(tied_keys_attr.keys()) | ||
| for key in missing_keys - tied_keys: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we have this?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ah yes indeed, could be just tied_keys = set(self.all_tied_weights_keys.keys())
for key in missing_keys - tied_keys:
do stuffwill update
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No need to cast it again as a |
||
| param = model_state_dict[key] | ||
| # Buffers are not initialized on the meta device, so we still need this check to avoid overwriting them | ||
| if param.device == torch.device("meta"): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is what we were trying to avoid but it might not really be possible as vlms do indeed just rely on child's ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes I think you're correct unfortunately, not 100% sure. invoking @zucchini-nlp 's brain here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
agreed with Arthur, I think
self.config.get_text_config(decoder=True).tie_word_embeddingsis equivalent of what we need. The fallback will probably just hide the bug if some configs are implemented poorlyThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, removed the fallback and added a warning!