[Bugfix] Length-aware trim of thinker_emb/hid in non-async-chunk Stage-1 builder for Qwen3 Omni#3645

Draft
natureofnature wants to merge 5 commits into
vllm-project:mainfrom
natureofnature:bugfix/refactor/pr1_5/hotfix-applied

Conversation

@natureofnature
Contributor

@natureofnature natureofnature commented May 15, 2026

Purpose

Fixes a long-output audio drift regression introduced by #2677. The unconditional [:-1] trim assumed that every finished request added one extra "stop-token" row to the accumulator, but requests that finish by max_tokens (FINISHED_LENGTH_CAPPED) do not run that extra forward, so the trim removed a real conditioning row instead. The talker codec then ran out of conditioning before its audio ended, producing the "test test test ... → T-T-T-T..." stutter observed in Buildkite 9702 test_mix_to_text_audio_001 (commit 82a0b3a4).

The fix replaces the blind [:-1] with a length-aware trim keyed off len(all_token_ids), which matches what _thinker_to_talker_prefill actually indexes downstream.

Behaviour by finish reason:

| accumulator rows | target_rows | result |
| --- | --- | --- |
| P + O + 1 (stop_token) | P + O | trim 1 row (unchanged) |
| P + O (max_tokens) | P + O | no trim (fixes regression) |
| < P + O (under-capture) | P + O | no trim, safe degrade |
| 0 (defensive) | 0 | no trim |

Also adds logger.warning on unexpected excess rows and logger.debug on under-capture so future invariant drift is observable.
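
The table and logging behaviour above can be sketched roughly as follows. This is a minimal illustration, not the code in this PR: `trim_to_target` is a hypothetical name, plain sequences stand in for the thinker_emb/hid tensors, and trimming all the way down to the target when more than one excess row appears is an assumption (the description only says a warning is logged in that case).

```python
import logging
from typing import Sequence, TypeVar

logger = logging.getLogger(__name__)
T = TypeVar("T")


def trim_to_target(rows: Sequence[T], all_token_ids: Sequence[int]) -> Sequence[T]:
    """Length-aware replacement for a blind rows[:-1] (hypothetical helper)."""
    target_rows = len(all_token_ids)
    if target_rows <= 0:
        return rows  # defensive: never slice the payload down to nothing
    excess = len(rows) - target_rows
    if excess < 0:
        # Under-capture: keep whatever was captured and make it observable.
        logger.debug("under-capture: %d rows < target %d", len(rows), target_rows)
        return rows
    if excess > 1:
        # More than the single stop-token row: the invariant has drifted.
        logger.warning("%d excess rows over target %d", excess, target_rows)
    # excess == 0 leaves rows untouched; excess >= 1 trims to the target
    # (trimming more than one row is an assumption of this sketch).
    return rows[:target_rows]
```

With P + O = 5, a six-row accumulator is trimmed to five rows, a five-row accumulator is returned unchanged, and a four-row accumulator is kept as-is.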

Test Plan

  • New tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py cases lock down all three row-count branches
    (== target, == target + 1, < target). pytest -q of that file: 12 passed.
  • pytest -s -v tests/e2e/offline_inference/test_qwen3_omni.py tests/e2e/online_serving/test_qwen3_omni.py -m 'advanced_model and cuda' --run-level 'advanced_model' twice

Test Result

pytest -s -v tests/e2e/offline_inference/test_qwen3_omni.py tests/e2e/online_serving/test_qwen3_omni.py -m 'advanced_model and cuda' --run-level 'advanced_model'
1st Run:
Screenshot from 2026-05-15 06-05-25
2nd Run:
Screenshot from 2026-05-15 06-24-32



…ge-1 input.

The connector accumulator emits P rows during the prefill forward (one
per prompt token) and 1 row per decode forward step.  When a request
finishes by stop_token, the final emission step adds one extra row;
when it finishes by max_tokens (FINISHED_LENGTH_CAPPED), vLLM does not
run another forward, so there is no extra row.  The previous
unconditional `[:-1]` correctly trimmed the stop-token row but
over-trimmed in the max_tokens path: the talker then ran out of
conditioning before its codec finished, causing long-output drift
(e.g. test_mix_to_text_audio_001 long-repeat regression on
Buildkite 9702 main build, commit 82a0b3a).

Switch the trim to be length-aware, keyed off the talker prefill target
`len(all_token_ids)` -- the downstream _thinker_to_talker_prefill
indexes by ids["all"] length.  This equals prompt + output under the
standard contract, but the baseline thinker2talker path already
reshapes both prompt_token_ids and output_ids via
_get_streaming_talker_tokens when a streaming context is active, so
all_token_ids and prompt + output can diverge in streaming /
PD-disagg scenarios.  Aligning the trim target with
len(all_token_ids) keeps the full-payload builder robust to any caller
that builds all_token_ids independently of prompt + output, without
needing to retrofit this site each time.

Behaviour by finish reason:
- stop_token finish (rows = target + 1)   -> trim 1 row (unchanged)
- max_tokens finish (rows = target)       -> no trim (fixes regression)
- accumulator under-capture (rows < target) -> keep all rows (safe degrade)
- empty payload (target_rows <= 0)        -> keep all rows (defensive)

Also lift output_token_ids retrieval out of the all_token_ids fallback
so it is always defined regardless of which token-id source is present.

Add observability:
- logger.warning when accumulator rows exceed target by more than one
  (catches future invariant drift)
- logger.debug when under-captured (catches silent upstream frame loss)

Tests:
- Extend test_thinker2talker_full_payload_packs_complete_tensors to
  lock down shape == target for the rows == target case.
- Add test_thinker2talker_full_payload_trims_excess_stop_token_row to
  verify the rows == target + 1 trim still works (guards against a
  regression in which the stop-token row is no longer trimmed).
- Add test_thinker2talker_full_payload_preserves_under_capture to
  cover the rows < target safe-degrade path.

Signed-off-by: natureofnature <wzliu@connect.hku.hk>
@natureofnature natureofnature changed the title [BugFix] Length-aware trim of thinker_emb/hid in non-async-chunk Stage-1 builder [BugFix] Length-aware trim of thinker_emb/hid in non-async-chunk Stage-1 builder for Qwen3 Omni May 15, 2026
@hsliuustc0106 hsliuustc0106 added the ready (label to trigger buildkite CI) and merge-test (label to trigger buildkite merge test CI) labels and removed the ready (label to trigger buildkite CI) label May 15, 2026
Comment thread vllm_omni/model_executor/stage_input_processors/qwen3_omni.py
…l_payload.

The output_token_ids variable was hoisted out of the all_token_ids
fallback branch in the prior commit (3c0ed9c), when an earlier draft
of the trim used `len(prompt_token_ids) + len(output_token_ids)`.

After switching the trim target to `len(all_token_ids)`, the hoist no
longer has any caller outside the fallback branch.  The hoist became
dead code -- this commit reverts it to keep the diff surgical and
scoped to the actual fix.

No functional change: output_token_ids is still computed in the
fallback when all_token_ids is empty, exactly as it was before
3c0ed9c.

Signed-off-by: natureofnature <wzliu@connect.hku.hk>
@natureofnature natureofnature force-pushed the bugfix/refactor/pr1_5/hotfix-applied branch from 7c0083b to 004dc4b Compare May 15, 2026 14:23
@natureofnature
Contributor Author

@codex review

Comment thread vllm_omni/model_executor/stage_input_processors/qwen3_omni.py

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 4705e1f7c7

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread vllm_omni/model_executor/stage_input_processors/qwen3_omni.py Outdated
@natureofnature natureofnature force-pushed the bugfix/refactor/pr1_5/hotfix-applied branch from b50a804 to fb9f993 Compare May 15, 2026 17:28
@natureofnature
Contributor Author

@codex review


@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 1442dfe10c

Comment thread vllm_omni/model_executor/stage_input_processors/qwen3_omni.py
@natureofnature natureofnature force-pushed the bugfix/refactor/pr1_5/hotfix-applied branch from 1442dfe to 8067795 Compare May 15, 2026 17:59
@natureofnature
Contributor Author

@codex review


@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: b9cc31420a

Comment thread vllm_omni/model_executor/stage_input_processors/qwen3_omni.py Outdated
…ion row.

Codex P1 review on the prior commit (`3c0ed9c6`) identified that the
length-aware trim using `len(all_token_ids)` skipped trimming for
FINISHED_STOPPED requests, reintroducing the original `fba23325`
spurious-phoneme regression on short stop-finished outputs.

Mechanism:
- vLLM appends the sampled token to `request.output_token_ids` BEFORE
  `check_stop` (`vllm/v1/core/sched/scheduler.py:1641-1651`), so for a
  stop-token finished request `len(all_token_ids)` includes the
  stop-token slot.
- The accumulator captured the forward step that produced that
  stop_token, so its row count equals `len(all_token_ids)`.
- With `target_rows = len(all_token_ids)` the trim is a no-op and the
  stop emission's hidden state leaks into talker prefill.
- For FINISHED_LENGTH_CAPPED (max_tokens) requests vLLM does NOT run
  another forward, so accumulator rows == len(all_token_ids) and no
  extra trim is needed.  The prior unconditional `[:-1]` over-trimmed
  here (BK 9702 long-repeat regression).
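
The ordering described in the mechanism can be illustrated with a toy decode loop. This is not vLLM's scheduler: `toy_decode` is a hypothetical function, and its row accounting (one accumulator row per prompt token, then one per sampled token) is an assumption taken from the statement above that the row count equals `len(all_token_ids)`.

```python
def toy_decode(
    prompt_len: int, sampled: list[int], stop_id: int, max_tokens: int
) -> tuple[str, list[int], int]:
    """Toy model of append-then-check ordering; NOT vLLM code."""
    output_token_ids: list[int] = []
    accumulator_rows = prompt_len  # assumed: one row per prompt token
    for token in sampled:
        accumulator_rows += 1           # assumed: one row per sampled token
        output_token_ids.append(token)  # appended BEFORE the stop check
        if token == stop_id:            # stop check runs first
            return "FINISHED_STOPPED", output_token_ids, accumulator_rows
        if len(output_token_ids) >= max_tokens:
            return "FINISHED_LENGTH_CAPPED", output_token_ids, accumulator_rows
    return "RUNNING", output_token_ids, accumulator_rows


# Both finish reasons end with rows == prompt_len + len(output_token_ids),
# so the length alone cannot tell them apart.
stopped = toy_decode(3, [10, 11, 99], stop_id=99, max_tokens=8)
capped = toy_decode(3, [10, 11, 12, 13], stop_id=99, max_tokens=3)
```

In this model the stop-finished request carries the stop token as its last output id, which is why a trim keyed only on `len(all_token_ids)` becomes a no-op for it.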

Fix:
- Detect FINISHED_STOPPED via exact match on `request.status.name`
  (set by `check_stop` at `vllm/v1/core/sched/utils.py:103-117`); drop
  one extra row only on that path.
- Status-name extraction handles both `RequestStatus.FINISHED_STOPPED`
  and plain `FINISHED_STOPPED` forms via `rsplit(".", 1)[-1]`.
- Worker-side `CachedRequestState` does not carry `.status` in vLLM v1
  (`vllm/v1/worker/gpu_input_batch.py:30-79`), so for the production
  path we fall back to a last-token-in-stop-set heuristic.  The
  fallback respects `sampling_params.ignore_eos`: when `ignore_eos=True`
  vLLM continues past EOS, so EOS is excluded from the stop set --
  this prevents a length-capped finish whose last sampled token
  happens to equal EOS from being incorrectly trimmed.
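
A rough sketch of the detection order described in the fix, under stated assumptions: `status_name` and `finished_by_stop` are hypothetical names, the EOS id is passed in explicitly because this message does not say where it is read from, and `SimpleNamespace` objects stand in for the real request state.

```python
from types import SimpleNamespace
from typing import Optional


def status_name(status: object) -> str:
    """Bare name for both 'RequestStatus.X' and plain 'X' forms."""
    return str(status).rsplit(".", 1)[-1]


def finished_by_stop(request: object, eos_token_id: Optional[int]) -> bool:
    """Hypothetical helper: should the stop-emission row be dropped?"""
    status = getattr(request, "status", None)
    if status is not None:
        # The status path takes precedence over the heuristic.
        return status_name(status) == "FINISHED_STOPPED"
    # Fallback for request state that carries no status.
    sampling_params = getattr(request, "sampling_params", None)
    output_token_ids = getattr(request, "output_token_ids", None) or []
    if sampling_params is None or not output_token_ids:
        return False
    stop_ids = set(getattr(sampling_params, "stop_token_ids", None) or [])
    ignore_eos = bool(getattr(sampling_params, "ignore_eos", False))
    if eos_token_id is not None and not ignore_eos:
        stop_ids.add(eos_token_id)  # EOS only counts when it can stop decoding
    return output_token_ids[-1] in stop_ids


# Stand-in worker-side state: no status, last sampled token is EOS (id 2).
params = SimpleNamespace(ignore_eos=False, stop_token_ids=None)
state = SimpleNamespace(output_token_ids=[5, 2], sampling_params=params)
drop_row = finished_by_stop(state, eos_token_id=2)
```

Setting `ignore_eos=True` on the same stand-in makes the helper return False, matching the gated fallback described above.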

Residual limitation (filed as follow-up):
- A propagated finish reason from the scheduler is the only way to
  fully disambiguate `FINISHED_LENGTH_CAPPED` from stop-token
  completion when the last token coincidentally matches a custom
  `stop_token_id`.  In vLLM's current `check_stop` ordering the
  `stop_token_ids` check runs before the length cap, so a last-token
  match in `stop_token_ids` is unambiguously a stop finish in
  practice; production qwen3-omni paths don't configure custom
  `stop_token_ids`, so the residual is theoretical.

Behaviour by finish reason:
- stop_token  (status=FINISHED_STOPPED)        -> trim 1 (talker gets P+O-1)
- max_tokens  (status=FINISHED_LENGTH_CAPPED)  -> no trim (talker gets P+O)
- status absent + last token in stop_ids + not ignore_eos  -> trim 1 (fallback)
- status absent + last token == EOS + ignore_eos=True      -> no trim (fallback gated)
- status absent + last token not in stop_ids               -> no trim
- under-capture                                             -> no trim, safe degrade
- defensive: target_rows <= 0                              -> no trim (don't slice to zero)
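
The branches listed above reduce to a small slicing rule once the finish reason is known. The sketch below is illustrative only: `rows_for_talker` is a hypothetical name, a plain sequence stands in for the accumulated tensor, and the `stopped` flag is assumed to come from whichever detection path applies.

```python
from typing import Sequence, TypeVar

T = TypeVar("T")


def rows_for_talker(rows: Sequence[T], num_all_tokens: int, stopped: bool) -> Sequence[T]:
    """Hypothetical helper mirroring the finish-reason list above."""
    # A stop finish drops the stop-emission row: P + O - 1 instead of P + O.
    target_rows = num_all_tokens - 1 if stopped else num_all_tokens
    if target_rows <= 0:
        return rows  # defensive: don't slice to zero
    if len(rows) <= target_rows:
        return rows  # exact match or under-capture: keep everything
    return rows[:target_rows]
```

With five rows and `num_all_tokens=5`, the talker receives four rows when `stopped` is True and all five when it is False.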

Tests (17 total, was 12):
- `..._drops_stop_emission_row_when_finished_stopped`: status path
- `..._drops_stop_emission_via_eos_fallback`: fallback path (default ignore_eos=False)
- `..._no_drop_when_finished_length_capped`: BK 9702 guard
- `..._no_drop_when_length_capped_with_trailing_eos`: status precedence guard
- `..._no_drop_when_ignore_eos_and_trailing_eos`: ignore_eos respects vLLM semantics

Signed-off-by: natureofnature <wzliu@connect.hku.hk>
@natureofnature natureofnature force-pushed the bugfix/refactor/pr1_5/hotfix-applied branch from b9cc314 to 9098fca Compare May 16, 2026 01:27
@natureofnature
Contributor Author

@codex review

@chatgpt-codex-connector

Codex Review: Didn't find any major issues. Bravo.

@Gaohan123 Gaohan123 added this to the v0.22.0 milestone May 16, 2026
Collaborator

@hsliuustc0106 hsliuustc0106 left a comment

One observation on the status-name extraction at line ~517 (str(status).rsplit(".", 1)[-1]): this depends on the enum's string representation. If vLLM changes the status enum __str__, it silently falls through to the last-token heuristic. Since the fallback is reasonable this is not blocking, but consider importing the actual enum type if feasible from the omni plugin context.

Otherwise LGTM — the fix is well-reasoned, the tests cover all row-count branches, and the e2e evidence is solid.
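
The concern about string-based extraction can be reproduced with a stand-in enum. `ToyStatus` below is not vLLM's status type; its overridden `__str__` simulates a hypothetical upstream change to the string form, and both helper names are illustrative.

```python
import enum


class ToyStatus(enum.Enum):
    """Stand-in enum; NOT vLLM's request status type."""

    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()

    def __str__(self) -> str:
        # Simulates an upstream change to the enum's string representation.
        return f"<{self.name.lower()}>"


def by_string(status: object) -> bool:
    """String-based check, as in the reviewed extraction."""
    return str(status).rsplit(".", 1)[-1] == "FINISHED_STOPPED"


def by_enum(status: object) -> bool:
    """Identity check against the imported enum member."""
    return status is ToyStatus.FINISHED_STOPPED


# The string check silently misses the stop finish; the enum check does not.
string_result = by_string(ToyStatus.FINISHED_STOPPED)
enum_result = by_enum(ToyStatus.FINISHED_STOPPED)
```

The string check still works for plain strings such as "RequestStatus.FINISHED_STOPPED", which is why a failure of this kind would fall through to the fallback without raising.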

sampling_params = getattr(request, "sampling_params", None)
if sampling_params is not None:
    stop_ids: set[int] = set()
    ignore_eos = bool(getattr(sampling_params, "ignore_eos", False))
Contributor

I think we need to discuss the testing scenarios and solutions in more detail.

Contributor Author

Agreed. Let's delay the merge until we have test cases that match exactly.

@natureofnature
Contributor Author

natureofnature commented May 18, 2026

As discussed with @amy-why-3459, we are delaying the merge because the failure occurs only in very rare cases and the implementation is too complicated. This PR will be reopened once the matching test case is integrated and a more general approach is introduced.

@natureofnature natureofnature changed the title [BugFix] Length-aware trim of thinker_emb/hid in non-async-chunk Stage-1 builder for Qwen3 Omni [WIP] Length-aware trim of thinker_emb/hid in non-async-chunk Stage-1 builder for Qwen3 Omni May 18, 2026
@natureofnature natureofnature marked this pull request as draft May 18, 2026 03:20
@natureofnature natureofnature changed the title [WIP] Length-aware trim of thinker_emb/hid in non-async-chunk Stage-1 builder for Qwen3 Omni [Bugfix] Length-aware trim of thinker_emb/hid in non-async-chunk Stage-1 builder for Qwen3 Omni May 18, 2026
Labels

merge-test label to trigger buildkite merge test CI

5 participants