[Bugfix] Fix qwen3-omni async thinker to talker decode alignment for #1758 (#2019)

Merged
Gaohan123 merged 4 commits into vllm-project:main from Sy0307:fix/qwen3-omni-async-decode-latest-main
Mar 26, 2026

Conversation

@Sy0307
Contributor

@Sy0307 Sy0307 commented Mar 19, 2026

Purpose

Fix the Qwen3-Omni async_chunk text/audio misalignment introduced by #1758 and reported in #1830.

Root cause

The bug is in Qwen3OmniMoeForConditionalGeneration._thinker_decode_to_talker_decode().

In async mode, the old logic used the cumulative thinker_output_token_ids length to decide how many decode conditions were still available:

if start_index >= len(thinker_output_token_ids) - 1:

This is incorrect for Qwen3-Omni async handoff because:

  1. The first assistant text token has already been consumed during talker prefill.
  2. The trailing <|im_end|> token is a terminal marker, not a normal text decode condition.

As a result, async decode consumed one extra condition step. In practice this caused the talker-side decode embedding sequence to become misaligned with the thinker output tokens, so text could still be correct while the generated codec frames drifted and the final audio no longer matched the text.

For example, for the prompt:

What is the capital of China? Reply with exactly one word.

the thinker output is effectively:

  • Be
  • ijing
  • <|im_end|>

The correct async decode conditions should be:

  • ijing
  • eos

But the old logic treated the cumulative token list as if there were two normal text decode conditions before EOS, so async consumed the wrong condition boundary and produced mismatched audio.
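
The corrected condition arithmetic can be sketched in a few lines. This is an illustrative re-derivation of the example above, not the project source; the `max(0, len - 1 - terminal)` form mirrors the expression visible in the reviewed diff, while `expected_condition_len` and `has_terminal_token` are assumed names.

```python
# Illustrative sketch of the corrected async condition count for the
# "Beijing" example, versus the old cumulative-length boundary.

def expected_condition_len(thinker_output_token_ids, has_terminal_token):
    # Drop the first text token (already consumed by talker prefill) and the
    # trailing <|im_end|> terminal marker when present.
    return max(0, len(thinker_output_token_ids) - 1 - int(has_terminal_token))

tokens = ["Be", "ijing", "<|im_end|>"]
old_count = len(tokens) - 1  # old logic: two "normal" conditions before EOS
new_count = expected_condition_len(tokens, has_terminal_token=True)
print(old_count, new_count)  # 2 1
```

With the corrected count of 1, only "ijing" remains as a text condition, followed by EOS, matching the expected sequence above.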

What this PR changes

This PR fixes the async decode-state handoff by:

  1. Computing the expected async decode condition length correctly:

    • exclude the first text token already consumed by talker prefill
    • exclude the terminal <|im_end|> token
  2. Rebuilding / reusing cached_thinker_decode_embeddings against that corrected condition length, so async consume steps stay aligned with the remaining thinker decode embeddings.

  3. Advancing num_processed_tokens only when a real text condition is consumed.

    • wait / eos / pad steps no longer incorrectly advance decode-state
  4. Skipping talker_mtp when async is only waiting for additional valid decode condition cache.

    • this avoids generating codec frames from invalid wait/pad steps
  5. Adding focused unit tests for the async handoff boundary logic.

In short: this PR makes async decode semantics match the sync path’s effective behavior for the remaining text conditions after talker prefill.
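
The step semantics described in points 1-4 can be sketched as a small state machine. This is a hedged illustration: the step names (consume/wait/eos/pad) and the function signature are assumptions for clarity, not the exact qwen3_omni.py API.

```python
# Illustrative async handoff step logic: advance decode state only on a real
# text condition; wait (skipping talker_mtp) when the thinker is behind.

def async_decode_step(cached_embeds, consumed, expected_len, eos_emitted):
    """Return (step_kind, should_advance_num_processed_tokens)."""
    if consumed < expected_len:
        if consumed < len(cached_embeds):
            return "consume", True   # real text condition: advance state
        return "wait", False         # thinker is behind: skip talker_mtp
    if not eos_emitted:
        return "eos", False          # terminal step, no text advance
    return "pad", False              # post-EOS filler steps

# Drive the "Beijing" example: one remaining text condition after prefill.
steps = []
cached, consumed, eos = ["ijing_embed"], 0, False
for _ in range(4):
    kind, advance = async_decode_step(cached, consumed, expected_len=1,
                                      eos_emitted=eos)
    steps.append(kind)
    consumed += int(advance)
    eos = eos or kind == "eos"
print(steps)  # ['consume', 'eos', 'pad', 'pad']
```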

Test Plan

1. Unit tests

Added:

  • tests/model_executor/models/test_qwen3_omni_async_decode.py

Covers:

  • single EOS emission followed by pad steps
  • consuming cached decode embeddings before EOS
  • skipping the first decode embedding already used by prefill
  • handling runtimes that return the full decode prefix
  • excluding terminal <|im_end|> from normal consume steps

Command:

python -m pytest tests/model_executor/models/test_qwen3_omni_async_decode.py

2. Remote async e2e validation

Validated on a real multi-GPU machine using:

  • Qwen3-Omni online serving
  • async_chunk stage config
  • fixed prompt:
    • What is the capital of China? Reply with exactly one word.
  • returned audio checked manually

Client:

  • examples/online_serving/qwen3_omni/openai_chat_completion_client_for_multimodal_generation.py

Test Result

Before fix

  • text output: Beijing
  • audio sample: before_fix.wav (attached)
  • the audio does not match the text output when listened to manually

After fix

  • text output: Beijing
  • audio sample: fix.wav (attached)
  • the audio matches the text output when listened to manually

Additional notes

  • The failure is async-specific and is caused by incorrect thinker->talker decode-state alignment.
  • This is not a sampling-only issue.
  • The fix is intentionally focused on the async handoff state machine in qwen3_omni.py.

cc @LJH-LBJ

@Sy0307 Sy0307 requested a review from hsliuustc0106 as a code owner March 19, 2026 14:38

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 7421b4df47


Comment on lines +613 to +616
if "run_talker_mtp" not in update_dict:
update_dict["run_talker_mtp"] = True
if not update_dict["run_talker_mtp"]:
update_dict["code_predictor_codes"] = torch.zeros(

P1 Badge Honor run_talker_mtp before scheduling decode MTP

run_talker_mtp=False does not currently suppress the MTP pass. I checked vllm_omni/worker/gpu_model_runner.py:1264-1284 and vllm_omni/platforms/npu/worker/npu_model_runner.py:409-417: both runners enqueue every span_len == 1 decode request and always call talker_mtp, without ever reading this flag. In async-chunk runs where the thinker is temporarily behind the talker, these wait/pad steps still execute the code predictor and overwrite the zeroed code_predictor_codes, so the PR's intended “skip talker_mtp while waiting” behavior never actually takes effect and bogus codec frames can still be emitted.


Comment on lines +935 to +936
elif cur_len >= expected_condition_len:
cached_thinker_decode_embeds = thinker_decode_embed[-expected_condition_len:]

P1 Badge Exclude <|im_end|> before tail-slicing full-prefix embeddings

This tail-slice is wrong on the finishing step when the runtime returns the full decode prefix. vllm_omni/model_executor/stage_input_processors/qwen3_omni.py:143-146 forwards pooling_output["0"] unchanged alongside thinker_output_token_ids, so a finished prefix can be [... text_embeds, im_end_embed]. With expected_condition_len == 1 (for example, a one-word answer plus <|im_end|>), thinker_decode_embed[-1:] caches the terminal embedding instead of the remaining text condition, which reintroduces the thinker/talker misalignment on the last spoken token for the same “full prefix” runtime this patch is trying to support.


@amy-why-3459
Contributor

@LJH-LBJ @ZeldaHuang PTAL

@LJH-LBJ
Contributor

LJH-LBJ commented Mar 20, 2026

Thanks for the detailed root-cause analysis. I've done an independent trace of the async decode path and have some observations:

#1758 did not modify _thinker_decode_to_talker_decode
The async decode alignment logic has been identical since it was first introduced — it predates #1758 entirely.

I think the len(thinker_output_token_ids) - 1 boundary is not the cause of the audio mismatch

For the "Beijing" example (thinker_output_token_ids = [Be, ijing, <|im_end|>], len=3):

Talker Step   start_index   start_index >= len-1 (2)?   text_step
Decode 1      0             No                          Be_embed (projected)
Decode 2      1             No                          ijing_embed (projected)
Decode 3      2             Yes                         tts_eos_embed
Decode 4+     3             Yes (finished)              tts_pad_embed

The - 1 means the <|im_end|> embedding is never consumed, going straight from ijing_embed to tts_eos. Compared to the sync path which feeds [ijing_embed, im_end_embed, tts_eos_embed], async skips the <|im_end|> embedding. This is a real divergence from sync — but <|im_end|> is a ChatML structural marker, not a spoken content token. The EOS signal is still emitted at the correct time. In practice this does not produce perceptible audio quality degradation.
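
The trace above can be reproduced with a few lines of plain Python. This is an illustrative re-derivation of the old boundary's behavior, not project code; the helper name is made up.

```python
# Simulate the old `start_index >= len - 1` boundary for the "Beijing"
# example: the <|im_end|> embedding is never fed to the talker, which goes
# straight from ijing_embed to tts_eos_embed.

def old_async_text_steps(thinker_output_token_ids, n_steps):
    steps, finished = [], False
    for start_index in range(n_steps):
        if start_index >= len(thinker_output_token_ids) - 1:
            steps.append("tts_pad_embed" if finished else "tts_eos_embed")
            finished = True
        else:
            steps.append(thinker_output_token_ids[start_index] + "_embed")
    return steps

print(old_async_text_steps(["Be", "ijing", "<|im_end|>"], 4))
# ['Be_embed', 'ijing_embed', 'tts_eos_embed', 'tts_pad_embed']
```

Note that "<|im_end|>_embed" never appears in the output, matching the table: the divergence from the sync path is the skipped structural marker, not a shifted text condition.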

@amy-why-3459
Contributor

Thank you so much for your fix. I tested your PR, and the accuracy is indeed normal now. Could you provide a performance comparison before and after the PR?

@Sy0307
Contributor Author

Sy0307 commented Mar 20, 2026

Thank you so much for your fix. I tested your PR, and the accuracy is indeed normal now. Could you provide a performance comparison before and after the PR?

I will test it later. Thanks a lot.

@Sy0307
Contributor Author

Sy0307 commented Mar 20, 2026

Thanks for the detailed root-cause analysis. I've done an independent trace of the async decode path and have some observations:

#1758 did not modify _thinker_decode_to_talker_decode The async decode alignment logic has been identical since it was first introduced — it predates #1758 entirely.

I think the len(thinker_output_token_ids) - 1 boundary is not the cause of the audio mismatch

For the "Beijing" example (thinker_output_token_ids = [Be, ijing, <|im_end|>], len=3):

Talker Step   start_index   start_index >= len-1 (2)?   text_step
Decode 1      0             No                          Be_embed (projected)
Decode 2      1             No                          ijing_embed (projected)
Decode 3      2             Yes                         tts_eos_embed
Decode 4+     3             Yes (finished)              tts_pad_embed
The - 1 means the <|im_end|> embedding is never consumed, going straight from ijing_embed to tts_eos. Compared to the sync path which feeds [ijing_embed, im_end_embed, tts_eos_embed], async skips the <|im_end|> embedding. This is a real divergence from sync — but <|im_end|> is a ChatML structural marker, not a spoken content token. The EOS signal is still emitted at the correct time. In practice this does not produce perceptible audio quality degradation.

Root cause still to be confirmed. I am still working on finding it out.

@hsliuustc0106 hsliuustc0106 added the ready label to trigger buildkite CI label Mar 20, 2026
@Sy0307
Contributor Author

Sy0307 commented Mar 21, 2026

Test result

Test plan is same as #1758 .

Concurrency   Branch   Success/Fail   Mean TTFT (ms)   Mean E2EL (ms)   Mean AUDIO_TTFP (ms)   Mean AUDIO_RTF   Audio Throughput
c1            main     10/0           261.18           22675.68         718.25                 0.1524           6.6216
c1            #2019    10/0           262.14           22915.80         724.47                 0.1596           6.2733
c4            main     10/0           838.88           38108.03         2027.76                0.2318           14.8967
c4            #2019    10/0           668.04           52984.44         1577.94                0.2499           13.1793
c10           main     10/0           7697.04          69542.82         11409.55               0.3885           20.1547
c10           #2019    10/0           7688.19          74148.83         10219.65               0.4500           17.7650

Under high concurrency, #2019 introduces some performance regression, but the gap is suspected to be an artifact: incorrect audio generation in the original implementation erroneously lowered RTF. We may need to check this further.

Please verify it, PTAL @amy-why-3459

@Sy0307
Contributor Author

Sy0307 commented Mar 21, 2026

Possible Root Cause Explanation

The old origin/main (before #2012 or #1758 merged) did not visibly expose this async handoff bug because the old pipeline did not use the more sensitive re-prefill producer/consumer path.

More specifically:

  • The old origin/main fed the talker with predictor hidden states, e.g. the old talker path built summed_embeddings from intermediate hidden states instead of raw sampled codec embeddings:
middle_hidden_states.append(current_input[:, 2:-1, :])
mid_residual_hiddens = middle_hidden_states[pos]
pos_codec_hiddens = torch.cat(
    [layer0_embed] + mid_list + [last_residual_hidden],
    dim=1,
)
pos_summed = pos_codec_hiddens.sum(dim=1, keepdim=True)
  • The #2012 line still keeps the re-prefill predictor path that re-runs the whole proj_buf prefix on every step, writes newly sampled codec embeddings back into that prefix, and then lets the talker consume them via:
summed_embeddings[:, pos, :] = proj_buf[:, 1:, :].sum(dim=1)

As a result, a small misalignment in _thinker_decode_to_talker_decode() is no longer mediated by the old hidden-state path: it turns into a wrong sampled codec earlier, gets fed back into the next re-prefill step, and is then amplified into an audible text/audio mismatch.
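
That feedback loop can be illustrated with a toy numeric sketch. Everything here is an assumption for illustration: scalars stand in for embeddings and `sample_codes` is a made-up predictor, but it reproduces the structural point that re-prefilling sampled codes back into the conditioning prefix turns one early misalignment into divergence of all later samples.

```python
# Toy model of the re-prefill path: every sampled code is written back into
# the prefix that conditions the next step, so a one-step misalignment in the
# conditions perturbs every subsequent sample.

def sample_codes(conditions):
    prefix, codes = [], []
    for cond in conditions:
        code = (2 * sum(prefix) + cond) % 11  # fake predictor over the prefix
        prefix.append(code)                   # re-prefill: feed sample back in
        codes.append(code)
    return codes

aligned = sample_codes([3, 5, 2])  # conditions in the correct order
shifted = sample_codes([5, 3, 2])  # one-step misalignment at the handoff
print(aligned, shifted)  # [3, 0, 8] [5, 2, 5]
```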

So #2012 does not introduce the handoff bug itself. It preserves the #1758 re-prefill-sensitive path that makes the pre-existing async handoff bug reproducible. #2019 is the fix for the actual async handoff root cause.

PTAL @LJH-LBJ @amy-why-3459

@LJH-LBJ
Contributor

LJH-LBJ commented Mar 21, 2026

In my opinion, the old code_predictor_forward uses the same discrete-sample → embedding-lookup → concat autoregressive loop. The middle_hidden_states is actually current_input[:, 2:-1, :] (codec embedding lookup results), not intermediate transformer hidden state outputs.

new_embed = self.model.codec_embedding[layer_idx](code)  # [batch, 1, hidden_size]
current_input = torch.cat([current_input, new_embed], dim=1)  # [batch, 3~n, hidden_size]

Collaborator

@lishunyang12 lishunyang12 left a comment


Left a few comments on the cache-rebuild logic and the hardcoded dtype.

expected_condition_len = max(0, len(thinker_output_token_ids) - 1 - int(has_terminal_token))

if cached_thinker_decode_embeds is not None:
cached_thinker_decode_embeds = cached_thinker_decode_embeds.to(device=device, dtype=torch.bfloat16)

Hardcoding torch.bfloat16 will silently down-cast on fp32 models (or up-cast on fp16). Use self.tts_eos_embed.dtype or the model dtype instead — applies to line 936 too.

if missing > 0:
tail = thinker_decode_embed[-min(cur_len, missing) :]
cached_thinker_decode_embeds = torch.cat([cached_thinker_decode_embeds, tail], dim=0)
elif cur_len > cached_len:

This branch is only reachable when missing <= 0, i.e. cached_len >= expected_condition_len — the cache is already full. Replacing it with thinker_decode_embed[-expected_condition_len:] silently discards the existing valid cache. Is this intentional, or is the branch dead in practice? If dead, drop it.

)

update_dict["num_processed_tokens"] = info_dict.get("num_processed_tokens", 0) + span_len
processed_delta = update_dict.pop("num_processed_tokens_delta", span_len)

The fallback to span_len means the prefill path silently relies on never setting this key. Worth an explicit update_dict["num_processed_tokens_delta"] = span_len in the prefill branch so the contract is obvious.


update_dict = {}
text_step = model._thinker_decode_to_talker_decode(
{

This test covers the wait step (run_talker_mtp=False) but never follows up with a call where new embeddings arrive and the model resumes. A wait-then-resume round-trip would strengthen coverage of the state machine.

@Sy0307
Contributor Author

Sy0307 commented Mar 23, 2026

I observed that after enabling @torch._dynamo.disable for both Qwen3OmniCodePredictorAttention.forward and Qwen3OmniCodePredictorMLP.forward, even without handling the bug fix in #2019, the correct audio is regenerated (the hash matches the audio generated after applying #2019). Perhaps the torch compile optimization here amplifies the precision error at this location. Please help check this as well. @LJH-LBJ

@LJH-LBJ
Contributor

LJH-LBJ commented Mar 23, 2026

I observed that after enabling @torch._dynamo.disable for both Qwen3OmniCodePredictorAttention.forward and Qwen3OmniCodePredictorMLP.forward, even without handling the bug fix in #2019, the correct audio is regenerated (the hash matches the audio generated after applying #2019). Perhaps the torch compile optimization here amplifies the precision error at this location. Please help check this as well. @LJH-LBJ

I use main branch to test

python examples/online_serving/qwen3_omni/openai_chat_completion_client_for_multimodal_generation.py \
    --query-type text --modalities audio --port 46354 \
    --model /workspace/models/Qwen3-Omni-30B-A3B-Instruct

test 1 prompt: "What is the capital of China? Reply with exactly one word."
test 2 prompt: "What is the capital of China? Reply with exactly ten word."

test 1 is wrong but test 2 is right

@Sy0307
Contributor Author

Sy0307 commented Mar 23, 2026

I observed that after enabling @torch._dynamo.disable for both Qwen3OmniCodePredictorAttention.forward and Qwen3OmniCodePredictorMLP.forward, even without handling the bug fix in #2019, the correct audio is regenerated (the hash matches the audio generated after applying #2019). Perhaps the torch compile optimization here amplifies the precision error at this location. Please help check this as well. @LJH-LBJ

I use main branch to test

python examples/online_serving/qwen3_omni/openai_chat_completion_client_for_multimodal_generation.py \
    --query-type text --modalities audio --port 46354 \
    --model /workspace/models/Qwen3-Omni-30B-A3B-Instruct

test 1 prompt: "What is the capital of China? Reply with exactly one word."
test 2 prompt: "What is the capital of China? Reply with exactly ten word."

test 1 is wrong but test 2 is right

Please re-test this way:

  1. Attention
    In Qwen3OmniCodePredictorAttention.forward (line 93):
+@torch._dynamo.disable
def forward(
    self,
    hidden_states: torch.Tensor,
    position_ids: torch.Tensor,
) -> torch.Tensor:
  2. MLP
    In Qwen3OmniCodePredictorMLP.forward (line 163):
+@torch._dynamo.disable
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

Do not turn torch.compile off entirely; disable it only for these two methods.

@amy-why-3459
Contributor

In high-concurrency scenarios there is some impact on performance. Can we optimize the operation to reduce this impact?

Signed-off-by: Sy03 <1370724210@qq.com>
@Sy0307 Sy0307 force-pushed the fix/qwen3-omni-async-decode-latest-main branch from 713eb01 to d83cb9c Compare March 24, 2026 12:36
@Sy0307
Contributor Author

Sy0307 commented Mar 24, 2026

Latest Change Summary

Changes Made:

  • Explicitly express the prefill/decode contract: record that talker prefill has already consumed the first text token.
  • Async decode starts from the post-prefill boundary instead of consuming the first text token again.
  • num_processed_tokens advances only when a text condition is actually consumed.
  • Skip talker_mtp on async wait steps to avoid invalid codec work during wait/pad steps.
  • Keep only one cached_thinker_decode_embeddings, used as a producer-consumer cache.
  • Do not introduce decode_condition_queue, pending_thinker_decode_embeds, or multi-cache state machines.

Rationale:

  • Async handoff is on the decode hot path.
  • Introducing heavier state structures like queue/pending/multi-cache would significantly increase per-step bookkeeping, tensor movement, and synchronization overhead.
  • A single linear cache plus an explicit consume-state fixes the handoff semantics without introducing noticeable steady-state performance loss.
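
The single-cache rebuild described above can be sketched as follows. The function name and list-of-embeddings representation are illustrative assumptions; `expected_condition_len` echoes the expression in the reviewed diff, and strings stand in for embedding tensors.

```python
# Illustrative producer-consumer cache rebuild: each step extends the cache
# from the newest thinker decode embeddings and caps it at the corrected
# condition length, so no queue/pending structures are needed.

def rebuild_cache(cached, thinker_decode_embed, expected_condition_len):
    cur_len = len(thinker_decode_embed)
    missing = expected_condition_len - len(cached)
    if missing > 0 and cur_len > 0:
        # append only the newest embeddings the cache is still missing
        cached = cached + thinker_decode_embed[-min(cur_len, missing):]
    return cached[:expected_condition_len]

cache = []
cache = rebuild_cache(cache, ["e1"], expected_condition_len=2)
cache = rebuild_cache(cache, ["e1", "e2"], expected_condition_len=2)
print(cache)  # ['e1', 'e2']
```

Because the cache is a single linear list that only grows toward the corrected condition length, per-step bookkeeping stays O(missing) with no extra synchronization.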

Benchmark Results:

Same vLLM 0.17.1 benchmark path, c10 results:

Branch          Success/Fail   Mean TTFT (ms)   Mean E2EL (ms)   Mean AUDIO_TTFP (ms)   Mean AUDIO_RTF   Audio Throughput
main            10/0           8646.03          75451.92         12580.72               0.3910           20.5673
current patch   10/0           8581.19          76697.11         11879.83               0.3995           21.3097

Performance Conclusion:

Overall, this patch maintains performance close to main while fixing async handoff correctness.

cc @amy-why-3459 @LJH-LBJ

@Sy0307 Sy0307 force-pushed the fix/qwen3-omni-async-decode-latest-main branch 3 times, most recently from 00a58b5 to f6baa1c Compare March 24, 2026 15:35
Signed-off-by: Sy03 <1370724210@qq.com>
@Sy0307 Sy0307 force-pushed the fix/qwen3-omni-async-decode-latest-main branch from f6baa1c to f918ebf Compare March 24, 2026 15:52
@amy-why-3459
Contributor

LGTM

@LJH-LBJ
Contributor

LJH-LBJ commented Mar 25, 2026

LGTM

Collaborator

@gcanlin gcanlin left a comment


It's better to add a test to avoid regression.

Signed-off-by: Sy03 <1370724210@qq.com>
@Sy0307 Sy0307 force-pushed the fix/qwen3-omni-async-decode-latest-main branch from 3c0d9c5 to d2d3c9a Compare March 25, 2026 09:46
@Sy0307
Contributor Author

Sy0307 commented Mar 25, 2026

It's better to add a test to avoid regression.

Done. Thanks for the advice.

Collaborator

@Gaohan123 Gaohan123 left a comment


LGTM. Thanks

@Gaohan123 Gaohan123 merged commit d15c91d into vllm-project:main Mar 26, 2026
8 checks passed
zhangj1an pushed a commit to zhangj1an/vllm-omni that referenced this pull request Mar 26, 2026
@Sy0307
Contributor Author

Sy0307 commented Mar 28, 2026

I observed that after enabling @torch._dynamo.disable for both Qwen3OmniCodePredictorAttention.forward and Qwen3OmniCodePredictorMLP.forward, even without handling the bug fix in #2019, the correct audio is regenerated (the hash matches the audio generated after applying #2019). Perhaps the torch compile optimization here amplifies the precision error at this location. Please help check this as well. @LJH-LBJ

This minor precision error needs to be re-checked, as in #2274. @amy-why-3459 @LJH-LBJ
