Skip to content
Merged
9 changes: 8 additions & 1 deletion src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,14 @@ def _maybe_initialize_input_ids_for_generation(
break

if "inputs_embeds" in model_kwargs:
return torch.ones((batch_size, 0), dtype=torch.long, device=self.device)
return torch.ones(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add comment with reference to here

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ohoh yes, forgot before push

Copy link
Copy Markdown
Contributor Author

@Sai-Suraj-27 Sai-Suraj-27 Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey, @ydshieh. Thanks for pushing the fix. I was able to run the test on RTX PRO 6000, & it's running fine without the meta device issue. But in the case of the A10 GPU, device_map="auto" is offloading the talker module to CPU &, if I understand correctly from the accelerate code, it keeps the parameters of cpu/disk-offloaded modules as meta tensors (which is why model.talker.device is giving "meta" in the case of A10) & only loads the real weights onto the GPU later, just before forward.

Since the test ran fine on the big GPU but fails on the A10, I think I can confirm with this fix that the issue is with how we are using self.device in this method. So maybe we can add a comment regarding this accelerate behaviour here, pointing to this accelerate code.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @Sai-Suraj-27

a comment is added (a few lines below)

                # Use the device of the existing tensor to avoid any potential `meta` device issue.
                # See PR #44848. (Previously, it used `self.device`.)

I think it's enough with the reference to this PR.

(batch_size, 0),
dtype=torch.long,
# Use the device of the existing tensor to avoid any potential `meta` device issue, which is likely
# linked to the offloading behavior (keeping it on meta device). See PR #44848. Previously, it used
# `self.device`.
device=self.device if self.device.type != "meta" else model_kwargs["inputs_embeds"].device,
)

if bos_token_id is None:
raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,7 @@ class Qwen3OmniMoeConfig(PreTrainedConfig):
system_token_id: int = 8948
user_token_id: int = 872
assistant_token_id: int = 77091
initializer_range: float | None = None

def __post_init__(self, **kwargs):
if self.thinker_config is None:
Expand All @@ -640,6 +641,9 @@ def __post_init__(self, **kwargs):
elif isinstance(self.code2wav_config, dict):
self.code2wav_config = Qwen3OmniMoeCode2WavConfig(**self.code2wav_config)

if self.initializer_range is None:
self.initializer_range = self.thinker_config.initializer_range

super().__post_init__(**kwargs)

def get_text_config(self, decoder=False) -> "PreTrainedConfig":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1896,10 +1896,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
config: Qwen3OmniMoeThinkerConfig
base_model_prefix = "thinker"
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
_no_split_modules = [
"Qwen3OmniMoeAudioEncoderLayer",
"Qwen3OmniMoeThinkerTextDecoderLayer",
]
_no_split_modules = ["Qwen3OmniMoeAudioEncoder", "Qwen3OmniMoeVisionEncoder"]
_can_record_outputs = {
"hidden_states": Qwen3OmniMoeThinkerTextDecoderLayer,
"attentions": Qwen3OmniMoeThinkerTextAttention,
Expand Down Expand Up @@ -3232,9 +3229,9 @@ def prepare_inputs_for_generation(
hidden_states = kwargs.pop("hidden_states", None)
inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values,
attention_mask,
inputs_embeds,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
is_first_iteration=is_first_iteration,
**kwargs,
)
Expand Down Expand Up @@ -4074,7 +4071,7 @@ def generate(
talker_codes = (
torch.stack([hid[-1] for hid in talker_result.hidden_states if hid[-1] is not None], dim=1)
.transpose(1, 2)
.to(self.code2wav.device)
.to(talker_result.hidden_states[-1][-1].device)
)
talker_wavs = self.code2wav.chunked_decode(talker_codes, chunk_size=300, left_context_size=25)

Expand Down
16 changes: 8 additions & 8 deletions src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,7 @@ class Qwen3OmniMoeConfig(PreTrainedConfig):
system_token_id: int = 8948
user_token_id: int = 872
assistant_token_id: int = 77091
initializer_range: float | None = None

def __post_init__(self, **kwargs):
if self.thinker_config is None:
Expand All @@ -549,6 +550,9 @@ def __post_init__(self, **kwargs):
elif isinstance(self.code2wav_config, dict):
self.code2wav_config = Qwen3OmniMoeCode2WavConfig(**self.code2wav_config)

if self.initializer_range is None:
self.initializer_range = self.thinker_config.initializer_range

super().__post_init__(**kwargs)

def get_text_config(self, decoder=False) -> "PreTrainedConfig":
Expand Down Expand Up @@ -1074,10 +1078,6 @@ class Qwen3OmniMoeThinkerCausalLMOutputWithPast(MoeCausalLMOutputWithPast):


class Qwen3OmniMoeThinkerForConditionalGeneration(Qwen2_5OmniThinkerForConditionalGeneration):
_no_split_modules = [
"Qwen3OmniMoeAudioEncoderLayer",
"Qwen3OmniMoeThinkerTextDecoderLayer",
]
_can_record_outputs = {
"hidden_states": Qwen3OmniMoeThinkerTextDecoderLayer,
"attentions": Qwen3OmniMoeThinkerTextAttention,
Expand Down Expand Up @@ -1775,9 +1775,9 @@ def prepare_inputs_for_generation(
hidden_states = kwargs.pop("hidden_states", None)
inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values,
attention_mask,
inputs_embeds,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
is_first_iteration=is_first_iteration,
**kwargs,
)
Expand Down Expand Up @@ -2425,7 +2425,7 @@ def generate(
talker_codes = (
torch.stack([hid[-1] for hid in talker_result.hidden_states if hid[-1] is not None], dim=1)
.transpose(1, 2)
.to(self.code2wav.device)
.to(talker_result.hidden_states[-1][-1].device)
)
talker_wavs = self.code2wav.chunked_decode(talker_codes, chunk_size=300, left_context_size=25)

Expand Down
54 changes: 33 additions & 21 deletions tests/models/qwen3_omni_moe/test_modeling_qwen3_omni_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
require_flash_attn,
require_torch,
require_torch_accelerator,
run_first,
slow,
torch_device,
)
Expand Down Expand Up @@ -677,7 +678,27 @@ def test_code_predictor_config_init(self):

@require_torch
class Qwen3OmniModelIntegrationTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = None

@classmethod
def get_model(cls):
if cls.model is None:
cls.model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
"Qwen/Qwen3-Omni-30B-A3B-Instruct", dtype=torch.bfloat16, device_map="auto"
)
return cls.model

@classmethod
def tearDownClass(cls):
if hasattr(cls, "model"):
del cls.model
cleanup(torch_device, gc_collect=True)

def setUp(self):
cleanup(torch_device, gc_collect=True)

self.processor = AutoProcessor.from_pretrained(
"Qwen/Qwen3-Omni-30B-A3B-Instruct", min_pixels=28 * 28, max_pixels=56 * 56
)
Expand Down Expand Up @@ -710,9 +731,7 @@ def tearDown(self):

@slow
def test_small_model_integration_test(self):
model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
"Qwen/Qwen3-Omni-30B-A3B-Instruct", dtype=torch.bfloat16, device_map="auto"
)
model = self.get_model()

text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True)
inputs = self.processor(
Expand Down Expand Up @@ -764,7 +783,7 @@ def test_small_model_integration_test(self):
)

EXPECTED_DECODED_TEXT = Expectations({
("cuda", (8, 6)): "user\nWhat's that sound and what kind of dog is this?\nassistant\nBased on the audio and visual information, here is a breakdown of what you're hearing and seeing:-",
("cuda", (8, 6)): "user\nWhat's that sound and what kind of dog is this?\nassistant\nBased on the audio and visual information, here is a breakdown of what you're hearing and seeing:\n\n",
("rocm", (9, 4)): "system\nYou are a helpful assistant.\nuser\nWhat's that sound and what kind of dog is this?\nassistant\nThe sound is glass shattering, and the dog is a Labrador Retriever.",
}).get_expectation() # fmt: skip

Expand All @@ -773,9 +792,7 @@ def test_small_model_integration_test(self):

@slow
def test_small_model_integration_test_batch(self):
model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
"Qwen/Qwen3-Omni-30B-A3B-Instruct", dtype=torch.bfloat16, device_map="auto"
)
model = self.get_model()
text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True)
inputs = self.processor(
text=[text] * 2,
Expand All @@ -791,13 +808,9 @@ def test_small_model_integration_test_batch(self):

EXPECTED_DECODED_TEXTS = Expectations(
{
("cuda", 7) : [
"system\nYou are a helpful assistant.\nuser\nWhat's that sound and what kind of dog is this?\nassistant\nThe sound is of glass shattering, and the dog in the picture is a Labrador Retriever",
"system\nYou are a helpful assistant.\nuser\nWhat's that sound and what kind of dog is this?\nassistant\nThe sound is of glass shattering, and the dog in the picture is a Labrador Retriever",
],
("cuda", 8): [
"user\nWhat's that sound and what kind of dog is this?\nassistant\nBased on the audio and visual information, here is a breakdown of what you're hearing and seeing:\n\n",
"user\nWhat's that sound and what kind of dog is this?\nassistant\nBased on the audio and visual information, here is a breakdown of what you're hearing and seeing:\n\n"
"user\nWhat's that sound and what kind of dog is this?\nassistant\nBased on the audio and visual information provided:\n\nThe sound you hear is the distinct, high-pitched",
"user\nWhat's that sound and what kind of dog is this?\nassistant\nBased on the audio and visual information provided:\n\nThe sound you hear is the distinct, high-pitched",
],
("rocm", (9, 4)): [
"system\nYou are a helpful assistant.\nuser\nWhat's that sound and what kind of dog is this?\nassistant\nThe sound is glass shattering, and the dog is a Labrador Retriever.",
Expand All @@ -811,9 +824,7 @@ def test_small_model_integration_test_batch(self):

@slow
def test_small_model_integration_test_multiturn(self):
model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
"Qwen/Qwen3-Omni-30B-A3B-Instruct", dtype=torch.bfloat16, device_map="auto"
)
model = self.get_model()

messages = [
self.messages[0],
Expand Down Expand Up @@ -857,9 +868,7 @@ def test_small_model_integration_test_multiturn(self):

@slow
def test_small_model_integration_test_w_audio(self):
model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
"Qwen/Qwen3-Omni-30B-A3B-Instruct", dtype=torch.bfloat16, device_map="auto"
)
model = self.get_model()
audio_url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/guess_age_gender.wav"

messages = [
Expand Down Expand Up @@ -894,8 +903,7 @@ def test_small_model_integration_test_w_audio(self):

EXPECTED_DECODED_TEXTS = Expectations(
{
("cuda", 7): "system\nYou are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech.\nuser\n\nassistant\nWell, I can try. But it's not always that accurate. I might be able to make",
("cuda", 8): "'system\nYou are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech.\nuser\n\nassistant\nYes, I can analyze audio inputs to understand spoken content, and I can also make inferences about'",
("cuda", 8): "system\nYou are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech.\nuser\n\nassistant\nYes, I can analyze audio inputs to understand spoken content, and I can also process and respond to",
}
) # fmt: skip
EXPECTED_DECODED_TEXT = EXPECTED_DECODED_TEXTS.get_expectation()
Expand All @@ -906,6 +914,10 @@ def test_small_model_integration_test_w_audio(self):
)
self.assertFalse(torch.isnan(output[1]).any().item())

# Run this test first because it needs to load the model with `flash_attention_2`. For other tests, we need to keep
# the loaded model (without FA) in `cls.model`. If this test is not run first, when loading the flash attention
# model here, there is already a previous loaded model `cls.model` and we will get GPU OOM.
@run_first
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason we want this to run first?

@slow
@require_flash_attn
@require_torch_accelerator
Expand Down
Loading