Skip to content
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/transformers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,10 @@ def from_dict(

config = cls(**config_dict)

# default tie_word_embeddings to False if None, see https://github.com/huggingface/transformers/issues/42313
if hasattr(config, "tie_word_embeddings") and config.tie_word_embeddings is None:
config.tie_word_embeddings = False

# Update config with kwargs if needed
if "num_labels" in kwargs and "id2label" in kwargs:
num_labels = kwargs["num_labels"]
Expand Down
1 change: 1 addition & 0 deletions src/transformers/conversion_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def _build_checkpoint_conversion_mapping():
mapping["qwen3_next"] = mapping["qwen2_moe"].copy()
mapping["qwen3_vl_moe"] = mapping["qwen2_moe"].copy()
mapping["hunyuan_v1_moe"] = mapping["qwen2_moe"].copy()
mapping["olmoe"] = mapping["qwen2_moe"].copy()
mapping["minimax"] = mapping["mixtral"].copy()

return mapping
Expand Down
14 changes: 11 additions & 3 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2248,8 +2248,10 @@ def get_expanded_tied_weights_keys(self, all_submodels: bool = False) -> dict:
return expanded_tied_weights

tied_mapping = self._tied_weights_keys
text_config = self.config.get_text_config(decoder=True)
tie_word_embeddings = getattr(text_config, "tie_word_embeddings", self.config.tie_word_embeddings)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is what we were trying to avoid but it might not really be possible as vlms do indeed just rely on child's ?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes I think you're correct unfortunately, not 100% sure. invoking @zucchini-nlp 's brain here

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agreed with Arthur, I think self.config.get_text_config(decoder=True).tie_word_embeddings is equivalent of what we need. The fallback will probably just hide the bug if some configs are implemented poorly

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, removed the fallback and added a warning!

# If the config does not specify any tying, return empty dict
if not self.config.tie_word_embeddings and not self.config.tie_encoder_decoder:
if not tie_word_embeddings and not self.config.tie_encoder_decoder:
return {}
# If None, return empty dict
elif tied_mapping is None:
Expand Down Expand Up @@ -3174,7 +3176,11 @@ def save_pretrained(
shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}

# Recursively descend to find tied weight keys
_tied_weights_keys = set(_get_tied_weight_keys(self))
tied_keys_attr = getattr(self, "all_tied_weights_keys", None)
if tied_keys_attr is not None:
_tied_weights_keys = set(tied_keys_attr.keys())
else:
_tied_weights_keys = set(_get_tied_weight_keys(self))

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the main change. IMO it's why tie_word_embeddings is not authoritative rn, it is entirely skipped if tied_weights_keys is filled

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, i think we fill it in in many cases just to pass some tests even if the released models don't tie weights

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, I am a bit lost here. Do we snow try to tie weights with _tied_weights_keys attribute even when config.tie_word_embeddings = False?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, it's what I observed. Tried to get a reproducer here @zucchini-nlp @ArthurZucker

from transformers import PaliGemmaConfig, PaliGemmaForConditionalGeneration
from transformers.models.gemma.configuration_gemma import GemmaConfig
from transformers.models.siglip.configuration_siglip import SiglipVisionConfig

vision_config = SiglipVisionConfig(
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=1,
    num_attention_heads=2,
    image_size=32,
    patch_size=4,
    vocab_size=64,
    projection_dim=32,
).to_dict()

text_config = GemmaConfig(
    vocab_size=64,
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=1,
    num_attention_heads=4,
    num_key_value_heads=4,
    tie_word_embeddings=False,
).to_dict()

config = PaliGemmaConfig(vision_config=vision_config, text_config=text_config)

print("Config tie flags:")
print(f"  config.tie_word_embeddings              -> {config.tie_word_embeddings}")
print(f"  config.text_config.tie_word_embeddings  -> {config.text_config.tie_word_embeddings}")

model = PaliGemmaForConditionalGeneration(config)
lm_head_weight = model.lm_head.weight
input_embed_weight = model.model.language_model.embed_tokens.weight
tied = lm_head_weight.data_ptr() == input_embed_weight.data_ptr()

print("\nModel tying details:")
print(f"  class _tied_weights_keys      -> {model._tied_weights_keys}")
print(f"  model.all_tied_weights_keys   -> {model.all_tied_weights_keys}")
print(f"  lm_head shares embedding tensor? -> {tied}")

this will show tied weights. However if you go to modeling_paligemma and change

    _tied_weights_keys = {}# {"lm_head.weight": "model.language_model.embed_tokens.weight"}

then it will work and weights will not be tied. I think the solution is to use all_tied_weights_keys instead of _get_tied_weights_keys here. It's what I understand from the logic flow

and for instance for mistral3 it breaks the model generation, but for several others as well, this is why I wanted the run-slow upgrade mentioned hehe.

There might be a simpler fix, not seeing it right now

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fact here is that I think some older models were tying using _tie_weights (the one with underscore) that were ALWAYS run, independently of the configs. So they could have tied weights even though the config say to not tie technically. In save_pretrained, we check the pointers of the weight tensors, so we know the tied weights from there and we cannot be wrong about it. But then if we use the good citizen all_tied_weights_keys, which correctly looks at the config, some of those tied weights will not find their source as they are not SUPPOSED to be tied (but they were anyway).

So technically looking at all potential _tied_weights_keys here is not wrong as we check pointers anyway, and this avoids this kind of issues.
But indeed, in the future we want to always know which weights are tied and simply rely on the internal list (or even better, recomputing it with get_expanded_tied_weights_keys(all_submodels=True))

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So technically looking at all potential _tied_weights_keys here is not wrong as we check pointers anyway, and this avoids this kind of issues.

Hmm ok, but then why are models like mistral3 broken by the current setup? Ah because you mean we did it in save_pretrained but load_pretrained has to do with what exists now?

Not super sure of the way forward there, can you elaborate a bit?

error_names = []
to_delete_names = set()
for names in shared_ptrs.values():
Expand Down Expand Up @@ -4408,7 +4414,9 @@ def _move_missing_keys_from_meta_to_cpu(
# The tied weight keys are in the "missing" usually, but they should not be moved (they will be tied anyway)
# This is especially important because if they are moved, they will lose the `_is_hf_initialized` flag, and they
# will be re-initialized for nothing (which can be quite long)
for key in missing_keys - self.all_tied_weights_keys.keys():
tied_keys_attr = getattr(self, "all_tied_weights_keys", {}) or {}
tied_keys = set(tied_keys_attr.keys())
for key in missing_keys - tied_keys:

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we have this? all_tied_weights_keys is guaranteed to be a dict already at this point from post_init

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah yes indeed, could be just

tied_keys = set(self.all_tied_weights_keys.keys())
  for key in missing_keys - tied_keys:
 do stuff

will update

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to cast it again as a set the keys already are a subclass of set!

param = model_state_dict[key]
# Buffers are not initialized on the meta device, so we still need this check to avoid overwriting them
if param.device == torch.device("meta"):
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/fsmt/configuration_fsmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def __init__(
bos_token_id=eos_token_id,
is_encoder_decoder=is_encoder_decoder,
num_hidden_layers=encoder_layers,
tie_word_embeddings=tie_word_embeddings,
)
if "decoder" in common_kwargs:
del common_kwargs["decoder"]
Expand Down
3 changes: 3 additions & 0 deletions tests/models/fsmt/test_modeling_fsmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def get_config(self):
eos_token_id=self.eos_token_id,
bos_token_id=self.bos_token_id,
pad_token_id=self.pad_token_id,
tie_word_embeddings=True,
)

def prepare_config_and_inputs_for_common(self):
Expand Down Expand Up @@ -254,6 +255,7 @@ def test_ensure_weights_are_shared(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs()

config.tie_word_embeddings = True
config.decoder.tie_word_embeddings = True
model = FSMTForConditionalGeneration(config)

# FSMT shares three weights.
Expand All @@ -270,6 +272,7 @@ def test_ensure_weights_are_shared(self):
)

config.tie_word_embeddings = False
config.decoder.tie_word_embeddings = False
model = FSMTForConditionalGeneration(config)

# FSMT shares three weights.
Expand Down