[Bugfix] Fix qwen3-omni async thinker-to-talker decode alignment for #1758 (#2019)
Conversation
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 7421b4df47
```python
if "run_talker_mtp" not in update_dict:
    update_dict["run_talker_mtp"] = True
if not update_dict["run_talker_mtp"]:
    update_dict["code_predictor_codes"] = torch.zeros(
```
Honor `run_talker_mtp` before scheduling decode MTP
run_talker_mtp=False does not currently suppress the MTP pass. I checked vllm_omni/worker/gpu_model_runner.py:1264-1284 and vllm_omni/platforms/npu/worker/npu_model_runner.py:409-417: both runners enqueue every span_len == 1 decode request and always call talker_mtp, without ever reading this flag. In async-chunk runs where the thinker is temporarily behind the talker, these wait/pad steps still execute the code predictor and overwrite the zeroed code_predictor_codes, so the PR's intended “skip talker_mtp while waiting” behavior never actually takes effect and bogus codec frames can still be emitted.
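A minimal sketch of the gating the reviewer is asking for (function and key names are assumptions, not the runner's real API): the runner checks the flag before invoking the MTP pass, so a wait/pad step leaves the zeroed codes untouched.

```python
def maybe_run_talker_mtp(update_dict, run_mtp):
    """Only run the talker MTP pass when run_talker_mtp allows it."""
    if update_dict.get("run_talker_mtp", True):
        return run_mtp()
    # While waiting for the thinker, keep the pre-zeroed codes untouched
    # instead of overwriting them with code-predictor output.
    return update_dict.get("code_predictor_codes")
```

This keeps the "skip talker_mtp while waiting" semantics in one place instead of relying on each runner to remember the flag.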
```python
elif cur_len >= expected_condition_len:
    cached_thinker_decode_embeds = thinker_decode_embed[-expected_condition_len:]
```
Exclude `<|im_end|>` before tail-slicing full-prefix embeddings
This tail-slice is wrong on the finishing step when the runtime returns the full decode prefix. vllm_omni/model_executor/stage_input_processors/qwen3_omni.py:143-146 forwards pooling_output["0"] unchanged alongside thinker_output_token_ids, so a finished prefix can be [... text_embeds, im_end_embed]. With expected_condition_len == 1 (for example, a one-word answer plus <|im_end|>), thinker_decode_embed[-1:] caches the terminal embedding instead of the remaining text condition, which reintroduces the thinker/talker misalignment on the last spoken token for the same “full prefix” runtime this patch is trying to support.
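One way to address this, sketched with assumed names (not the file's real helper): strip the terminal embedding before taking the condition tail, so a finished full prefix yields text conditions only.

```python
import torch

def tail_conditions(thinker_decode_embed, expected_condition_len, prefix_has_im_end):
    # Drop the terminal <|im_end|> embedding first, so a full decode prefix
    # [... text_embeds, im_end_embed] never leaks the terminal embedding
    # into the cached text conditions.
    if prefix_has_im_end:
        thinker_decode_embed = thinker_decode_embed[:-1]
    return thinker_decode_embed[-expected_condition_len:]
```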
@LJH-LBJ @ZeldaHuang PTAL
Force-pushed 7421b4d to 84c4c42
Thanks for the detailed root-cause analysis. I've done an independent trace of the async decode path and have some observations: #1758 did not modify … I think the … For the "Beijing" example (…

The …
Thank you so much for your fix. I tested your PR, and the accuracy is indeed normal now. Could you provide a performance comparison before and after the PR?

I will test it later. Thanks a lot.

Root cause to be checked later. I am still working on finding it out.
Test result

Test plan is the same as #1758.

Under high concurrency, #2019 introduced some performance regression, but this is suspected to be caused by an erroneous reduction in RTF due to some audio generation being incorrect in the original implementation. Maybe we need to check it further. Please verify it, PTAL @amy-why-3459
Possible Root Cause Explanation

The old … More specifically:

```python
middle_hidden_states.append(current_input[:, 2:-1, :])
mid_residual_hiddens = middle_hidden_states[pos]
pos_codec_hiddens = torch.cat(
    [layer0_embed] + mid_list + [last_residual_hidden],
    dim=1,
)
pos_summed = pos_codec_hiddens.sum(dim=1, keepdim=True)
summed_embeddings[:, pos, :] = proj_buf[:, 1:, :].sum(dim=1)
```

As a result, a small misalignment in … So PTAL @LJH-LBJ @amy-why-3459
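To make the drift mechanism concrete, here is a toy demonstration (shapes deliberately simplified, not the real codec tensors): shifting a windowed summation by one position changes every summed frame, even though each individual embedding is still a valid text embedding.

```python
import torch

# 6 fake per-token embeddings of width 1.
embeds = torch.arange(6, dtype=torch.float32).reshape(6, 1)

# Correctly aligned fixed windows of size 2.
aligned = torch.stack([embeds[i:i + 2].sum(0) for i in range(4)])

# The same windows after a one-step misalignment: every sum differs.
shifted = torch.stack([embeds[i + 1:i + 3].sum(0) for i in range(4)])
```

This is why the misalignment surfaces as drifting codec frames rather than an obviously broken tensor.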
In my opinion, the old …
lishunyang12
left a comment
Left a few comments on the cache-rebuild logic and the hardcoded dtype.
```python
expected_condition_len = max(0, len(thinker_output_token_ids) - 1 - int(has_terminal_token))

if cached_thinker_decode_embeds is not None:
    cached_thinker_decode_embeds = cached_thinker_decode_embeds.to(device=device, dtype=torch.bfloat16)
```
Hardcoding torch.bfloat16 will silently down-cast on fp32 models (or up-cast on fp16). Use self.tts_eos_embed.dtype or the model dtype instead — applies to line 936 too.
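A sketch of the fix the comment suggests (the reference tensor name is an assumption based on the review's mention of `tts_eos_embed`): follow an existing model tensor's dtype and device instead of hardcoding bfloat16.

```python
import torch

def cast_like(cached_embeds, reference):
    # Inherit dtype/device from a model-owned tensor (e.g. tts_eos_embed),
    # so fp32 and fp16 checkpoints are not silently cast to bfloat16.
    return cached_embeds.to(device=reference.device, dtype=reference.dtype)
```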
```python
if missing > 0:
    tail = thinker_decode_embed[-min(cur_len, missing):]
    cached_thinker_decode_embeds = torch.cat([cached_thinker_decode_embeds, tail], dim=0)
elif cur_len > cached_len:
```
This branch is only reachable when missing <= 0, i.e. cached_len >= expected_condition_len — the cache is already full. Replacing it with thinker_decode_embed[-expected_condition_len:] silently discards the existing valid cache. Is this intentional, or is the branch dead in practice? If dead, drop it.
```python
)

update_dict["num_processed_tokens"] = info_dict.get("num_processed_tokens", 0) + span_len
processed_delta = update_dict.pop("num_processed_tokens_delta", span_len)
```
The fallback to span_len means the prefill path silently relies on never setting this key. Worth an explicit update_dict["num_processed_tokens_delta"] = span_len in the prefill branch so the contract is obvious.
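The contract the comment asks for can be sketched like this (helper name assumed): decode steps set an explicit delta, prefill defaults to `span_len`, and the fallback is never load-bearing by accident.

```python
def advance_num_processed(info_dict, update_dict, span_len):
    # Explicit contract: a step that consumed fewer (or zero) conditions sets
    # num_processed_tokens_delta itself; otherwise span_len is the delta.
    delta = update_dict.pop("num_processed_tokens_delta", span_len)
    update_dict["num_processed_tokens"] = info_dict.get("num_processed_tokens", 0) + delta
    return update_dict
```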
```python
update_dict = {}
text_step = model._thinker_decode_to_talker_decode(
    {
```
This test covers the wait step (run_talker_mtp=False) but never follows up with a call where new embeddings arrive and the model resumes. A wait-then-resume round-trip would strengthen coverage of the state machine.
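The round-trip the reviewer asks for could look roughly like this (all names are placeholders, not the real test or model API): the first step waits because the cache is short, the second resumes once enough embeddings have arrived.

```python
class FakeTalkerState:
    """Toy stand-in for the async handoff state machine."""

    def step(self, cached_len, expected_len):
        # run_talker_mtp only once enough decode conditions are cached.
        return {"run_talker_mtp": cached_len >= expected_len}

model = FakeTalkerState()
wait = model.step(cached_len=0, expected_len=1)    # thinker behind: wait
resume = model.step(cached_len=1, expected_len=1)  # embeddings arrived: resume
```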
I observed that after enabling …
I used the main branch to test: test 1 is wrong but test 2 is right.
Please re-test this way: run with full compile off.
In high-concurrency scenarios, there is some impact on performance. Can we optimize the operation to reduce this impact?
Signed-off-by: Sy03 <1370724210@qq.com>
Force-pushed 713eb01 to d83cb9c
Latest Change Summary

Changes Made: …

Rationale: …

Benchmark Results: Same vLLM 0.17.1 benchmark path, c10 results: …

Performance Conclusion: Overall, this patch maintains performance close to main while fixing async handoff correctness.
Force-pushed 00a58b5 to f6baa1c
Signed-off-by: Sy03 <1370724210@qq.com>
Force-pushed f6baa1c to f918ebf
LGTM

LGTM
gcanlin
left a comment
It's better to add a test to avoid regression.
Signed-off-by: Sy03 <1370724210@qq.com>
Force-pushed 3c0d9c5 to d2d3c9a
Done. Thanks for the advice.
…llm-project#1758 (vllm-project#2019) Signed-off-by: Zhang <jianmusings@gmail.com>
This minor precision error needs to be re-checked, like #2274. @amy-why-3459 @LJH-LBJ
Purpose
Fix the Qwen3-Omni async_chunk text/audio misalignment introduced by #1758 and reported in #1830.
Root cause
The bug is in `Qwen3OmniMoeForConditionalGeneration._thinker_decode_to_talker_decode()`.

In async mode, the old logic used the cumulative `thinker_output_token_ids` length to decide how many decode conditions were still available. This is incorrect for Qwen3-Omni async handoff because:

- The first output token is consumed by talker prefill, not by async decode.
- The `<|im_end|>` token is a terminal marker, not a normal text decode condition.

As a result, async decode consumed one extra condition step. In practice this caused the talker-side decode embedding sequence to become misaligned with the thinker output tokens, so text could still be correct while the generated codec frames drifted and the final audio no longer matched the text.
For example, for the prompt:

What is the capital of China? Reply with exactly one word.

the thinker output is effectively:

Beijing<|im_end|>

The correct async decode conditions should be `ijing` followed by `eos`. But the old logic treated the cumulative token list as if there were two normal text decode conditions before EOS, so async consumed the wrong condition boundary and produced mismatched audio.
What this PR changes
This PR fixes the async decode-state handoff by:

1. Computing the expected async decode condition length correctly, excluding the terminal `<|im_end|>` token.
2. Rebuilding / reusing `cached_thinker_decode_embeddings` against that corrected condition length, so async consume steps stay aligned with the remaining thinker decode embeddings.
3. Advancing `num_processed_tokens` only when a real text condition is consumed.
4. Skipping `talker_mtp` when async is only waiting for additional valid decode condition cache.
5. Adding focused unit tests for the async handoff boundary logic.

In short: this PR makes async decode semantics match the sync path's effective behavior for the remaining text conditions after talker prefill.
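The cache rebuild step can be sketched as follows, mirroring the top-up logic quoted in the review hunks (the standalone function is an assumption; the real code lives inside the model):

```python
import torch

def topup_cache(cached, thinker_decode_embed, expected_len):
    # Append only the missing tail of the newly arrived decode embeddings;
    # a cache that is already full is left untouched.
    cur_len = thinker_decode_embed.shape[0]
    cached_len = 0 if cached is None else cached.shape[0]
    missing = expected_len - cached_len
    if missing > 0:
        tail = thinker_decode_embed[-min(cur_len, missing):]
        cached = tail if cached is None else torch.cat([cached, tail], dim=0)
    return cached
```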
Test Plan
1. Unit tests
Added: `tests/model_executor/models/test_qwen3_omni_async_decode.py`

Covers:
- excluding `<|im_end|>` from normal consume steps

Command:

2. Remote async e2e validation

Validated on a real multi-GPU machine using the prompt:

What is the capital of China? Reply with exactly one word.

Client: `examples/online_serving/qwen3_omni/openai_chat_completion_client_for_multimodal_generation.py`

Test Result
Before fix

- Text output: Beijing
- Audio: before_fix_async_bad.wav (attached in the PR composer before submission as before_fix.wav)

After fix

- Text output: Beijing
- Audio: after_fix_async_good.wav (attached in the PR composer before submission as fix.wav)
Additional notes

… `qwen3_omni.py`.

cc @LJH-LBJ