-
Notifications
You must be signed in to change notification settings - Fork 1k
[Bugfix] Length-aware trim of thinker_emb/hid in non-async-chunk Stage-1 builder for Qwen3 Omni #3645
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[Bugfix] Length-aware trim of thinker_emb/hid in non-async-chunk Stage-1 builder for Qwen3 Omni #3645
Changes from all commits
3c0ed9c
004dc4b
9098fca
32ac2b5
946c2dd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -484,29 +484,100 @@ def thinker2talker_full_payload( | |
| return None | ||
|
|
||
| prompt_token_ids = _ensure_list(getattr(request, "prompt_token_ids", []) or []) | ||
| output_token_ids = _ensure_list(getattr(request, "output_token_ids", []) or []) | ||
| all_token_ids = _ensure_list(getattr(request, "all_token_ids", None) or []) | ||
| if not all_token_ids: | ||
| output_token_ids = _ensure_list(getattr(request, "output_token_ids", []) or []) | ||
| all_token_ids = list(prompt_token_ids) + list(output_token_ids) | ||
|
|
||
| # Trim the trailing stop-token row from the accumulated thinker output. | ||
| # The accumulator captures one hidden-state row per executed thinker | ||
| # forward (prefill + every decode step including the one that emitted | ||
| # the stop_token), so for a finished request thinker_emb has exactly one | ||
| # row more than the rows the talker should consume. async_chunk's | ||
| # chunk-0 path naturally captures only the prefill / non-stop portion, | ||
| # which is why the [async_chunk] parametrization passes while [default] | ||
| # over-generates one codec frame on short outputs (e.g. | ||
| # test_one_word_prompt_001[default]: audio extends "London" with | ||
| # spurious phonemes). | ||
| if isinstance(thinker_emb, torch.Tensor) and thinker_emb.shape[0] > 0: | ||
| thinker_emb_prefill = thinker_emb[:-1] | ||
| else: | ||
| thinker_emb_prefill = thinker_emb | ||
| if isinstance(thinker_hid, torch.Tensor) and thinker_hid.shape[0] > 0: | ||
| thinker_hid_prefill = thinker_hid[:-1] | ||
| else: | ||
| thinker_hid_prefill = thinker_hid | ||
| # Length-aware trim of accumulated thinker output, finish-reason-aware. | ||
| # vLLM appends the sampled token to `output_token_ids` BEFORE | ||
| # `check_stop` (scheduler.py:1641-1651), so a stop-finished request | ||
| # has accumulator_rows == len(all_token_ids) including the stop | ||
| # emission row -- the talker must NOT consume that row (fba23325 | ||
| # spurious-phoneme regression). Max-token finishes do not append | ||
| # an extra forward, so no drop is needed (BK 9702 long-output | ||
| # regression). Primary: distinguish via `request.status`. Fallback | ||
| # only when status is absent: last-token-in-stop-id heuristic. | ||
| status = getattr(request, "status", None) | ||
| status_name = getattr(status, "name", None) or "" | ||
| if not status_name and status is not None: | ||
| status_name = str(status).rsplit(".", 1)[-1] | ||
| stop_emission_drop = 1 if status_name == "FINISHED_STOPPED" else 0 | ||
| if stop_emission_drop == 0 and not status_name and output_token_ids: | ||
| # Worker-side CachedRequestState has no `.status` field in vLLM | ||
| # v1, so this fallback runs for every production request. When | ||
| # `sampling_params.ignore_eos=True` vLLM continues past EOS, so | ||
| # a length-capped finish whose last sampled token coincidentally | ||
| # equals EOS must NOT be trimmed -- skip EOS from the stop set | ||
| # in that case. Custom `stop_token_ids` are still treated as | ||
| # stops; vLLM's `check_stop` runs stop-id matching before the | ||
| # length cap and ignores `ignore_eos` for `stop_token_ids`, so | ||
| # a last-token match there is unambiguously a stop finish. | ||
| sampling_params = getattr(request, "sampling_params", None) | ||
| if sampling_params is not None: | ||
| stop_ids: set[int] = set() | ||
| ignore_eos = bool(getattr(sampling_params, "ignore_eos", False)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we need to discuss the testing scenarios and solutions in more detail.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agree. Then let's delay the merge until we have exactly the matching testing cases. |
||
| # Custom stop_token_ids always trigger stop in vLLM, regardless | ||
| # of ignore_eos (vLLM v1: `update_from_generation_config` writes | ||
| # secondary EOSes here too). Read the public list. | ||
| for sid in getattr(sampling_params, "stop_token_ids", None) or (): | ||
| if isinstance(sid, int): | ||
| stop_ids.add(sid) | ||
| # EOS sources are only stops when ignore_eos=False. Read both | ||
| # the public @property (`eos_token_id`, `all_stop_token_ids`) | ||
| # AND the private fields (`_eos_token_id`, `_all_stop_token_ids`) | ||
| # because property behavior can vary across msgspec serialization | ||
| # boundaries while the private fields are always serialized. | ||
| if not ignore_eos: | ||
| for eos in ( | ||
| getattr(sampling_params, "eos_token_id", None), | ||
| getattr(sampling_params, "_eos_token_id", None), | ||
| ): | ||
| if isinstance(eos, int): | ||
| stop_ids.add(eos) | ||
| for sid in ( | ||
| getattr(sampling_params, "all_stop_token_ids", None) | ||
| or getattr(sampling_params, "_all_stop_token_ids", None) | ||
| or () | ||
| ): | ||
| if isinstance(sid, int): | ||
| stop_ids.add(sid) | ||
| if stop_ids and output_token_ids[-1] in stop_ids: | ||
| stop_emission_drop = 1 | ||
| target_rows = max(0, len(all_token_ids) - stop_emission_drop) | ||
|
|
||
| def _trim_to_target(t): | ||
| if not isinstance(t, torch.Tensor) or t.dim() < 1 or t.shape[0] == 0: | ||
| return t | ||
| if target_rows <= 0: | ||
|
hsliuustc0106 marked this conversation as resolved.
|
||
| # Defensive: empty prompt+output (or stop-only output) should | ||
| # not reach this builder; keep all rows rather than slicing | ||
| # to zero. | ||
| return t | ||
| if t.shape[0] > target_rows + 1: | ||
| logger.warning( | ||
| "thinker2talker_full_payload: unexpected excess rows " | ||
| "(got %d, target %d, stop_drop %d) for req=%s; trimming to target", | ||
| int(t.shape[0]), | ||
| target_rows, | ||
| stop_emission_drop, | ||
| getattr(request, "request_id", None), | ||
| ) | ||
| if t.shape[0] > target_rows: | ||
| return t[:target_rows] | ||
| if t.shape[0] < target_rows: | ||
| logger.debug( | ||
| "thinker2talker_full_payload: under-captured rows " | ||
| "(got %d, target %d, stop_drop %d) for req=%s; talker may index past end", | ||
| int(t.shape[0]), | ||
| target_rows, | ||
| stop_emission_drop, | ||
| getattr(request, "request_id", None), | ||
| ) | ||
| return t | ||
|
|
||
| thinker_emb_prefill = _trim_to_target(thinker_emb) | ||
| thinker_hid_prefill = _trim_to_target(thinker_hid) | ||
|
|
||
| payload: OmniPayload = { | ||
| "embed": { | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.