Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def test_talker2code2wav_full_payload_keeps_all_zero_codec_rows() -> None:


def test_thinker2talker_full_payload_packs_complete_tensors() -> None:
"""Standard max_tokens finish path: rows == target → no trim."""
request = SimpleNamespace(
request_id="thinker",
prompt_token_ids=[151644, 872],
Expand All @@ -187,3 +188,255 @@ def test_thinker2talker_full_payload_packs_complete_tensors() -> None:
assert payload["embed"]["prefill"].device.type == "cpu"
assert payload["hidden_states"]["output"].device.type == "cpu"
assert payload["next_stage_prompt_len"] > 0
# Lock down the no-trim invariant for rows == target.
assert payload["embed"]["prefill"].shape[0] == 3
assert payload["hidden_states"]["output"].shape[0] == 3


def test_thinker2talker_full_payload_trims_excess_stop_token_row() -> None:
    """Excess-rows path: one captured row beyond the target is trimmed.

    The request holds three token ids and supplies neither a status nor
    sampling params, while the accumulator delivers four rows; both packed
    tensors are expected to come back with three rows.
    """
    prompt_ids = [151644, 872]
    generated_ids = [3]
    captured_rows, expected_rows = 4, 3

    request = SimpleNamespace(
        request_id="thinker-excess",
        prompt_token_ids=prompt_ids,
        output_token_ids=generated_ids,
        all_token_ids=prompt_ids + generated_ids,
        sampling_params=None,
        status=None,
    )
    pooling_output = {
        "hidden_states.layer_0": torch.ones(captured_rows, 2),
        "hidden_states.layer_24": torch.full((captured_rows, 2), 2.0),
        "embed.tts_bos": torch.zeros(1, 2),
    }

    payload = q3.thinker2talker_full_payload(None, pooling_output, request)

    assert payload is not None
    prefill_embed = payload["embed"]["prefill"]
    output_hidden = payload["hidden_states"]["output"]
    assert prefill_embed.shape[0] == expected_rows
    assert output_hidden.shape[0] == expected_rows


def test_thinker2talker_full_payload_drops_stop_emission_row_when_finished_stopped() -> None:
    """FINISHED_STOPPED: one extra row is dropped even when rows == token count.

    vLLM appends the stop token to output_token_ids before check_stop, so
    len(all_token_ids) counts the stop slot and the accumulator holds the
    stop emission's forward row as well; both equal P+O (3 here). The talker
    target should be P+O-1 (= 2). Without the extra drop the stop emission's
    hidden state leaks into talker prefill (fba23325 spurious-phoneme
    regression).
    """
    prompt_ids = [151644, 872]
    generated_ids = [3]
    captured_rows, expected_rows = 3, 2
    finish_status = SimpleNamespace(name="FINISHED_STOPPED")

    request = SimpleNamespace(
        request_id="thinker-stop-finished",
        prompt_token_ids=prompt_ids,
        output_token_ids=generated_ids,
        all_token_ids=prompt_ids + generated_ids,
        sampling_params=None,
        status=finish_status,
    )
    pooling_output = {
        "hidden_states.layer_0": torch.ones(captured_rows, 2),
        "hidden_states.layer_24": torch.full((captured_rows, 2), 2.0),
        "embed.tts_bos": torch.zeros(1, 2),
    }

    payload = q3.thinker2talker_full_payload(None, pooling_output, request)

    assert payload is not None
    prefill_embed = payload["embed"]["prefill"]
    output_hidden = payload["hidden_states"]["output"]
    assert prefill_embed.shape[0] == expected_rows
    assert output_hidden.shape[0] == expected_rows


def test_thinker2talker_full_payload_drops_stop_emission_via_eos_fallback() -> None:
    """Stop-detection fallback: the last token equals sampling_params.eos_token_id.

    No status is supplied, so the stop has to be recognised from the final
    output token alone; four captured rows are expected to shrink to three.
    """
    eos_id = 151645
    prompt_ids = [151644, 872]
    generated_ids = [3, eos_id]
    captured_rows, expected_rows = 4, 3
    stop_params = SimpleNamespace(eos_token_id=eos_id, stop_token_ids=None)

    request = SimpleNamespace(
        request_id="thinker-stop-fallback",
        prompt_token_ids=prompt_ids,
        output_token_ids=generated_ids,
        all_token_ids=prompt_ids + generated_ids,
        sampling_params=stop_params,
        status=None,
    )
    pooling_output = {
        "hidden_states.layer_0": torch.ones(captured_rows, 2),
        "hidden_states.layer_24": torch.full((captured_rows, 2), 2.0),
        "embed.tts_bos": torch.zeros(1, 2),
    }

    payload = q3.thinker2talker_full_payload(None, pooling_output, request)

    assert payload is not None
    prefill_embed = payload["embed"]["prefill"]
    output_hidden = payload["hidden_states"]["output"]
    assert prefill_embed.shape[0] == expected_rows
    assert output_hidden.shape[0] == expected_rows


def test_thinker2talker_full_payload_no_drop_when_finished_length_capped() -> None:
    """FINISHED_LENGTH_CAPPED (max_tokens): no extra row is dropped.

    Regression guard for BK 9702: with three token ids and three captured
    rows, both packed tensors must keep all three rows.
    """
    prompt_ids = [151644, 872]
    generated_ids = [3]
    captured_rows = expected_rows = 3
    stop_params = SimpleNamespace(eos_token_id=999, stop_token_ids=None)
    finish_status = SimpleNamespace(name="FINISHED_LENGTH_CAPPED")

    request = SimpleNamespace(
        request_id="thinker-length-capped",
        prompt_token_ids=prompt_ids,
        output_token_ids=generated_ids,
        all_token_ids=prompt_ids + generated_ids,
        sampling_params=stop_params,
        status=finish_status,
    )
    pooling_output = {
        "hidden_states.layer_0": torch.ones(captured_rows, 2),
        "hidden_states.layer_24": torch.full((captured_rows, 2), 2.0),
        "embed.tts_bos": torch.zeros(1, 2),
    }

    payload = q3.thinker2talker_full_payload(None, pooling_output, request)

    assert payload is not None
    prefill_embed = payload["embed"]["prefill"]
    output_hidden = payload["hidden_states"]["output"]
    assert prefill_embed.shape[0] == expected_rows
    assert output_hidden.shape[0] == expected_rows


def test_thinker2talker_full_payload_drops_via_private_eos_field() -> None:
    """Stop is detected from the private EOS fields when the public one is None.

    Models worker-side sampling_params (the msgspec-deserialization shape on
    the worker boundary) where the public `eos_token_id` property is None
    but `_eos_token_id` / `_all_stop_token_ids` carry the primary EOS. The
    fallback must read the private fields to detect the stop.
    """
    eos_id = 151643
    prompt_ids = [151644, 872]
    generated_ids = [3, eos_id]
    captured_rows, expected_rows = 4, 3

    # Public `eos_token_id` looks empty; only the private fields carry it.
    stop_params = SimpleNamespace(
        eos_token_id=None,
        stop_token_ids=None,
        ignore_eos=False,
        _eos_token_id=eos_id,
        _all_stop_token_ids={eos_id},
    )
    request = SimpleNamespace(
        request_id="thinker-private-eos",
        prompt_token_ids=prompt_ids,
        output_token_ids=generated_ids,
        all_token_ids=prompt_ids + generated_ids,
        sampling_params=stop_params,
        status=None,
    )
    pooling_output = {
        "hidden_states.layer_0": torch.ones(captured_rows, 2),
        "hidden_states.layer_24": torch.full((captured_rows, 2), 2.0),
        "embed.tts_bos": torch.zeros(1, 2),
    }

    payload = q3.thinker2talker_full_payload(None, pooling_output, request)

    assert payload is not None
    prefill_embed = payload["embed"]["prefill"]
    output_hidden = payload["hidden_states"]["output"]
    assert prefill_embed.shape[0] == expected_rows
    assert output_hidden.shape[0] == expected_rows


def test_thinker2talker_full_payload_drops_via_all_stop_token_ids() -> None:
    """Secondary EOS present only in `_all_stop_token_ids` still triggers the drop.

    Multi-EOS Qwen3 case: the model finished on a secondary EOS that is not
    the value held in `_eos_token_id`.
    """
    primary_eos_id = 151643  # primary, not the one we hit
    secondary_eos_id = 151645
    prompt_ids = [151644, 872]
    generated_ids = [3, secondary_eos_id]
    captured_rows, expected_rows = 4, 3

    stop_params = SimpleNamespace(
        eos_token_id=primary_eos_id,
        stop_token_ids=None,
        ignore_eos=False,
        _eos_token_id=primary_eos_id,
        _all_stop_token_ids={primary_eos_id, secondary_eos_id},
    )
    request = SimpleNamespace(
        request_id="thinker-secondary-eos",
        prompt_token_ids=prompt_ids,
        output_token_ids=generated_ids,
        all_token_ids=prompt_ids + generated_ids,
        sampling_params=stop_params,
        status=None,
    )
    pooling_output = {
        "hidden_states.layer_0": torch.ones(captured_rows, 2),
        "hidden_states.layer_24": torch.full((captured_rows, 2), 2.0),
        "embed.tts_bos": torch.zeros(1, 2),
    }

    payload = q3.thinker2talker_full_payload(None, pooling_output, request)

    assert payload is not None
    prefill_embed = payload["embed"]["prefill"]
    output_hidden = payload["hidden_states"]["output"]
    assert prefill_embed.shape[0] == expected_rows
    assert output_hidden.shape[0] == expected_rows


def test_thinker2talker_full_payload_no_drop_when_ignore_eos_and_trailing_eos() -> None:
    """ignore_eos=True + length-capped + last token == EOS: no row is dropped.

    The production worker uses CachedRequestState (no `.status` field), so
    the status path cannot catch this case; the fallback relies on the
    `sampling_params.ignore_eos` flag to suppress the EOS-as-stop heuristic.
    """
    eos_id = 151645
    prompt_ids = [151644, 872]
    generated_ids = [3, eos_id]
    captured_rows = expected_rows = 4

    stop_params = SimpleNamespace(
        eos_token_id=eos_id,
        stop_token_ids=None,
        ignore_eos=True,
    )
    request = SimpleNamespace(
        request_id="thinker-ignore-eos-trailing-eos",
        prompt_token_ids=prompt_ids,
        output_token_ids=generated_ids,
        all_token_ids=prompt_ids + generated_ids,
        sampling_params=stop_params,
        status=None,  # production worker state has no status
    )
    pooling_output = {
        "hidden_states.layer_0": torch.ones(captured_rows, 2),
        "hidden_states.layer_24": torch.full((captured_rows, 2), 2.0),
        "embed.tts_bos": torch.zeros(1, 2),
    }

    payload = q3.thinker2talker_full_payload(None, pooling_output, request)

    assert payload is not None
    prefill_embed = payload["embed"]["prefill"]
    output_hidden = payload["hidden_states"]["output"]
    assert prefill_embed.shape[0] == expected_rows
    assert output_hidden.shape[0] == expected_rows


def test_thinker2talker_full_payload_no_drop_when_length_capped_with_trailing_eos() -> None:
    """FINISHED_LENGTH_CAPPED + last token == EOS coincidence: no row is dropped.

    The status path takes precedence over the last-token heuristic. Without
    this guard the fallback would incorrectly drop a row when a
    length-capped request happens to end on the EOS token id.
    """
    eos_id = 151645
    prompt_ids = [151644, 872]
    generated_ids = [3, eos_id]
    captured_rows = expected_rows = 4
    stop_params = SimpleNamespace(eos_token_id=eos_id, stop_token_ids=None)
    finish_status = SimpleNamespace(name="FINISHED_LENGTH_CAPPED")

    request = SimpleNamespace(
        request_id="thinker-len-cap-trailing-eos",
        prompt_token_ids=prompt_ids,
        output_token_ids=generated_ids,
        all_token_ids=prompt_ids + generated_ids,
        sampling_params=stop_params,
        status=finish_status,
    )
    pooling_output = {
        "hidden_states.layer_0": torch.ones(captured_rows, 2),
        "hidden_states.layer_24": torch.full((captured_rows, 2), 2.0),
        "embed.tts_bos": torch.zeros(1, 2),
    }

    payload = q3.thinker2talker_full_payload(None, pooling_output, request)

    assert payload is not None
    prefill_embed = payload["embed"]["prefill"]
    output_hidden = payload["hidden_states"]["output"]
    assert prefill_embed.shape[0] == expected_rows
    assert output_hidden.shape[0] == expected_rows


def test_thinker2talker_full_payload_preserves_under_capture() -> None:
    """Under-capture path: fewer rows than the target are passed through untrimmed.

    The request object deliberately carries no `status` or `sampling_params`
    attributes. With three token ids but only two captured rows, both packed
    tensors must keep their two rows (safe degrade).
    """
    prompt_ids = [151644, 872]
    generated_ids = [3]
    captured_rows = expected_rows = 2

    request = SimpleNamespace(
        request_id="thinker-undercap",
        prompt_token_ids=prompt_ids,
        output_token_ids=generated_ids,
        all_token_ids=prompt_ids + generated_ids,
    )
    pooling_output = {
        "hidden_states.layer_0": torch.ones(captured_rows, 2),
        "hidden_states.layer_24": torch.full((captured_rows, 2), 2.0),
        "embed.tts_bos": torch.zeros(1, 2),
    }

    payload = q3.thinker2talker_full_payload(None, pooling_output, request)

    assert payload is not None
    prefill_embed = payload["embed"]["prefill"]
    output_hidden = payload["hidden_states"]["output"]
    assert prefill_embed.shape[0] == expected_rows
    assert output_hidden.shape[0] == expected_rows
109 changes: 90 additions & 19 deletions vllm_omni/model_executor/stage_input_processors/qwen3_omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,29 +484,100 @@ def thinker2talker_full_payload(
return None

prompt_token_ids = _ensure_list(getattr(request, "prompt_token_ids", []) or [])
output_token_ids = _ensure_list(getattr(request, "output_token_ids", []) or [])
all_token_ids = _ensure_list(getattr(request, "all_token_ids", None) or [])
if not all_token_ids:
output_token_ids = _ensure_list(getattr(request, "output_token_ids", []) or [])
all_token_ids = list(prompt_token_ids) + list(output_token_ids)

# Trim the trailing stop-token row from the accumulated thinker output.
# The accumulator captures one hidden-state row per executed thinker
# forward (prefill + every decode step including the one that emitted
# the stop_token), so for a finished request thinker_emb has exactly one
# row more than the rows the talker should consume. async_chunk's
# chunk-0 path naturally captures only the prefill / non-stop portion,
# which is why the [async_chunk] parametrization passes while [default]
# over-generates one codec frame on short outputs (e.g.
# test_one_word_prompt_001[default]: audio extends "London" with
# spurious phonemes).
if isinstance(thinker_emb, torch.Tensor) and thinker_emb.shape[0] > 0:
thinker_emb_prefill = thinker_emb[:-1]
else:
thinker_emb_prefill = thinker_emb
if isinstance(thinker_hid, torch.Tensor) and thinker_hid.shape[0] > 0:
thinker_hid_prefill = thinker_hid[:-1]
else:
thinker_hid_prefill = thinker_hid
# Length-aware trim of accumulated thinker output, finish-reason-aware.
# vLLM appends the sampled token to `output_token_ids` BEFORE
# `check_stop` (scheduler.py:1641-1651), so a stop-finished request
# has accumulator_rows == len(all_token_ids) including the stop
# emission row -- the talker must NOT consume that row (fba23325
# spurious-phoneme regression). Max-token finishes do not append
# an extra forward, so no drop is needed (BK 9702 long-output
# regression). Primary: distinguish via `request.status`. Fallback
# only when status is absent: last-token-in-stop-id heuristic.
status = getattr(request, "status", None)
status_name = getattr(status, "name", None) or ""
if not status_name and status is not None:
status_name = str(status).rsplit(".", 1)[-1]
stop_emission_drop = 1 if status_name == "FINISHED_STOPPED" else 0
if stop_emission_drop == 0 and not status_name and output_token_ids:
# Worker-side CachedRequestState has no `.status` field in vLLM
# v1, so this fallback runs for every production request. When
# `sampling_params.ignore_eos=True` vLLM continues past EOS, so
# a length-capped finish whose last sampled token coincidentally
# equals EOS must NOT be trimmed -- skip EOS from the stop set
# in that case. Custom `stop_token_ids` are still treated as
# stops; vLLM's `check_stop` runs stop-id matching before the
# length cap and ignores `ignore_eos` for `stop_token_ids`, so
# a last-token match there is unambiguously a stop finish.
sampling_params = getattr(request, "sampling_params", None)
Comment thread
natureofnature marked this conversation as resolved.
if sampling_params is not None:
stop_ids: set[int] = set()
ignore_eos = bool(getattr(sampling_params, "ignore_eos", False))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need to discuss the testing scenarios and solutions in more detail.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree. Then let's delay the merge until we have exactly the matching testing cases.

# Custom stop_token_ids always trigger stop in vLLM, regardless
# of ignore_eos (vLLM v1: `update_from_generation_config` writes
# secondary EOSes here too). Read the public list.
for sid in getattr(sampling_params, "stop_token_ids", None) or ():
if isinstance(sid, int):
stop_ids.add(sid)
# EOS sources are only stops when ignore_eos=False. Read both
# the public @property (`eos_token_id`, `all_stop_token_ids`)
# AND the private fields (`_eos_token_id`, `_all_stop_token_ids`)
# because property behavior can vary across msgspec serialization
# boundaries while the private fields are always serialized.
if not ignore_eos:
for eos in (
getattr(sampling_params, "eos_token_id", None),
getattr(sampling_params, "_eos_token_id", None),
):
if isinstance(eos, int):
stop_ids.add(eos)
for sid in (
getattr(sampling_params, "all_stop_token_ids", None)
or getattr(sampling_params, "_all_stop_token_ids", None)
or ()
):
if isinstance(sid, int):
stop_ids.add(sid)
if stop_ids and output_token_ids[-1] in stop_ids:
stop_emission_drop = 1
target_rows = max(0, len(all_token_ids) - stop_emission_drop)

def _trim_to_target(t):
if not isinstance(t, torch.Tensor) or t.dim() < 1 or t.shape[0] == 0:
return t
if target_rows <= 0:
Comment thread
hsliuustc0106 marked this conversation as resolved.
# Defensive: empty prompt+output (or stop-only output) should
# not reach this builder; keep all rows rather than slicing
# to zero.
return t
if t.shape[0] > target_rows + 1:
logger.warning(
"thinker2talker_full_payload: unexpected excess rows "
"(got %d, target %d, stop_drop %d) for req=%s; trimming to target",
int(t.shape[0]),
target_rows,
stop_emission_drop,
getattr(request, "request_id", None),
)
if t.shape[0] > target_rows:
return t[:target_rows]
if t.shape[0] < target_rows:
logger.debug(
"thinker2talker_full_payload: under-captured rows "
"(got %d, target %d, stop_drop %d) for req=%s; talker may index past end",
int(t.shape[0]),
target_rows,
stop_emission_drop,
getattr(request, "request_id", None),
)
return t

thinker_emb_prefill = _trim_to_target(thinker_emb)
thinker_hid_prefill = _trim_to_target(thinker_hid)

payload: OmniPayload = {
"embed": {
Expand Down
Loading