
Fix Qwen3OmniMoe Talker weight loading and config initialization#43084

Merged
ArthurZucker merged 3 commits into huggingface:main from HuiyingLi:huiyingl/qwen3_omni_fix
Jan 5, 2026

Conversation

@HuiyingLi
Contributor

What does this PR do?

This PR fixes several issues preventing Qwen3OmniMoeForConditionalGeneration from loading and running correctly from the Qwen/Qwen3-Omni-30B-A3B-Instruct checkpoint.

Issues Fixed

Issue: #43083

  1. Incorrect inherited _tied_weights_keys in Talker
    Qwen3OmniMoeTalkerForConditionalGeneration inherits from Qwen3MoeForCausalLM, which defines _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}. However, the Talker model uses codec_head instead of lm_head and doesn't tie weights. The loader therefore expected keys that don't exist in the checkpoint, causing loading errors and garbled audio output.
    Fix: Override _tied_weights_keys = {} in the Talker class.

  2. Incorrect _tp_plan and _pp_plan references
    The inherited tensor/pipeline parallelism plans reference lm_head, but the Talker uses codec_head.
    Fix: Override with _tp_plan = {"codec_head": "colwise_rep"} and _pp_plan = {"codec_head": (["hidden_states"], ["logits"])}.

  3. Missing initializer_range in config classes
    Qwen3OmniMoeTalkerConfig, Qwen3OmniMoeCode2WavConfig, and Qwen3OmniMoeConfig were missing the initializer_range attribute. This caused AttributeError during _init_weights() calls.
    Fix: Add initializer_range attribute to all affected config classes.
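
The three fixes above can be sketched as follows. This is a minimal, self-contained illustration with stand-in class names, not the actual transformers source:

```python
# Hedged sketch of the three fixes; ParentForCausalLM, TalkerForConditionalGeneration,
# and TalkerConfig are illustrative stand-ins for the real classes.

class ParentForCausalLM:
    # Inherited defaults that reference lm_head, which the Talker lacks.
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

class TalkerForConditionalGeneration(ParentForCausalLM):
    # Fix 1: the Talker uses codec_head, not lm_head, and does not tie weights.
    _tied_weights_keys = {}
    # Fix 2: point the parallelism plans at codec_head instead of lm_head.
    _tp_plan = {"codec_head": "colwise_rep"}
    _pp_plan = {"codec_head": (["hidden_states"], ["logits"])}

class TalkerConfig:
    # Fix 3: expose initializer_range so _init_weights() can read it
    # without raising AttributeError.
    def __init__(self, initializer_range=0.02):
        self.initializer_range = initializer_range
```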

Who can review?

@zucchini-nlp @ArthurZucker @Cyrilvallez

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>

@auto_docstring
class Qwen3OmniMoeTalkerForConditionalGeneration(Qwen3OmniMoeThinkerTextPreTrainedModel, GenerationMixin):
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
Member

I believe the correct way is codec_head: model.codec_embedding.weight. It will allow users to tie weights if needed. We just need to make sure that the model is not tying weights; I see that the default is already tie_word_embeddings=False in the config.
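
The suggestion above amounts to keeping a tying map that points codec_head at the codec embedding rather than dropping the mapping entirely. A sketch, with key/value strings assumed from the comment and the lm_head pattern shown earlier:

```python
# Hedged sketch of the suggested mapping (names assumed from the review
# comment). With tie_word_embeddings=False, the config default, no tying
# is performed; users who enable tying get the correct pair of weights.
tied_weights_keys = {"codec_head.weight": "model.codec_embedding.weight"}
```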

Contributor Author

Thank you very much @zucchini-nlp ! Fixed.

self.audio_start_token_id = audio_start_token_id
self.vision_start_token_id = vision_start_token_id
self.speaker_id = speaker_id
self.initializer_range = self.text_config.initializer_range
Member
Since we're using the text config's init range in any case, we can instead update _init_weights to use config.get_text_config().initializer_range.
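
This alternative can be sketched as follows; the class shapes and the init_std helper are assumptions for illustration, not the actual transformers implementation:

```python
# Hedged sketch: instead of copying initializer_range onto every sub-config,
# read it from the text config when initializing weights. Class names and
# the init_std helper are hypothetical.

class TextConfig:
    def __init__(self, initializer_range=0.02):
        self.initializer_range = initializer_range

class CompositeConfig:
    def __init__(self):
        self.text_config = TextConfig()

    def get_text_config(self):
        # Composite transformers configs expose get_text_config() to reach
        # the text sub-config.
        return self.text_config

def init_std(config):
    # _init_weights would use this std for normally initialized weights.
    return config.get_text_config().initializer_range
```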

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@github-actions
Contributor

github-actions bot commented Jan 5, 2026

[For maintainers] Suggested jobs to run (before merge)

run-slow: qwen3_omni_moe

@ArthurZucker (Collaborator) left a comment

TY!

@ArthurZucker ArthurZucker merged commit 64a476b into huggingface:main Jan 5, 2026
17 checks passed
sniper35 pushed a commit to sniper35/transformers that referenced this pull request Jan 5, 2026
…gingface#43084)

* fix modular_qwen3_omni_moe

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>

* update generated configuration and modeling file

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>

* fix tie weight keys

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>

---------

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
SangbumChoi pushed a commit to SangbumChoi/transformers that referenced this pull request Jan 23, 2026