From 3c0ed9c6bb1025fa5cda2a2930d4cab72ad3bc7e Mon Sep 17 00:00:00 2001 From: natureofnature Date: Fri, 15 May 2026 08:15:06 +0000 Subject: [PATCH 1/3] [BugFix] Length-aware trim of thinker_emb/hid for non-async-chunk Stage-1 input. The connector accumulator emits P rows during the prefill forward (one per prompt token) and 1 row per decode forward step. When a request finishes by stop_token, the final emission step adds one extra row; when it finishes by max_tokens (FINISHED_LENGTH_CAPPED), vLLM does not run another forward, so there is no extra row. The previous unconditional `[:-1]` correctly trimmed the stop-token row but over-trimmed in the max_tokens path: the talker then ran out of conditioning before its codec finished, causing long-output drift (e.g. test_mix_to_text_audio_001 long-repeat regression on Buildkite 9702 main build, commit 82a0b3a4). Switch the trim to be length-aware, keyed off the talker prefill target `len(all_token_ids)` -- the downstream _thinker_to_talker_prefill indexes by ids["all"] length. This equals prompt + output under the standard contract, but the baseline thinker2talker path already reshapes both prompt_token_ids and output_ids via _get_streaming_talker_tokens when a streaming context is active, so all_token_ids and prompt + output can diverge in streaming / PD-disagg scenarios. Aligning the trim target with len(all_token_ids) keeps the full-payload builder robust to any caller that builds all_token_ids independently of prompt + output, without needing to retrofit this site each time. Behaviour by finish reason: - stop_token finish (rows = target + 1) -> trim 1 row (unchanged) - max_tokens finish (rows = target) -> no trim (fixes regression) - accumulator under-capture (rows < target) -> keep all rows (safe degrade) - empty payload (target_rows <= 0) -> keep all rows (defensive) Also lift output_token_ids retrieval out of the all_token_ids fallback so it is always defined regardless of which token-id source is present. Add observability: - logger.warning when accumulator rows exceed target by more than one (catches future invariant drift) - logger.debug when under-captured (catches silent upstream frame loss) Tests: - Extend test_thinker2talker_full_payload_packs_complete_tensors to lock down shape == target for the rows == target case. - Add test_thinker2talker_full_payload_trims_excess_stop_token_row to verify the rows == target + 1 trim still works (prevents regression to unconditional [:-1] not catching the stop-token row). - Add test_thinker2talker_full_payload_preserves_under_capture to cover the rows < target safe-degrade path. Signed-off-by: natureofnature --- .../test_qwen3_omni_streaming_helpers.py | 46 +++++++++++++ .../stage_input_processors/qwen3_omni.py | 68 +++++++++++++------ 2 files changed, 95 insertions(+), 19 deletions(-) diff --git a/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py b/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py index f11a4654ec2..4ad48b9d47d 100644 --- a/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py +++ b/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py @@ -168,6 +168,7 @@ def test_talker2code2wav_full_payload_keeps_all_zero_codec_rows() -> None: def test_thinker2talker_full_payload_packs_complete_tensors() -> None: + """Standard max_tokens finish path: rows == target → no trim.""" request = SimpleNamespace( request_id="thinker", prompt_token_ids=[151644, 872], @@ -187,3 +188,48 @@ def test_thinker2talker_full_payload_packs_complete_tensors() -> None: assert payload["embed"]["prefill"].device.type == "cpu" assert payload["hidden_states"]["output"].device.type == "cpu" assert payload["next_stage_prompt_len"] > 0 + # Lock down the no-trim invariant for rows == target. + assert payload["embed"]["prefill"].shape[0] == 3 + assert payload["hidden_states"]["output"].shape[0] == 3 + + +def test_thinker2talker_full_payload_trims_excess_stop_token_row() -> None: + """Stop_token finish path: rows == target + 1 → trim trailing row.""" + request = SimpleNamespace( + request_id="thinker-stop", + prompt_token_ids=[151644, 872], + output_token_ids=[3], + all_token_ids=[151644, 872, 3], + ) + pooling_output = { + "hidden_states.layer_0": torch.ones(4, 2), + "hidden_states.layer_24": torch.full((4, 2), 2.0), + "embed.tts_bos": torch.zeros(1, 2), + } + + payload = q3.thinker2talker_full_payload(None, pooling_output, request) + + assert payload is not None + assert payload["embed"]["prefill"].shape[0] == 3 + assert payload["hidden_states"]["output"].shape[0] == 3 + + +def test_thinker2talker_full_payload_preserves_under_capture() -> None: + """Under-capture path: rows < target → no trim, safe degrade.""" + request = SimpleNamespace( + request_id="thinker-undercap", + prompt_token_ids=[151644, 872], + output_token_ids=[3], + all_token_ids=[151644, 872, 3], + ) + pooling_output = { + "hidden_states.layer_0": torch.ones(2, 2), + "hidden_states.layer_24": torch.full((2, 2), 2.0), + "embed.tts_bos": torch.zeros(1, 2), + } + + payload = q3.thinker2talker_full_payload(None, pooling_output, request) + + assert payload is not None + assert payload["embed"]["prefill"].shape[0] == 2 + assert payload["hidden_states"]["output"].shape[0] == 2 diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py index b1672612bf3..939a2714689 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -483,29 +483,59 @@ def thinker2talker_full_payload( return None prompt_token_ids = _ensure_list(getattr(request, "prompt_token_ids", []) or []) + output_token_ids = _ensure_list(getattr(request, "output_token_ids", []) or []) all_token_ids = _ensure_list(getattr(request, "all_token_ids", None) or []) if not all_token_ids: - output_token_ids = _ensure_list(getattr(request, "output_token_ids", []) or []) all_token_ids = list(prompt_token_ids) + list(output_token_ids) - # Trim the trailing stop-token row from the accumulated thinker output. - # The accumulator captures one hidden-state row per executed thinker - # forward (prefill + every decode step including the one that emitted - # the stop_token), so for a finished request thinker_emb has exactly one - # row more than the rows the talker should consume. async_chunk's - # chunk-0 path naturally captures only the prefill / non-stop portion, - # which is why the [async_chunk] parametrization passes while [default] - # over-generates one codec frame on short outputs (e.g. - # test_one_word_prompt_001[default]: audio extends "London" with - # spurious phonemes). - if isinstance(thinker_emb, torch.Tensor) and thinker_emb.shape[0] > 0: - thinker_emb_prefill = thinker_emb[:-1] - else: - thinker_emb_prefill = thinker_emb - if isinstance(thinker_hid, torch.Tensor) and thinker_hid.shape[0] > 0: - thinker_hid_prefill = thinker_hid[:-1] - else: - thinker_hid_prefill = thinker_hid + # Length-aware trim of the accumulated thinker output. + # The accumulator emits P rows during the prefill forward (one per + # prompt token) and 1 row per decode forward step. When the request + # finishes by stop_token, the final emission step adds one extra row; + # when it finishes by max_tokens (FINISHED_LENGTH_CAPPED), vLLM does + # not run another forward, so there is no extra row. The previous + # unconditional `[:-1]` correctly trimmed the stop-token row but + # over-trimmed in the max_tokens path -- the talker then ran out of + # conditioning before its codec finished, causing long-output drift + # (e.g. test_mix_to_text_audio_001 long-repeat regression on BK 9702). + # Target the talker's actual prefill consumption length -- the + # downstream `_thinker_to_talker_prefill` indexes by + # `len(ids["all"])`, which equals `prompt + output` in the standard + # contract but can diverge under streaming / PD-disagg paths. Using + # `len(all_token_ids)` keeps the trim aligned with the consumer. + target_rows = len(all_token_ids) + + def _trim_to_target(t): + if not isinstance(t, torch.Tensor) or t.dim() < 1 or t.shape[0] == 0: + return t + if target_rows <= 0: + # Defensive: an empty prompt+output should not reach this + # builder (the pooler short-circuits upstream), but guard + # against slicing valid rows to zero if some future caller + # ever invokes us in that state. + return t + if t.shape[0] > target_rows + 1: + logger.warning( + "thinker2talker_full_payload: unexpected excess rows " + "(got %d, target %d) for req=%s; trimming to target", + int(t.shape[0]), + target_rows, + getattr(request, "request_id", None), + ) + if t.shape[0] > target_rows: + return t[:target_rows] + if t.shape[0] < target_rows: + logger.debug( + "thinker2talker_full_payload: under-captured rows " + "(got %d, target %d) for req=%s; talker may index past end", + int(t.shape[0]), + target_rows, + getattr(request, "request_id", None), + ) + return t + + thinker_emb_prefill = _trim_to_target(thinker_emb) + thinker_hid_prefill = _trim_to_target(thinker_hid) payload: OmniPayload = { "embed": { From 004dc4bd691afbf3064f6063c9d1a9b6a76a3b2d Mon Sep 17 00:00:00 2001 From: natureofnature Date: Fri, 15 May 2026 14:17:28 +0000 Subject: [PATCH 2/3] [Refactor] Drop leftover output_token_ids hoist in thinker2talker_full_payload. The output_token_ids variable was hoisted out of the all_token_ids fallback branch in the prior commit (3c0ed9c6), when an earlier draft of the trim used `len(prompt_token_ids) + len(output_token_ids)`. After switching the trim target to `len(all_token_ids)`, the hoist no longer has any caller outside the fallback branch. The hoist became dead code -- this commit reverts it to keep the diff surgical and scoped to the actual fix. No functional change: output_token_ids is still computed in the fallback when all_token_ids is empty, exactly as it was before 3c0ed9c6. Signed-off-by: natureofnature --- vllm_omni/model_executor/stage_input_processors/qwen3_omni.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py index 939a2714689..26696accda2 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -483,9 +483,9 @@ def thinker2talker_full_payload( return None prompt_token_ids = _ensure_list(getattr(request, "prompt_token_ids", []) or []) - output_token_ids = _ensure_list(getattr(request, "output_token_ids", []) or []) all_token_ids = _ensure_list(getattr(request, "all_token_ids", None) or []) if not all_token_ids: + output_token_ids = _ensure_list(getattr(request, "output_token_ids", []) or []) all_token_ids = list(prompt_token_ids) + list(output_token_ids) # Length-aware trim of the accumulated thinker output. From 9098fca9f7ff1a7da9d98f7de284d161f922c705 Mon Sep 17 00:00:00 2001 From: natureofnature Date: Fri, 15 May 2026 16:56:21 +0000 Subject: [PATCH 3/3] [BugFix] Finish-reason-aware trim of thinker_emb/hid; drop stop-emission row. Codex P1 review on the prior commit (`3c0ed9c6`) identified that the length-aware trim using `len(all_token_ids)` skipped trimming for FINISHED_STOPPED requests, reintroducing the original `fba23325` spurious-phoneme regression on short stop-finished outputs. Mechanism: - vLLM appends the sampled token to `request.output_token_ids` BEFORE `check_stop` (`vllm/v1/core/sched/scheduler.py:1641-1651`), so for a stop-token finished request `len(all_token_ids)` includes the stop-token slot. - The accumulator captured the forward step that produced that stop_token, so its row count equals `len(all_token_ids)`. - With `target_rows = len(all_token_ids)` the trim is a no-op and the stop emission's hidden state leaks into talker prefill. - For FINISHED_LENGTH_CAPPED (max_tokens) requests vLLM does NOT run another forward, so accumulator rows == len(all_token_ids) and no extra trim is needed. The prior unconditional `[:-1]` over-trimmed here (BK 9702 long-repeat regression). Fix: - Detect FINISHED_STOPPED via exact match on `request.status.name` (set by `check_stop` at `vllm/v1/core/sched/utils.py:103-117`); drop one extra row only on that path. - Status-name extraction handles both `RequestStatus.FINISHED_STOPPED` and plain `FINISHED_STOPPED` forms via `rsplit(".", 1)[-1]`. - Worker-side `CachedRequestState` does not carry `.status` in vLLM v1 (`vllm/v1/worker/gpu_input_batch.py:30-79`), so for the production path we fall back to a last-token-in-stop-set heuristic. The fallback respects `sampling_params.ignore_eos`: when `ignore_eos=True` vLLM continues past EOS, so EOS is excluded from the stop set -- this prevents a length-capped finish whose last sampled token happens to equal EOS from being incorrectly trimmed. Residual limitation (filed as follow-up): - A propagated finish reason from the scheduler is the only way to fully disambiguate `FINISHED_LENGTH_CAPPED` from stop-token completion when the last token coincidentally matches a custom `stop_token_id`. In vLLM's current `check_stop` ordering the `stop_token_ids` check runs before the length cap, so a last-token match in `stop_token_ids` is unambiguously a stop finish in practice; production qwen3-omni paths don't configure custom `stop_token_ids`, so the residual is theoretical. Behaviour by finish reason: - stop_token (status=FINISHED_STOPPED) -> trim 1 (talker gets P+O-1) - max_tokens (status=FINISHED_LENGTH_CAPPED) -> no trim (talker gets P+O) - status absent + last token in stop_ids + not ignore_eos -> trim 1 (fallback) - status absent + last token == EOS + ignore_eos=True -> no trim (fallback gated) - status absent + last token not in stop_ids -> no trim - under-capture -> no trim, safe degrade - defensive: target_rows <= 0 -> no trim (don't slice to zero) Tests (17 total, was 12): - `..._drops_stop_emission_row_when_finished_stopped`: status path - `..._drops_stop_emission_via_eos_fallback`: fallback path (default ignore_eos=False) - `..._no_drop_when_finished_length_capped`: BK 9702 guard - `..._no_drop_when_length_capped_with_trailing_eos`: status precedence guard - `..._no_drop_when_ignore_eos_and_trailing_eos`: ignore_eos respects vLLM semantics Signed-off-by: natureofnature --- .../test_qwen3_omni_streaming_helpers.py | 211 +++++++++++++++++- .../stage_input_processors/qwen3_omni.py | 87 ++++++-- 2 files changed, 273 insertions(+), 25 deletions(-) diff --git a/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py b/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py index 4ad48b9d47d..8ec4f9cda75 100644 --- a/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py +++ b/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py @@ -194,12 +194,14 @@ def test_thinker2talker_full_payload_packs_complete_tensors() -> None: def test_thinker2talker_full_payload_trims_excess_stop_token_row() -> None: - """Stop_token finish path: rows == target + 1 → trim trailing row.""" + """Excess-rows path: rows == target + 1 → trim trailing row.""" request = SimpleNamespace( - request_id="thinker-stop", + request_id="thinker-excess", prompt_token_ids=[151644, 872], output_token_ids=[3], all_token_ids=[151644, 872, 3], + sampling_params=None, + status=None, ) pooling_output = { "hidden_states.layer_0": torch.ones(4, 2), @@ -214,6 +216,211 @@ def test_thinker2talker_full_payload_trims_excess_stop_token_row() -> None: assert payload["hidden_states"]["output"].shape[0] == 3 +def test_thinker2talker_full_payload_drops_stop_emission_row_when_finished_stopped() -> None: + """FINISHED_STOPPED: drop 1 extra row even when rows == target. + + vLLM appends the stop-token to output_token_ids before check_stop, so + len(all_token_ids) includes the stop slot AND the accumulator has the + stop emission's forward row. Both counts equal P+O (here 3). Talker + target should be P+O-1 (=2), not P+O. Without the extra drop the + stop emission's hidden state leaks into talker prefill (fba23325 + spurious-phoneme regression). + """ + request = SimpleNamespace( + request_id="thinker-stop-finished", + prompt_token_ids=[151644, 872], + output_token_ids=[3], + all_token_ids=[151644, 872, 3], + sampling_params=None, + status=SimpleNamespace(name="FINISHED_STOPPED"), + ) + pooling_output = { + "hidden_states.layer_0": torch.ones(3, 2), + "hidden_states.layer_24": torch.full((3, 2), 2.0), + "embed.tts_bos": torch.zeros(1, 2), + } + + payload = q3.thinker2talker_full_payload(None, pooling_output, request) + + assert payload is not None + assert payload["embed"]["prefill"].shape[0] == 2 + assert payload["hidden_states"]["output"].shape[0] == 2 + + +def test_thinker2talker_full_payload_drops_stop_emission_via_eos_fallback() -> None: + """Stop-detection fallback: last token in sampling_params.eos_token_id.""" + EOS = 151645 + request = SimpleNamespace( + request_id="thinker-stop-fallback", + prompt_token_ids=[151644, 872], + output_token_ids=[3, EOS], + all_token_ids=[151644, 872, 3, EOS], + sampling_params=SimpleNamespace(eos_token_id=EOS, stop_token_ids=None), + status=None, + ) + pooling_output = { + "hidden_states.layer_0": torch.ones(4, 2), + "hidden_states.layer_24": torch.full((4, 2), 2.0), + "embed.tts_bos": torch.zeros(1, 2), + } + + payload = q3.thinker2talker_full_payload(None, pooling_output, request) + + assert payload is not None + assert payload["embed"]["prefill"].shape[0] == 3 + assert payload["hidden_states"]["output"].shape[0] == 3 + + +def test_thinker2talker_full_payload_no_drop_when_finished_length_capped() -> None: + """FINISHED_LENGTH_CAPPED (max_tokens): no extra drop; BK 9702 regression guard.""" + request = SimpleNamespace( + request_id="thinker-length-capped", + prompt_token_ids=[151644, 872], + output_token_ids=[3], + all_token_ids=[151644, 872, 3], + sampling_params=SimpleNamespace(eos_token_id=999, stop_token_ids=None), + status=SimpleNamespace(name="FINISHED_LENGTH_CAPPED"), + ) + pooling_output = { + "hidden_states.layer_0": torch.ones(3, 2), + "hidden_states.layer_24": torch.full((3, 2), 2.0), + "embed.tts_bos": torch.zeros(1, 2), + } + + payload = q3.thinker2talker_full_payload(None, pooling_output, request) + + assert payload is not None + assert payload["embed"]["prefill"].shape[0] == 3 + assert payload["hidden_states"]["output"].shape[0] == 3 + + +def test_thinker2talker_full_payload_drops_via_private_eos_field() -> None: + """Worker-side sampling_params where the public `eos_token_id` property is + None but the private `_eos_token_id` / `_all_stop_token_ids` carry the + primary EOS (the msgspec-deserialization shape on the worker boundary). + + The fallback must read the private fields to detect the stop. + """ + EOS = 151643 + request = SimpleNamespace( + request_id="thinker-private-eos", + prompt_token_ids=[151644, 872], + output_token_ids=[3, EOS], + all_token_ids=[151644, 872, 3, EOS], + # Public `eos_token_id` looks empty; only the private fields carry it. + sampling_params=SimpleNamespace( + eos_token_id=None, + stop_token_ids=None, + ignore_eos=False, + _eos_token_id=EOS, + _all_stop_token_ids={EOS}, + ), + status=None, + ) + pooling_output = { + "hidden_states.layer_0": torch.ones(4, 2), + "hidden_states.layer_24": torch.full((4, 2), 2.0), + "embed.tts_bos": torch.zeros(1, 2), + } + + payload = q3.thinker2talker_full_payload(None, pooling_output, request) + + assert payload is not None + assert payload["embed"]["prefill"].shape[0] == 3 + assert payload["hidden_states"]["output"].shape[0] == 3 + + +def test_thinker2talker_full_payload_drops_via_all_stop_token_ids() -> None: + """Secondary EOS only in `_all_stop_token_ids` (not in `_eos_token_id`): + multi-EOS Qwen3 case where the model finished on a secondary EOS. + """ + SECONDARY_EOS = 151645 + request = SimpleNamespace( + request_id="thinker-secondary-eos", + prompt_token_ids=[151644, 872], + output_token_ids=[3, SECONDARY_EOS], + all_token_ids=[151644, 872, 3, SECONDARY_EOS], + sampling_params=SimpleNamespace( + eos_token_id=151643, # primary, not the one we hit + stop_token_ids=None, + ignore_eos=False, + _eos_token_id=151643, + _all_stop_token_ids={151643, SECONDARY_EOS}, + ), + status=None, + ) + pooling_output = { + "hidden_states.layer_0": torch.ones(4, 2), + "hidden_states.layer_24": torch.full((4, 2), 2.0), + "embed.tts_bos": torch.zeros(1, 2), + } + + payload = q3.thinker2talker_full_payload(None, pooling_output, request) + + assert payload is not None + assert payload["embed"]["prefill"].shape[0] == 3 + assert payload["hidden_states"]["output"].shape[0] == 3 + + +def test_thinker2talker_full_payload_no_drop_when_ignore_eos_and_trailing_eos() -> None: + """ignore_eos=True + length-capped + last token == EOS: no drop. + + Production worker uses CachedRequestState (no `.status` field), so + the status path doesn't catch this case; we rely on the + `sampling_params.ignore_eos` flag in the fallback to suppress the + EOS-as-stop heuristic. + """ + EOS = 151645 + request = SimpleNamespace( + request_id="thinker-ignore-eos-trailing-eos", + prompt_token_ids=[151644, 872], + output_token_ids=[3, EOS], + all_token_ids=[151644, 872, 3, EOS], + sampling_params=SimpleNamespace(eos_token_id=EOS, stop_token_ids=None, ignore_eos=True), + status=None, # production worker state has no status + ) + pooling_output = { + "hidden_states.layer_0": torch.ones(4, 2), + "hidden_states.layer_24": torch.full((4, 2), 2.0), + "embed.tts_bos": torch.zeros(1, 2), + } + + payload = q3.thinker2talker_full_payload(None, pooling_output, request) + + assert payload is not None + assert payload["embed"]["prefill"].shape[0] == 4 + assert payload["hidden_states"]["output"].shape[0] == 4 + + +def test_thinker2talker_full_payload_no_drop_when_length_capped_with_trailing_eos() -> None: + """FINISHED_LENGTH_CAPPED + last token == EOS coincidence: no drop. + + Status path takes precedence over last-token heuristic. Without + this guard the fallback would incorrectly drop a row when a length-capped + request happens to end on the EOS token id. + """ + EOS = 151645 + request = SimpleNamespace( + request_id="thinker-len-cap-trailing-eos", + prompt_token_ids=[151644, 872], + output_token_ids=[3, EOS], + all_token_ids=[151644, 872, 3, EOS], + sampling_params=SimpleNamespace(eos_token_id=EOS, stop_token_ids=None), + status=SimpleNamespace(name="FINISHED_LENGTH_CAPPED"), + ) + pooling_output = { + "hidden_states.layer_0": torch.ones(4, 2), + "hidden_states.layer_24": torch.full((4, 2), 2.0), + "embed.tts_bos": torch.zeros(1, 2), + } + + payload = q3.thinker2talker_full_payload(None, pooling_output, request) + + assert payload is not None + assert payload["embed"]["prefill"].shape[0] == 4 + assert payload["hidden_states"]["output"].shape[0] == 4 + + def test_thinker2talker_full_payload_preserves_under_capture() -> None: """Under-capture path: rows < target → no trim, safe degrade.""" request = SimpleNamespace( diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py index 26696accda2..5be79ef2eac 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -483,43 +483,83 @@ def thinker2talker_full_payload( return None prompt_token_ids = _ensure_list(getattr(request, "prompt_token_ids", []) or []) + output_token_ids = _ensure_list(getattr(request, "output_token_ids", []) or []) all_token_ids = _ensure_list(getattr(request, "all_token_ids", None) or []) if not all_token_ids: - output_token_ids = _ensure_list(getattr(request, "output_token_ids", []) or []) all_token_ids = list(prompt_token_ids) + list(output_token_ids) - # Length-aware trim of the accumulated thinker output. - # The accumulator emits P rows during the prefill forward (one per - # prompt token) and 1 row per decode forward step. When the request - # finishes by stop_token, the final emission step adds one extra row; - # when it finishes by max_tokens (FINISHED_LENGTH_CAPPED), vLLM does - # not run another forward, so there is no extra row. The previous - # unconditional `[:-1]` correctly trimmed the stop-token row but - # over-trimmed in the max_tokens path -- the talker then ran out of - # conditioning before its codec finished, causing long-output drift - # (e.g. test_mix_to_text_audio_001 long-repeat regression on BK 9702). - # Target the talker's actual prefill consumption length -- the - # downstream `_thinker_to_talker_prefill` indexes by - # `len(ids["all"])`, which equals `prompt + output` in the standard - # contract but can diverge under streaming / PD-disagg paths. Using - # `len(all_token_ids)` keeps the trim aligned with the consumer. - target_rows = len(all_token_ids) + # Length-aware trim of accumulated thinker output, finish-reason-aware. + # vLLM appends the sampled token to `output_token_ids` BEFORE + # `check_stop` (scheduler.py:1641-1651), so a stop-finished request + # has accumulator_rows == len(all_token_ids) including the stop + # emission row -- the talker must NOT consume that row (fba23325 + # spurious-phoneme regression). Max-token finishes do not append + # an extra forward, so no drop is needed (BK 9702 long-output + # regression). Primary: distinguish via `request.status`. Fallback + # only when status is absent: last-token-in-stop-id heuristic. + status = getattr(request, "status", None) + status_name = getattr(status, "name", None) or "" + if not status_name and status is not None: + status_name = str(status).rsplit(".", 1)[-1] + stop_emission_drop = 1 if status_name == "FINISHED_STOPPED" else 0 + if stop_emission_drop == 0 and not status_name and output_token_ids: + # Worker-side CachedRequestState has no `.status` field in vLLM + # v1, so this fallback runs for every production request. When + # `sampling_params.ignore_eos=True` vLLM continues past EOS, so + # a length-capped finish whose last sampled token coincidentally + # equals EOS must NOT be trimmed -- skip EOS from the stop set + # in that case. Custom `stop_token_ids` are still treated as + # stops; vLLM's `check_stop` runs stop-id matching before the + # length cap and ignores `ignore_eos` for `stop_token_ids`, so + # a last-token match there is unambiguously a stop finish. + sampling_params = getattr(request, "sampling_params", None) + if sampling_params is not None: + stop_ids: set[int] = set() + ignore_eos = bool(getattr(sampling_params, "ignore_eos", False)) + # Custom stop_token_ids always trigger stop in vLLM, regardless + # of ignore_eos (vLLM v1: `update_from_generation_config` writes + # secondary EOSes here too). Read the public list. + for sid in getattr(sampling_params, "stop_token_ids", None) or (): + if isinstance(sid, int): + stop_ids.add(sid) + # EOS sources are only stops when ignore_eos=False. Read both + # the public @property (`eos_token_id`, `all_stop_token_ids`) + # AND the private fields (`_eos_token_id`, `_all_stop_token_ids`) + # because property behavior can vary across msgspec serialization + # boundaries while the private fields are always serialized. + if not ignore_eos: + for eos in ( + getattr(sampling_params, "eos_token_id", None), + getattr(sampling_params, "_eos_token_id", None), + ): + if isinstance(eos, int): + stop_ids.add(eos) + for sid in ( + getattr(sampling_params, "all_stop_token_ids", None) + or getattr(sampling_params, "_all_stop_token_ids", None) + or () + ): + if isinstance(sid, int): + stop_ids.add(sid) + if stop_ids and output_token_ids[-1] in stop_ids: + stop_emission_drop = 1 + target_rows = max(0, len(all_token_ids) - stop_emission_drop) def _trim_to_target(t): if not isinstance(t, torch.Tensor) or t.dim() < 1 or t.shape[0] == 0: return t if target_rows <= 0: - # Defensive: an empty prompt+output should not reach this - # builder (the pooler short-circuits upstream), but guard - # against slicing valid rows to zero if some future caller - # ever invokes us in that state. + # Defensive: empty prompt+output (or stop-only output) should + # not reach this builder; keep all rows rather than slicing + # to zero. return t if t.shape[0] > target_rows + 1: logger.warning( "thinker2talker_full_payload: unexpected excess rows " - "(got %d, target %d) for req=%s; trimming to target", + "(got %d, target %d, stop_drop %d) for req=%s; trimming to target", int(t.shape[0]), target_rows, + stop_emission_drop, getattr(request, "request_id", None), ) if t.shape[0] > target_rows: @@ -527,9 +567,10 @@ def _trim_to_target(t): if t.shape[0] < target_rows: logger.debug( "thinker2talker_full_payload: under-captured rows " - "(got %d, target %d) for req=%s; talker may index past end", + "(got %d, target %d, stop_drop %d) for req=%s; talker may index past end", int(t.shape[0]), target_rows, + stop_emission_drop, getattr(request, "request_id", None), ) return t