From 7aadeeef2464cab069074925d5797dcc1b67aea0 Mon Sep 17 00:00:00 2001 From: natureofnature Date: Fri, 15 May 2026 03:59:57 +0000 Subject: [PATCH 01/19] [Refactor] Address PR #2677 review feedback: scheduler hygiene, IPC type safety, DRY. Five reviewer items bundled (file:line per item in the original review): omni_ar_scheduler._free_request: wrap the KV-transfer + standard-freeing block in try/finally so the input_coordinator entry is pruned along every return path. Mirrors the pattern already in omni_generation_scheduler._free_request. Four inline "if self.input_coordinator is not None: self._free_input_coordinator_request(...)" calls collapsed into a single finally clause. init_omni_connectors: set self._omni_connector_initialized = True at the end of init. OmniGPUModelRunner._update_states gates cleanup_finished_request on this explicit flag instead of probing the private "_request_ids_mapping" attribute name. Removes the implicit "is the mixin done initialising" contract. pending_input_registrations: now typed list[OmniInputRegistration] (new minimal dataclass in vllm_omni/core/sched/output.py carrying request_id + external_req_id only - the two fields register_chunk_recv actually consumes). Replaces the prior list[Any], for which msgspec falls back to JSON-ish serialisation under PD-disagg / multi-node executor IPC. Wire payload also drops by ~one Request struct per pending registration. Tests stay green via duck-typed attribute access. OmniSchedulerMixin: three new helpers lift the consume / capture / wrap boilerplate duplicated between omni_ar_scheduler and omni_generation_scheduler: - _consume_pending_connector_output(model_mode) -- drains _latest_omni_connector_output at top of schedule() - _capture_omni_connector_output(model_runner_output, model_mode) -- stashes omni_connector_output at tail of update_from_output() - _wrap_omni_scheduler_output(base, **extras) -- builds OmniSchedulerOutput from a base SchedulerOutput. AR + generation schedulers each lose 3 copy-pasted blocks.
Verified on H800 dev environment with --run-level full_model -m "full_model and H800 and omni": test_one_word_prompt_001 + test_speaker_002 ([default] and [async_chunk]) 4 passed in 6:32. Signed-off-by: natureofnature --- vllm_omni/core/sched/omni_ar_scheduler.py | 157 ++++++++---------- .../core/sched/omni_generation_scheduler.py | 43 +---- vllm_omni/core/sched/omni_scheduler_mixin.py | 72 ++++++++ .../core/sched/omni_scheduling_coordinator.py | 22 ++- vllm_omni/core/sched/output.py | 20 ++- vllm_omni/worker/gpu_model_runner.py | 13 +- .../omni_connector_model_runner_mixin.py | 7 + 7 files changed, 197 insertions(+), 137 deletions(-) diff --git a/vllm_omni/core/sched/omni_ar_scheduler.py b/vllm_omni/core/sched/omni_ar_scheduler.py index 09ee55ba972..39dd4d04957 100644 --- a/vllm_omni/core/sched/omni_ar_scheduler.py +++ b/vllm_omni/core/sched/omni_ar_scheduler.py @@ -24,7 +24,6 @@ OmniSchedulingCoordinator, uses_qwen3_omni_full_payload_input_coordinator, ) -from vllm_omni.core.sched.output import OmniSchedulerOutput from vllm_omni.distributed.omni_connectors.transfer_adapter.chunk_transfer_adapter import ( OmniChunkTransferAdapter, ) @@ -211,18 +210,7 @@ def schedule(self) -> SchedulerOutput: # type: ignore[override] for req in list(queue): if getattr(req, "status", None) == RequestStatus.FINISHED_ABORTED: queue.remove(req) - connector_output = self._latest_omni_connector_output - self._latest_omni_connector_output = None - if self.input_coordinator: - if connector_output and connector_output.request_metadata: - self.input_coordinator.update_request_metadata( - self.requests, connector_output.request_metadata, model_mode="ar" - ) - self.input_coordinator.process_pending_full_payload_inputs( - self.waiting, - self.running, - connector_output.stage_recv_req_ids if connector_output else set(), - ) + self._consume_pending_connector_output(model_mode="ar") if self.chunk_transfer_adapter: self.chunk_transfer_adapter.process_pending_chunks(self.waiting, self.running) @@ 
-278,13 +266,9 @@ def schedule(self) -> SchedulerOutput: # type: ignore[override] finished_reqs = {} # Wrap in omni scheduler output to carry transfer metadata. - base_fields = SchedulerOutput.__dataclass_fields__.keys() - base_data = {name: getattr(scheduler_output, name) for name in base_fields} - input_regs = self.input_coordinator.pending_input_registrations if self.input_coordinator else [] - return OmniSchedulerOutput( - **base_data, + return self._wrap_omni_scheduler_output( + scheduler_output, finished_requests_needing_kv_transfer=finished_reqs, - pending_input_registrations=input_regs, ) def update_from_output( @@ -581,15 +565,7 @@ def update_from_output( engine_core_outputs[0] = eco = EngineCoreOutputs() eco.scheduler_stats = stats - omni_output = getattr(model_runner_output, "omni_connector_output", None) - if omni_output is not None: - self._latest_omni_connector_output = omni_output - if self.input_coordinator and omni_output.request_metadata: - self.input_coordinator.update_request_metadata( - self.requests, - omni_output.request_metadata, - model_mode="ar", - ) + self._capture_omni_connector_output(model_runner_output, model_mode="ar") # Free blocks that were held for transfer (kv_ready and # active_kv_transfers updates already done before the per-request loop). @@ -668,70 +644,73 @@ def _free_request(self, request: Request, delay_free_blocks: bool = False) -> di if self.finished_req_ids_dict is not None: self.finished_req_ids_dict[request.client_index].add(request_id) - # 2. Omni Specific: Check if we need to transfer KV - if self._should_transfer_kv_for_request(request_id): - already_triggered = request_id in self.transfer_triggered_requests - is_active = request_id in self.active_kv_transfers - - if already_triggered: - if is_active: - # It triggered but hasn't finished yet. We MUST wait. - logger.debug(f"[Omni] Request {request_id} finished but transfer is still ACTIVE. 
Waiting.") + # Mirror the generation scheduler's try/finally pattern so the + # input_coordinator entry is always pruned along every return path, + # including the early returns for in-flight / waiting KV transfers + # below. _free_input_coordinator_request is a no-op when the + # coordinator is None, so the unconditional finally is safe. + try: + # 2. Omni Specific: Check if we need to transfer KV + if self._should_transfer_kv_for_request(request_id): + already_triggered = request_id in self.transfer_triggered_requests + is_active = request_id in self.active_kv_transfers + + if already_triggered: + if is_active: + # It triggered but hasn't finished yet. We MUST wait. + logger.debug(f"[Omni] Request {request_id} finished but transfer is still ACTIVE. Waiting.") + self.waiting_for_transfer_free.add(request_id) + kv_xfer_params = None + return kv_xfer_params + elif request_id in self.waiting_for_transfer_free: + # Blocks held until KV extraction completes in a future step. + return None + else: + logger.debug( + f"[Omni] Request {request_id} finished and transfer no longer ACTIVE (extracted/acked). " + "Freeing immediately." 
+ ) + else: self.waiting_for_transfer_free.add(request_id) - if self.input_coordinator is not None: - self._free_input_coordinator_request(request_id) - kv_xfer_params = None + confirmed_computed = self._get_confirmed_num_computed_tokens(request) + self._mark_request_for_kv_transfer(request_id, confirmed_computed) + # Return KV transfer metadata so it propagates to RequestOutput + if request_id in self.requests_needing_kv_transfer: + transfer_data = self.requests_needing_kv_transfer[request_id] + kv_xfer_params = { + "past_key_values": transfer_data["block_ids"], + "kv_metadata": { + "seq_len": transfer_data["seq_len"], + "block_ids": transfer_data["block_ids"], + }, + } + # Also update request.additional_information for good measure + add_info = getattr(request, "additional_information", None) + # If additional_information is an AdditionalInformationPayload-like object, + # unpack it into a plain dict. + if ( + add_info is not None + and hasattr(add_info, "entries") + and isinstance(getattr(add_info, "entries"), dict) + ): + request.additional_information = deserialize_additional_information(add_info) + add_info = request.additional_information + if add_info is None: + request.additional_information = {} + add_info = request.additional_information + if isinstance(add_info, dict): + add_info.update(kv_xfer_params) + return kv_xfer_params - elif request_id in self.waiting_for_transfer_free: - # Blocks held until KV extraction completes in a future step. - if self.input_coordinator is not None: - self._free_input_coordinator_request(request_id) - return None - else: - logger.debug( - f"[Omni] Request {request_id} finished and transfer no longer ACTIVE (extracted/acked). " - "Freeing immediately." 
- ) - else: - self.waiting_for_transfer_free.add(request_id) - confirmed_computed = self._get_confirmed_num_computed_tokens(request) - self._mark_request_for_kv_transfer(request_id, confirmed_computed) - # Return KV transfer metadata so it propagates to RequestOutput - if request_id in self.requests_needing_kv_transfer: - transfer_data = self.requests_needing_kv_transfer[request_id] - kv_xfer_params = { - "past_key_values": transfer_data["block_ids"], - "kv_metadata": {"seq_len": transfer_data["seq_len"], "block_ids": transfer_data["block_ids"]}, - } - # Also update request.additional_information for good measure - add_info = getattr(request, "additional_information", None) - # If additional_information is an AdditionalInformationPayload-like object, - # unpack it into a plain dict. - if ( - add_info is not None - and hasattr(add_info, "entries") - and isinstance(getattr(add_info, "entries"), dict) - ): - request.additional_information = deserialize_additional_information(add_info) - add_info = request.additional_information - if add_info is None: - request.additional_information = {} - add_info = request.additional_information - if isinstance(add_info, dict): - add_info.update(kv_xfer_params) - - if self.input_coordinator is not None: - self._free_input_coordinator_request(request_id) - return kv_xfer_params - - # 3. Standard Freeing - delay_free_blocks |= connector_delay_free_blocks - if self.input_coordinator is not None: - self._free_input_coordinator_request(request_id) - if not delay_free_blocks: - self._free_blocks(request) - return kv_xfer_params + # 3. 
Standard Freeing + delay_free_blocks |= connector_delay_free_blocks + if not delay_free_blocks: + self._free_blocks(request) + + return kv_xfer_params + finally: + self._free_input_coordinator_request(request_id) def _free_blocks(self, request: Request): # Helper to match base class structure if not directly available diff --git a/vllm_omni/core/sched/omni_generation_scheduler.py b/vllm_omni/core/sched/omni_generation_scheduler.py index 957b7e5f677..af0b7dcff46 100644 --- a/vllm_omni/core/sched/omni_generation_scheduler.py +++ b/vllm_omni/core/sched/omni_generation_scheduler.py @@ -28,7 +28,7 @@ OmniSchedulingCoordinator, uses_qwen3_omni_full_payload_input_coordinator, ) -from vllm_omni.core.sched.output import OmniCachedRequestData, OmniNewRequestData, OmniSchedulerOutput +from vllm_omni.core.sched.output import OmniCachedRequestData, OmniNewRequestData from vllm_omni.distributed.omni_connectors.transfer_adapter.chunk_transfer_adapter import ( OmniChunkTransferAdapter, ) @@ -82,18 +82,7 @@ def schedule(self) -> SchedulerOutput: # Temporary queue: preserve waiting order, do not disturb non-diffusion requests skipped_waiting_requests = create_request_queue(self.policy) req_index = 0 - connector_output = self._latest_omni_connector_output - self._latest_omni_connector_output = None - if self.input_coordinator: - if connector_output and connector_output.request_metadata: - self.input_coordinator.update_request_metadata( - self.requests, connector_output.request_metadata, model_mode="generation" - ) - self.input_coordinator.process_pending_full_payload_inputs( - self.waiting, - self.running, - connector_output.stage_recv_req_ids if connector_output else set(), - ) + self._consume_pending_connector_output(model_mode="generation") if self.chunk_transfer_adapter: self.chunk_transfer_adapter.process_pending_chunks(self.waiting, self.running) @@ -227,14 +216,7 @@ def schedule(self) -> SchedulerOutput: res = super().schedule() if self.input_coordinator: 
self.input_coordinator.restore_queues(self.waiting, self.running) - base_fields = SchedulerOutput.__dataclass_fields__.keys() - base_data = {name: getattr(res, name) for name in base_fields} - return OmniSchedulerOutput( - **base_data, - pending_input_registrations=( - self.input_coordinator.pending_input_registrations if self.input_coordinator else [] - ), - ) + return self._wrap_omni_scheduler_output(res) # Compute common prefix blocks (aligned with v1) num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups) @@ -362,14 +344,7 @@ def schedule(self) -> SchedulerOutput: if self.input_coordinator: self.input_coordinator.restore_queues(self.waiting, self.running) - base_fields = SchedulerOutput.__dataclass_fields__.keys() - base_data = {name: getattr(scheduler_output, name) for name in base_fields} - return OmniSchedulerOutput( - **base_data, - pending_input_registrations=( - self.input_coordinator.pending_input_registrations if self.input_coordinator else [] - ), - ) + return self._wrap_omni_scheduler_output(scheduler_output) def finish_requests(self, request_ids, finished_status: RequestStatus) -> list[tuple[str, int]]: """Handles the finish signal from outside the scheduler. 
@@ -683,15 +658,7 @@ def update_from_output( engine_core_outputs[0] = eco = EngineCoreOutputs() eco.scheduler_stats = stats - omni_output = getattr(model_runner_output, "omni_connector_output", None) - if omni_output is not None: - self._latest_omni_connector_output = omni_output - if self.input_coordinator and omni_output.request_metadata: - self.input_coordinator.update_request_metadata( - self.requests, - omni_output.request_metadata, - model_mode="generation", - ) + self._capture_omni_connector_output(model_runner_output, model_mode="generation") return engine_core_outputs diff --git a/vllm_omni/core/sched/omni_scheduler_mixin.py b/vllm_omni/core/sched/omni_scheduler_mixin.py index 606739e9087..fba514756a7 100644 --- a/vllm_omni/core/sched/omni_scheduler_mixin.py +++ b/vllm_omni/core/sched/omni_scheduler_mixin.py @@ -1,8 +1,13 @@ from __future__ import annotations +from typing import Any + +from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.engine import EngineCoreEventType from vllm.v1.request import Request, RequestStatus, StreamingUpdate +from vllm_omni.core.sched.output import OmniInputRegistration, OmniSchedulerOutput + class OmniSchedulerMixin: """Shared scheduler helpers for omni-specific request handling.""" @@ -13,6 +18,73 @@ def _free_input_coordinator_request(self, request_id: str) -> None: if input_coordinator is not None: input_coordinator.free_finished_request(request_id) + # ------------------------------------------------------------------ # + # Shared scheduler/output helpers (lift the AR / generation duplicates) + # ------------------------------------------------------------------ # + + def _consume_pending_connector_output(self, model_mode: str) -> None: + """Drain ``self._latest_omni_connector_output`` into the coordinator. + + Called at the top of every ``schedule()`` cycle. Identical between + AR and generation schedulers except for the ``model_mode`` argument + forwarded to ``update_request_metadata``. 
+ """ + connector_output = getattr(self, "_latest_omni_connector_output", None) + self._latest_omni_connector_output = None + input_coordinator = getattr(self, "input_coordinator", None) + if input_coordinator is None: + return + if connector_output and connector_output.request_metadata: + input_coordinator.update_request_metadata( + self.requests, connector_output.request_metadata, model_mode=model_mode + ) + input_coordinator.process_pending_full_payload_inputs( + self.waiting, + self.running, + connector_output.stage_recv_req_ids if connector_output else set(), + ) + + def _capture_omni_connector_output(self, model_runner_output: Any, model_mode: str) -> None: + """Stash the model runner's omni_connector_output for next schedule(). + + Called at the tail of every ``update_from_output()``. Identical + between AR and generation schedulers except for ``model_mode``. + """ + omni_output = getattr(model_runner_output, "omni_connector_output", None) + if omni_output is None: + return + self._latest_omni_connector_output = omni_output + input_coordinator = getattr(self, "input_coordinator", None) + if input_coordinator and omni_output.request_metadata: + input_coordinator.update_request_metadata( + self.requests, + omni_output.request_metadata, + model_mode=model_mode, + ) + + def _wrap_omni_scheduler_output( + self, + base: SchedulerOutput, + *, + finished_requests_needing_kv_transfer: dict | None = None, + pending_input_registrations: list[OmniInputRegistration] | None = None, + ) -> OmniSchedulerOutput: + """Wrap a base ``SchedulerOutput`` in ``OmniSchedulerOutput``. + + Pulls each base ``SchedulerOutput`` dataclass field via ``getattr`` + and forwards optional omni-specific fields. Lifted from 4 separate + copy-pastes between AR (1) and generation (3) schedulers. 
+ """ + base_data = {name: getattr(base, name) for name in SchedulerOutput.__dataclass_fields__} + input_coordinator = getattr(self, "input_coordinator", None) + if pending_input_registrations is None: + pending_input_registrations = input_coordinator.pending_input_registrations if input_coordinator else [] + return OmniSchedulerOutput( + **base_data, + finished_requests_needing_kv_transfer=finished_requests_needing_kv_transfer or {}, + pending_input_registrations=pending_input_registrations, + ) + def _replace_session_with_streaming_update( self, session: Request, diff --git a/vllm_omni/core/sched/omni_scheduling_coordinator.py b/vllm_omni/core/sched/omni_scheduling_coordinator.py index 6c32ed4cda8..28ded619fc2 100644 --- a/vllm_omni/core/sched/omni_scheduling_coordinator.py +++ b/vllm_omni/core/sched/omni_scheduling_coordinator.py @@ -19,6 +19,8 @@ from vllm.logger import init_logger from vllm.v1.request import Request, RequestStatus +from vllm_omni.core.sched.output import OmniInputRegistration + logger = init_logger(__name__) @@ -59,7 +61,11 @@ def __init__(self, scheduler_max_num_seqs: int, stage_id: int = 0, async_chunk: # Requests waiting for full_payload stage input (WAITING_FOR_INPUT). self._waiting_for_input: deque[Any] = deque() - self.pending_input_registrations: list[Any] = [] + # Per-cycle list of minimal handles to ship to the model runner so it + # can call register_chunk_recv(). Typed concretely (not list[Any]) so + # the surrounding OmniSchedulerOutput stays msgspec-friendly across + # default, PD-disagg, and multi-node executor IPC paths. + self.pending_input_registrations: list[OmniInputRegistration] = [] # Monotonic timestamp recording when each request first entered # WAITING_FOR_CHUNK or WAITING_FOR_INPUT. 
Used by @@ -166,7 +172,12 @@ def process_pending_full_payload_inputs( self._waiting_since.setdefault(request.request_id, time.monotonic()) to_remove.append(request) self._waiting_for_input.append(request) - self.pending_input_registrations.append(request) + self.pending_input_registrations.append( + OmniInputRegistration( + request_id=request.request_id, + external_req_id=getattr(request, "external_req_id", None), + ) + ) elif request.status == RequestStatus.WAITING_FOR_INPUT: if request.request_id in stage_recv_req_ids: request.status = RequestStatus.WAITING @@ -174,7 +185,12 @@ def process_pending_full_payload_inputs( else: to_remove.append(request) self._waiting_for_input.append(request) - self.pending_input_registrations.append(request) + self.pending_input_registrations.append( + OmniInputRegistration( + request_id=request.request_id, + external_req_id=getattr(request, "external_req_id", None), + ) + ) if to_remove: # Use the bulk-remove helper: one O(N) sweep instead of N # repeated O(N) removes from a list-backed queue. diff --git a/vllm_omni/core/sched/output.py b/vllm_omni/core/sched/output.py index 800881d9ff8..800eaf39815 100644 --- a/vllm_omni/core/sched/output.py +++ b/vllm_omni/core/sched/output.py @@ -1,5 +1,4 @@ from dataclasses import dataclass, field -from typing import Any from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput from vllm.v1.request import Request @@ -72,9 +71,26 @@ class OmniCachedRequestData(CachedRequestData): additional_information: dict[str, dict | None] +@dataclass +class OmniInputRegistration: + """Minimal identifier carried from scheduler to runner for chunk-recv + registration. + + The runner's ``register_chunk_recv`` only consumes ``request_id`` and + ``external_req_id`` from each pending request, so we ship just those + two fields instead of the full Request object. 
Concrete typing + keeps msgspec serialization deterministic across IPC (default, + PD-disagg, multi-node executor variants) and avoids the + ``list[Any]`` fallback path. + """ + + request_id: str + external_req_id: str | None = None + + @dataclass class OmniSchedulerOutput(SchedulerOutput): """Scheduler output with omni-specific transfer metadata.""" finished_requests_needing_kv_transfer: dict[str, dict] = field(default_factory=dict) - pending_input_registrations: list[Any] = field(default_factory=list) + pending_input_registrations: list[OmniInputRegistration] = field(default_factory=list) diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index 0f164fea6df..f157e5db1a1 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -361,12 +361,15 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> Callable | None # Remove finished requests from the cached states. # cleanup_finished_request lives on OmniConnectorModelRunnerMixin and - # is only safe to call once init_omni_connectors() has populated the - # mixin state. Archs that inherit the method via MRO without running - # that init must be skipped, so probe a mixin-owned attribute as the - # "state initialized" gate. + # is only safe to call once init_omni_connectors() has finished + # populating mixin state (it sets ``_omni_connector_initialized = True`` + # at the very end). Archs that inherit the method via MRO without + # running that init must be skipped, so gate on the explicit flag + # rather than probing private attribute names. 
cleanup_finished_request = ( - getattr(self, "cleanup_finished_request", None) if hasattr(self, "_request_ids_mapping") else None + getattr(self, "cleanup_finished_request", None) + if getattr(self, "_omni_connector_initialized", False) + else None ) for req_id in scheduler_output.finished_req_ids: self.requests.pop(req_id, None) diff --git a/vllm_omni/worker/omni_connector_model_runner_mixin.py b/vllm_omni/worker/omni_connector_model_runner_mixin.py index eb5b47b53e3..b316b5ade44 100644 --- a/vllm_omni/worker/omni_connector_model_runner_mixin.py +++ b/vllm_omni/worker/omni_connector_model_runner_mixin.py @@ -189,6 +189,13 @@ def init_omni_connectors( ) self._save_thread.start() + # Explicit "fully initialised" marker so other parts of the runner + # (e.g. _update_states cleanup) can branch on a stable contract + # instead of probing for private mixin attribute names. Must be set + # only after every field above has been bound, so a partially + # constructed mixin is never observable as initialised. + self._omni_connector_initialized = True + def shutdown_omni_connectors(self) -> None: """Stop background threads and release connector resources.""" self._stop_event.set() From 1c3e4b1c353843b93d9338dfe03be88346742779 Mon Sep 17 00:00:00 2001 From: natureofnature Date: Fri, 15 May 2026 08:15:06 +0000 Subject: [PATCH 02/19] [BugFix] qwen3_omni thinker_emb/hid trim refinement series Three squashed bugfix commits all targeting qwen3_omni thinker2talker_full_payload's row-counting logic for the worker connector data plane: 1. Length-aware trim (orig c805e487): switch from unconditional [:-1] trim to a target_rows = len(all_token_ids) computation, so max-token finishes (which do not append a stop-emission row) are not over-trimmed. Fixes long-output regression on test_mix_to_text_audio from BK 9702 main build. 2. Drop output_token_ids hoist (orig 0f14863a): surgical revert of an unused hoist left behind by the prior trim commit; no functional change. 3. 
Finish-reason-aware trim (orig 6d95f8bd): add a stop_emission_drop subtraction so FINISHED_STOPPED requests still drop their extra accumulated hidden-state row. Codex P1 review on the prior commit identified this as a regression on short stop-finished outputs (spurious-phoneme on test_speaker_002). Detection: primary via request.status; fallback heuristic via last-token-in-stop-set when worker-side CachedRequestState has no .status field. Signed-off-by: natureofnature --- .../test_qwen3_omni_streaming_helpers.py | 253 ++++++++++++++++++ .../stage_input_processors/qwen3_omni.py | 109 ++++++-- 2 files changed, 343 insertions(+), 19 deletions(-) diff --git a/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py b/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py index f11a4654ec2..8ec4f9cda75 100644 --- a/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py +++ b/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py @@ -168,6 +168,7 @@ def test_talker2code2wav_full_payload_keeps_all_zero_codec_rows() -> None: def test_thinker2talker_full_payload_packs_complete_tensors() -> None: + """Standard max_tokens finish path: rows == target → no trim.""" request = SimpleNamespace( request_id="thinker", prompt_token_ids=[151644, 872], @@ -187,3 +188,255 @@ def test_thinker2talker_full_payload_packs_complete_tensors() -> None: assert payload["embed"]["prefill"].device.type == "cpu" assert payload["hidden_states"]["output"].device.type == "cpu" assert payload["next_stage_prompt_len"] > 0 + # Lock down the no-trim invariant for rows == target. 
+ assert payload["embed"]["prefill"].shape[0] == 3 + assert payload["hidden_states"]["output"].shape[0] == 3 + + +def test_thinker2talker_full_payload_trims_excess_stop_token_row() -> None: + """Excess-rows path: rows == target + 1 → trim trailing row.""" + request = SimpleNamespace( + request_id="thinker-excess", + prompt_token_ids=[151644, 872], + output_token_ids=[3], + all_token_ids=[151644, 872, 3], + sampling_params=None, + status=None, + ) + pooling_output = { + "hidden_states.layer_0": torch.ones(4, 2), + "hidden_states.layer_24": torch.full((4, 2), 2.0), + "embed.tts_bos": torch.zeros(1, 2), + } + + payload = q3.thinker2talker_full_payload(None, pooling_output, request) + + assert payload is not None + assert payload["embed"]["prefill"].shape[0] == 3 + assert payload["hidden_states"]["output"].shape[0] == 3 + + +def test_thinker2talker_full_payload_drops_stop_emission_row_when_finished_stopped() -> None: + """FINISHED_STOPPED: drop 1 extra row even when rows == target. + + vLLM appends the stop-token to output_token_ids before check_stop, so + len(all_token_ids) includes the stop slot AND the accumulator has the + stop emission's forward row. Both counts equal P+O (here 3). Talker + target should be P+O-1 (=2), not P+O. Without the extra drop the + stop emission's hidden state leaks into talker prefill (fba23325 + spurious-phoneme regression). 
+ """ + request = SimpleNamespace( + request_id="thinker-stop-finished", + prompt_token_ids=[151644, 872], + output_token_ids=[3], + all_token_ids=[151644, 872, 3], + sampling_params=None, + status=SimpleNamespace(name="FINISHED_STOPPED"), + ) + pooling_output = { + "hidden_states.layer_0": torch.ones(3, 2), + "hidden_states.layer_24": torch.full((3, 2), 2.0), + "embed.tts_bos": torch.zeros(1, 2), + } + + payload = q3.thinker2talker_full_payload(None, pooling_output, request) + + assert payload is not None + assert payload["embed"]["prefill"].shape[0] == 2 + assert payload["hidden_states"]["output"].shape[0] == 2 + + +def test_thinker2talker_full_payload_drops_stop_emission_via_eos_fallback() -> None: + """Stop-detection fallback: last token in sampling_params.eos_token_id.""" + EOS = 151645 + request = SimpleNamespace( + request_id="thinker-stop-fallback", + prompt_token_ids=[151644, 872], + output_token_ids=[3, EOS], + all_token_ids=[151644, 872, 3, EOS], + sampling_params=SimpleNamespace(eos_token_id=EOS, stop_token_ids=None), + status=None, + ) + pooling_output = { + "hidden_states.layer_0": torch.ones(4, 2), + "hidden_states.layer_24": torch.full((4, 2), 2.0), + "embed.tts_bos": torch.zeros(1, 2), + } + + payload = q3.thinker2talker_full_payload(None, pooling_output, request) + + assert payload is not None + assert payload["embed"]["prefill"].shape[0] == 3 + assert payload["hidden_states"]["output"].shape[0] == 3 + + +def test_thinker2talker_full_payload_no_drop_when_finished_length_capped() -> None: + """FINISHED_LENGTH_CAPPED (max_tokens): no extra drop; BK 9702 regression guard.""" + request = SimpleNamespace( + request_id="thinker-length-capped", + prompt_token_ids=[151644, 872], + output_token_ids=[3], + all_token_ids=[151644, 872, 3], + sampling_params=SimpleNamespace(eos_token_id=999, stop_token_ids=None), + status=SimpleNamespace(name="FINISHED_LENGTH_CAPPED"), + ) + pooling_output = { + "hidden_states.layer_0": torch.ones(3, 2), + 
"hidden_states.layer_24": torch.full((3, 2), 2.0), + "embed.tts_bos": torch.zeros(1, 2), + } + + payload = q3.thinker2talker_full_payload(None, pooling_output, request) + + assert payload is not None + assert payload["embed"]["prefill"].shape[0] == 3 + assert payload["hidden_states"]["output"].shape[0] == 3 + + +def test_thinker2talker_full_payload_drops_via_private_eos_field() -> None: + """Worker-side sampling_params where the public `eos_token_id` property is + None but the private `_eos_token_id` / `_all_stop_token_ids` carry the + primary EOS (the msgspec-deserialization shape on the worker boundary). + + The fallback must read the private fields to detect the stop. + """ + EOS = 151643 + request = SimpleNamespace( + request_id="thinker-private-eos", + prompt_token_ids=[151644, 872], + output_token_ids=[3, EOS], + all_token_ids=[151644, 872, 3, EOS], + # Public `eos_token_id` looks empty; only the private fields carry it. + sampling_params=SimpleNamespace( + eos_token_id=None, + stop_token_ids=None, + ignore_eos=False, + _eos_token_id=EOS, + _all_stop_token_ids={EOS}, + ), + status=None, + ) + pooling_output = { + "hidden_states.layer_0": torch.ones(4, 2), + "hidden_states.layer_24": torch.full((4, 2), 2.0), + "embed.tts_bos": torch.zeros(1, 2), + } + + payload = q3.thinker2talker_full_payload(None, pooling_output, request) + + assert payload is not None + assert payload["embed"]["prefill"].shape[0] == 3 + assert payload["hidden_states"]["output"].shape[0] == 3 + + +def test_thinker2talker_full_payload_drops_via_all_stop_token_ids() -> None: + """Secondary EOS only in `_all_stop_token_ids` (not in `_eos_token_id`): + multi-EOS Qwen3 case where the model finished on a secondary EOS. 
+ """ + SECONDARY_EOS = 151645 + request = SimpleNamespace( + request_id="thinker-secondary-eos", + prompt_token_ids=[151644, 872], + output_token_ids=[3, SECONDARY_EOS], + all_token_ids=[151644, 872, 3, SECONDARY_EOS], + sampling_params=SimpleNamespace( + eos_token_id=151643, # primary, not the one we hit + stop_token_ids=None, + ignore_eos=False, + _eos_token_id=151643, + _all_stop_token_ids={151643, SECONDARY_EOS}, + ), + status=None, + ) + pooling_output = { + "hidden_states.layer_0": torch.ones(4, 2), + "hidden_states.layer_24": torch.full((4, 2), 2.0), + "embed.tts_bos": torch.zeros(1, 2), + } + + payload = q3.thinker2talker_full_payload(None, pooling_output, request) + + assert payload is not None + assert payload["embed"]["prefill"].shape[0] == 3 + assert payload["hidden_states"]["output"].shape[0] == 3 + + +def test_thinker2talker_full_payload_no_drop_when_ignore_eos_and_trailing_eos() -> None: + """ignore_eos=True + length-capped + last token == EOS: no drop. + + Production worker uses CachedRequestState (no `.status` field), so + the status path doesn't catch this case; we rely on the + `sampling_params.ignore_eos` flag in the fallback to suppress the + EOS-as-stop heuristic. 
+ """ + EOS = 151645 + request = SimpleNamespace( + request_id="thinker-ignore-eos-trailing-eos", + prompt_token_ids=[151644, 872], + output_token_ids=[3, EOS], + all_token_ids=[151644, 872, 3, EOS], + sampling_params=SimpleNamespace(eos_token_id=EOS, stop_token_ids=None, ignore_eos=True), + status=None, # production worker state has no status + ) + pooling_output = { + "hidden_states.layer_0": torch.ones(4, 2), + "hidden_states.layer_24": torch.full((4, 2), 2.0), + "embed.tts_bos": torch.zeros(1, 2), + } + + payload = q3.thinker2talker_full_payload(None, pooling_output, request) + + assert payload is not None + assert payload["embed"]["prefill"].shape[0] == 4 + assert payload["hidden_states"]["output"].shape[0] == 4 + + +def test_thinker2talker_full_payload_no_drop_when_length_capped_with_trailing_eos() -> None: + """FINISHED_LENGTH_CAPPED + last token == EOS coincidence: no drop. + + Status path takes precedence over last-token heuristic. Without + this guard the fallback would incorrectly drop a row when a length-capped + request happens to end on the EOS token id. 
+ """ + EOS = 151645 + request = SimpleNamespace( + request_id="thinker-len-cap-trailing-eos", + prompt_token_ids=[151644, 872], + output_token_ids=[3, EOS], + all_token_ids=[151644, 872, 3, EOS], + sampling_params=SimpleNamespace(eos_token_id=EOS, stop_token_ids=None), + status=SimpleNamespace(name="FINISHED_LENGTH_CAPPED"), + ) + pooling_output = { + "hidden_states.layer_0": torch.ones(4, 2), + "hidden_states.layer_24": torch.full((4, 2), 2.0), + "embed.tts_bos": torch.zeros(1, 2), + } + + payload = q3.thinker2talker_full_payload(None, pooling_output, request) + + assert payload is not None + assert payload["embed"]["prefill"].shape[0] == 4 + assert payload["hidden_states"]["output"].shape[0] == 4 + + +def test_thinker2talker_full_payload_preserves_under_capture() -> None: + """Under-capture path: rows < target → no trim, safe degrade.""" + request = SimpleNamespace( + request_id="thinker-undercap", + prompt_token_ids=[151644, 872], + output_token_ids=[3], + all_token_ids=[151644, 872, 3], + ) + pooling_output = { + "hidden_states.layer_0": torch.ones(2, 2), + "hidden_states.layer_24": torch.full((2, 2), 2.0), + "embed.tts_bos": torch.zeros(1, 2), + } + + payload = q3.thinker2talker_full_payload(None, pooling_output, request) + + assert payload is not None + assert payload["embed"]["prefill"].shape[0] == 2 + assert payload["hidden_states"]["output"].shape[0] == 2 diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py index fd7cfd2aa60..1ee84cf6019 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -484,29 +484,100 @@ def thinker2talker_full_payload( return None prompt_token_ids = _ensure_list(getattr(request, "prompt_token_ids", []) or []) + output_token_ids = _ensure_list(getattr(request, "output_token_ids", []) or []) all_token_ids = _ensure_list(getattr(request, "all_token_ids", None) 
or []) if not all_token_ids: - output_token_ids = _ensure_list(getattr(request, "output_token_ids", []) or []) all_token_ids = list(prompt_token_ids) + list(output_token_ids) - # Trim the trailing stop-token row from the accumulated thinker output. - # The accumulator captures one hidden-state row per executed thinker - # forward (prefill + every decode step including the one that emitted - # the stop_token), so for a finished request thinker_emb has exactly one - # row more than the rows the talker should consume. async_chunk's - # chunk-0 path naturally captures only the prefill / non-stop portion, - # which is why the [async_chunk] parametrization passes while [default] - # over-generates one codec frame on short outputs (e.g. - # test_one_word_prompt_001[default]: audio extends "London" with - # spurious phonemes). - if isinstance(thinker_emb, torch.Tensor) and thinker_emb.shape[0] > 0: - thinker_emb_prefill = thinker_emb[:-1] - else: - thinker_emb_prefill = thinker_emb - if isinstance(thinker_hid, torch.Tensor) and thinker_hid.shape[0] > 0: - thinker_hid_prefill = thinker_hid[:-1] - else: - thinker_hid_prefill = thinker_hid + # Length-aware trim of accumulated thinker output, finish-reason-aware. + # vLLM appends the sampled token to `output_token_ids` BEFORE + # `check_stop` (scheduler.py:1641-1651), so a stop-finished request + # has accumulator_rows == len(all_token_ids) including the stop + # emission row -- the talker must NOT consume that row (fba23325 + # spurious-phoneme regression). Max-token finishes do not append + # an extra forward, so no drop is needed (BK 9702 long-output + # regression). Primary: distinguish via `request.status`. Fallback + # only when status is absent: last-token-in-stop-id heuristic. 
+ status = getattr(request, "status", None) + status_name = getattr(status, "name", None) or "" + if not status_name and status is not None: + status_name = str(status).rsplit(".", 1)[-1] + stop_emission_drop = 1 if status_name == "FINISHED_STOPPED" else 0 + if stop_emission_drop == 0 and not status_name and output_token_ids: + # Worker-side CachedRequestState has no `.status` field in vLLM + # v1, so this fallback runs for every production request. When + # `sampling_params.ignore_eos=True` vLLM continues past EOS, so + # a length-capped finish whose last sampled token coincidentally + # equals EOS must NOT be trimmed -- skip EOS from the stop set + # in that case. Custom `stop_token_ids` are still treated as + # stops; vLLM's `check_stop` runs stop-id matching before the + # length cap and ignores `ignore_eos` for `stop_token_ids`, so + # a last-token match there is unambiguously a stop finish. + sampling_params = getattr(request, "sampling_params", None) + if sampling_params is not None: + stop_ids: set[int] = set() + ignore_eos = bool(getattr(sampling_params, "ignore_eos", False)) + # Custom stop_token_ids always trigger stop in vLLM, regardless + # of ignore_eos (vLLM v1: `update_from_generation_config` writes + # secondary EOSes here too). Read the public list. + for sid in getattr(sampling_params, "stop_token_ids", None) or (): + if isinstance(sid, int): + stop_ids.add(sid) + # EOS sources are only stops when ignore_eos=False. Read both + # the public @property (`eos_token_id`, `all_stop_token_ids`) + # AND the private fields (`_eos_token_id`, `_all_stop_token_ids`) + # because property behavior can vary across msgspec serialization + # boundaries while the private fields are always serialized. 
+ if not ignore_eos: + for eos in ( + getattr(sampling_params, "eos_token_id", None), + getattr(sampling_params, "_eos_token_id", None), + ): + if isinstance(eos, int): + stop_ids.add(eos) + for sid in ( + getattr(sampling_params, "all_stop_token_ids", None) + or getattr(sampling_params, "_all_stop_token_ids", None) + or () + ): + if isinstance(sid, int): + stop_ids.add(sid) + if stop_ids and output_token_ids[-1] in stop_ids: + stop_emission_drop = 1 + target_rows = max(0, len(all_token_ids) - stop_emission_drop) + + def _trim_to_target(t): + if not isinstance(t, torch.Tensor) or t.dim() < 1 or t.shape[0] == 0: + return t + if target_rows <= 0: + # Defensive: empty prompt+output (or stop-only output) should + # not reach this builder; keep all rows rather than slicing + # to zero. + return t + if t.shape[0] > target_rows + 1: + logger.warning( + "thinker2talker_full_payload: unexpected excess rows " + "(got %d, target %d, stop_drop %d) for req=%s; trimming to target", + int(t.shape[0]), + target_rows, + stop_emission_drop, + getattr(request, "request_id", None), + ) + if t.shape[0] > target_rows: + return t[:target_rows] + if t.shape[0] < target_rows: + logger.debug( + "thinker2talker_full_payload: under-captured rows " + "(got %d, target %d, stop_drop %d) for req=%s; talker may index past end", + int(t.shape[0]), + target_rows, + stop_emission_drop, + getattr(request, "request_id", None), + ) + return t + + thinker_emb_prefill = _trim_to_target(thinker_emb) + thinker_hid_prefill = _trim_to_target(thinker_hid) payload: OmniPayload = { "embed": { From 78fbd89acefa2da6a99050f96da5c22a80d1fdbb Mon Sep 17 00:00:00 2001 From: natureofnature Date: Sun, 17 May 2026 10:44:40 +0000 Subject: [PATCH 03/19] [PR3] Phase 2a/2d structural-gate infrastructure for per-model migration MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Consolidates the two infrastructure commits that prepared PR3 for per-model sync-data-plane migration: - 
6d4b4890 [Phase 2a] Structural gate via `_is_sync_input` marker (drop hard-coded `model_arch == Qwen3OmniMoeForConditionalGeneration` in `omni_scheduling_coordinator.uses_full_payload_input_coordinator`). Marker is set on the consumer-side `*_token_only` builder in each model's SIP module; the consumer-side scheduler gate reads it via the resolved `custom_process_input_func` callable. - d7bc85fa [Phase 2d] REPLACE-keys accumulator hook + arch-gate cleanup. Per-model `_FULL_PAYLOAD_REPLACE_KEYS: frozenset[str]` declared in the SIP module; the worker accumulator (`omni_connector_model_runner_mixin.accumulate_full_payload_output`) looks it up via `proc.__module__` and uses REPLACE semantics for those keys (default is CONCAT). Used by models where a key carries the full result so far rather than per-step deltas. Net effect: - `core/sched/omni_scheduling_coordinator.py`: marker-based structural gate - `worker/omni_connector_model_runner_mixin.py`: `accumulate_full_payload_output`, `_resolve_full_payload_replace_keys`, `should_accumulate_full_payload_output`, related arch-gate cleanup - `model_executor/stage_input_processors/qwen3_omni.py`: declares `_FULL_PAYLOAD_REPLACE_KEYS` (currently an empty frozenset; qwen3_omni has no REPLACE keys yet), marks `thinker2talker_token_only` with `_is_sync_input = True`, and removes the arch-specific `should_accumulate_qwen3_omni_full_payload_output` - Per-stage `custom_process_input_func` and `sync_process_input_func` selection plumbing remains in `config/stage_config.py:_select_processor_funcs`. After this commit, per-arch SIP modules can declare: - a `*_token_only` builder (sync_process_input_func, marked `_is_sync_input = True`) - a `*_full_payload` builder (custom_process_next_stage_input_func) - optional `_FULL_PAYLOAD_REPLACE_KEYS` for REPLACE semantics and the worker connector + scheduler coordinator handle the rest uniformly. Per-arch migrations follow in 9 subsequent commits (covo_audio, dynin_omni × 2, mimo_audio, qwen3_tts, cosyvoice3, ming_flash_omni, qwen2_5_omni × 2). 
Signed-off-by: natureofnature --- .../test_qwen3_omni_streaming_helpers.py | 57 ++++++++++++ vllm_omni/core/sched/omni_ar_scheduler.py | 4 +- .../core/sched/omni_generation_scheduler.py | 4 +- .../core/sched/omni_scheduling_coordinator.py | 25 ++++-- .../stage_input_processors/qwen3_omni.py | 24 +++-- .../omni_connector_model_runner_mixin.py | 88 ++++++++++++++++--- 6 files changed, 165 insertions(+), 37 deletions(-) diff --git a/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py b/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py index 8ec4f9cda75..9f4b453d765 100644 --- a/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py +++ b/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py @@ -440,3 +440,60 @@ def test_thinker2talker_full_payload_preserves_under_capture() -> None: assert payload is not None assert payload["embed"]["prefill"].shape[0] == 2 assert payload["hidden_states"]["output"].shape[0] == 2 + + +def test_accumulator_replaces_keys_in_replace_set() -> None: + """REPLACE-key semantics: subsequent emissions of the same key replace, not append.""" + from vllm_omni.worker.omni_connector_model_runner_mixin import OmniConnectorModelRunnerMixin + + class _StubMixin(OmniConnectorModelRunnerMixin): + def __init__(self): + self._pending_full_payload_send = {} + self._full_payload_replace_keys_cached = frozenset({"model_outputs"}) + + stub = _StubMixin() + stub.accumulate_full_payload_output( + "req1", + { + "model_outputs": torch.tensor([[1.0, 2.0]]), + "hidden_states.output": torch.tensor([[10.0]]), + }, + request=None, + ) + stub.accumulate_full_payload_output( + "req1", + { + "model_outputs": torch.tensor([[3.0, 4.0]]), + "hidden_states.output": torch.tensor([[20.0]]), + }, + request=None, + ) + output, _ = stub._materialize_full_payload_entry(stub._pending_full_payload_send["req1"]) + # model_outputs REPLACED (second value only): + assert 
torch.equal(output["model_outputs"], torch.tensor([[3.0, 4.0]])) + # hidden_states.output CONCATENATED: + assert torch.equal(output["hidden_states.output"], torch.tensor([[10.0], [20.0]])) + + +def test_accumulator_concat_default_when_no_replace_keys() -> None: + """Default semantics: 2-D+ tensors concat across emissions when not in replace_keys.""" + from vllm_omni.worker.omni_connector_model_runner_mixin import OmniConnectorModelRunnerMixin + + class _StubMixin(OmniConnectorModelRunnerMixin): + def __init__(self): + self._pending_full_payload_send = {} + self._full_payload_replace_keys_cached = frozenset() + + stub = _StubMixin() + stub.accumulate_full_payload_output( + "req1", + {"embed.prefill": torch.tensor([[1.0]])}, + request=None, + ) + stub.accumulate_full_payload_output( + "req1", + {"embed.prefill": torch.tensor([[2.0]])}, + request=None, + ) + output, _ = stub._materialize_full_payload_entry(stub._pending_full_payload_send["req1"]) + assert torch.equal(output["embed.prefill"], torch.tensor([[1.0], [2.0]])) diff --git a/vllm_omni/core/sched/omni_ar_scheduler.py b/vllm_omni/core/sched/omni_ar_scheduler.py index 39dd4d04957..fb76fd52bf1 100644 --- a/vllm_omni/core/sched/omni_ar_scheduler.py +++ b/vllm_omni/core/sched/omni_ar_scheduler.py @@ -22,7 +22,7 @@ from vllm_omni.core.sched.omni_scheduler_mixin import OmniSchedulerMixin from vllm_omni.core.sched.omni_scheduling_coordinator import ( OmniSchedulingCoordinator, - uses_qwen3_omni_full_payload_input_coordinator, + uses_full_payload_input_coordinator, ) from vllm_omni.distributed.omni_connectors.transfer_adapter.chunk_transfer_adapter import ( OmniChunkTransferAdapter, @@ -80,7 +80,7 @@ def __init__(self, *args, **kwargs): if getattr(model_config, "async_chunk", False): self.chunk_transfer_adapter = OmniChunkTransferAdapter(self.vllm_config) self.input_coordinator: OmniSchedulingCoordinator | None = None - if uses_qwen3_omni_full_payload_input_coordinator(model_config): + if 
uses_full_payload_input_coordinator(model_config): self.input_coordinator = OmniSchedulingCoordinator( scheduler_max_num_seqs=self.vllm_config.scheduler_config.max_num_seqs, stage_id=getattr(model_config, "stage_id", 0), diff --git a/vllm_omni/core/sched/omni_generation_scheduler.py b/vllm_omni/core/sched/omni_generation_scheduler.py index af0b7dcff46..efff106cf63 100644 --- a/vllm_omni/core/sched/omni_generation_scheduler.py +++ b/vllm_omni/core/sched/omni_generation_scheduler.py @@ -26,7 +26,7 @@ from vllm_omni.core.sched.omni_scheduler_mixin import OmniSchedulerMixin from vllm_omni.core.sched.omni_scheduling_coordinator import ( OmniSchedulingCoordinator, - uses_qwen3_omni_full_payload_input_coordinator, + uses_full_payload_input_coordinator, ) from vllm_omni.core.sched.output import OmniCachedRequestData, OmniNewRequestData from vllm_omni.distributed.omni_connectors.transfer_adapter.chunk_transfer_adapter import ( @@ -46,7 +46,7 @@ def __init__(self, *args, **kwargs): self.chunk_transfer_adapter = OmniChunkTransferAdapter(self.vllm_config) self._pending_finish_reqs: list[Request] = [] self.input_coordinator: OmniSchedulingCoordinator | None = None - if uses_qwen3_omni_full_payload_input_coordinator(model_config): + if uses_full_payload_input_coordinator(model_config): self.input_coordinator = OmniSchedulingCoordinator( scheduler_max_num_seqs=self.vllm_config.scheduler_config.max_num_seqs, stage_id=getattr(model_config, "stage_id", 0), diff --git a/vllm_omni/core/sched/omni_scheduling_coordinator.py b/vllm_omni/core/sched/omni_scheduling_coordinator.py index 28ded619fc2..a6e844205c5 100644 --- a/vllm_omni/core/sched/omni_scheduling_coordinator.py +++ b/vllm_omni/core/sched/omni_scheduling_coordinator.py @@ -24,13 +24,24 @@ logger = init_logger(__name__) -def uses_qwen3_omni_full_payload_input_coordinator(model_config: Any) -> bool: - return ( - getattr(model_config, "stage_id", 0) > 0 - and not getattr(model_config, "async_chunk", False) - and 
getattr(model_config, "model_arch", None) == "Qwen3OmniMoeForConditionalGeneration" - and getattr(model_config, "model_stage", None) in {"talker", "code2wav"} - ) +def uses_full_payload_input_coordinator(model_config: Any) -> bool: + """Structural gate: a stage uses the full-payload input coordinator iff + it is a downstream multi-stage stage (stage_id > 0), is not in async_chunk + mode, and has a sync-side input builder wired (detected via the + `_is_sync_input` marker on the resolved `custom_process_input_func`). + + The marker is set per-builder in each model's stage_input_processor + module (e.g. `thinker2talker_token_only._is_sync_input = True` in + qwen3_omni.py). This avoids hard-coding the arch / stage_name whitelist. + """ + if getattr(model_config, "stage_id", 0) <= 0: + return False + if getattr(model_config, "async_chunk", False): + return False + proc = getattr(model_config, "custom_process_input_func", None) + if proc is None or not getattr(proc, "_is_sync_input", False): + return False + return getattr(model_config, "model_stage", None) is not None class OmniSchedulingCoordinator: diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py index 1ee84cf6019..b56e15a7362 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -35,6 +35,12 @@ # Pooling output layer keys: "0" = word embedding, "24" = accept_hidden_layer _EMBED_LAYER_KEY = "0" _HIDDEN_LAYER_KEY = "24" +# Per-model REPLACE-keys for the full-payload accumulator. Keys in this +# set use REPLACE semantics (subsequent emissions discard prior chunks) +# instead of CONCAT. qwen3-omni currently has none — model_outputs is +# not emitted by the thinker/talker forward. 
+_FULL_PAYLOAD_REPLACE_KEYS: frozenset[str] = frozenset() + _QWEN3_CODEC_CODEBOOK_SIZE = 2048 _QWEN3_CODEC_PAD_TOKEN_ID = 4196 _QWEN3_CODEC_BOS_TOKEN_ID = 4197 @@ -119,19 +125,6 @@ def _is_valid_qwen3_codec_token_id(token_id: Any) -> bool: return 0 <= token_id < _QWEN3_CODEC_CODEBOOK_SIZE -def should_accumulate_qwen3_omni_full_payload_output( - model_config: Any, - custom_process_func: Any, -) -> bool: - """Return whether Qwen3-Omni should accumulate full-payload outputs.""" - return ( - custom_process_func is not None - and not getattr(model_config, "async_chunk", False) - and getattr(model_config, "model_arch", None) == "Qwen3OmniMoeForConditionalGeneration" - and getattr(model_config, "model_stage", None) in {"thinker", "talker"} - ) - - def _extract_qwen3_full_payload_codec_rows( code_predictor_codes: torch.Tensor, output_token_ids: list[int], @@ -963,3 +956,8 @@ def talker2code2wav( ) return code2wav_inputs + + +# Mark sync-side builders for the structural full-payload gate (see +# should_accumulate_full_payload_output in worker/omni_connector_model_runner_mixin.py). +thinker2talker_token_only._is_sync_input = True diff --git a/vllm_omni/worker/omni_connector_model_runner_mixin.py b/vllm_omni/worker/omni_connector_model_runner_mixin.py index b316b5ade44..ed36688d58b 100644 --- a/vllm_omni/worker/omni_connector_model_runner_mixin.py +++ b/vllm_omni/worker/omni_connector_model_runner_mixin.py @@ -46,6 +46,24 @@ logger = init_logger(__name__) +def should_accumulate_full_payload_output(model_config, custom_process_func) -> bool: + """Structural gate: accumulate full-payload outputs iff the configured + custom_process_func is a sync-side builder (marked by `_is_sync_input`) + and the stage is not in async_chunk mode. + + Lives at module level so the producer-side `OmniConnectorModelRunnerMixin + ._should_accumulate_full_payload_output()` does not need an arch-specific + import chain (the structural check is arch-agnostic). 
+ """ + if custom_process_func is None: + return False + if getattr(model_config, "async_chunk", False): + return False + if not getattr(custom_process_func, "_is_sync_input", False): + return False + return getattr(model_config, "model_stage", None) is not None + + class OmniConnectorModelRunnerMixin: """Unified data-plane communication mixin for Model Runners. @@ -698,19 +716,12 @@ def _should_accumulate_full_payload_output(self) -> bool: if model_config is None: self._should_accumulate_full_payload_output_cached = False return False - if getattr(model_config, "model_arch", None) == "Qwen3OmniMoeForConditionalGeneration": - from vllm_omni.model_executor.stage_input_processors.qwen3_omni import ( - should_accumulate_qwen3_omni_full_payload_output, - ) - - result = should_accumulate_qwen3_omni_full_payload_output( - model_config, - getattr(self, "_custom_process_func", None), - ) - self._should_accumulate_full_payload_output_cached = result - return result - self._should_accumulate_full_payload_output_cached = False - return False + result = should_accumulate_full_payload_output( + model_config, + getattr(self, "_custom_process_func", None), + ) + self._should_accumulate_full_payload_output_cached = result + return result @staticmethod def _new_full_payload_accumulator(output: dict[str, Any]): @@ -736,6 +747,43 @@ def _materialize_full_payload_entry(entry): output[k] = tensors[0] if len(tensors) == 1 else torch.cat(tensors, dim=0) return output, request + def _resolve_full_payload_replace_keys(self) -> frozenset: + """Per-model REPLACE-key set for the full-payload accumulator. + + Looked up from the SIP module that defines this runner's configured + processor (`self._custom_process_func.__module__`). The module + declares ``_FULL_PAYLOAD_REPLACE_KEYS: frozenset[str]``; if absent, + returns the empty set. + + Cached per instance. 
Keys in this set use REPLACE semantics in the + accumulator (subsequent emissions discard prior chunks) instead of + the default CONCAT semantics. Use for tensors that carry the full + result so far rather than per-step deltas (e.g. ``model_outputs``). + """ + cached = getattr(self, "_full_payload_replace_keys_cached", None) + if cached is not None: + return cached + proc = getattr(self, "_custom_process_func", None) + if proc is None: + self._full_payload_replace_keys_cached = frozenset() + return self._full_payload_replace_keys_cached + module_name = getattr(proc, "__module__", None) + if module_name is None: + self._full_payload_replace_keys_cached = frozenset() + return self._full_payload_replace_keys_cached + try: + import importlib as _il + import sys as _sys + + mod = _sys.modules.get(module_name) or _il.import_module(module_name) + keys = getattr(mod, "_FULL_PAYLOAD_REPLACE_KEYS", frozenset()) + except ImportError: + keys = frozenset() + if not isinstance(keys, (frozenset, set)): + keys = frozenset() + self._full_payload_replace_keys_cached = frozenset(keys) + return self._full_payload_replace_keys_cached + def accumulate_full_payload_output( self, req_id: str, @@ -758,6 +806,7 @@ def accumulate_full_payload_output( The data is actually sent when ``flush_full_payload_outputs`` is called with the finished request IDs from the next scheduler cycle. """ + replace_keys = self._resolve_full_payload_replace_keys() existing = self._pending_full_payload_send.get(req_id) if existing is None: @@ -773,6 +822,19 @@ def accumulate_full_payload_output( for k, v in pooler_output.items(): if v is None: continue + if k in replace_keys: + # Explicit REPLACE semantics: the new value supersedes any + # prior chunks (e.g. `model_outputs` carries the full result + # so far, not an appendable per-step delta). 
+ latest.pop(k, None) + if isinstance(v, torch.Tensor) and v.dim() >= 2: + chunks[k] = [v] + rows[k] = int(v.shape[0]) + else: + chunks.pop(k, None) + rows.pop(k, None) + latest[k] = v + continue if isinstance(v, torch.Tensor) and v.dim() >= 2: if k in chunks and chunks[k] and v.shape[1:] == chunks[k][0].shape[1:]: chunks[k].append(v) From 2fe03c3092bbc9e22f00310b5c7e2c95bf91638f Mon Sep 17 00:00:00 2001 From: natureofnature Date: Sat, 16 May 2026 03:13:36 +0000 Subject: [PATCH 04/19] [PR3] Per-arch SIP builders + coordinator gate generalization Squash of 9 commits that wire per-arch stage_input_processor (SIP) builders for the worker-connector data plane, plus the coordinator gate generalization that makes them reachable: - [Pilot] covo_audio (Group B llm->code2wav): pilot impl exercising the new structural gate. - dynin_omni Group B (both transitions): SIP builders + yaml wires for stage_configs. - qwen2_5_omni Group B half (talker->code2wav). - mimo_audio Group B (llm->code2wav). - qwen3_tts Group C (talker->code2wav). - cosyvoice3 Group D-ish (text->flow). - ming_flash_omni Group D (thinker->talker). - qwen2_5_omni Group A reduced to D-minimal (thinker->talker structural sync builder). - Coordinator gate generalization (drops hard-coded Qwen3-only check in uses_full_payload_input_coordinator and replaces it with the _FULL_PAYLOAD_INPUT_STAGES (arch, stage) whitelist). Tests in tests/core/sched/test_omni_scheduling_coordinator.py + tests/worker/test_omni_gpu_model_runner.py adjust to the new whitelist contract. Each per-arch SIP commit adds a builder pair (`*_token_only` and `*_full_payload`) and a pipeline.py wire; tests in test_qwen3_omni_streaming_helpers.py cover the structural expectations. 
Signed-off-by: natureofnature --- .../sched/test_omni_scheduling_coordinator.py | 56 +- .../test_qwen3_omni_streaming_helpers.py | 497 ++++++++++++++++++ tests/worker/test_omni_gpu_model_runner.py | 36 +- .../core/sched/omni_scheduling_coordinator.py | 43 +- .../models/cosyvoice3/pipeline.py | 4 +- .../models/covo_audio/pipeline.py | 2 + .../models/mimo_audio/pipeline.py | 4 +- .../models/ming_flash_omni/pipeline.py | 2 + .../models/qwen2_5_omni/pipeline.py | 5 + .../models/qwen3_tts/pipeline.py | 4 +- .../stage_configs/dynin_omni.yaml | 2 + .../dynin_omni_multiconnector.yaml | 2 + .../stage_input_processors/cosyvoice3.py | 99 ++++ .../stage_input_processors/covo_audio.py | 83 ++- .../stage_input_processors/dynin_omni.py | 121 +++++ .../stage_input_processors/mimo_audio.py | 118 +++++ .../stage_input_processors/ming_flash_omni.py | 92 ++++ .../stage_input_processors/qwen2_5_omni.py | 183 +++++++ .../stage_input_processors/qwen3_tts.py | 195 +++++++ .../omni_connector_model_runner_mixin.py | 29 +- 20 files changed, 1518 insertions(+), 59 deletions(-) diff --git a/tests/core/sched/test_omni_scheduling_coordinator.py b/tests/core/sched/test_omni_scheduling_coordinator.py index 8c7cc20b0d9..d2fea3a3f49 100644 --- a/tests/core/sched/test_omni_scheduling_coordinator.py +++ b/tests/core/sched/test_omni_scheduling_coordinator.py @@ -16,7 +16,7 @@ import vllm_omni.core.sched.omni_scheduling_coordinator as coord_mod from vllm_omni.core.sched.omni_scheduling_coordinator import ( OmniSchedulingCoordinator, - uses_qwen3_omni_full_payload_input_coordinator, + uses_full_payload_input_coordinator, ) # ------------------------------------------------------------------ # @@ -92,7 +92,16 @@ def remove_requests(self, requests): class TestFullPayloadCoordinatorSelection(unittest.TestCase): - def test_qwen3_omni_talker_and_code2wav_use_full_payload_input_coordinator(self): + """Tests for the (model_arch, model_stage) whitelist gate. 
+ + The gate scope must stay aligned with init_omni_connectors arch scope in + gpu_ar_model_runner.py / gpu_generation_model_runner.py. Until those init + sites are generalised (planned for a later PR matching the tmp/trim_refactor + branch shape), only Qwen3-Omni talker / code2wav route full_payload stage + input through the worker connector. + """ + + def test_qwen3_omni_talker_and_code2wav_fire_gate(self): for model_stage in ("talker", "code2wav"): model_config = SimpleNamespace( stage_id=1, @@ -100,39 +109,44 @@ def test_qwen3_omni_talker_and_code2wav_use_full_payload_input_coordinator(self) model_arch="Qwen3OmniMoeForConditionalGeneration", model_stage=model_stage, ) + self.assertTrue( + uses_full_payload_input_coordinator(model_config), + msg=f"expected gate to fire for Qwen3Omni/{model_stage}", + ) - self.assertTrue(uses_qwen3_omni_full_payload_input_coordinator(model_config)) - - def test_async_chunk_and_non_qwen3_omni_do_not_use_full_payload_input_coordinator(self): + def test_other_arch_or_stage_or_mode_does_not_fire(self): cases = [ SimpleNamespace( - stage_id=1, - async_chunk=True, - model_arch="Qwen3OmniMoeForConditionalGeneration", - model_stage="talker", + stage_id=1, async_chunk=True, model_arch="Qwen3OmniMoeForConditionalGeneration", model_stage="talker" ), SimpleNamespace( - stage_id=1, - async_chunk=False, - model_arch="Qwen3TTSForConditionalGeneration", - model_stage="code2wav", + stage_id=0, async_chunk=False, model_arch="Qwen3OmniMoeForConditionalGeneration", model_stage="thinker" ), SimpleNamespace( stage_id=1, async_chunk=False, - model_arch="Qwen2_5OmniForConditionalGeneration", - model_stage="talker", + model_arch="Qwen3OmniMoeForConditionalGeneration", + model_stage="some_future_stage", ), SimpleNamespace( - stage_id=0, - async_chunk=False, - model_arch="Qwen3OmniMoeForConditionalGeneration", - model_stage="thinker", + stage_id=1, async_chunk=False, model_arch="Qwen2_5OmniForConditionalGeneration", model_stage="talker" + ), + 
SimpleNamespace( + stage_id=1, async_chunk=False, model_arch="Qwen3TTSForConditionalGeneration", model_stage="code2wav" + ), + SimpleNamespace( + stage_id=1, async_chunk=False, model_arch="MingFlashOmniForConditionalGeneration", model_stage="talker" + ), + SimpleNamespace(stage_id=1, async_chunk=False, model_arch=None, model_stage="talker"), + SimpleNamespace( + stage_id=1, async_chunk=False, model_arch="Qwen3OmniMoeForConditionalGeneration", model_stage=None ), ] - for model_config in cases: - self.assertFalse(uses_qwen3_omni_full_payload_input_coordinator(model_config)) + self.assertFalse( + uses_full_payload_input_coordinator(model_config), + msg=f"expected gate OFF for {model_config}", + ) class TestChunkCoordinatorStateTransition(unittest.TestCase): diff --git a/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py b/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py index 9f4b453d765..7b176b29b75 100644 --- a/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py +++ b/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py @@ -497,3 +497,500 @@ def __init__(self): ) output, _ = stub._materialize_full_payload_entry(stub._pending_full_payload_send["req1"]) assert torch.equal(output["embed.prefill"], torch.tensor([[1.0], [2.0]])) + + +def test_covo_audio_llm2code2wav_token_only_smoke() -> None: + """Smoke: covo_audio token-only builder marks `_is_sync_input` + and returns placeholder prompts sized to audio_codes count.""" + from vllm_omni.model_executor.stage_input_processors.covo_audio import ( + llm2code2wav_token_only, + ) + + assert getattr(llm2code2wav_token_only, "_is_sync_input", False) is True + + # source_outputs is a list of objects with .outputs[0].token_ids + from vllm_omni.model_executor.models.covo_audio.config_covo_audio import COVO_AUDIO_TOKEN_INDEX + + class _Out: + def __init__(self, tids): + self.token_ids = tids + + class _Wrapper: + 
def __init__(self, tids): + self.outputs = [_Out(tids)] + + # 3 codec tokens + 2 non-codec + src = [_Wrapper([COVO_AUDIO_TOKEN_INDEX + 0, COVO_AUDIO_TOKEN_INDEX + 1, COVO_AUDIO_TOKEN_INDEX + 2, 100, 200])] + out = llm2code2wav_token_only(src) + assert len(out) == 1 + assert len(out[0]["prompt_token_ids"]) == 3 + assert out[0]["additional_information"] is None + + +def test_covo_audio_llm2code2wav_full_payload_smoke() -> None: + """Smoke: covo_audio producer-side packer returns audio_codes + finished.""" + from types import SimpleNamespace + + from vllm_omni.model_executor.models.covo_audio.config_covo_audio import COVO_AUDIO_TOKEN_INDEX + from vllm_omni.model_executor.stage_input_processors.covo_audio import ( + llm2code2wav_full_payload, + ) + + req = SimpleNamespace( + output_token_ids=[COVO_AUDIO_TOKEN_INDEX + 5, COVO_AUDIO_TOKEN_INDEX + 6, 99], + ) + payload = llm2code2wav_full_payload(None, {}, req) + assert payload is not None + assert payload["codes"]["audio"] == [5, 6] + assert payload["meta"]["finished"].item() is True + + +def test_dynin_omni_token_only_smoke() -> None: + """Smoke: dynin_omni token-only builders mark _is_sync_input and return placeholders.""" + from vllm_omni.model_executor.stage_input_processors.dynin_omni import ( + token2image_to_token2audio_token_only, + token2text_to_token2image_token_only, + ) + + assert getattr(token2text_to_token2image_token_only, "_is_sync_input", False) is True + assert getattr(token2image_to_token2audio_token_only, "_is_sync_input", False) is True + + class _Out: + def __init__(self, tids, mm=None): + self.token_ids = tids + self.multimodal_output = mm + + class _Wrapper: + def __init__(self, tids, mm=None): + self.outputs = [_Out(tids, mm)] + self.request_id = "r0" + + class _Stage: + def __init__(self, outs): + self.engine_outputs = outs + + src = [_Wrapper([10, 11, 12])] + out = token2text_to_token2image_token_only([_Stage(src)], [0]) + assert len(out) == 1 + assert len(out[0]["prompt_token_ids"]) == 3 + 
assert out[0]["additional_information"] is None + + +def test_dynin_omni_full_payload_smoke() -> None: + """Smoke: dynin_omni producer-side packer returns token_ids + finished.""" + from types import SimpleNamespace + + from vllm_omni.model_executor.stage_input_processors.dynin_omni import ( + token2text_to_token2image_full_payload, + ) + + pooling = {"token_ids": [1, 2, 3]} + req = SimpleNamespace(output_token_ids=[], additional_information={"speaker": ["alice"]}) + payload = token2text_to_token2image_full_payload(None, pooling, req) + assert payload is not None + assert payload["code_predictor_codes"] == [1, 2, 3] + assert payload["finished"].item() is True + # additional_information carried forward as list-wrapped (speaker) + assert payload.get("speaker") == ["alice"] + + +def test_qwen2_5_omni_talker2code2wav_token_only_smoke() -> None: + """Smoke: qwen2_5_omni talker→code2wav token_only marker + boundary strip.""" + from vllm_omni.model_executor.stage_input_processors.qwen2_5_omni import ( + TALKER_CODEC_END_TOKEN_ID, + TALKER_CODEC_START_TOKEN_ID, + talker2code2wav_token_only, + ) + + assert getattr(talker2code2wav_token_only, "_is_sync_input", False) is True + + class _Out: + def __init__(self, tids): + self.cumulative_token_ids = tids + + class _Wrap: + def __init__(self, tids): + self.outputs = [_Out(tids)] + + # 3 inner codes wrapped by START + END + src = [_Wrap([TALKER_CODEC_START_TOKEN_ID, 10, 11, 12, TALKER_CODEC_END_TOKEN_ID])] + out = talker2code2wav_token_only(src) + assert len(out) == 1 + assert len(out[0]["prompt_token_ids"]) == 3 + assert out[0]["additional_information"] is None + + +def test_qwen2_5_omni_talker2code2wav_full_payload_smoke() -> None: + """Smoke: qwen2_5_omni producer-side packer strips boundaries.""" + from types import SimpleNamespace + + from vllm_omni.model_executor.stage_input_processors.qwen2_5_omni import ( + TALKER_CODEC_END_TOKEN_ID, + TALKER_CODEC_START_TOKEN_ID, + talker2code2wav_full_payload, + ) + + req = 
SimpleNamespace( + output_token_ids=[TALKER_CODEC_START_TOKEN_ID, 5, 6, 7, TALKER_CODEC_END_TOKEN_ID], + ) + payload = talker2code2wav_full_payload(None, {}, req) + assert payload is not None + assert payload["codes"]["audio"] == [5, 6, 7] + assert payload["meta"]["finished"].item() is True + + +def test_mimo_audio_llm2code2wav_token_only_smoke() -> None: + """Smoke: mimo_audio token-only builder marks _is_sync_input + sizes prompt.""" + import torch + + from vllm_omni.model_executor.stage_input_processors.mimo_audio import ( + llm2code2wav_token_only, + ) + + assert getattr(llm2code2wav_token_only, "_is_sync_input", False) is True + + class _Out: + def __init__(self, mm): + self.multimodal_output = mm + + class _Wrap: + def __init__(self, mm): + self.outputs = [_Out(mm)] + + # 3 batch rows of [1, 8, 4]: prepend_and_flatten_colmajor → 3*1*4*9 = 108 + codes = torch.arange(96, dtype=torch.long).reshape(3, 1, 8, 4) + codes = codes.clamp(min=1) # ensure nonzero so zero-row filter doesn't drop them + src = [_Wrap({"codes": {"audio": codes}})] + out = llm2code2wav_token_only(src) + assert len(out) == 1 + assert len(out[0]["prompt_token_ids"]) == 108 + assert out[0]["additional_information"] is None + + +def test_mimo_audio_llm2code2wav_full_payload_smoke() -> None: + """Smoke: mimo_audio producer-side packer reads flat codes.audio + flattens.""" + from types import SimpleNamespace + + import torch + + from vllm_omni.model_executor.stage_input_processors.mimo_audio import ( + TALKER_CODEC_PAD_TOKEN_ID, + llm2code2wav_full_payload, + ) + + # Simulate accumulator output: 2 steps of [1, 1, 8, 4] CONCAT'd → [2, 1, 8, 4] + audio = torch.arange(2 * 1 * 8 * 4, dtype=torch.long).reshape(2, 1, 8, 4) + audio = audio.clamp(min=1) # avoid zero-row drop + pooling_output = {"codes.audio": audio} + req = SimpleNamespace(output_token_ids=[]) + payload = llm2code2wav_full_payload(None, pooling_output, req) + assert payload is not None + assert "codes" in payload and "audio" in 
payload["codes"] + # Flattened length = numel + B*4 (per-batch pad_vec prepended by prepend_and_flatten_colmajor) + batch_size = int(audio.shape[0]) + assert len(payload["codes"]["audio"]) == audio.numel() + batch_size * 4 + # prepend_and_flatten_colmajor: PAD appears at column start in col-major flatten. + # For shape [B=2, 1, 9, 4], each column has 1 PAD then 8 codec vals → PAD at indices 0, 9, 18, 27. + out = payload["codes"]["audio"] + assert out[0] == TALKER_CODEC_PAD_TOKEN_ID + assert out[9] == TALKER_CODEC_PAD_TOKEN_ID + assert payload["meta"]["finished"].item() is True + + +def test_mimo_audio_full_payload_nested_fallback() -> None: + """Back-compat: full_payload still works if runtime returns nested codes.audio.""" + from types import SimpleNamespace + + import torch + + from vllm_omni.model_executor.stage_input_processors.mimo_audio import ( + llm2code2wav_full_payload, + ) + + audio = torch.arange(1 * 1 * 8 * 4, dtype=torch.long).reshape(1, 1, 8, 4) + audio = audio.clamp(min=1) + pooling_output = {"codes": {"audio": audio}} # nested, not flat + req = SimpleNamespace(output_token_ids=[]) + payload = llm2code2wav_full_payload(None, pooling_output, req) + assert payload is not None + assert len(payload["codes"]["audio"]) == audio.numel() + int(audio.shape[0]) * 4 + + +def test_qwen3_tts_talker2code2wav_token_only_smoke() -> None: + """Smoke: qwen3_tts token-only marks _is_sync_input + sizes placeholder.""" + import torch + + from vllm_omni.model_executor.stage_input_processors.qwen3_tts import ( + talker2code2wav_token_only, + ) + + assert getattr(talker2code2wav_token_only, "_is_sync_input", False) is True + + class _Out: + def __init__(self, mm, tids): + self.multimodal_output = mm + self.cumulative_token_ids = tids + + class _Wrap: + def __init__(self, mm, tids): + self.outputs = [_Out(mm, tids)] + self.finished = True + + # 3 valid codec frames Q=16; non-zero & under codebook size + audio = torch.arange(3 * 16, dtype=torch.long).reshape(3, 16) + 1 + mm 
= {"codes": {"audio": audio}} + src = [_Wrap(mm, list(range(10)))] # seq_len = 9; 3 < 9, no trim + out = talker2code2wav_token_only(src) + assert len(out) == 1 + # Codebook-major flat: 16 * 3 = 48 + assert len(out[0]["prompt_token_ids"]) == 48 + + +def test_qwen3_tts_talker2code2wav_full_payload_smoke() -> None: + """Smoke: qwen3_tts full_payload reads flat codes.audio + flattens col-major.""" + from types import SimpleNamespace + + import torch + + from vllm_omni.model_executor.stage_input_processors.qwen3_tts import ( + talker2code2wav_full_payload, + ) + + # 3 valid codec frames [3, 16] CONCAT'd from per-step emits via flatten + audio = torch.arange(3 * 16, dtype=torch.long).reshape(3, 16) + 1 + pooling_output = {"codes.audio": audio} + req = SimpleNamespace(output_token_ids=list(range(10))) # seq_len = 9 + payload = talker2code2wav_full_payload(None, pooling_output, req) + assert payload is not None + assert "codes" in payload and "audio" in payload["codes"] + # codebook-major: shape [3, 16] -> [16, 3] -> flatten = 48 entries + assert len(payload["codes"]["audio"]) == 48 + assert payload["meta"]["finished"].item() is True + + +def test_qwen3_tts_full_payload_with_ref_code() -> None: + """Smoke: ref_code prepended via codes.ref + meta.ref_code_len from flat path.""" + from types import SimpleNamespace + + import torch + + from vllm_omni.model_executor.stage_input_processors.qwen3_tts import ( + talker2code2wav_full_payload, + ) + + # Audio: 3 frames [3, 16] + audio = torch.arange(3 * 16, dtype=torch.long).reshape(3, 16) + 1 + # Ref code: 2 frames [2, 16] (already 2-D) + ref = torch.arange(2 * 16, dtype=torch.long).reshape(2, 16) + 100 + pooling_output = { + "codes.audio": audio, + "codes.ref": [ref], + "meta.ref_code_len": torch.tensor([2], dtype=torch.int32), + } + req = SimpleNamespace(output_token_ids=list(range(10))) + payload = talker2code2wav_full_payload(None, pooling_output, req) + assert payload is not None + # Total frames = 2 (ref) + 3 (audio) = 5; 
codebook-major: 16 * 5 = 80 + assert len(payload["codes"]["audio"]) == 80 + + +def test_qwen3_tts_full_payload_nested_fallback() -> None: + """Back-compat: full_payload works if pooler returns un-flattened nested dict.""" + from types import SimpleNamespace + + import torch + + from vllm_omni.model_executor.stage_input_processors.qwen3_tts import ( + talker2code2wav_full_payload, + ) + + audio = torch.arange(2 * 16, dtype=torch.long).reshape(2, 16) + 1 + pooling_output = {"codes": {"audio": audio}} # nested, not flat + req = SimpleNamespace(output_token_ids=list(range(10))) + payload = talker2code2wav_full_payload(None, pooling_output, req) + assert payload is not None + assert len(payload["codes"]["audio"]) == 32 # 16 * 2 + + +def test_cosyvoice3_text2flow_token_only_smoke() -> None: + """Smoke: cosyvoice3 token-only marks _is_sync_input + carries ids.prompt plus inline embed.""" + from vllm_omni.model_executor.stage_input_processors.cosyvoice3 import ( + text2flow_token_only, + ) + + assert getattr(text2flow_token_only, "_is_sync_input", False) is True + + class _Out: + def __init__(self, tids): + self.cumulative_token_ids = tids + self.multimodal_output = {} + + class _Wrap: + def __init__(self, output_tids, prompt_tids): + self.outputs = [_Out(output_tids)] + self.prompt_token_ids = prompt_tids + self.finished = True + + # multimodal_output has embed.* + we expect token_only to preserve it (Phase 4 #90 follow-up). + import torch + + embed = {"speech_token": torch.zeros(2, 4)} + src = [_Wrap(output_tids=[10, 20, 30], prompt_tids=[1, 2, 3, 4])] + src[0].outputs[0].multimodal_output = {"embed": embed} + out = text2flow_token_only(src) + assert len(out) == 1 + # prompt_token_ids is the talker's cumulative_token_ids (real codec tokens, not zeros). + assert out[0]["prompt_token_ids"] == [10, 20, 30] + # additional_information carries ids.prompt PLUS the original multimodal_output (embed.* still inline).
+ # Heavy embed.* removal pending the model_intermediate_buffer plumbing on the code2wav side. + assert out[0]["additional_information"]["ids"]["prompt"] == [1, 2, 3, 4] + assert "embed" in out[0]["additional_information"] + + +def test_cosyvoice3_text2flow_full_payload_smoke() -> None: + """Smoke: cosyvoice3 producer-side reads flat embed.* keys.""" + from types import SimpleNamespace + + import torch + + from vllm_omni.model_executor.stage_input_processors.cosyvoice3 import ( + text2flow_full_payload, + ) + + speech_token = torch.randn(4, 8) + speech_feat = torch.randn(4, 16) + embedding = torch.randn(1, 32) + pooling_output = { + "embed.speech_token": speech_token, + "embed.speech_feat": speech_feat, + "embed.embedding": embedding, + } + req = SimpleNamespace(external_req_id="r-1") + payload = text2flow_full_payload(None, pooling_output, req) + assert payload is not None + assert "embed" in payload + assert torch.equal(payload["embed"]["speech_token"], speech_token) + assert torch.equal(payload["embed"]["speech_feat"], speech_feat) + assert torch.equal(payload["embed"]["embedding"], embedding) + assert payload["meta"]["finished"].item() is True + + +def test_cosyvoice3_text2flow_full_payload_nested_fallback() -> None: + """Back-compat: full_payload works if pooler returns un-flattened nested embed.""" + from types import SimpleNamespace + + import torch + + from vllm_omni.model_executor.stage_input_processors.cosyvoice3 import ( + text2flow_full_payload, + ) + + speech_token = torch.randn(3, 8) + pooling_output = {"embed": {"speech_token": speech_token}} # nested, not flat + req = SimpleNamespace(external_req_id="r-2") + payload = text2flow_full_payload(None, pooling_output, req) + assert payload is not None + assert "speech_token" in payload["embed"] + assert torch.equal(payload["embed"]["speech_token"], speech_token) + + +def test_cosyvoice3_full_payload_replace_keys_present() -> None: + """Confirm _FULL_PAYLOAD_REPLACE_KEYS lists the three embed.* keys.""" + 
from vllm_omni.model_executor.stage_input_processors.cosyvoice3 import ( + _FULL_PAYLOAD_REPLACE_KEYS, + ) + + assert _FULL_PAYLOAD_REPLACE_KEYS == frozenset({"embed.speech_token", "embed.speech_feat", "embed.embedding"}) + + +def test_ming_flash_omni_thinker2talker_token_only_smoke() -> None: + """Smoke: ming_flash_omni token-only marks _is_sync_input + carries voice metadata.""" + from vllm_omni.model_executor.stage_input_processors.ming_flash_omni import ( + thinker2talker_token_only, + ) + + assert getattr(thinker2talker_token_only, "_is_sync_input", False) is True + + class _Out: + def __init__(self, text): + self.text = text + + class _Wrap: + def __init__(self, text): + self.outputs = [_Out(text)] + + class _Prompt: + def __init__(self, info): + self.additional_information = info + + src = [_Wrap("hello world")] + prompt = _Prompt({"voice_name": "ZH_FEMALE", "prompt_text": "ref text"}) + out = thinker2talker_token_only(src, prompt=prompt) + assert len(out) == 1 + assert out[0]["prompt_token_ids"] == [0] # talker self-tokenizes; dummy id + info = out[0]["additional_information"] + assert info["text"] == "hello world" + assert info["voice_name"] == "ZH_FEMALE" + assert info["prompt_text"] == "ref text" + assert info["ming_task"] == "omni" + + +def test_ming_flash_omni_thinker2talker_full_payload_noop() -> None: + """thinker2talker_full_payload returns None — no heavy tensor migration.""" + from vllm_omni.model_executor.stage_input_processors.ming_flash_omni import ( + thinker2talker_full_payload, + ) + + payload = thinker2talker_full_payload(None, {"anything": "ignored"}, None) + assert payload is None + + +def test_qwen2_5_omni_thinker2talker_token_only_smoke() -> None: + """Smoke: qwen2_5_omni thinker token-only marks _is_sync_input + ports legacy body.""" + import torch + + from vllm_omni.model_executor.stage_input_processors.qwen2_5_omni import ( + TALKER_CODEC_END_TOKEN_ID, + TALKER_CODEC_PAD_TOKEN_ID, + TALKER_CODEC_START_TOKEN_ID, + 
thinker2talker_token_only, + ) + + assert getattr(thinker2talker_token_only, "_is_sync_input", False) is True + + class _Out: + def __init__(self, ctids, mm): + self.cumulative_token_ids = ctids + self.multimodal_output = mm + + class _Wrap: + def __init__(self, prompt_tids, ctids, mm, rid): + self.outputs = [_Out(ctids, mm)] + self.prompt_token_ids = prompt_tids + self.request_id = rid + + class _Prompt(dict): + pass + + # Latent shaped [prompt_len + decode_len, hidden] = [5 + 3, 8] + latent = torch.randn(8, 8) + src = [_Wrap(prompt_tids=[1, 2, 3, 4, 5], ctids=[10, 20, 30], mm={"latent": latent}, rid="r-1")] + prompt = [_Prompt(multi_modal_data=None)] + out = thinker2talker_token_only(src, prompt=prompt) + assert len(out) == 1 + # Talker prompt = START + PAD*prompt_len + END + expected_prompt_len = 1 + len([1, 2, 3, 4, 5]) + 1 + assert len(out[0]["prompt_token_ids"]) == expected_prompt_len + assert out[0]["prompt_token_ids"][0] == TALKER_CODEC_START_TOKEN_ID + assert out[0]["prompt_token_ids"][-1] == TALKER_CODEC_END_TOKEN_ID + assert all(t == TALKER_CODEC_PAD_TOKEN_ID for t in out[0]["prompt_token_ids"][1:-1]) + + +def test_qwen2_5_omni_thinker2talker_full_payload_noop() -> None: + """thinker2talker_full_payload returns None — no heavy tensor migration today.""" + from vllm_omni.model_executor.stage_input_processors.qwen2_5_omni import ( + thinker2talker_full_payload, + ) + + payload = thinker2talker_full_payload(None, {"any": "thing"}, None) + assert payload is None diff --git a/tests/worker/test_omni_gpu_model_runner.py b/tests/worker/test_omni_gpu_model_runner.py index b834d8733b0..94ffd937ead 100644 --- a/tests/worker/test_omni_gpu_model_runner.py +++ b/tests/worker/test_omni_gpu_model_runner.py @@ -472,18 +472,40 @@ def test_accumulate_full_payload_output_keeps_all_zero_qwen3_omni_prefill_placeh def test_full_payload_output_accumulation_hook_matrix(): + """Producer-side gate: fires iff custom_process_func is loaded and not async_chunk. 
+ + Phase 2a generalized the gate from an arch + stage whitelist to a structural + check on the loaded packer. `_custom_process_func is None` short-circuits; + that maps to terminal stages (e.g. code2wav, qwen3_tts code2wav, qwen2_5 + code2wav) whose stage_config has no `custom_process_next_stage_input_func` + and no `*_full_payload` derivative of `custom_process_input_func`. + """ + # Thinker / talker producer stages: packer loaded -> gate fires. assert _make_full_payload_accumulation_runner(model_stage="thinker")._should_accumulate_full_payload_output() assert _make_full_payload_accumulation_runner(model_stage="talker")._should_accumulate_full_payload_output() - assert not _make_full_payload_accumulation_runner(model_stage="code2wav")._should_accumulate_full_payload_output() + + # Terminal stage: emulate `_load_custom_func` returning None (no downstream). + runner = _make_full_payload_accumulation_runner(model_stage="code2wav") + runner._custom_process_func = None + runner._should_accumulate_full_payload_output_cached = None + assert not runner._should_accumulate_full_payload_output() + + # async_chunk mode -> gate off. assert not _make_full_payload_accumulation_runner( model_stage="talker", async_chunk=True )._should_accumulate_full_payload_output() - assert not _make_full_payload_accumulation_runner( - model_arch="Qwen3TTSForConditionalGeneration" - )._should_accumulate_full_payload_output() - assert not _make_full_payload_accumulation_runner( - model_arch="Qwen2_5OmniForConditionalGeneration" - )._should_accumulate_full_payload_output() + + # Non-qwen3 arches: gate is now arch-agnostic, but if the fixture's arch + # has no PR3 wire its runtime `_custom_process_func` would be None. + # Emulate that. 
+ runner = _make_full_payload_accumulation_runner(model_arch="Qwen3TTSForConditionalGeneration") + runner._custom_process_func = None + runner._should_accumulate_full_payload_output_cached = None + assert not runner._should_accumulate_full_payload_output() + runner = _make_full_payload_accumulation_runner(model_arch="Qwen2_5OmniForConditionalGeneration") + runner._custom_process_func = None + runner._should_accumulate_full_payload_output_cached = None + assert not runner._should_accumulate_full_payload_output() def test_sync_local_stage_payloads_retains_payload_until_request_is_active(): diff --git a/vllm_omni/core/sched/omni_scheduling_coordinator.py b/vllm_omni/core/sched/omni_scheduling_coordinator.py index a6e844205c5..1aaa486873f 100644 --- a/vllm_omni/core/sched/omni_scheduling_coordinator.py +++ b/vllm_omni/core/sched/omni_scheduling_coordinator.py @@ -24,24 +24,43 @@ logger = init_logger(__name__) +# (arch, model_stage) pairs that route their full_payload stage input via +# the worker connector and therefore need the scheduler-side coordinator to +# park requests in WAITING_FOR_INPUT until the recv side delivers. This set +# must stay aligned with the arch scope of `init_omni_connectors` in +# gpu_ar_model_runner.py and gpu_generation_model_runner.py. Adding a stage +# here without also wiring its worker connector init produces a permanent +# Stage 1 hang (gate parks the request, no transport ever releases it). +# +# The `_is_sync_input` markers on per-model `*_token_only` builders in +# stage_input_processors/ remain as forward-compat documentation; when init +# is generalised (see tmp/trim_refactor branch) this whitelist can move back +# to a structural marker check or be dropped entirely. 
+_FULL_PAYLOAD_INPUT_STAGES: frozenset[tuple[str, str]] = frozenset( + { + ("Qwen3OmniMoeForConditionalGeneration", "talker"), + ("Qwen3OmniMoeForConditionalGeneration", "code2wav"), + } +) + + def uses_full_payload_input_coordinator(model_config: Any) -> bool: - """Structural gate: a stage uses the full-payload input coordinator iff - it is a downstream multi-stage stage (stage_id > 0), is not in async_chunk - mode, and has a sync-side input builder wired (detected via the - `_is_sync_input` marker on the resolved `custom_process_input_func`). - - The marker is set per-builder in each model's stage_input_processor - module (e.g. `thinker2talker_token_only._is_sync_input = True` in - qwen3_omni.py). This avoids hard-coding the arch / stage_name whitelist. + """Returns True iff this stage parks pending requests in + WAITING_FOR_INPUT awaiting a full_payload delivery on the worker connector. + + Gated by (model_arch, model_stage) — see _FULL_PAYLOAD_INPUT_STAGES for the + rationale on why this is a whitelist instead of a marker-driven structural + gate. 
""" if getattr(model_config, "stage_id", 0) <= 0: return False if getattr(model_config, "async_chunk", False): return False - proc = getattr(model_config, "custom_process_input_func", None) - if proc is None or not getattr(proc, "_is_sync_input", False): - return False - return getattr(model_config, "model_stage", None) is not None + key = ( + getattr(model_config, "model_arch", None), + getattr(model_config, "model_stage", None), + ) + return key in _FULL_PAYLOAD_INPUT_STAGES class OmniSchedulingCoordinator: diff --git a/vllm_omni/model_executor/models/cosyvoice3/pipeline.py b/vllm_omni/model_executor/models/cosyvoice3/pipeline.py index 4480a0dd831..ed35c93bd13 100644 --- a/vllm_omni/model_executor/models/cosyvoice3/pipeline.py +++ b/vllm_omni/model_executor/models/cosyvoice3/pipeline.py @@ -31,6 +31,7 @@ owns_tokenizer=True, engine_output_type="latent", async_chunk_process_next_stage_input_func=(f"{_PROC}.talker2code2wav_async_chunk"), + custom_process_next_stage_input_func=f"{_PROC}.text2flow_full_payload", sampling_constraints={ # merged speech stop token (logsumexp of all 200 stop logits) "stop_token_ids": [6562], @@ -44,7 +45,8 @@ final_output=True, final_output_type="audio", engine_output_type="latent", - sync_process_input_func=f"{_PROC}.text2flow", + custom_process_input_func=f"{_PROC}.text2flow", + sync_process_input_func=f"{_PROC}.text2flow_token_only", ), ), ) diff --git a/vllm_omni/model_executor/models/covo_audio/pipeline.py b/vllm_omni/model_executor/models/covo_audio/pipeline.py index 5b1a31d6ea8..97053e3286f 100644 --- a/vllm_omni/model_executor/models/covo_audio/pipeline.py +++ b/vllm_omni/model_executor/models/covo_audio/pipeline.py @@ -29,6 +29,7 @@ owns_tokenizer=True, requires_multimodal_data=True, engine_output_type="latent", + custom_process_next_stage_input_func=f"{_PROC}.llm2code2wav_full_payload", sampling_constraints={ "detokenize": True, "stop_token_ids": [151645], @@ -44,6 +45,7 @@ final_output_type="audio", engine_output_type="audio", 
custom_process_input_func=f"{_PROC}.llm2code2wav", + sync_process_input_func=f"{_PROC}.llm2code2wav_token_only", sampling_constraints={"detokenize": False}, ), ), diff --git a/vllm_omni/model_executor/models/mimo_audio/pipeline.py b/vllm_omni/model_executor/models/mimo_audio/pipeline.py index 70d14ef78aa..126c901763c 100644 --- a/vllm_omni/model_executor/models/mimo_audio/pipeline.py +++ b/vllm_omni/model_executor/models/mimo_audio/pipeline.py @@ -39,6 +39,7 @@ owns_tokenizer=True, engine_output_type="latent", async_chunk_process_next_stage_input_func=(f"{_PROC}.llm2code2wav_async_chunk"), + custom_process_next_stage_input_func=f"{_PROC}.llm2code2wav_full_payload", sampling_constraints={ "detokenize": True, # Stop once the speech/text interleaved span ends. Code2Wav @@ -55,7 +56,8 @@ final_output=True, final_output_type="audio", engine_output_type="audio", - sync_process_input_func=f"{_PROC}.llm2code2wav", + custom_process_input_func=f"{_PROC}.llm2code2wav", + sync_process_input_func=f"{_PROC}.llm2code2wav_token_only", sampling_constraints={"detokenize": False}, ), ), diff --git a/vllm_omni/model_executor/models/ming_flash_omni/pipeline.py b/vllm_omni/model_executor/models/ming_flash_omni/pipeline.py index a9d66fbc22b..c818aebad3d 100644 --- a/vllm_omni/model_executor/models/ming_flash_omni/pipeline.py +++ b/vllm_omni/model_executor/models/ming_flash_omni/pipeline.py @@ -42,6 +42,7 @@ # Thinker reads the LLM sub-config of BailingMM2Config hf_config_name="llm_config", engine_output_type="text", + custom_process_next_stage_input_func=f"{_PROC}.thinker2talker_full_payload", sampling_constraints={"detokenize": True}, ), StagePipelineConfig( @@ -58,6 +59,7 @@ engine_output_type="audio", tokenizer_subdir="talker/llm", custom_process_input_func=f"{_PROC}.thinker2talker", + sync_process_input_func=f"{_PROC}.thinker2talker_token_only", ), ), ) diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/pipeline.py b/vllm_omni/model_executor/models/qwen2_5_omni/pipeline.py 
index de0644803b5..afd0a92a531 100644 --- a/vllm_omni/model_executor/models/qwen2_5_omni/pipeline.py +++ b/vllm_omni/model_executor/models/qwen2_5_omni/pipeline.py @@ -30,6 +30,7 @@ requires_multimodal_data=True, engine_output_type="latent", sampling_constraints={"detokenize": True}, + custom_process_next_stage_input_func=f"{_PROC}.thinker2talker_full_payload", ), StagePipelineConfig( stage_id=1, @@ -38,6 +39,8 @@ input_sources=(0,), engine_output_type="latent", custom_process_input_func=f"{_PROC}.thinker2talker", + sync_process_input_func=f"{_PROC}.thinker2talker_token_only", + custom_process_next_stage_input_func=f"{_PROC}.talker2code2wav_full_payload", sampling_constraints={ "detokenize": True, "stop_token_ids": [8294], @@ -52,6 +55,7 @@ final_output_type="audio", engine_output_type="audio", custom_process_input_func=f"{_PROC}.talker2code2wav", + sync_process_input_func=f"{_PROC}.talker2code2wav_token_only", sampling_constraints={"detokenize": True}, ), ), @@ -74,6 +78,7 @@ requires_multimodal_data=True, engine_output_type="latent", sampling_constraints={"detokenize": True}, + custom_process_next_stage_input_func=f"{_PROC}.thinker2talker_full_payload", ), ), ) diff --git a/vllm_omni/model_executor/models/qwen3_tts/pipeline.py b/vllm_omni/model_executor/models/qwen3_tts/pipeline.py index 5051715ceac..d37dd23c4fe 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/pipeline.py +++ b/vllm_omni/model_executor/models/qwen3_tts/pipeline.py @@ -26,6 +26,7 @@ owns_tokenizer=True, engine_output_type="latent", async_chunk_process_next_stage_input_func=(f"{_PROC}.talker2code2wav_async_chunk"), + custom_process_next_stage_input_func=f"{_PROC}.talker2code2wav_full_payload", sampling_constraints={ "detokenize": False, "stop_token_ids": [2150], @@ -40,7 +41,8 @@ final_output_type="audio", engine_output_type="audio", model_arch="Qwen3TTSCode2Wav", - sync_process_input_func=f"{_PROC}.talker2code2wav", + custom_process_input_func=f"{_PROC}.talker2code2wav", + 
sync_process_input_func=f"{_PROC}.talker2code2wav_token_only", sampling_constraints={"detokenize": True}, extras={"tts_args": {"max_instructions_length": 500}}, ), diff --git a/vllm_omni/model_executor/stage_configs/dynin_omni.yaml b/vllm_omni/model_executor/stage_configs/dynin_omni.yaml index 131a0d1cd70..0e7171eb9ee 100644 --- a/vllm_omni/model_executor/stage_configs/dynin_omni.yaml +++ b/vllm_omni/model_executor/stage_configs/dynin_omni.yaml @@ -6,6 +6,7 @@ stage_args: max_batch_size: 1 engine_args: model_stage: token2text + custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.dynin_omni.token2text_to_token2image_full_payload model_arch: DyninOmniForConditionalGeneration worker_type: generation scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler @@ -39,6 +40,7 @@ stage_args: max_num_batched_tokens: 32768 engine_input_source: [0] custom_process_input_func: vllm_omni.model_executor.stage_input_processors.dynin_omni.token2text_to_token2image + sync_process_input_func: vllm_omni.model_executor.stage_input_processors.dynin_omni.token2text_to_token2image_token_only final_output: true final_output_type: image diff --git a/vllm_omni/model_executor/stage_configs/dynin_omni_multiconnector.yaml b/vllm_omni/model_executor/stage_configs/dynin_omni_multiconnector.yaml index 4a54f8188aa..0189718cea0 100644 --- a/vllm_omni/model_executor/stage_configs/dynin_omni_multiconnector.yaml +++ b/vllm_omni/model_executor/stage_configs/dynin_omni_multiconnector.yaml @@ -6,6 +6,7 @@ stage_args: max_batch_size: 1 engine_args: model_stage: token2text + custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.dynin_omni.token2text_to_token2image_full_payload model_arch: DyninOmniForConditionalGeneration worker_type: generation scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler @@ -38,6 +39,7 @@ stage_args: max_num_batched_tokens: 32768 engine_input_source: [0] 
custom_process_input_func: vllm_omni.model_executor.stage_input_processors.dynin_omni.token2text_to_token2image + sync_process_input_func: vllm_omni.model_executor.stage_input_processors.dynin_omni.token2text_to_token2image_token_only final_output: true final_output_type: image input_connectors: diff --git a/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py b/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py index cf1ca39ee59..912c227fb5e 100644 --- a/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py +++ b/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py @@ -271,3 +271,102 @@ def talker2code2wav_async_chunk( state["emitted_chunks"] = int(state.get("emitted_chunks", 0)) + 1 return payload + + +# ============================================================================ +# PR3 worker-connector data plane (non-async-chunk path) — Group D-ish. +# cosyvoice3 talker emits `multimodal_outputs={"embed": {"speech_token": t, +# "speech_feat": t, "embedding": t}}` ONLY at prefill (decode steps emit +# `{}`). After flatten_payload (data_entry_keys.py:280-302) these become +# flat top-level keys `embed.speech_token` etc., persisted across decode +# steps by the accumulator (decode doesn't re-emit them). Shipping via +# the connector keeps the orchestrator off the heavy-tensor path. +# ============================================================================ + +# All three embed tensors are emitted once at prefill and must REPLACE-not- +# CONCAT across the (already trivial) per-request accumulator history so a +# regression where decode unexpectedly re-emits them does not silently +# duplicate the prefill tensor. See mixin._FULL_PAYLOAD_REPLACE_KEYS. 
+_FULL_PAYLOAD_REPLACE_KEYS: frozenset[str] = frozenset({"embed.speech_token", "embed.speech_feat", "embed.embedding"}) + + +def text2flow_token_only( + source_outputs: list, + prompt: OmniTokensPrompt | TextPrompt = None, + _requires_multimodal_data: bool = True, +): + """Sync-side builder for the non-async-chunk text→flow path. + + Mirrors the legacy `text2flow` shape. For now the prefill embed + tensors stay inline in `additional_information`, even though they are also shipped via the worker + connector payload built by `text2flow_full_payload`. Small metadata + (`ids.prompt` = the talker's original prompt token prefix) stays + inline so the flow stage can still locate the prefix. + + prompt_token_ids = the talker's `cumulative_token_ids` (real codec + tokens, not [0]*N) because the flow stage consumes talker output + verbatim. + """ + del prompt + engine_inputs: list[OmniTokensPrompt] = [] + for source_output in source_outputs: + if not source_output.finished: + continue + output = source_output.outputs[0] + output_ids = _ensure_list(output.cumulative_token_ids) + prefix_ids = _ensure_list(source_output.prompt_token_ids) + # Preserve full multimodal_output (incl. embed.*) in additional_information + # for now: the worker connector wire to model_intermediate_buffer on the + # code2wav side is not yet plumbed. Filed as Phase 4 follow-up (parallel + # to #90 speaker/language re-routing). text2flow_full_payload still ships + # embed.* via the connector — currently redundant but ready for activation.
+ multi_modal_data = output.multimodal_output + if multi_modal_data is None: + raise RuntimeError(f"Missing multimodal_output for request {source_output.request_id}") + additional_info: dict[str, Any] = dict(multi_modal_data) + additional_info.setdefault("ids", {})["prompt"] = prefix_ids + engine_inputs.append( + OmniTokensPrompt( + prompt_token_ids=output_ids, + additional_information=additional_info, + multi_modal_data=None, + mm_processor_kwargs=None, + ) + ) + return engine_inputs + + +text2flow_token_only._is_sync_input = True + + +def text2flow_full_payload( + transfer_manager, + pooling_output, + request, +): + """Producer-side packer. + + Reads the prefill-emitted `embed.{speech_token, speech_feat, embedding}` + from the accumulator (flat dotted keys after flatten_payload; nested + fallback for safety) and ships them as a single connector payload. + The downstream flow stage reads these from `model_intermediate_buffer` + (see cosyvoice3.py:671 in the code2wav forward — runtime_info pickup). + """ + del transfer_manager, request + if not isinstance(pooling_output, dict): + return None + embed_out: dict[str, Any] = {} + for key in ("speech_token", "speech_feat", "embedding"): + v = pooling_output.get(f"embed.{key}") + if v is None: + nested = pooling_output.get("embed") + if isinstance(nested, dict): + v = nested.get(key) + if isinstance(v, torch.Tensor) and v.numel() > 0: + embed_out[key] = v + if not embed_out: + return None + return { + "embed": embed_out, + "meta": {"finished": torch.tensor(True, dtype=torch.bool)}, + } diff --git a/vllm_omni/model_executor/stage_input_processors/covo_audio.py b/vllm_omni/model_executor/stage_input_processors/covo_audio.py index a0a964bdd2f..52ca8a44bca 100644 --- a/vllm_omni/model_executor/stage_input_processors/covo_audio.py +++ b/vllm_omni/model_executor/stage_input_processors/covo_audio.py @@ -1,33 +1,102 @@ # Copyright 2026 Tencent. 
def _filter_audio_codes(token_ids: list[int]) -> list[int]:
    """Keep codec-range ids and rebase them to start at zero.

    Ids below ``COVO_AUDIO_TOKEN_INDEX`` are dropped; the remaining ids have
    ``COVO_AUDIO_TOKEN_INDEX`` subtracted. When nothing survives the filter
    the single-element sentinel ``[-1]`` is returned instead of an empty list.
    """
    rebased: list[int] = []
    for token_id in token_ids:
        if token_id >= COVO_AUDIO_TOKEN_INDEX:
            rebased.append(token_id - COVO_AUDIO_TOKEN_INDEX)
    return rebased if rebased else [-1]
def llm2code2wav_full_payload(
    transfer_manager: Any,
    pooling_output: dict[str, Any],
    request: Any,
) -> dict[str, Any] | None:
    """Producer-side packer for the worker connector data plane.

    The payload is built purely from ``request.output_token_ids``: the ids
    are filtered/rebased by ``_filter_audio_codes`` and shipped together
    with a ``finished`` marker. ``transfer_manager`` and ``pooling_output``
    are not read.

    Returns:
        ``None`` when the request has no output token ids, otherwise a dict
        with ``codes.audio`` (list of ints) and ``meta.finished`` (bool
        tensor).
    """
    raw_ids = getattr(request, "output_token_ids", None)
    generated_ids = list(raw_ids or [])
    if len(generated_ids) == 0:
        return None

    payload: dict[str, Any] = {
        "codes": {"audio": _filter_audio_codes(generated_ids)},
        "meta": {"finished": torch.tensor(True, dtype=torch.bool)},
    }
    return payload
def _token_only_from_source(source_outputs: list[Any]) -> list[OmniTokensPrompt]:
    """Build length-only placeholder prompts, one per source output.

    The placeholder length is taken from the first non-empty candidate, in
    this order: ``multimodal_output["token_ids"]``,
    ``multimodal_output["text_tokens"]``, the output's own ``token_ids``,
    and finally a single-slot fallback. Only the *length* is used: every
    placeholder id is ``0``.
    """
    placeholders: list[OmniTokensPrompt] = []
    for item in source_outputs:
        completion = item.outputs[0]
        mm_payload = getattr(completion, "multimodal_output", None) or {}
        # `or` short-circuits exactly like the chained "if not token_ids"
        # fallbacks: later candidates are evaluated only when needed.
        reference_ids = (
            _to_token_id_list(mm_payload.get("token_ids"))
            or _to_token_id_list(mm_payload.get("text_tokens"))
            or list(getattr(completion, "token_ids", []) or [])
            or [0]
        )
        placeholder = OmniTokensPrompt(
            prompt_token_ids=[0] * len(reference_ids),
            additional_information=None,
            multi_modal_data=None,
            mm_processor_kwargs=None,
        )
        placeholders.append(placeholder)
    return placeholders
+token2text_to_token2image_token_only._is_sync_input = True +token2image_to_token2audio_token_only._is_sync_input = True diff --git a/vllm_omni/model_executor/stage_input_processors/mimo_audio.py b/vllm_omni/model_executor/stage_input_processors/mimo_audio.py index b0b6e887857..e4debdf9d06 100644 --- a/vllm_omni/model_executor/stage_input_processors/mimo_audio.py +++ b/vllm_omni/model_executor/stage_input_processors/mimo_audio.py @@ -300,3 +300,121 @@ def llm2code2wav( ) return code2wav_inputs + + +# ============================================================================ +# PR3 worker-connector data plane (non-async-chunk path) — Group B. +# AR runner's `flatten_payload` (data_entry_keys.py:280-302) converts the +# model emit `multimodal_outputs={"codes": {"audio": ...}}` to flat +# `pooling_output["codes.audio"]` before the accumulator runs, so default +# CONCAT semantics build the full codec tensor across all decode steps. +# ============================================================================ + +# Per-model REPLACE-keys for the full-payload accumulator. mimo_audio's +# producer side emits per-step codec frames that should be CONCAT'd across +# steps (not REPLACE'd), so this stays empty. +_FULL_PAYLOAD_REPLACE_KEYS: frozenset[str] = frozenset() + + +def _filter_zero_codec_rows(codec_codes: torch.Tensor) -> torch.Tensor: + """Drop zero-padded codec rows from a 4-D `[N, 1, 8, 4]` tensor. + + Mirrors the zero-row filter in the orchestrator-path `llm2code2wav` + body (see this file's ``llm2code2wav`` around line 224). 
+ """ + if codec_codes.ndim != 4 or codec_codes.numel() == 0: + return codec_codes + is_all_zero = (codec_codes == 0).all(dim=(1, 2, 3)) + nonzero_idx = (~is_all_zero).nonzero(as_tuple=True)[0] + if len(nonzero_idx) == 0: + return codec_codes + if len(nonzero_idx) < codec_codes.shape[0]: + return codec_codes[nonzero_idx] + return codec_codes + + +def llm2code2wav_token_only( + source_outputs: list, + _prompt=None, + _requires_multimodal_data: bool = False, +) -> list: + """Sync-side placeholder for the non-async-chunk Stage-1 (code2wav) input. + + Returns an ``OmniTokensPrompt`` sized to the orchestrator-shape codec + length so the consumer runtime allocates the right number of slots. + The actual codec ids are delivered via the worker connector payload + built by ``llm2code2wav_full_payload``. + """ + from vllm_omni.inputs.data import OmniTokensPrompt + + code2wav_inputs: list = [] + for output_wrapper in source_outputs: + out = output_wrapper.outputs[0] + mm = out.multimodal_output if hasattr(out, "multimodal_output") else None + mm = mm if isinstance(mm, dict) else {} + mm_codes = mm.get("codes", {}) if isinstance(mm, dict) else {} + prompt_len = 0 + if isinstance(mm_codes, dict) and "audio" in mm_codes: + audio = mm_codes["audio"] + if isinstance(audio, torch.Tensor) and audio.numel() > 0: + audio = audio.to(torch.long) + audio = _filter_zero_codec_rows(audio) + # +B*4 per batch row for the prepended pad_vec (see prepend_and_flatten_colmajor) + batch_size = int(audio.shape[0]) if audio.ndim >= 1 else 1 + prompt_len = int(audio.numel()) + batch_size * 4 + if prompt_len > MAX_CODE2WAV_TOKENS: + prompt_len = MAX_CODE2WAV_TOKENS + code2wav_inputs.append( + OmniTokensPrompt( + prompt_token_ids=[0] * prompt_len, + additional_information=None, + multi_modal_data=None, + mm_processor_kwargs=None, + ) + ) + return code2wav_inputs + + +llm2code2wav_token_only._is_sync_input = True + + +def llm2code2wav_full_payload( + transfer_manager, + pooling_output: dict, + request, 
+) -> dict | None: + """Producer-side packer for the worker connector data plane. + + AR runner's ``flatten_payload`` converts the per-step model emit + ``{"codes": {"audio": ...}}`` to ``pooling_output["codes.audio"]``. + The accumulator CONCATs per-step tensors along dim 0, so by flush + time this holds the full ``[total_steps, 1, 8, 4]`` codec tensor. + + A back-compat fallback to nested ``pooling_output["codes"]["audio"]`` + is kept in case a future runtime path bypasses `flatten_payload`. + """ + del transfer_manager + if not isinstance(pooling_output, dict): + return None + codec_codes = pooling_output.get("codes.audio") + if codec_codes is None: + # Back-compat fallback for un-flattened pooler emits. + codes = pooling_output.get("codes") + if isinstance(codes, dict): + codec_codes = codes.get("audio") + if not isinstance(codec_codes, torch.Tensor) or codec_codes.numel() == 0: + return None + codec_codes = codec_codes.to(torch.long) + codec_codes = _filter_zero_codec_rows(codec_codes) + if codec_codes.numel() == 0: + return None + + pad_vec = torch.tensor([TALKER_CODEC_PAD_TOKEN_ID] * 4) + code_final = prepend_and_flatten_colmajor(codec_codes, pad_vec).tolist() + if len(code_final) > MAX_CODE2WAV_TOKENS: + code_final = code_final[:MAX_CODE2WAV_TOKENS] + + return { + "codes": {"audio": code_final}, + "meta": {"finished": torch.tensor(True, dtype=torch.bool)}, + } diff --git a/vllm_omni/model_executor/stage_input_processors/ming_flash_omni.py b/vllm_omni/model_executor/stage_input_processors/ming_flash_omni.py index e0d538cb3b0..ce9a13807ac 100644 --- a/vllm_omni/model_executor/stage_input_processors/ming_flash_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/ming_flash_omni.py @@ -535,9 +535,101 @@ def thinker2talker( return talker_inputs +# ============================================================================ +# PR3 worker-connector data plane (non-async-chunk path) — Group D minimal. 
def thinker2talker_token_only(
    source_outputs: list[Any],
    prompt: OmniTokensPrompt | TextPrompt | None = None,
    _requires_multimodal_data: bool = False,
) -> list[OmniTokensPrompt]:
    """Sync-side builder for the non-async-chunk thinker→talker path.

    For every thinker output, the generated text is combined with the
    voice/speaker fields found on the matching request prompt's
    ``additional_information`` and wrapped in a single-token
    ``OmniTokensPrompt`` for the talker.

    Args:
        source_outputs: Thinker outputs; ``outputs[0].text`` is read from each.
        prompt: A prompt or list of prompts, matched to outputs by position.
            Outputs without a matching prompt fall back to the defaults.
        _requires_multimodal_data: Unused.

    Returns:
        One talker prompt per source output, in order.
    """
    prompts = prompt if isinstance(prompt, list) else [prompt]
    prompt_count = len(prompts)

    talker_inputs: list[OmniTokensPrompt] = []
    for idx, thinker_output in enumerate(source_outputs):
        completion = thinker_output.outputs[0]
        # Missing or empty text both collapse to the empty string.
        generated_text = getattr(completion, "text", None) or ""

        request_prompt = prompts[idx] if idx < prompt_count else None
        # NOTE(review): this is attribute access only. If prompts reach this
        # function as plain mappings (dict-style), `hasattr` is False and the
        # speaker fields silently fall back to defaults — confirm with caller.
        if request_prompt is not None and hasattr(request_prompt, "additional_information"):
            user_info: dict[str, Any] = request_prompt.additional_information or {}
        else:
            user_info = {}

        speaker_embedding = user_info.get("spk_emb", None)
        is_plain_number_list = (
            isinstance(speaker_embedding, list)
            and speaker_embedding
            and not hasattr(speaker_embedding[0], "device")
        )
        if is_plain_number_list:
            import torch

            # Plain Python list → float32 tensor with a leading batch dim.
            speaker_embedding = torch.tensor(speaker_embedding, dtype=torch.float32).unsqueeze(0)

        talker_info = {
            "ming_task": "omni",
            "text": generated_text,
            "spk_emb": speaker_embedding,
            "voice_name": user_info.get("voice_name", "DB30"),
            "prompt_text": user_info.get("prompt_text", None),
            "prompt_wav_lat": user_info.get("prompt_wav_lat", None),
            "prompt_wav_emb": user_info.get("prompt_wav_emb", None),
            "max_text_length": user_info.get("max_text_length", 50),
        }

        talker_prompt = OmniTokensPrompt(
            prompt_token_ids=[0],
            additional_information=talker_info,
            multi_modal_data=None,
            mm_processor_kwargs=None,
        )
        talker_inputs.append(talker_prompt)

    return talker_inputs
def _strip_codec_boundaries(token_ids: list[int]) -> list[int]:
    """Return a copy of ``token_ids`` without the codec boundary tokens.

    A leading ``TALKER_CODEC_START_TOKEN_ID`` is removed first; then, if
    anything is left, a trailing ``TALKER_CODEC_END_TOKEN_ID`` is removed.
    Boundary ids appearing anywhere else are kept. The input is not mutated.
    """
    stripped = list(token_ids)
    has_start = bool(stripped) and stripped[0] == TALKER_CODEC_START_TOKEN_ID
    stripped = stripped[1:] if has_start else stripped
    has_end = bool(stripped) and stripped[-1] == TALKER_CODEC_END_TOKEN_ID
    return stripped[:-1] if has_end else stripped
def talker2code2wav_full_payload(
    transfer_manager,
    pooling_output: dict,
    request,
) -> dict | None:
    """Producer-side packer: ship the boundary-stripped codec ids.

    Only ``request.output_token_ids`` is read; ``pooling_output`` is not
    consulted. The ids are passed through ``_strip_codec_boundaries`` and
    packed with a ``finished`` marker.

    Returns:
        ``None`` when there are no output ids, or nothing remains after the
        boundary tokens are stripped; otherwise the connector payload dict.
    """
    del transfer_manager
    generated_ids = list(getattr(request, "output_token_ids", None) or [])
    codec_ids = _strip_codec_boundaries(generated_ids) if generated_ids else []
    if len(codec_ids) == 0:
        return None

    finished_flag = torch.tensor(True, dtype=torch.bool)
    return {
        "codes": {"audio": codec_ids},
        "meta": {"finished": finished_flag},
    }
def thinker2talker_token_only(
    source_outputs,
    prompt: OmniTokensPrompt | TextPrompt | None = None,
    requires_multimodal_data: bool = False,
):
    """Sync-side builder for the non-async-chunk thinker->talker path.

    For every thinker output the ``latent`` tensor found in its
    ``multimodal_output`` is split at the prompt length into a prefill part
    and a decode part (both cast to float32) and packed into
    ``additional_information`` together with the prompt/output token ids.
    The talker prompt itself is a placeholder sequence:
    ``[START] + [PAD] * len(prompt_token_ids) + [END]``.

    Args:
        source_outputs: Thinker outputs; each exposes ``request_id``,
            ``prompt_token_ids`` and ``outputs``.
        prompt: A prompt or a list of prompts, matched to outputs by
            position. ``None`` (the default) and entries without a ``get``
            method are treated as carrying no multimodal data.
        requires_multimodal_data: When true, the matching prompt's
            ``multi_modal_data`` is forwarded to the talker prompt; outputs
            with no matching prompt get ``None``.

    Returns:
        One ``OmniTokensPrompt`` per source output, in order.
    """
    prompts = prompt if isinstance(prompt, list) else [prompt]

    # request_id -> multi_modal_data of the positionally matching prompt.
    # zip() stops at the shorter sequence, so outputs beyond the prompt list
    # simply have no entry here (looked up with .get below).
    mm_data_by_request = {}
    for thinker_output, original_prompt in zip(source_outputs, prompts):
        getter = getattr(original_prompt, "get", None)
        mm_data_by_request[thinker_output.request_id] = (
            getter("multi_modal_data", None) if callable(getter) else None
        )

    talker_inputs = []
    for thinker_output in source_outputs:
        output = thinker_output.outputs[0]
        prompt_token_ids = thinker_output.prompt_token_ids
        thinker_output_ids = output.cumulative_token_ids
        prompt_token_ids_len = len(prompt_token_ids)
        mm: OmniPayload = output.multimodal_output
        # Detached private copy so the slices below never alias the
        # producer's tensor (the copy stays on the latent's own device).
        thinker_hidden_states = mm["latent"].detach().clone()
        decode_hidden = thinker_hidden_states[prompt_token_ids_len:].to(torch.float32)
        prefill_hidden = thinker_hidden_states[:prompt_token_ids_len].to(torch.float32)
        additional_information = to_dict(
            OmniPayloadStruct(
                hidden_states=HiddenStatesStruct(output=decode_hidden, output_shape=list(decode_hidden.shape)),
                embed=EmbeddingsStruct(prefill=prefill_hidden, prefill_shape=list(prefill_hidden.shape)),
                ids=IdsStruct(prompt=list(prompt_token_ids), output=list(thinker_output_ids)),
            )
        )
        talker_inputs.append(
            OmniTokensPrompt(
                prompt_token_ids=[TALKER_CODEC_START_TOKEN_ID]
                + [TALKER_CODEC_PAD_TOKEN_ID] * prompt_token_ids_len
                + [TALKER_CODEC_END_TOKEN_ID],
                additional_information=additional_information,
                multi_modal_data=(
                    mm_data_by_request.get(thinker_output.request_id) if requires_multimodal_data else None
                ),
                mm_processor_kwargs=None,
            )
        )

    return talker_inputs
That + field is NOT plumbed into the AR runner's pooler_output chain + (data_entry_keys.flatten_payload + gpu_ar_model_runner emit), so the + accumulator cannot ship it via the worker connector today. + + Returning None tells the connector to skip the send for this + transition; the consumer reads the latent via additional_information + (preserved by thinker2talker_token_only). Filed as Phase 4 follow-up + alongside #90 (speaker/language) and cosyvoice3 (embed.*). + """ + del transfer_manager, pooling_output, request + return None diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py index faa7e4cc4d3..8ef188009f2 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py @@ -280,3 +280,198 @@ def talker2code2wav_async_chunk( speaker=extract_speaker_from_request(request), language=extract_language_from_request(request), ) + + +# ============================================================================ +# PR3 worker-connector data plane (non-async-chunk path) — Group C multi-key. +# AR runner's `flatten_payload` (data_entry_keys.py:280-302) converts the +# model emit `multimodal_outputs={"codes": {"audio": ..., "ref": ...}, +# "meta": {"ref_code_len": ..., "codec_streaming": ...}}` to flat dotted keys +# (`codes.audio`, `codes.ref`, `meta.ref_code_len`, `meta.codec_streaming`) +# before the accumulator runs. +# - codes.audio is 2-D so default CONCAT across steps builds the full sequence. +# - codes.ref is a list (not Tensor with dim>=2) so accumulator LATEST-wins +# keeps the prefill-emitted ref tensor across decode steps (which don't emit +# ref again). +# - meta.ref_code_len is 1-D so LATEST-wins; consumer reads [-1]. +# ============================================================================ + +# Per-model REPLACE-keys for the full-payload accumulator. 
qwen3_tts's +# producer side emits codec frames that should CONCAT (codes.audio) plus +# scalars/lists that are correctly handled by default LATEST-wins, so this +# stays empty. +_FULL_PAYLOAD_REPLACE_KEYS: frozenset[str] = frozenset() + +_CODEBOOK_SIZE = 2048 +_NUM_QUANTIZERS_DEFAULT = 16 + + +def _filter_audio_codes_qwen3_tts(audio_codes: torch.Tensor) -> torch.Tensor: + """Filter zero-padded + out-of-range codec frames. + + Mirrors the orchestrator-path body in `talker2code2wav` above. + """ + if not isinstance(audio_codes, torch.Tensor) or audio_codes.numel() == 0: + return audio_codes + if audio_codes.ndim != 2: + return audio_codes + valid_mask = audio_codes.any(dim=1) & (audio_codes.max(dim=1).values < _CODEBOOK_SIZE) + return audio_codes[valid_mask] + + +def _normalize_ref_code(ref_code, num_quantizers: int, ref_code_len: int): + """Coerce ref_code into a [ref_len, Q] tensor or None. Mirrors orchestrator path.""" + if isinstance(ref_code, list): + ref_code = ref_code[0] if ref_code else None + if not isinstance(ref_code, torch.Tensor) or ref_code.numel() == 0: + return None, 0 + ref_code = ref_code.to(torch.long).cpu().contiguous() + if ref_code.ndim == 1: + if ref_code.numel() % num_quantizers != 0: + return None, 0 + ref_code = ref_code.reshape(-1, num_quantizers) + elif ref_code.ndim != 2: + return None, 0 + if ref_code_len > 0 and int(ref_code.shape[0]) > ref_code_len: + ref_code = ref_code[:ref_code_len] + return ref_code, int(ref_code.shape[0]) + + +def talker2code2wav_token_only( + source_outputs: list, + prompt=None, + _requires_multimodal_data: bool = False, +) -> list: + """Sync-side placeholder for the non-async-chunk Stage-1 (code2wav) input. + + Sized to the expected codec token count (codebook-major flat: + Q * (ref_frames + audio_frames)). Speaker / language metadata are + extracted from `prompt` and threaded via `additional_information` + (orchestrator-style; same as the legacy `talker2code2wav` builder). 
+ Actual codec ids are delivered via the worker connector payload built + by `talker2code2wav_full_payload`. + """ + from vllm_omni.inputs.data import OmniTokensPrompt + + code2wav_inputs: list = [] + for i, talker_output in enumerate(source_outputs): + if not talker_output.finished: + continue + output = talker_output.outputs[0] + mm = output.multimodal_output if hasattr(output, "multimodal_output") else None + mm = mm if isinstance(mm, dict) else {} + mm_codes = mm.get("codes", {}) if isinstance(mm, dict) else {} + token_ids = getattr(output, "cumulative_token_ids", []) or [] + seq_len = max(len(token_ids) - 1, 0) + + audio = mm_codes.get("audio") if isinstance(mm_codes, dict) else None + if isinstance(audio, torch.Tensor) and audio.numel() > 0: + audio = audio.to(torch.long) + audio = _filter_audio_codes_qwen3_tts(audio) + if seq_len > 0 and audio.ndim == 2 and int(audio.shape[0]) > seq_len: + audio = audio[-seq_len:] + num_audio_frames = int(audio.shape[0]) if audio.ndim == 2 else 0 + num_quantizers = int(audio.shape[1]) if audio.ndim == 2 and audio.shape[1] > 0 else _NUM_QUANTIZERS_DEFAULT + else: + num_audio_frames = 0 + num_quantizers = _NUM_QUANTIZERS_DEFAULT + + ref_code_raw = mm_codes.get("ref") if isinstance(mm_codes, dict) else None + ref_code_len_raw = mm.get("meta", {}).get("ref_code_len") if isinstance(mm.get("meta"), dict) else None + if isinstance(ref_code_len_raw, torch.Tensor): + ref_code_len = int(ref_code_len_raw.reshape(-1)[-1].item()) if ref_code_len_raw.numel() > 0 else 0 + elif ref_code_len_raw is None: + ref_code_len = 0 + else: + ref_code_len = int(ref_code_len_raw) + _, ref_frames = _normalize_ref_code(ref_code_raw, num_quantizers, ref_code_len) + + # Codebook-major flat: Q * (ref_frames + audio_frames) + prompt_len = num_quantizers * (ref_frames + num_audio_frames) + + additional_info = to_dict( + OmniPayloadStruct( + meta=MetaStruct(left_context_size=ref_frames) if ref_frames > 0 else None, + speaker=extract_speaker_from_prompt(prompt, 
def talker2code2wav_full_payload(
    transfer_manager,
    pooling_output,
    request,
):
    """Producer-side packer for the talker→code2wav connector payload.

    Inputs are looked up in ``pooling_output`` by flat dotted key first
    (``"codes.audio"``, ``"meta.ref_code_len"``, ``"codes.ref"``) and then,
    if absent, in the nested ``{"codes": {...}}`` / ``{"meta": {...}}``
    form.

    Steps: filter the audio frames, keep at most the last
    ``len(request.output_token_ids) - 1`` frames, prepend the normalized
    reference code when one is available, then flatten codebook-major
    (transpose, then reshape to 1-D).

    Returns:
        ``None`` when ``pooling_output`` is not a dict or no usable audio
        frames exist; otherwise ``{"codes": {"audio": [...]}, "meta":
        {"finished": <bool tensor>}}``.
    """
    del transfer_manager
    if not isinstance(pooling_output, dict):
        return None

    def _flat_or_nested(group: str, leaf: str):
        # Flat dotted key wins; fall back to the nested dict form.
        value = pooling_output.get(f"{group}.{leaf}")
        if value is not None:
            return value
        nested = pooling_output.get(group)
        return nested.get(leaf) if isinstance(nested, dict) else None

    frames = _flat_or_nested("codes", "audio")
    if not isinstance(frames, torch.Tensor) or frames.numel() == 0:
        return None
    frames = _filter_audio_codes_qwen3_tts(frames.to(torch.long))
    if frames.numel() == 0:
        return None

    generated_ids = list(getattr(request, "output_token_ids", None) or [])
    seq_len = max(len(generated_ids) - 1, 0)
    is_2d = frames.ndim == 2
    if is_2d and seq_len > 0 and int(frames.shape[0]) > seq_len:
        frames = frames[-seq_len:]

    if is_2d and frames.shape[1] > 0:
        num_quantizers = int(frames.shape[1])
    else:
        num_quantizers = _NUM_QUANTIZERS_DEFAULT

    raw_ref_len = _flat_or_nested("meta", "ref_code_len")
    if raw_ref_len is None:
        ref_code_len = 0
    elif isinstance(raw_ref_len, torch.Tensor):
        # Tensor form: use the last element, or 0 for an empty tensor.
        ref_code_len = int(raw_ref_len.reshape(-1)[-1].item()) if raw_ref_len.numel() > 0 else 0
    else:
        ref_code_len = int(raw_ref_len)

    ref_code, _ = _normalize_ref_code(_flat_or_nested("codes", "ref"), num_quantizers, ref_code_len)
    if ref_code is not None:
        frames = torch.cat([ref_code.to(frames.device), frames], dim=0)

    flat_codes = frames.transpose(0, 1).cpu().reshape(-1).tolist()
    return {
        "codes": {"audio": flat_codes},
        "meta": {"finished": torch.tensor(True, dtype=torch.bool)},
    }
+ + Fires iff the worker has a connector payload builder loaded + (``custom_process_func`` resolved via ``_load_custom_func`` from the + stage_config's ``custom_process_next_stage_input_func`` or the + ``*_full_payload`` derivative of ``custom_process_input_func``), the + stage is not in async_chunk mode, and ``model_stage`` is set. + + NOTE: the ``_is_sync_input`` marker is on the *consumer-side* + ``*_token_only`` builder, not on the ``*_full_payload`` packer that + workers load on the *producer* side. So checking it here would always + return False and the accumulator would never run. The + consumer-side scheduler gate (``uses_full_payload_input_coordinator`` + in ``omni_scheduling_coordinator.py``) is where the marker is + appropriately tested. + + Pre-Phase-2a, this gate was an arch + stage whitelist + (``Qwen3OmniMoeForConditionalGeneration`` and ``thinker``/``talker``). + Phase 2a generalized that to "any stage with a loaded packer + not + async_chunk + model_stage set" — arch-agnostic. """ if custom_process_func is None: return False if getattr(model_config, "async_chunk", False): return False - if not getattr(custom_process_func, "_is_sync_input", False): - return False return getattr(model_config, "model_stage", None) is not None From 6fa1364461bf1b9dd674d6db3812ec4331ab1957 Mon Sep 17 00:00:00 2001 From: natureofnature Date: Sun, 17 May 2026 11:22:31 +0000 Subject: [PATCH 05/19] [PR3] Block A allowlist activation per-arch (init + coordinator) Squash of 5 [PR3 Block A] Enable commits that wire each migrated arch into both the worker init allowlist (_BLOCK_A_INIT_ALLOWLIST in gpu_ar_model_runner.py + gpu_generation_model_runner.py) and the scheduler coordinator allowlist (_FULL_PAYLOAD_INPUT_STAGES in omni_scheduling_coordinator.py), keeping the two in lockstep so a gate-enabled stage always has a wired-up worker connector. 
Archs activated (in commit order): - qwen2_5_omni talker -> code2wav (Block A pilot for q25; thinker -> talker stays orchestrator-routed at this commit because the producer builder is still a no-op). Also widens init_omni_connectors arch allowlist from Qwen3-only to a 7-arch frozenset. - covo_audio fused_thinker_talker -> code2wav. - mimo_audio fused_thinker_talker -> code2wav. - qwen3_tts talker -> code2wav (Qwen3TTSTalkerForConditionalGeneration -> Qwen3TTSCode2Wav). - cosyvoice3 cosyvoice3_talker -> cosyvoice3_code2wav. After this commit each arch's Stage-1 receives the full-payload delivery via the worker connector instead of via the orchestrator- side additional_information path. The producer builders themselves were added in the previous "Per-arch SIP builders" commit. Signed-off-by: natureofnature --- .../core/sched/omni_scheduling_coordinator.py | 16 ++++++++++++++ .../stage_input_processors/mimo_audio.py | 14 +++++++++++++ .../stage_input_processors/qwen2_5_omni.py | 14 ++++++++++++- vllm_omni/worker/gpu_ar_model_runner.py | 21 ++++++++++++++----- .../worker/gpu_generation_model_runner.py | 15 +++++++++---- 5 files changed, 70 insertions(+), 10 deletions(-) diff --git a/vllm_omni/core/sched/omni_scheduling_coordinator.py b/vllm_omni/core/sched/omni_scheduling_coordinator.py index 1aaa486873f..d00ab7687d9 100644 --- a/vllm_omni/core/sched/omni_scheduling_coordinator.py +++ b/vllm_omni/core/sched/omni_scheduling_coordinator.py @@ -40,6 +40,22 @@ { ("Qwen3OmniMoeForConditionalGeneration", "talker"), ("Qwen3OmniMoeForConditionalGeneration", "code2wav"), + # PR3 Block A incremental: enabling qwen2_5_omni talker->code2wav only. + # thinker->talker stays orchestrator-routed because its + # `thinker2talker_full_payload` is a no-op (heavy `text_hidden_states` + # not yet plumbed into pooler_output). Adding (Qwen2_5, talker) here + # without that plumbing would park talker requests in WAITING_FOR_INPUT + # with no transport to release them. 
+ ("Qwen2_5OmniForConditionalGeneration", "code2wav"), + # PR3 Block A: covo_audio is fused_thinker_talker (Stage 0) → code2wav (Stage 1) + ("CovoAudioForConditionalGeneration", "code2wav"), + # PR3 Block A: mimo_audio is fused_thinker_talker (Stage 0) → code2wav (Stage 1) + ("MiMoAudioModel", "code2wav"), + # PR3 Block A: qwen3_tts is Qwen3TTSTalkerForConditionalGeneration (Stage 0) + # → Qwen3TTSCode2Wav (Stage 1). Stage 1 is the consumer. + ("Qwen3TTSCode2Wav", "code2wav"), + # PR3 Block A: cosyvoice3 stages cosyvoice3_talker (Stage 0) → cosyvoice3_code2wav (Stage 1) + ("CosyVoice3Model", "cosyvoice3_code2wav"), } ) diff --git a/vllm_omni/model_executor/stage_input_processors/mimo_audio.py b/vllm_omni/model_executor/stage_input_processors/mimo_audio.py index e4debdf9d06..3f203a61e4f 100644 --- a/vllm_omni/model_executor/stage_input_processors/mimo_audio.py +++ b/vllm_omni/model_executor/stage_input_processors/mimo_audio.py @@ -147,6 +147,20 @@ def llm2code2wav_async_chunk( Accumulates codes in connector per request_id, returns payload only when chunk_size is full or request is finished; returns None when waiting. """ + # Null guard: under Block A universal-ish init, the producer-side + # chunk_transfer_adapter calls this every emit step including no-output + # steps where pooling_output is None. Pre-Block-A this code path was + # unreachable (no connector init for mimo_audio). 
+ if pooling_output is None or not isinstance(pooling_output, dict): + if is_finished: + connector = getattr(transfer_manager, "connector", None) + raw_cfg = getattr(connector, "config", {}) or {} + cfg = raw_cfg.get("extra", raw_cfg) if isinstance(raw_cfg, dict) else {} + chunk_size = int(cfg.get("codec_chunk_frames", 3)) + left_context_size = int(cfg.get("codec_left_context_frames", 3)) + request_id = getattr(request, "external_req_id", None) + return _flush_remaining_codes(transfer_manager, request_id, chunk_size, left_context_size) + return None connector = getattr(transfer_manager, "connector", None) raw_cfg = getattr(connector, "config", {}) or {} cfg = raw_cfg.get("extra", raw_cfg) if isinstance(raw_cfg, dict) else {} diff --git a/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py index ab49d909817..026aa3cd44d 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py @@ -101,12 +101,24 @@ def talker2code2wav( def _strip_codec_boundaries(token_ids: list[int]) -> list[int]: - """Drop TALKER_CODEC_START/END boundary tokens (mirror talker2code2wav).""" + """Drop TALKER_CODEC_START/END boundary tokens (mirror talker2code2wav) + and filter sentinel/invalid codec ids. + + The talker emits codec ids on `request.output_token_ids`. Negative ids + (e.g., -1) appear as "stopped early" / "no token sampled this step" + sentinels and are NOT valid codec embedding indices. Passing -1 to + `torch.embedding` triggers a CUDA gather-kernel OOB assert in the + code2wav stage (`vectorized_gather_kernel index out of bounds`). We + filter them here at the producer-side strip so the worker connector + payload only ships valid codec ids. 
+ """ tids = list(token_ids) if tids and tids[0] == TALKER_CODEC_START_TOKEN_ID: tids = tids[1:] if tids and tids[-1] == TALKER_CODEC_END_TOKEN_ID: tids = tids[:-1] + # Filter negative sentinel ids that the talker engine may insert. + tids = [t for t in tids if t >= 0] return tids diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index 5da4cf9d870..482d42228e4 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -84,11 +84,22 @@ def __init__(self, *args, **kwargs): self.inputs_embeds = self._make_buffer(self.max_num_tokens, self.hidden_size, dtype=self.dtype, numpy=False) # Initialize KV cache manager (preserve vllm_config fallback behavior) self.kv_transfer_manager = OmniKVTransferManager.from_vllm_config(self.vllm_config, self.model_config) - # Only Qwen3-Omni currently consumes the connector-based full-payload - # handoff added in this PR. Other model architectures (e.g. Bagel - # diffusion) retain their pre-existing runner behavior so this PR - # does not perturb them. - if getattr(self.model_config, "model_arch", None) == "Qwen3OmniMoeForConditionalGeneration": + # Worker-connector full-payload init is gated by an arch allowlist that + # grows as each per-arch transition is verified end-to-end (PR3 incremental + # Block A). Adding an arch here without also wiring its scheduler-side + # gate entries in `omni_scheduling_coordinator._FULL_PAYLOAD_INPUT_STAGES` + # produces a Stage-1 hang on the consumer side (request parks but no + # transport ever releases). Keep the two in lockstep. 
+ _BLOCK_A_INIT_ALLOWLIST = { + "Qwen3OmniMoeForConditionalGeneration", + "Qwen2_5OmniForConditionalGeneration", + "CovoAudioForConditionalGeneration", + "MiMoAudioModel", + "Qwen3TTSTalkerForConditionalGeneration", + "Qwen3TTSCode2Wav", + "CosyVoice3Model", + } + if getattr(self.model_config, "model_arch", None) in _BLOCK_A_INIT_ALLOWLIST: self.init_omni_connectors( vllm_config=self.vllm_config, model_config=self.model_config, diff --git a/vllm_omni/worker/gpu_generation_model_runner.py b/vllm_omni/worker/gpu_generation_model_runner.py index 175547ff31d..ec20c07801a 100644 --- a/vllm_omni/worker/gpu_generation_model_runner.py +++ b/vllm_omni/worker/gpu_generation_model_runner.py @@ -59,10 +59,17 @@ class GPUGenerationModelRunner(OmniGPUModelRunner, OmniConnectorModelRunnerMixin def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # Scope full-payload connector init to Qwen3-Omni: other generation - # models (e.g. Bagel DiT) retain their pre-existing runner setup - # so this refactor does not perturb them. - if getattr(self.model_config, "model_arch", None) == "Qwen3OmniMoeForConditionalGeneration": + # See gpu_ar_model_runner.py for Block A allowlist policy. 
+ _BLOCK_A_INIT_ALLOWLIST = { + "Qwen3OmniMoeForConditionalGeneration", + "Qwen2_5OmniForConditionalGeneration", + "CovoAudioForConditionalGeneration", + "MiMoAudioModel", + "Qwen3TTSTalkerForConditionalGeneration", + "Qwen3TTSCode2Wav", + "CosyVoice3Model", + } + if getattr(self.model_config, "model_arch", None) in _BLOCK_A_INIT_ALLOWLIST: self.init_omni_connectors( vllm_config=self.vllm_config, model_config=self.model_config, From 77b6762914b01db1805810d248138b625c011215 Mon Sep 17 00:00:00 2001 From: natureofnature Date: Tue, 19 May 2026 01:48:43 +0000 Subject: [PATCH 06/19] [PR3] Closure: per-arch full_payload producers + dynin connector migration + Codex review Consolidates the PR3 closure work: prefix_caching state-leak fix, qwen2_5_omni real producer plumbing, dynin connector data-plane migration, and Codex review feedback. `omni_connector_model_runner_mixin.py`: `flush_full_payload_outputs` is invoked at the start of `cleanup_finished_request` to drain any pending full_payload entry for the finishing request before the rest of the cleanup runs. An earlier unconditional pop of `_pending_full_payload_send` raced with flush and broke audio consumers; flush-then-cleanup is idempotent and safe for paths without a downstream consumer. Fixes `test_thinker_prefix_caching[omni_server0]` state-leak regression. `stage_input_processors/qwen2_5_omni.py`: `thinker2talker_full_payload` is no longer a no-op. The real producer-side packer reads `pooling_output["hidden"]`, splits prefill/decode, applies a stop-emission trim aligned with the legacy contract (mirrors qwen3_omni; covers max-token finishes without losing a hidden row), and builds an `OmniPayloadStruct` matching the field set that `thinker2talker_token_only` writes into `additional_information`. `deploy/dynin_omni_ci.yaml`: adds `custom_process_next_stage_input_func` on Stage 0 (token2text) and Stage 1 (token2image) pointing at the `_full_payload` builders. 
Without this entry, `_load_custom_func` finds no builder, `_should_accumulate_full_payload_output()` returns False, and the producer never enqueues a connector message. End-to-end connector flow observed at runtime: Stage-0 flush_full_payload_outputs(req_id) -> to_send=[req_id] Stage-0 send_full_payload_outputs payload_keys=[... 'code_predictor_codes', ...] Stage-1 full_payload recv complete: payload_type=dict Stage-1 flush_full_payload_outputs ... -> to_send=[req_id] (Stage-2 recv likewise) dynin's legacy consumer-side `custom_process_input_func` on Stage 1 and Stage 2 is retained (see DECISIONS.md D-017): the connector path is primary, but `_bridge_tokens` is still wired to propagate `additional_information` to the downstream `OmniTokensPrompt` -- a propagation the scheduler-side rewrite (`metadata` -> `request.prompt_token_ids`) does not yet do. Offline t2s passes through the pure connector pipeline; the online server test relies on `request.additional_information` for response assembly, so the legacy SIP stays until the scheduler is extended in a follow-up (see PR4 direction in D-017). `core/sched/omni_scheduling_coordinator.py` adds three (arch, stage) pairs to `_FULL_PAYLOAD_INPUT_STAGES`: - `(Qwen2_5OmniForConditionalGeneration, talker)` -- consumer gate for the newly-active qwen2_5_omni connector path. - `(DyninOmniForConditionalGeneration, token2image)` and `(..., token2audio)` -- consumer gates that park the dynin Stage 1 and Stage 2 requests in WAITING_FOR_INPUT until the upstream payload arrives. `worker/gpu_ar_model_runner.py` and `worker/gpu_generation_model_runner.py` add `DyninOmniForConditionalGeneration` to `_BLOCK_A_INIT_ALLOWLIST` so `init_omni_connectors` runs for dynin workers. Codex review fixes: - Issue 1 (q25 SIP trim heuristic): replace unconditional `output_token_ids[:-1] + h[:-1]` trim with a `stop_emission_drop`- based trim that mirrors qwen3_omni's contract. Folded into the qwen2_5_omni SIP rewrite above. 
- Issue 3 (external_req_id None fallback): `register_chunk_recv` and `_resolve_external_req_id` now fall back to the internal `request_id` when `external_req_id` on the request struct is explicitly `None`, preventing recv-key collisions like `None__` across requests. - Issue 4 (coordinator test): rewrite the positive case in `tests/core/sched/test_omni_scheduling_coordinator.py` to iterate `_FULL_PAYLOAD_INPUT_STAGES` so newly-whitelisted (arch, stage) pairs fire the gate by construction. Remove the q2.5/talker entry from the negative cases (now whitelisted). `stage_configs/dynin_omni.yaml` and `stage_configs/dynin_omni_multiconnector.yaml`: remove a transiently-added `sync_process_input_func` line. The active deploy path is the `deploy/dynin_omni_ci.yaml` rewrite above; the stage_configs siblings are kept in sync to avoid drift. Verified on H800 dev environment: - dynin e2e (offline + online, full): 11 passed / 1 skipped / 0 fail in 19:12. Connector flow verified by stage-0/1/2 flush+send+recv log triplet. - qwen3_omni online_serving (3 tests, dynin diff applied): 3 passed / 0 failed. - Canonical CI sweep (test-ready.yml + test-merge.yml, 21 steps): 21 pass / 0 fail in 77.7 min.
Signed-off-by: natureofnature --- .../sched/test_omni_scheduling_coordinator.py | 18 +- .../core/sched/omni_scheduling_coordinator.py | 16 +- vllm_omni/deploy/dynin_omni_ci.yaml | 2 + .../stage_configs/dynin_omni.yaml | 1 - .../dynin_omni_multiconnector.yaml | 1 - .../stage_input_processors/qwen2_5_omni.py | 154 ++++++++++++++++-- vllm_omni/worker/gpu_ar_model_runner.py | 1 + .../worker/gpu_generation_model_runner.py | 1 + .../omni_connector_model_runner_mixin.py | 30 +++- 9 files changed, 187 insertions(+), 37 deletions(-) diff --git a/tests/core/sched/test_omni_scheduling_coordinator.py b/tests/core/sched/test_omni_scheduling_coordinator.py index d2fea3a3f49..74670b06c7f 100644 --- a/tests/core/sched/test_omni_scheduling_coordinator.py +++ b/tests/core/sched/test_omni_scheduling_coordinator.py @@ -101,17 +101,22 @@ class TestFullPayloadCoordinatorSelection(unittest.TestCase): input through the worker connector. """ - def test_qwen3_omni_talker_and_code2wav_fire_gate(self): - for model_stage in ("talker", "code2wav"): + def test_all_whitelisted_arch_stage_pairs_fire_gate(self): + """All (arch, stage) pairs in _FULL_PAYLOAD_INPUT_STAGES must fire + the gate when stage_id > 0 and async_chunk=False. 
+ """ + from vllm_omni.core.sched.omni_scheduling_coordinator import _FULL_PAYLOAD_INPUT_STAGES + + for arch, stage in _FULL_PAYLOAD_INPUT_STAGES: model_config = SimpleNamespace( stage_id=1, async_chunk=False, - model_arch="Qwen3OmniMoeForConditionalGeneration", - model_stage=model_stage, + model_arch=arch, + model_stage=stage, ) self.assertTrue( uses_full_payload_input_coordinator(model_config), - msg=f"expected gate to fire for Qwen3Omni/{model_stage}", + msg=f"expected gate to fire for {arch}/{stage} (entry in _FULL_PAYLOAD_INPUT_STAGES)", ) def test_other_arch_or_stage_or_mode_does_not_fire(self): @@ -128,9 +133,6 @@ def test_other_arch_or_stage_or_mode_does_not_fire(self): model_arch="Qwen3OmniMoeForConditionalGeneration", model_stage="some_future_stage", ), - SimpleNamespace( - stage_id=1, async_chunk=False, model_arch="Qwen2_5OmniForConditionalGeneration", model_stage="talker" - ), SimpleNamespace( stage_id=1, async_chunk=False, model_arch="Qwen3TTSForConditionalGeneration", model_stage="code2wav" ), diff --git a/vllm_omni/core/sched/omni_scheduling_coordinator.py b/vllm_omni/core/sched/omni_scheduling_coordinator.py index d00ab7687d9..effd56e318c 100644 --- a/vllm_omni/core/sched/omni_scheduling_coordinator.py +++ b/vllm_omni/core/sched/omni_scheduling_coordinator.py @@ -40,12 +40,11 @@ { ("Qwen3OmniMoeForConditionalGeneration", "talker"), ("Qwen3OmniMoeForConditionalGeneration", "code2wav"), - # PR3 Block A incremental: enabling qwen2_5_omni talker->code2wav only. - # thinker->talker stays orchestrator-routed because its - # `thinker2talker_full_payload` is a no-op (heavy `text_hidden_states` - # not yet plumbed into pooler_output). Adding (Qwen2_5, talker) here - # without that plumbing would park talker requests in WAITING_FOR_INPUT - # with no transport to release them. + # PR3 Block A: qwen2_5_omni thinker->talker now uses the real + # full-payload producer builder (text_hidden_states routed via + # pooler_output["hidden"] -> accumulator -> connector). 
Both + # stages of qwen2_5_omni are enabled. + ("Qwen2_5OmniForConditionalGeneration", "talker"), ("Qwen2_5OmniForConditionalGeneration", "code2wav"), # PR3 Block A: covo_audio is fused_thinker_talker (Stage 0) → code2wav (Stage 1) ("CovoAudioForConditionalGeneration", "code2wav"), @@ -56,6 +55,11 @@ ("Qwen3TTSCode2Wav", "code2wav"), # PR3 Block A: cosyvoice3 stages cosyvoice3_talker (Stage 0) → cosyvoice3_code2wav (Stage 1) ("CosyVoice3Model", "cosyvoice3_code2wav"), + # PR3 dynin migration: token2text (Stage 0) -> token2image (Stage 1) + # -> token2audio (Stage 2). Producer wires via + # custom_process_next_stage_input_func: *_full_payload in deploy yaml. + ("DyninOmniForConditionalGeneration", "token2image"), + ("DyninOmniForConditionalGeneration", "token2audio"), } ) diff --git a/vllm_omni/deploy/dynin_omni_ci.yaml b/vllm_omni/deploy/dynin_omni_ci.yaml index 525b7d888c2..2ddc281e7a7 100644 --- a/vllm_omni/deploy/dynin_omni_ci.yaml +++ b/vllm_omni/deploy/dynin_omni_ci.yaml @@ -14,6 +14,7 @@ stage_args: worker_type: generation scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler engine_output_type: latent + custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.dynin_omni.token2text_to_token2image_full_payload trust_remote_code: true gpu_memory_utilization: 0.5 enforce_eager: true @@ -36,6 +37,7 @@ stage_args: worker_type: generation scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler engine_output_type: latent + custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.dynin_omni.token2image_to_token2audio_full_payload trust_remote_code: true gpu_memory_utilization: 0.2 enforce_eager: true diff --git a/vllm_omni/model_executor/stage_configs/dynin_omni.yaml b/vllm_omni/model_executor/stage_configs/dynin_omni.yaml index 0e7171eb9ee..024443e8d16 100644 --- a/vllm_omni/model_executor/stage_configs/dynin_omni.yaml +++ 
b/vllm_omni/model_executor/stage_configs/dynin_omni.yaml @@ -40,7 +40,6 @@ stage_args: max_num_batched_tokens: 32768 engine_input_source: [0] custom_process_input_func: vllm_omni.model_executor.stage_input_processors.dynin_omni.token2text_to_token2image - sync_process_input_func: vllm_omni.model_executor.stage_input_processors.dynin_omni.token2text_to_token2image_token_only final_output: true final_output_type: image diff --git a/vllm_omni/model_executor/stage_configs/dynin_omni_multiconnector.yaml b/vllm_omni/model_executor/stage_configs/dynin_omni_multiconnector.yaml index 0189718cea0..1ab65f0fab9 100644 --- a/vllm_omni/model_executor/stage_configs/dynin_omni_multiconnector.yaml +++ b/vllm_omni/model_executor/stage_configs/dynin_omni_multiconnector.yaml @@ -39,7 +39,6 @@ stage_args: max_num_batched_tokens: 32768 engine_input_source: [0] custom_process_input_func: vllm_omni.model_executor.stage_input_processors.dynin_omni.token2text_to_token2image - sync_process_input_func: vllm_omni.model_executor.stage_input_processors.dynin_omni.token2text_to_token2image_token_only final_output: true final_output_type: image input_connectors: diff --git a/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py index 026aa3cd44d..411f4973dd1 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py @@ -1,3 +1,5 @@ +import logging + import torch from vllm.inputs import TextPrompt @@ -11,6 +13,8 @@ ) from vllm_omni.inputs.data import OmniTokensPrompt +logger = logging.getLogger(__name__) + TALKER_CODEC_PAD_TOKEN_ID = 8292 TALKER_CODEC_START_TOKEN_ID = 8293 TALKER_CODEC_END_TOKEN_ID = 8294 @@ -265,19 +269,139 @@ def thinker2talker_full_payload( pooling_output, request, ): - """Producer-side packer — no-op. 
- - qwen2_5_omni's thinker emits its last-layer hidden state via - ``text_hidden_states`` (the OmniOutput field), which is materialized - into ``multimodal_output["latent"]`` at the engine boundary. That - field is NOT plumbed into the AR runner's pooler_output chain - (data_entry_keys.flatten_payload + gpu_ar_model_runner emit), so the - accumulator cannot ship it via the worker connector today. - - Returning None tells the connector to skip the send for this - transition; the consumer reads the latent via additional_information - (preserved by thinker2talker_token_only). Filed as Phase 4 follow-up - alongside #90 (speaker/language) and cosyvoice3 (embed.*). + """Producer-side packer for the worker-connector data plane. + + The AR runner emits per-step ``pooling_output["hidden"]`` (the + thinker's last-layer hidden states for the request span, unpacked + from ``OmniOutput.text_hidden_states``). The full-payload + accumulator concatenates those per-step rows across decode steps, so + by the time this builder fires the materialized + ``pooling_output["hidden"]`` contains the full prefill+decode + hidden-state trajectory of size + ``len(prompt_token_ids) + len(output_token_ids)``. + + We split it at ``len(prompt_token_ids)`` into prefill embeddings and + decode hidden states, then pack the ``OmniPayload``-shaped dict that + the talker's ``thinker_to_talker_process`` already reads (keys + ``embed.prefill`` / ``hidden_states.output`` / ``ids.prompt`` / + ``ids.output``). Shape matches what + ``thinker2talker_token_only`` writes into + ``additional_information``, so the consumer-side coordinator gate + flip is a drop-in once the no-touch coordinator file is updated. 
+ + Like ``qwen3_omni.thinker2talker_full_payload``, we apply a + finish-reason-aware stop-row trim: vLLM v1 appends the sampled + token to ``output_token_ids`` before ``check_stop``, so a request + that finished via ``FINISHED_STOPPED`` has one extra accumulated + hidden-state row that the talker must not consume. Max-token + finishes need no drop. Status is read from the request when + available; otherwise we fall back to a last-token-in-stop-set + heuristic. """ - del transfer_manager, pooling_output, request - return None + del transfer_manager + if not isinstance(pooling_output, dict): + return None + + hidden = pooling_output.get("hidden") + if not isinstance(hidden, torch.Tensor): + return None + + def _ensure_list(x): + if x is None: + return [] + if hasattr(x, "_x"): + return list(x._x) + if isinstance(x, list): + return list(x) + return list(x) + + prompt_token_ids = _ensure_list(getattr(request, "prompt_token_ids", None)) + output_token_ids = _ensure_list(getattr(request, "output_token_ids", None)) + all_token_ids = _ensure_list(getattr(request, "all_token_ids", None) or []) + if not all_token_ids: + all_token_ids = list(prompt_token_ids) + list(output_token_ids) + + # Length-aware trim of accumulated thinker output, finish-reason-aware. + # Mirror qwen3_omni.thinker2talker_full_payload's logic so a stop-finish + # does not leak an extra hidden-state row to the talker. + status = getattr(request, "status", None) + status_name = getattr(status, "name", None) or "" + if not status_name and status is not None: + status_name = str(status).rsplit(".", 1)[-1] + stop_emission_drop = 1 if status_name == "FINISHED_STOPPED" else 0 + if stop_emission_drop == 0 and not status_name and output_token_ids: + # Worker-side CachedRequestState has no `.status` field in vLLM v1; + # fall back to a last-token-in-stop-set heuristic. 
+ sampling_params = getattr(request, "sampling_params", None) + if sampling_params is not None: + stop_ids: set[int] = set() + ignore_eos = bool(getattr(sampling_params, "ignore_eos", False)) + for sid in getattr(sampling_params, "stop_token_ids", None) or (): + if isinstance(sid, int): + stop_ids.add(sid) + if not ignore_eos: + for eos in ( + getattr(sampling_params, "eos_token_id", None), + getattr(sampling_params, "_eos_token_id", None), + ): + if isinstance(eos, int): + stop_ids.add(eos) + for sid in ( + getattr(sampling_params, "all_stop_token_ids", None) + or getattr(sampling_params, "_all_stop_token_ids", None) + or () + ): + if isinstance(sid, int): + stop_ids.add(sid) + if stop_ids and output_token_ids[-1] in stop_ids: + stop_emission_drop = 1 + + # Trim accumulated thinker output based on stop_emission_drop computed + # above. Mirror qwen3_omni.thinker2talker_full_payload's contract: + # target_rows = len(all_token_ids) - stop_emission_drop + # which excludes the stop-emission row for FINISHED_STOPPED but keeps + # all rows for FINISHED_LENGTH_CAPPED (max_tokens) finishes. + if stop_emission_drop > 0 and len(output_token_ids) >= stop_emission_drop: + output_token_ids = output_token_ids[:-stop_emission_drop] + h = hidden.detach().cpu().to(torch.float32) + target_rows = max(0, len(all_token_ids) - stop_emission_drop) + if target_rows <= 0: + return None + if h.dim() >= 1 and h.shape[0] > target_rows: + logger.warning( + "qwen2_5_omni.thinker2talker_full_payload: excess hidden rows " + "(got %d, target %d, stop_drop %d) for req=%s; trimming", + int(h.shape[0]), + target_rows, + stop_emission_drop, + getattr(request, "request_id", None), + ) + h = h[:target_rows] + + prompt_len = len(prompt_token_ids) + if h.shape[0] < prompt_len: + # Under-captured prefill — defensively skip rather than ship a + # truncated payload that would confuse the talker's prefill path. 
+ return None + + prefill_hidden = h[:prompt_len] + decode_hidden = h[prompt_len:] + + payload: OmniPayload = to_dict( + OmniPayloadStruct( + hidden_states=HiddenStatesStruct( + output=decode_hidden, + output_shape=list(decode_hidden.shape), + ), + embed=EmbeddingsStruct( + prefill=prefill_hidden, + prefill_shape=list(prefill_hidden.shape), + ), + ids=IdsStruct( + prompt=list(prompt_token_ids), + output=list(output_token_ids), + ), + ) + ) + # payload["meta"] removed — was the only diff vs legacy payload, causes mix_to_text_audio_001 failure + return payload diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index 482d42228e4..c89ebcee023 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -98,6 +98,7 @@ def __init__(self, *args, **kwargs): "Qwen3TTSTalkerForConditionalGeneration", "Qwen3TTSCode2Wav", "CosyVoice3Model", + "DyninOmniForConditionalGeneration", } if getattr(self.model_config, "model_arch", None) in _BLOCK_A_INIT_ALLOWLIST: self.init_omni_connectors( diff --git a/vllm_omni/worker/gpu_generation_model_runner.py b/vllm_omni/worker/gpu_generation_model_runner.py index ec20c07801a..232224555d1 100644 --- a/vllm_omni/worker/gpu_generation_model_runner.py +++ b/vllm_omni/worker/gpu_generation_model_runner.py @@ -68,6 +68,7 @@ def __init__(self, *args, **kwargs): "Qwen3TTSTalkerForConditionalGeneration", "Qwen3TTSCode2Wav", "CosyVoice3Model", + "DyninOmniForConditionalGeneration", } if getattr(self.model_config, "model_arch", None) in _BLOCK_A_INIT_ALLOWLIST: self.init_omni_connectors( diff --git a/vllm_omni/worker/omni_connector_model_runner_mixin.py b/vllm_omni/worker/omni_connector_model_runner_mixin.py index ff6ed255218..032d87d27bc 100644 --- a/vllm_omni/worker/omni_connector_model_runner_mixin.py +++ b/vllm_omni/worker/omni_connector_model_runner_mixin.py @@ -253,6 +253,21 @@ def cleanup_finished_request(self, req_id: str) -> None: saves is added to 
``_deferred_send_cleanup`` so the bg save's decrement path drains it without leaving orphans. """ + # Force-flush any pending full-payload accumulator entry before + # cleanup proceeds. Without this, finished requests with no + # downstream consumer (e.g. text-only on multi-modal arch) leave + # the entry orphaned in _pending_full_payload_send across requests, + # which empirically destabilises subsequent thinker forwards + # (test_thinker_prefix_caching regression). flush is a near-no-op + # for paths with no consumer, and idempotent when the entry has + # already been flushed by the scheduler-driven path. + try: + self.flush_full_payload_outputs({req_id}) + except Exception: + # Defensive: connector may not be initialised for archs + # outside the Block A allowlist. Cleanup must still proceed. + pass + ext_id = self._request_ids_mapping.pop(req_id, None) keys_to_clean: list[str] = [req_id] if ext_id is not None and ext_id != req_id: @@ -1007,11 +1022,11 @@ def register_chunk_recv(self, request: Any) -> None: if self._stage_id == 0: return request_id = request.request_id - self._request_ids_mapping[request_id] = getattr( - request, - "external_req_id", - request_id, - ) + # Codex Issue 3: explicit external_req_id=None should fall back to + # request_id; otherwise recv keys become `None__` and + # collide across requests. + ext = getattr(request, "external_req_id", None) + self._request_ids_mapping[request_id] = ext if ext is not None else request_id with self._lock: if request_id in self._stage_recv_req_ids: return @@ -2195,7 +2210,10 @@ def _resolve_external_req_id(self, request: Any, fallback_req_id: str) -> str: if mapped is not None: return mapped if request is not None: - return getattr(request, "external_req_id", fallback_req_id) + # Codex Issue 3: external_req_id may be explicitly None; fall back. 
+ ext = getattr(request, "external_req_id", None) + if ext is not None: + return ext return fallback_req_id def _resolve_next_stage_id(self, model_config: Any) -> int: From b6e03475d4efc29bb51f3440e04d7786d5690dc4 Mon Sep 17 00:00:00 2001 From: natureofnature Date: Wed, 20 May 2026 09:31:40 +0000 Subject: [PATCH 07/19] [PR3] v5: cosyvoice3 connector codec hook + conditional code2wav trim Layered fix for cosyvoice3 sync voice cloning: producer emits real codec via the worker connector only when the talker reaches a stop token; consumer overlays it into the placeholder input ids and trims the prompt-conditioning mel only when that connector path is active. Max-token fallbacks keep the legacy `additional_information` path intact so non-cosyvoice3 archs and other terminal conditions are unaffected. Producer (model side, `models/cosyvoice3/cosyvoice3.py`) ------------------------------------------------------- - New `build_pooler_payload(req_id, req_index, input_batch, sampled_token_ids, invalid_req_indices)` hook discovered via duck-typed `getattr` on the runner. Returns `{"codes.audio": tensor}` only when the per-request codec history is non-empty AND the talker has emitted at least one stop-class token (id >= speech_token_size). Mid-step or pre-finish polling returns `None` so the connector message is shipped exactly once at finish. - `_pooler_codec_rows` keeps cumulative codec ids in `self._pooler_codec_history_by_req` and tracks two per-req sentinels: - `_pooler_codec_sampled_seen_by_req`: any in-vocab token observed for the request. - `_pooler_codec_sampled_finished_by_req`: a stop-class id was observed. Sampled path is preferred; on cold-start it falls back to `_pooler_output_history_from_input_batch` (vllm leaves the decoded slots at -1 under `prefer_model_sampler=True`, so this fallback is normally inert here but kept for resume paths). 
SIP (`stage_input_processors/cosyvoice3.py`) -------------------------------------------- - `text2flow_full_payload` reads `pooling_output["codes.audio"]` (flat dotted + nested fallback), tensor-wraps into `codes.audio`, and sets `meta.next_stage_prompt_len = len(token_ids)` for the overlay length contract. - `_FULL_PAYLOAD_REPLACE_KEYS` adds `codes.audio` (per-step payload carries cumulative codec, not delta). - `text2flow_token_only` keeps the legacy `additional_information` packing (multimodal_output + `ids.prompt`) so the orchestrator still has a usable fallback when `codes.audio` is not shipped. Runner dispatch (`worker/gpu_ar_model_runner.py`) ------------------------------------------------- - `_attach_model_pooler_payload` invokes the model hook and merges returned keys via `_pooler_payload_has_key` (handles both flat dotted and nested layouts). Non-cosyvoice3 archs fall through unchanged. - `_output_token_ids_for_model_sampler` now trims at the first -1 per request so the model sampler never sees placeholder slots. Consumer overlay (`worker/gpu_generation_model_runner.py`) ---------------------------------------------------------- - `_overlay_full_payload_input_ids` runs before each generation step (non-async-chunk path). For each scheduled request, looks up the connector payload via `model_intermediate_buffer`, reads `_payload_audio_codes`, flattens to a 1-D tensor, and copies it into the placeholder slots in `input_ids`. A length mismatch raises a loud RuntimeError (this catches drift between the producer and the scheduler's `next_stage_prompt_len` contract). code2wav trim (`models/cosyvoice3/cosyvoice3_code2wav.py`) --------------------------------------------------------- - Sync `forward()` now accepts `token_offset_tokens: int = 0` and threads it through to `_forward_mel`; the model-side caller in `cosyvoice3.py` passes `speech_token.shape[1]` only when `payload.codes is not None and payload.codes.audio is not None`, i.e., the request actually traveled the connector path.
Legacy fallback continues to call `forward()` with the default (0) so pre-existing en_001 behavior is preserved. Tests ----- - `tests/model_executor/models/cosyvoice3/test_cosyvoice3_model_helpers.py` extends with hook contract cases (sampled-vs-history priority, finish gating, cache reuse). - `tests/model_executor/models/cosyvoice3/test_cosyvoice3_components.py` covers the new sentinel sets. - `tests/worker/test_omni_gpu_model_runner.py` adds dispatch + dotted-key resolution coverage for the new hook. Verified on H800 dev environment with `--run-level full_model -m "full_model and tts"`: * voice_clone_zh_001 (sync, connector path active): payload_keys=["codes", "embed", "meta"] code_len=292 similarity=0.903 PASS * voice_clone_en_001 (sync, legacy fallback when talker hits max-tokens without stop): payload_keys=["embed", "meta"] code_len=None similarity=0.963 PASS Follow-up (not gating): `_pooler_codec_history_by_req`, `_pooler_codec_sampled_seen_by_req`, and `_pooler_codec_sampled_finished_by_req` are not yet pruned in `cleanup_finished_request`; long-running multi-request servers may accumulate per-req entries. CI is unaffected (per-test server). 
Signed-off-by: natureofnature --- .../cosyvoice3/test_cosyvoice3_components.py | 10 +- .../test_cosyvoice3_model_helpers.py | 76 ++++++++++++ tests/worker/test_omni_gpu_model_runner.py | 17 +++ .../models/cosyvoice3/cosyvoice3.py | 108 ++++++++++++++++++ .../models/cosyvoice3/cosyvoice3_code2wav.py | 3 +- .../stage_input_processors/cosyvoice3.py | 68 +++++++---- vllm_omni/worker/gpu_ar_model_runner.py | 98 +++++++++++++++- .../worker/gpu_generation_model_runner.py | 46 ++++++++ 8 files changed, 400 insertions(+), 26 deletions(-) diff --git a/tests/model_executor/models/cosyvoice3/test_cosyvoice3_components.py b/tests/model_executor/models/cosyvoice3/test_cosyvoice3_components.py index 0e071f724e5..dd73c85cad9 100644 --- a/tests/model_executor/models/cosyvoice3/test_cosyvoice3_components.py +++ b/tests/model_executor/models/cosyvoice3/test_cosyvoice3_components.py @@ -293,14 +293,22 @@ def inference(self, speech_feat, finalize=True): model = object.__new__(CosyVoice3Code2Wav) nn.Module.__init__(model) model.hift = DummyHiFT() - model._forward_mel = lambda **_: torch.ones((1, 80, 8), dtype=torch.float32) + forward_mel_calls = [] + + def fake_forward_mel(**kwargs): + forward_mel_calls.append(kwargs) + return torch.ones((1, 80, 8), dtype=torch.float32) + + model._forward_mel = fake_forward_mel out = model.forward( token=torch.tensor([[1, 2, 3]], dtype=torch.int32), prompt_token=torch.tensor([[4, 5]], dtype=torch.int32), prompt_feat=torch.ones((1, 4, 80), dtype=torch.float32), embedding=torch.ones((1, 192), dtype=torch.float32), + token_offset_tokens=2, ) assert out.shape == (1, 1, 8) assert model.hift.finalize_calls == [True] + assert forward_mel_calls[0]["token_offset_tokens"] == 2 diff --git a/tests/model_executor/models/cosyvoice3/test_cosyvoice3_model_helpers.py b/tests/model_executor/models/cosyvoice3/test_cosyvoice3_model_helpers.py index 956d32528bb..cdf6cc358cd 100644 --- a/tests/model_executor/models/cosyvoice3/test_cosyvoice3_model_helpers.py +++ 
b/tests/model_executor/models/cosyvoice3/test_cosyvoice3_model_helpers.py @@ -138,6 +138,56 @@ def _make_sampling_metadata( ) +def test_build_pooler_payload_waits_for_sampled_stop_token(): + model = _make_talker_model() + input_batch = SimpleNamespace( + req_id_to_index={"r1": 0}, + num_prompt_tokens=[3], + num_tokens_no_spec=[3], + token_ids_cpu=torch.tensor([[101, 102, 103]], dtype=torch.long), + ) + + payload = model.build_pooler_payload( + req_id="r1", + req_index=0, + input_batch=input_batch, + sampled_token_ids=[[10]], + invalid_req_indices=set(), + ) + assert payload is None + + payload = model.build_pooler_payload( + req_id="r1", + req_index=0, + input_batch=input_batch, + sampled_token_ids=[[20, 6562]], + invalid_req_indices=set(), + ) + assert payload is not None + assert torch.equal(payload["codes.audio"], torch.tensor([[10], [20]], dtype=torch.long)) + + +def test_build_pooler_payload_falls_back_to_input_batch_history(): + model = _make_talker_model() + input_batch = SimpleNamespace( + req_id_to_index={"r1": 0}, + num_prompt_tokens=[3], + num_tokens_no_spec=[8], + token_ids_cpu=torch.tensor([[101, 102, 103, 10, 20, 6562, -1, 30]], dtype=torch.long), + ) + + payload = model.build_pooler_payload( + req_id="r1", + req_index=0, + input_batch=input_batch, + sampled_token_ids=None, + invalid_req_indices=set(), + ) + + assert payload is not None + assert torch.equal(payload["codes.audio"], torch.tensor([[10], [20], [30]], dtype=torch.long)) + + def test_split_request_ids_uses_seq_token_counts(): CosyVoice3Model, _ = _cosyvoice3_model_and_runner() ids = torch.tensor([10, 11, 12, 13, 14], dtype=torch.long) @@ -265,6 +315,32 @@ def test_forward_uses_non_stream_decode_without_chunk_metadata(): assert len(model.code2wav.forward_streaming_calls) == 0 call = model.code2wav.forward_calls[0] assert call["token"].tolist() == [[0, 1, 2]] + assert call["token_offset_tokens"] == 0 + + +def test_forward_trims_non_streaming_connector_codes(): + model = _make_code2wav_model() 
+ + runtime_info = [ + { + "embed": { + "speech_token": torch.tensor([[1, 2, 3]], dtype=torch.long), + "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32), + "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32), + }, + "codes": {"audio": torch.tensor([0, 1, 2], dtype=torch.long)}, + "meta": {"next_stage_prompt_len": 3}, + } + ] + + model.forward( + input_ids=torch.tensor([0, 1, 2], dtype=torch.long), + positions=torch.tensor([0, 1, 2], dtype=torch.long), + model_intermediate_buffer=runtime_info, + seq_token_counts=[3], + ) + + assert model.code2wav.forward_calls[0]["token_offset_tokens"] == 3 def test_forward_reuses_streaming_cache_state_between_chunks(): diff --git a/tests/worker/test_omni_gpu_model_runner.py b/tests/worker/test_omni_gpu_model_runner.py index 94ffd937ead..a605a5adfaf 100644 --- a/tests/worker/test_omni_gpu_model_runner.py +++ b/tests/worker/test_omni_gpu_model_runner.py @@ -160,6 +160,23 @@ class _DummyVllmConfig: return runner +def test_generation_overlay_full_payload_input_ids_replaces_placeholders(): + from vllm_omni.worker.gpu_generation_model_runner import GPUGenerationModelRunner + + runner = object.__new__(GPUGenerationModelRunner) + runner.model_config = SimpleNamespace(async_chunk=False) + runner.model_intermediate_buffer = { + "r1": {"codes": {"audio": [9, 8]}}, + "r2": {"codes": {"audio": torch.tensor([7, 6, 5], dtype=torch.long)}}, + } + runner.query_start_loc = SimpleNamespace(cpu=torch.tensor([0, 2], dtype=torch.int32)) + input_ids = torch.zeros(5, dtype=torch.long) + + GPUGenerationModelRunner._overlay_full_payload_input_ids(runner, input_ids, ["r1", "r2"], [2, 3]) + + assert input_ids.tolist() == [9, 8, 7, 6, 5] + + def test_talker_mtp_forward_cpu_updates_inputs_and_info(monkeypatch): # Patch the module-level `set_forward_context` symbol used inside # OmniGPUModelRunner._talker_mtp_forward. 
diff --git a/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py b/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py index 5023307ff8c..4ce11a4ec07 100644 --- a/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py +++ b/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py @@ -5,6 +5,7 @@ from dataclasses import replace from functools import partial from threading import Lock +from typing import Any import torch import torch.nn as nn @@ -532,6 +533,111 @@ def _cosyvoice3_ras_enabled(self, sampling_metadata: SamplingMetadata) -> bool: return False return True + def build_pooler_payload( + self, + *, + req_id: str, + req_index: int, + input_batch: Any, + sampled_token_ids: Any | None = None, + invalid_req_indices: set[int] | None = None, + ) -> dict[str, object] | None: + if self.model_stage != "cosyvoice3_talker": + return None + codec_rows = self._pooler_codec_rows( + req_id=req_id, + req_index=req_index, + input_batch=input_batch, + sampled_token_ids=sampled_token_ids, + invalid_req_indices=invalid_req_indices, + ) + if codec_rows is None: + return None + return {"codes.audio": codec_rows} + + def _pooler_codec_rows( + self, + *, + req_id: str, + req_index: int, + input_batch: Any, + sampled_token_ids: Any | None, + invalid_req_indices: set[int] | None, + ) -> torch.Tensor | None: + input_req_index = getattr(input_batch, "req_id_to_index", {}).get(req_id) + if input_req_index is None: + return None + + speech_token_size = int(self.config.llm["speech_token_size"]) + cache = getattr(self, "_pooler_codec_history_by_req", None) + if cache is None: + cache = {} + self._pooler_codec_history_by_req = cache + sampled_seen = getattr(self, "_pooler_codec_sampled_seen_by_req", None) + if sampled_seen is None: + sampled_seen = set() + self._pooler_codec_sampled_seen_by_req = sampled_seen + sampled_finished = getattr(self, "_pooler_codec_sampled_finished_by_req", None) + if sampled_finished is None: + sampled_finished = set() + 
self._pooler_codec_sampled_finished_by_req = sampled_finished + + if sampled_token_ids is not None and (invalid_req_indices is None or req_index not in invalid_req_indices): + sampled_ids = self._pooler_sampled_token_ids(sampled_token_ids, req_index) + if sampled_ids: + sampled_seen.add(req_id) + current = cache.setdefault(req_id, []) + current.extend(token_id for token_id in sampled_ids if 0 <= token_id < speech_token_size) + if any(token_id >= speech_token_size for token_id in sampled_ids): + sampled_finished.add(req_id) + elif req_id not in cache: + history = self._pooler_output_history_from_input_batch( + input_batch, + input_req_index, + speech_token_size, + ) + if history: + cache[req_id] = history + + token_ids = cache.get(req_id, []) + if not token_ids or (req_id in sampled_seen and req_id not in sampled_finished): + return None + return torch.tensor(token_ids, dtype=torch.long).reshape(-1, 1) + + @staticmethod + def _pooler_output_history_from_input_batch( + input_batch: Any, + req_index: int, + speech_token_size: int, + ) -> list[int]: + prompt_lens = getattr(input_batch, "num_prompt_tokens", None) + num_tokens = getattr(input_batch, "num_tokens_no_spec", None) + token_ids_cpu = getattr(input_batch, "token_ids_cpu", None) + if prompt_lens is None or num_tokens is None or token_ids_cpu is None: + return [] + start = int(prompt_lens[req_index]) + end = int(num_tokens[req_index]) + if end <= start: + return [] + return [ + int(token_id) + for token_id in token_ids_cpu[req_index, start:end].tolist() + if 0 <= int(token_id) < speech_token_size + ] + + @staticmethod + def _pooler_sampled_token_ids(sampled_token_ids: Any, req_index: int) -> list[int]: + if sampled_token_ids is None or req_index >= len(sampled_token_ids): + return [] + req_sampled_ids = sampled_token_ids[req_index] + if isinstance(req_sampled_ids, torch.Tensor): + req_sampled_ids = req_sampled_ids.detach().to("cpu").reshape(-1).tolist() + elif not isinstance(req_sampled_ids, list): + 
req_sampled_ids = list(req_sampled_ids) if req_sampled_ids is not None else [] + if -1 in req_sampled_ids: + req_sampled_ids = req_sampled_ids[: req_sampled_ids.index(-1)] + return [int(token_id) for token_id in req_sampled_ids] + def sample( self, logits: torch.Tensor, @@ -784,12 +890,14 @@ def forward( else: self._stream_vocoder_cache_by_req[req_id] = new_cache_state else: + uses_connector_codes = payload.codes is not None and payload.codes.audio is not None tts_speech = self.code2wav.forward( token=token.unsqueeze(0), prompt_token=speech_token[:1], prompt_feat=speech_feat[:1], embedding=embedding[:1], n_timesteps=10, + token_offset_tokens=speech_token.shape[1] if uses_connector_codes else 0, ) audio = tts_speech.reshape(-1).to(dtype=torch.float32) diff --git a/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3_code2wav.py b/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3_code2wav.py index 186a258c809..cb3228c13a7 100644 --- a/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3_code2wav.py +++ b/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3_code2wav.py @@ -292,6 +292,7 @@ def forward( prompt_feat: torch.Tensor, embedding: torch.Tensor, n_timesteps: int = 10, + token_offset_tokens: int = 0, ) -> torch.Tensor: """Generate audio waveform from speech tokens.""" feat = self._forward_mel( @@ -300,7 +301,7 @@ def forward( prompt_feat=prompt_feat, embedding=embedding, n_timesteps=n_timesteps, - token_offset_tokens=0, + token_offset_tokens=token_offset_tokens, streaming=False, finalize=True, ) diff --git a/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py b/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py index 912c227fb5e..d6686414758 100644 --- a/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py +++ b/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py @@ -46,6 +46,25 @@ def _ensure_list(x: Any) -> list[Any]: return [x] +def _to_token_id_list(value: Any) -> list[int]: + if value is None: + return [] + if 
isinstance(value, torch.Tensor): + value = value.detach().to("cpu").reshape(-1).tolist() + token_ids: list[int] = [] + for item in _ensure_list(value): + if isinstance(item, torch.Tensor): + token_ids.extend(_to_token_id_list(item)) + continue + if isinstance(item, (list, tuple)): + token_ids.extend(_to_token_id_list(item)) + continue + token_id = int(item) + if token_id >= 0: + token_ids.append(token_id) + return token_ids + + def _to_cpu_tensor(x: Any) -> torch.Tensor | None: if isinstance(x, list): if not x: @@ -287,7 +306,9 @@ def talker2code2wav_async_chunk( # CONCAT across the (already trivial) per-request accumulator history so a # regression where decode unexpectedly re-emits them does not silently # duplicate the prefill tensor. See mixin._FULL_PAYLOAD_REPLACE_KEYS. -_FULL_PAYLOAD_REPLACE_KEYS: frozenset[str] = frozenset({"embed.speech_token", "embed.speech_feat", "embed.embedding"}) +_FULL_PAYLOAD_REPLACE_KEYS: frozenset[str] = frozenset( + {"codes.audio", "embed.speech_token", "embed.speech_feat", "embed.embedding"} +) def text2flow_token_only( @@ -297,15 +318,9 @@ def text2flow_token_only( ): """Sync-side builder for the non-async-chunk text→flow path. - Mirrors the legacy `text2flow` shape but strips the prefill embed - tensors out of `additional_information` — those travel via the worker - connector payload built by `text2flow_full_payload`. Small metadata - (`ids.prompt` = the talker's original prompt token prefix) stays - inline so the flow stage can still locate the prefix. - - prompt_token_ids = the talker's `cumulative_token_ids` (real codec - tokens, not [0]*N) because the flow stage consumes talker output - verbatim. + Connector-delivered codec ids replace these only when the talker reached + a real stop token; max-token fallbacks keep this legacy token path and + prompt conditioning metadata. 
""" del prompt engine_inputs: list[OmniTokensPrompt] = [] @@ -315,11 +330,6 @@ def text2flow_token_only( output = source_output.outputs[0] output_ids = _ensure_list(output.cumulative_token_ids) prefix_ids = _ensure_list(source_output.prompt_token_ids) - # Preserve full multimodal_output (incl. embed.*) in additional_information - # for now: the worker connector wire to model_intermediate_buffer on the - # code2wav side is not yet plumbed. Filed as Phase 4 follow-up (parallel - # to #90 speaker/language re-routing). text2flow_full_payload still ships - # embed.* via the connector — currently redundant but ready for activation. multi_modal_data = output.multimodal_output if multi_modal_data is None: raise RuntimeError(f"Missing multimodal_output for request {source_output.request_id}") @@ -346,13 +356,13 @@ def text2flow_full_payload( ): """Producer-side packer. - Reads the prefill-emitted `embed.{speech_token, speech_feat, embedding}` - from the accumulator (flat dotted keys after flatten_payload; nested - fallback for safety) and ships them as a single connector payload. + Reads accumulated talker codec ids plus prefill-emitted + `embed.{speech_token, speech_feat, embedding}` from the accumulator and + ships them as a single connector payload. The downstream flow stage reads these from `model_intermediate_buffer` (see cosyvoice3.py:671 in the code2wav forward — runtime_info pickup). 
""" - del transfer_manager, request + del transfer_manager if not isinstance(pooling_output, dict): return None embed_out: dict[str, Any] = {} @@ -364,9 +374,21 @@ def text2flow_full_payload( v = nested.get(key) if isinstance(v, torch.Tensor) and v.numel() > 0: embed_out[key] = v - if not embed_out: + token_ids = _to_token_id_list(pooling_output.get("codes.audio")) + if not token_ids: + nested_codes = pooling_output.get("codes") + if isinstance(nested_codes, dict): + token_ids = _to_token_id_list(nested_codes.get("audio")) + if not embed_out and not token_ids: return None - return { - "embed": embed_out, - "meta": {"finished": torch.tensor(True, dtype=torch.bool)}, + payload: dict[str, Any] = { + "meta": { + "finished": torch.tensor(True, dtype=torch.bool), + } } + if embed_out: + payload["embed"] = embed_out + if token_ids: + payload["codes"] = {"audio": torch.tensor(token_ids, dtype=torch.long)} + payload["meta"]["next_stage_prompt_len"] = len(token_ids) + return payload diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index c89ebcee023..033d2861e63 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -121,6 +121,94 @@ def _make_buffer(self, *size, dtype, numpy=True): with maybe_disable_pin_memory_for_ray(self, total_bytes): return super()._make_buffer(*size, dtype=dtype, numpy=numpy) + def _build_model_sampler_output_token_ids(self) -> list[list[int]]: + """Build decoded-token history for custom model samplers. + + vLLM only populates sampling_metadata.output_token_ids when penalties or + logits processors require it. CosyVoice3's custom RAS sampler also + depends on this history, so we reconstruct it directly from the input + batch for prefer_model_sampler models. 
+ """ + req_output_token_ids = getattr(self.input_batch, "req_output_token_ids", []) + req_ids = list(getattr(self.input_batch, "req_ids", [])) + output_token_ids = [list(req_output_token_ids[idx] or []) for idx in range(len(req_ids))] + + sampled_token_ids_cpu = getattr(self.input_batch, "sampled_token_ids_cpu", None) + async_copy_ready_event = getattr(self.input_batch, "async_copy_ready_event", None) + prev_req_id_to_index = getattr(self.input_batch, "prev_req_id_to_index", None) + if sampled_token_ids_cpu is None or not output_token_ids or prev_req_id_to_index is None: + return output_token_ids + + sampled_token_ids: list[list[int]] | None = None + for index, req_id in enumerate(req_ids): + prev_index = prev_req_id_to_index.get(req_id) + if prev_index is None: + continue + req_history = output_token_ids[index] + if not req_history or req_history[-1] != -1: + continue + if sampled_token_ids is None: + assert async_copy_ready_event is not None + async_copy_ready_event.synchronize() + sampled_token_ids = sampled_token_ids_cpu.tolist() + new_ids = list(sampled_token_ids[prev_index]) + if not new_ids: + continue + num_sampled_ids = len(new_ids) if new_ids[-1] != -1 else new_ids.index(-1) + first_placeholder = req_history.index(-1) + num_placeholders = len(req_history) - first_placeholder + num_to_replace = min(num_sampled_ids, num_placeholders) + req_history[first_placeholder : first_placeholder + num_to_replace] = new_ids[:num_to_replace] + + for index, req_history in enumerate(output_token_ids): + if -1 in req_history: + output_token_ids[index] = req_history[: req_history.index(-1)] + + return output_token_ids + + def _sampling_metadata_for_model_sampler(self, sampling_metadata): + output_token_ids = self._build_model_sampler_output_token_ids() + if output_token_ids == sampling_metadata.output_token_ids: + return sampling_metadata + return replace(sampling_metadata, output_token_ids=output_token_ids) + + @staticmethod + def _pooler_payload_has_key(payload: 
dict[str, object], key: str) -> bool: + if payload.get(key) is not None: + return True + if "." not in key: + return False + cur: object = payload + for part in key.split("."): + if not isinstance(cur, dict) or part not in cur: + return False + cur = cur[part] + return cur is not None + + def _attach_model_pooler_payload( + self, + payload: dict[str, object], + req_id: str, + sampled_token_ids: Any, + req_index: int, + invalid_req_indices: set[int] | None, + ) -> None: + build_pooler_payload = getattr(self.model, "build_pooler_payload", None) + if not callable(build_pooler_payload): + return + updates = build_pooler_payload( + req_id=req_id, + req_index=req_index, + input_batch=self.input_batch, + sampled_token_ids=sampled_token_ids, + invalid_req_indices=invalid_req_indices, + ) + if not isinstance(updates, dict): + return + for key, value in updates.items(): + if value is not None and not self._pooler_payload_has_key(payload, key): + payload[key] = value + def _request_final_stage_id(self, req_id: str) -> int | None: info = self.model_intermediate_buffer.get(req_id) if not isinstance(info, dict): @@ -917,6 +1005,7 @@ def propose_draft_token_ids(sampled_token_ids): engine_output_type, downstream_req_ids = self._resolve_pooler_payload_req_ids(req_ids_output_copy) needs_pooler_payload = len(downstream_req_ids) > 0 downstream_req_id_set = set(downstream_req_ids) + invalid_req_indices_set = set(invalid_req_indices) hidden_states_cpu = None req_hidden_states_cpu: dict[str, torch.Tensor] | None = None if needs_pooler_payload: @@ -986,7 +1075,7 @@ def propose_draft_token_ids(sampled_token_ids): req_hidden_states_cpu[rid] = hidden_states[start:end].detach().to("cpu").contiguous() pooler_output = [] - for rid in req_ids_output_copy: + for out_idx, rid in enumerate(req_ids_output_copy): if rid not in downstream_req_id_set: pooler_output.append({}) continue @@ -1041,6 +1130,13 @@ def _unwrap_lists(v): seq_len=seq_len, ) payload.update(mm_payload) + 
self._attach_model_pooler_payload( + payload, + rid, + sampler_output.sampled_token_ids, + out_idx, + invalid_req_indices_set, + ) # Flatten nested dicts to dotted keys so pooling_output # stays dict[str, torch.Tensor] for msgspec serialization. pooler_output.append(flatten_payload(payload)) diff --git a/vllm_omni/worker/gpu_generation_model_runner.py b/vllm_omni/worker/gpu_generation_model_runner.py index 232224555d1..a34ef57dd2f 100644 --- a/vllm_omni/worker/gpu_generation_model_runner.py +++ b/vllm_omni/worker/gpu_generation_model_runner.py @@ -76,6 +76,51 @@ def __init__(self, *args, **kwargs): model_config=self.model_config, ) + @staticmethod + def _flatten_audio_codes_to_tensor(codes, device: torch.device) -> torch.Tensor | None: + if codes is None: + return None + if isinstance(codes, torch.Tensor): + return codes.reshape(-1).to(device=device, dtype=torch.long) + if isinstance(codes, (list, tuple)): + if not codes: + return torch.empty(0, device=device, dtype=torch.long) + if all(isinstance(item, torch.Tensor) for item in codes): + return torch.cat([item.reshape(-1).to(device=device, dtype=torch.long) for item in codes], dim=0) + try: + return torch.as_tensor(codes, device=device, dtype=torch.long).reshape(-1) + except (TypeError, ValueError): + return None + + def _overlay_full_payload_input_ids( + self, + input_ids: torch.Tensor | None, + req_ids: list[str], + num_scheduled_tokens_np: np.ndarray, + ) -> None: + if input_ids is None or getattr(self.model_config, "async_chunk", False): + return + + for req_index, req_id in enumerate(req_ids): + scheduled = int(num_scheduled_tokens_np[req_index]) + if scheduled <= 0: + continue + payload = self.model_intermediate_buffer.get(req_id) + codes = self._payload_audio_codes(payload) + if codes is None: + continue + flat_codes = self._flatten_audio_codes_to_tensor(codes, input_ids.device) + if flat_codes is None or flat_codes.numel() == 0: + continue + start = int(self.query_start_loc.cpu[req_index]) + end = start + 
scheduled + if flat_codes.numel() != scheduled: + message = "full-payload input_ids override length mismatch for req=%s: payload=%d scheduled=%d" + message_args = (req_id, int(flat_codes.numel()), scheduled) + logger.error(message, *message_args) + raise RuntimeError(message % message_args) + input_ids[start:end].copy_(flat_codes.to(dtype=input_ids.dtype)) + def _update_request_states(self, scheduler_output: SchedulerOutput): # remove requests for req_id in scheduler_output.finished_req_ids: @@ -307,6 +352,7 @@ def execute_model( num_tokens_padded, intermediate_tensors, ) + self._overlay_full_payload_input_ids(input_ids, req_ids, num_scheduled_tokens_np) # [Omni] Pass token counts per request for code2wav output slicing model_kwargs["seq_token_counts"] = tokens From 00a1a756e00302bd477be2577f58f174952c958c Mon Sep 17 00:00:00 2001 From: natureofnature Date: Wed, 20 May 2026 17:29:19 +0000 Subject: [PATCH 08/19] [PR3] v6-final: revert codec connector path; legacy + SIP prompt-strip MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Address issue on 809f6e16 (v5): - High #1: build_pooler_payload received `out_idx` but sampler_output/invalid_req_indices index by input_batch row. - High #2: gpu_generation_model_runner._overlay_full_payload_input_ids was CosyVoice3-specific in the common runner. - Medium #3: cosyvoice3._pooler_output_history_from_input_batch didn't stop at -1 placeholder. Resolution: drop the connector codec path for CosyVoice3 sync and deliver codec via legacy `additional_information`, strip prompt/reference prefix at the SIP layer, and gate code2wav mel-trim on the talker-prefill offset only when a speech-stop token was seen. Source changes: - gpu_ar_model_runner.py: remove build_pooler_payload hook + _attach_model_pooler_payload + _pooler_payload_has_key. - gpu_generation_model_runner.py: remove _flatten_audio_codes_to_tensor + _overlay_full_payload_input_ids + its call site. 
- cosyvoice3.py (model): remove build_pooler_payload + _pooler_codec_rows + _pooler_output_history_from_input_batch + _pooler_sampled_token_ids (and three per-req caches: _pooler_codec_history_by_req, _pooler_codec_sampled_seen_by_req, _pooler_codec_sampled_finished_by_req). code2wav.forward token_offset_tokens now reads `meta.talker_prefill_offset` (already a struct field used by qwen3_tts). - SIP cosyvoice3.py: text2flow + text2flow_token_only strip the prompt token prefix and the prompt speech_token prefix from cumulative_token_ids; set meta.talker_prefill_offset only when raw output contains a speech-stop token. text2flow_full_payload no longer ships codes.audio (embed/meta only). Drop `codes.audio` from _FULL_PAYLOAD_REPLACE_KEYS. - _to_token_id_list no longer filters negative ids (needed for stop-token detection on raw cumulative ids). Side effects: - v5's cosyvoice3 per-req cache leak is gone (no pooler hook → no accumulator). - The pre-existing baseline `voice_clone_zh_001[cosyvoice3]` sim=0.00 (transcript "先") failure is fixed. Verification on H800 GPU with `--run-level full_model -m "full_model and tts"`: - test_voice_clone_zh_001[cosyvoice3]: PASS sim=1.000 (baseline FAIL sim=0.00; v5 PASS sim=0.903) - test_voice_clone_en_001[cosyvoice3]: PASS sim=0.963 (baseline PASS sim=0.946; v5 PASS sim=0.963) Trade-off vs project_pr3_scope: CosyVoice3 sync codec stays on legacy additional_information; embed/prompt conditioning still ships via connector. Other PR3-migrated archs are unaffected (none consumed codes.audio via the removed overlay). 
Signed-off-by: natureofnature --- .../cosyvoice3/test_cosyvoice3_components.py | 3 +- .../test_cosyvoice3_model_helpers.py | 76 +----------- .../test_cosyvoice3_stage_input_processors.py | 73 +++++++++++- tests/worker/test_omni_gpu_model_runner.py | 17 --- .../models/cosyvoice3/cosyvoice3.py | 110 +----------------- .../stage_input_processors/cosyvoice3.py | 82 ++++++++----- vllm_omni/worker/gpu_ar_model_runner.py | 45 ------- .../worker/gpu_generation_model_runner.py | 47 -------- 8 files changed, 128 insertions(+), 325 deletions(-) diff --git a/tests/model_executor/models/cosyvoice3/test_cosyvoice3_components.py b/tests/model_executor/models/cosyvoice3/test_cosyvoice3_components.py index dd73c85cad9..bf2261cb920 100644 --- a/tests/model_executor/models/cosyvoice3/test_cosyvoice3_components.py +++ b/tests/model_executor/models/cosyvoice3/test_cosyvoice3_components.py @@ -306,9 +306,8 @@ def fake_forward_mel(**kwargs): prompt_token=torch.tensor([[4, 5]], dtype=torch.int32), prompt_feat=torch.ones((1, 4, 80), dtype=torch.float32), embedding=torch.ones((1, 192), dtype=torch.float32), - token_offset_tokens=2, ) assert out.shape == (1, 1, 8) assert model.hift.finalize_calls == [True] - assert forward_mel_calls[0]["token_offset_tokens"] == 2 + assert forward_mel_calls[0]["token_offset_tokens"] == 0 diff --git a/tests/model_executor/models/cosyvoice3/test_cosyvoice3_model_helpers.py b/tests/model_executor/models/cosyvoice3/test_cosyvoice3_model_helpers.py index cdf6cc358cd..b0afc95921a 100644 --- a/tests/model_executor/models/cosyvoice3/test_cosyvoice3_model_helpers.py +++ b/tests/model_executor/models/cosyvoice3/test_cosyvoice3_model_helpers.py @@ -138,77 +138,6 @@ def _make_sampling_metadata( ) -def test_build_pooler_payload_waits_for_sampled_stop_token(): - model = _make_talker_model() - input_batch = SimpleNamespace( - req_id_to_index={"r1": 0}, - num_prompt_tokens=[3], - num_tokens_no_spec=[3], - token_ids_cpu=torch.tensor([[101, 102, 103]], dtype=torch.long), - ) - 
- payload = model.build_pooler_payload( - req_id="r1", - req_index=0, - input_batch=input_batch, - sampled_token_ids=[[10]], - invalid_req_indices=set(), - ) - assert payload is None - - payload = model.build_pooler_payload( - req_id="r1", - req_index=0, - input_batch=input_batch, - sampled_token_ids=[[20, 6562]], - invalid_req_indices=set(), - ) - assert payload is not None - assert torch.equal(payload["codes.audio"], torch.tensor([[10], [20]], dtype=torch.long)) - - -def test_build_pooler_payload_falls_back_to_input_batch_history(): - model = _make_talker_model() - input_batch = SimpleNamespace( - req_id_to_index={"r1": 0}, - num_prompt_tokens=[3], - num_tokens_no_spec=[8], - token_ids_cpu=torch.tensor([[101, 102, 103, 10, 20, 6562, -1, 30]], dtype=torch.long), - ) - - payload = model.build_pooler_payload( - req_id="r1", - req_index=0, - input_batch=input_batch, - sampled_token_ids=None, - invalid_req_indices=set(), - ) - - assert payload is not None - assert torch.equal(payload["codes.audio"], torch.tensor([[10], [20], [30]], dtype=torch.long)) - - -def test_split_request_ids_uses_seq_token_counts(): - CosyVoice3Model, _ = _cosyvoice3_model_and_runner() - ids = torch.tensor([10, 11, 12, 13, 14], dtype=torch.long) - chunks = CosyVoice3Model._split_request_ids(ids, [2, 2, 2]) - assert [c.tolist() for c in chunks] == [[10, 11], [12, 13], [14]] - - -def test_split_request_ids_honors_single_request_seq_token_counts(): - CosyVoice3Model, _ = _cosyvoice3_model_and_runner() - ids = torch.tensor([10, 11, 12, 13, 14], dtype=torch.long) - chunks = CosyVoice3Model._split_request_ids(ids, [3]) - assert [c.tolist() for c in chunks] == [[10, 11, 12]] - - -def test_sanitize_codec_tokens_filters_out_of_range(): - model = _make_code2wav_model() - raw = torch.tensor([-1, 0, 3, 4, 99], dtype=torch.long) - clean = model._sanitize_codec_tokens(raw) - assert clean.tolist() == [0, 3] - - def test_forward_prefers_token_offset_when_present(): model = _make_code2wav_model() @@ -318,7 
+247,7 @@ def test_forward_uses_non_stream_decode_without_chunk_metadata(): assert call["token_offset_tokens"] == 0 -def test_forward_trims_non_streaming_connector_codes(): +def test_forward_uses_non_stream_talker_prefill_offset(): model = _make_code2wav_model() runtime_info = [ @@ -328,8 +257,7 @@ def test_forward_trims_non_streaming_connector_codes(): "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32), "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32), }, - "codes": {"audio": torch.tensor([0, 1, 2], dtype=torch.long)}, - "meta": {"next_stage_prompt_len": 3}, + "meta": {"talker_prefill_offset": 3}, } ] diff --git a/tests/model_executor/stage_input_processors/test_cosyvoice3_stage_input_processors.py b/tests/model_executor/stage_input_processors/test_cosyvoice3_stage_input_processors.py index d54533fd0cb..6debabb0e3f 100644 --- a/tests/model_executor/stage_input_processors/test_cosyvoice3_stage_input_processors.py +++ b/tests/model_executor/stage_input_processors/test_cosyvoice3_stage_input_processors.py @@ -6,13 +6,19 @@ import torch -from vllm_omni.model_executor.stage_input_processors.cosyvoice3 import talker2code2wav_async_chunk, text2flow +from vllm_omni.model_executor.stage_input_processors.cosyvoice3 import ( + talker2code2wav_async_chunk, + text2flow, + text2flow_full_payload, + text2flow_token_only, +) -def _source_output(request_id: str, prompt_ids: list[int], out_ids: list[int], mm: dict): +def _source_output(request_id: str, prompt_ids: list[int], out_ids: list[int], mm: dict, finished: bool = True): return SimpleNamespace( request_id=request_id, prompt_token_ids=prompt_ids, + finished=finished, outputs=[SimpleNamespace(token_ids=out_ids, cumulative_token_ids=out_ids, multimodal_output=mm)], ) @@ -45,8 +51,8 @@ def _transfer_manager( def test_text2flow_supports_batched_source_outputs(): source_outputs = [ - _source_output("req-0", [10, 11], [1, 2, 3], {"speech_token": torch.tensor([[1, 2]])}), - 
_source_output("req-1", [20, 21], [4, 5], {"speech_token": torch.tensor([[3, 4]])}), + _source_output("req-0", [10, 11], [1, 2, 3], {"speech_token": torch.tensor([[8, 9]])}), + _source_output("req-1", [20, 21], [4, 5], {"speech_token": torch.tensor([[6, 7]])}), ] outputs = text2flow(source_outputs=source_outputs, prompt=None) @@ -58,6 +64,65 @@ def test_text2flow_supports_batched_source_outputs(): assert outputs[1]["additional_information"]["ids"]["prompt"] == [20, 21] +def test_text2flow_strips_reference_speech_prefix_from_cumulative_ids(): + source_outputs = [ + _source_output("req-0", [10, 11], [8, 9, 1, 2, 3], {"speech_token": torch.tensor([[8, 9]])}), + ] + + outputs = text2flow(source_outputs=source_outputs, prompt=None) + + assert outputs[0]["prompt_token_ids"] == [1, 2, 3] + + +def test_text2flow_token_only_strips_reference_speech_prefix_from_cumulative_ids(): + source_outputs = [ + _source_output( + "req-strip", + [10, 11], + [4, 5, 1, 2, 3], + {"embed": {"speech_token": torch.tensor([[4, 5]])}}, + ) + ] + + outputs = text2flow_token_only(source_outputs=source_outputs, prompt=None) + + assert len(outputs) == 1 + assert outputs[0]["prompt_token_ids"] == [1, 2, 3] + assert outputs[0]["additional_information"]["ids"]["prompt"] == [10, 11] + + +def test_text2flow_token_only_marks_prompt_trim_for_stop_token_completion(): + source_outputs = [ + _source_output( + "req-stop", + [10, 11], + [4, 5, 1, 2, 6562], + {"embed": {"speech_token": torch.tensor([[4, 5]])}}, + ) + ] + + outputs = text2flow_token_only(source_outputs=source_outputs, prompt=None) + + assert outputs[0]["prompt_token_ids"] == [1, 2, 6562] + assert outputs[0]["additional_information"]["meta"]["talker_prefill_offset"] == 2 + + +def test_text2flow_full_payload_does_not_send_codec_ids(): + payload = text2flow_full_payload( + None, + { + "embed.speech_token": torch.tensor([[1, 2]], dtype=torch.long), + "codes.audio": torch.tensor([7, 8, 9], dtype=torch.long), + }, + SimpleNamespace(), + ) + + assert 
payload is not None + assert "codes" not in payload + assert "next_stage_prompt_len" not in payload["meta"] + assert torch.equal(payload["embed"]["speech_token"], torch.tensor([[1, 2]], dtype=torch.long)) + + def test_talker2code2wav_async_chunk_final_payload_uses_absolute_token_offset(): transfer_manager = _transfer_manager() request = SimpleNamespace( diff --git a/tests/worker/test_omni_gpu_model_runner.py b/tests/worker/test_omni_gpu_model_runner.py index a605a5adfaf..94ffd937ead 100644 --- a/tests/worker/test_omni_gpu_model_runner.py +++ b/tests/worker/test_omni_gpu_model_runner.py @@ -160,23 +160,6 @@ class _DummyVllmConfig: return runner -def test_generation_overlay_full_payload_input_ids_replaces_placeholders(): - from vllm_omni.worker.gpu_generation_model_runner import GPUGenerationModelRunner - - runner = object.__new__(GPUGenerationModelRunner) - runner.model_config = SimpleNamespace(async_chunk=False) - runner.model_intermediate_buffer = { - "r1": {"codes": {"audio": [9, 8]}}, - "r2": {"codes": {"audio": torch.tensor([7, 6, 5], dtype=torch.long)}}, - } - runner.query_start_loc = SimpleNamespace(cpu=torch.tensor([0, 2], dtype=torch.int32)) - input_ids = torch.zeros(5, dtype=torch.long) - - GPUGenerationModelRunner._overlay_full_payload_input_ids(runner, input_ids, ["r1", "r2"], [2, 3]) - - assert input_ids.tolist() == [9, 8, 7, 6, 5] - - def test_talker_mtp_forward_cpu_updates_inputs_and_info(monkeypatch): # Patch the module-level `set_forward_context` symbol used inside # OmniGPUModelRunner._talker_mtp_forward. 
diff --git a/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py b/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py index 4ce11a4ec07..2dc81174cb0 100644 --- a/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py +++ b/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py @@ -5,7 +5,6 @@ from dataclasses import replace from functools import partial from threading import Lock -from typing import Any import torch import torch.nn as nn @@ -533,111 +532,6 @@ def _cosyvoice3_ras_enabled(self, sampling_metadata: SamplingMetadata) -> bool: return False return True - def build_pooler_payload( - self, - *, - req_id: str, - req_index: int, - input_batch: Any, - sampled_token_ids: Any | None = None, - invalid_req_indices: set[int] | None = None, - ) -> dict[str, object] | None: - if self.model_stage != "cosyvoice3_talker": - return None - codec_rows = self._pooler_codec_rows( - req_id=req_id, - req_index=req_index, - input_batch=input_batch, - sampled_token_ids=sampled_token_ids, - invalid_req_indices=invalid_req_indices, - ) - if codec_rows is None: - return None - return {"codes.audio": codec_rows} - - def _pooler_codec_rows( - self, - *, - req_id: str, - req_index: int, - input_batch: Any, - sampled_token_ids: Any | None, - invalid_req_indices: set[int] | None, - ) -> torch.Tensor | None: - input_req_index = getattr(input_batch, "req_id_to_index", {}).get(req_id) - if input_req_index is None: - return None - - speech_token_size = int(self.config.llm["speech_token_size"]) - cache = getattr(self, "_pooler_codec_history_by_req", None) - if cache is None: - cache = {} - self._pooler_codec_history_by_req = cache - sampled_seen = getattr(self, "_pooler_codec_sampled_seen_by_req", None) - if sampled_seen is None: - sampled_seen = set() - self._pooler_codec_sampled_seen_by_req = sampled_seen - sampled_finished = getattr(self, "_pooler_codec_sampled_finished_by_req", None) - if sampled_finished is None: - sampled_finished = set() - 
self._pooler_codec_sampled_finished_by_req = sampled_finished - - if sampled_token_ids is not None and (invalid_req_indices is None or req_index not in invalid_req_indices): - sampled_ids = self._pooler_sampled_token_ids(sampled_token_ids, req_index) - if sampled_ids: - sampled_seen.add(req_id) - current = cache.setdefault(req_id, []) - current.extend(token_id for token_id in sampled_ids if 0 <= token_id < speech_token_size) - if any(token_id >= speech_token_size for token_id in sampled_ids): - sampled_finished.add(req_id) - elif req_id not in cache: - history = self._pooler_output_history_from_input_batch( - input_batch, - input_req_index, - speech_token_size, - ) - if history: - cache[req_id] = history - - token_ids = cache.get(req_id, []) - if not token_ids or (req_id in sampled_seen and req_id not in sampled_finished): - return None - return torch.tensor(token_ids, dtype=torch.long).reshape(-1, 1) - - @staticmethod - def _pooler_output_history_from_input_batch( - input_batch: Any, - req_index: int, - speech_token_size: int, - ) -> list[int]: - prompt_lens = getattr(input_batch, "num_prompt_tokens", None) - num_tokens = getattr(input_batch, "num_tokens_no_spec", None) - token_ids_cpu = getattr(input_batch, "token_ids_cpu", None) - if prompt_lens is None or num_tokens is None or token_ids_cpu is None: - return [] - start = int(prompt_lens[req_index]) - end = int(num_tokens[req_index]) - if end <= start: - return [] - return [ - int(token_id) - for token_id in token_ids_cpu[req_index, start:end].tolist() - if 0 <= int(token_id) < speech_token_size - ] - - @staticmethod - def _pooler_sampled_token_ids(sampled_token_ids: Any, req_index: int) -> list[int]: - if sampled_token_ids is None or req_index >= len(sampled_token_ids): - return [] - req_sampled_ids = sampled_token_ids[req_index] - if isinstance(req_sampled_ids, torch.Tensor): - req_sampled_ids = req_sampled_ids.detach().to("cpu").reshape(-1).tolist() - elif not isinstance(req_sampled_ids, list): - 
req_sampled_ids = list(req_sampled_ids) if req_sampled_ids is not None else [] - if -1 in req_sampled_ids: - req_sampled_ids = req_sampled_ids[: req_sampled_ids.index(-1)] - return [int(token_id) for token_id in req_sampled_ids] - def sample( self, logits: torch.Tensor, @@ -890,14 +784,14 @@ def forward( else: self._stream_vocoder_cache_by_req[req_id] = new_cache_state else: - uses_connector_codes = payload.codes is not None and payload.codes.audio is not None + token_offset = max(0, meta.talker_prefill_offset or 0) if meta else 0 tts_speech = self.code2wav.forward( token=token.unsqueeze(0), prompt_token=speech_token[:1], prompt_feat=speech_feat[:1], embedding=embedding[:1], n_timesteps=10, - token_offset_tokens=speech_token.shape[1] if uses_connector_codes else 0, + token_offset_tokens=token_offset, ) audio = tts_speech.reshape(-1).to(dtype=torch.float32) diff --git a/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py b/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py index d6686414758..f355af8c308 100644 --- a/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py +++ b/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py @@ -16,6 +16,8 @@ ) from vllm_omni.inputs.data import OmniTokensPrompt +_COSYVOICE3_SPEECH_TOKEN_SIZE = 6561 + def _build_prompt_embed_struct(prompt_payload: dict[str, Any]) -> EmbeddingsStruct | None: """Wrap prompt_payload's flat speech_token/speech_feat/embedding tensors into EmbeddingsStruct.""" @@ -59,12 +61,39 @@ def _to_token_id_list(value: Any) -> list[int]: if isinstance(item, (list, tuple)): token_ids.extend(_to_token_id_list(item)) continue - token_id = int(item) - if token_id >= 0: - token_ids.append(token_id) + token_ids.append(int(item)) return token_ids +def _strip_prompt_prefix(output_ids: list[Any], prefix_ids: list[Any]) -> list[Any]: + if prefix_ids and len(output_ids) >= len(prefix_ids) and output_ids[: len(prefix_ids)] == prefix_ids: + return output_ids[len(prefix_ids) :] + return 
output_ids + + +def _prompt_speech_token_ids(multi_modal_data: dict[str, Any]) -> list[int]: + speech_token = multi_modal_data.get("speech_token") + if speech_token is None: + embed = multi_modal_data.get("embed") + if isinstance(embed, dict): + speech_token = embed.get("speech_token") + return _to_token_id_list(speech_token) + + +def _has_speech_stop_token(output_ids: list[Any]) -> bool: + return any(token_id >= _COSYVOICE3_SPEECH_TOKEN_SIZE for token_id in _to_token_id_list(output_ids)) + + +def _set_non_stream_prompt_trim(additional_info: dict[str, Any], prompt_speech_len: int) -> None: + if prompt_speech_len <= 0: + return + meta = additional_info.get("meta") + if not isinstance(meta, dict): + meta = {} + additional_info["meta"] = meta + meta["talker_prefill_offset"] = prompt_speech_len + + def _to_cpu_tensor(x: Any) -> torch.Tensor | None: if isinstance(x, list): if not x: @@ -114,9 +143,14 @@ def text2flow( if multi_modal_data is None: raise RuntimeError(f"Missing multimodal_output for request {source_output.request_id}") - output_ids = _ensure_list(output.cumulative_token_ids) prefix_ids = _ensure_list(source_output.prompt_token_ids) + raw_output_ids = _ensure_list(output.cumulative_token_ids) + prompt_speech_ids = _prompt_speech_token_ids(multi_modal_data) + output_ids = _strip_prompt_prefix(raw_output_ids, prefix_ids) + output_ids = _strip_prompt_prefix(output_ids, prompt_speech_ids) additional_info = dict(multi_modal_data) + if _has_speech_stop_token(raw_output_ids): + _set_non_stream_prompt_trim(additional_info, len(prompt_speech_ids)) additional_info.setdefault("ids", {})["prompt"] = prefix_ids engine_inputs.append(OmniTokensPrompt(prompt_token_ids=output_ids, additional_information=additional_info)) return engine_inputs @@ -306,9 +340,7 @@ def talker2code2wav_async_chunk( # CONCAT across the (already trivial) per-request accumulator history so a # regression where decode unexpectedly re-emits them does not silently # duplicate the prefill tensor. 
See mixin._FULL_PAYLOAD_REPLACE_KEYS. -_FULL_PAYLOAD_REPLACE_KEYS: frozenset[str] = frozenset( - {"codes.audio", "embed.speech_token", "embed.speech_feat", "embed.embedding"} -) +_FULL_PAYLOAD_REPLACE_KEYS: frozenset[str] = frozenset({"embed.speech_token", "embed.speech_feat", "embed.embedding"}) def text2flow_token_only( @@ -318,9 +350,9 @@ def text2flow_token_only( ): """Sync-side builder for the non-async-chunk text→flow path. - Connector-delivered codec ids replace these only when the talker reached - a real stop token; max-token fallbacks keep this legacy token path and - prompt conditioning metadata. + CosyVoice3 sync keeps codec ids on the legacy token path. Some vLLM v1 + histories include the source prompt prefix, so strip it only when it is an + exact leading match. """ del prompt engine_inputs: list[OmniTokensPrompt] = [] @@ -328,12 +360,17 @@ def text2flow_token_only( if not source_output.finished: continue output = source_output.outputs[0] - output_ids = _ensure_list(output.cumulative_token_ids) prefix_ids = _ensure_list(source_output.prompt_token_ids) + raw_output_ids = _ensure_list(output.cumulative_token_ids) + output_ids = _strip_prompt_prefix(raw_output_ids, prefix_ids) multi_modal_data = output.multimodal_output if multi_modal_data is None: raise RuntimeError(f"Missing multimodal_output for request {source_output.request_id}") + prompt_speech_ids = _prompt_speech_token_ids(multi_modal_data) + output_ids = _strip_prompt_prefix(output_ids, prompt_speech_ids) additional_info: dict[str, Any] = dict(multi_modal_data) + if _has_speech_stop_token(raw_output_ids): + _set_non_stream_prompt_trim(additional_info, len(prompt_speech_ids)) additional_info.setdefault("ids", {})["prompt"] = prefix_ids engine_inputs.append( OmniTokensPrompt( @@ -356,9 +393,8 @@ def text2flow_full_payload( ): """Producer-side packer. 
- Reads accumulated talker codec ids plus prefill-emitted - `embed.{speech_token, speech_feat, embedding}` from the accumulator and - ships them as a single connector payload. + Reads prefill-emitted `embed.{speech_token, speech_feat, embedding}` from + the accumulator and ships prompt conditioning as a connector payload. The downstream flow stage reads these from `model_intermediate_buffer` (see cosyvoice3.py:671 in the code2wav forward — runtime_info pickup). """ @@ -374,21 +410,11 @@ def text2flow_full_payload( v = nested.get(key) if isinstance(v, torch.Tensor) and v.numel() > 0: embed_out[key] = v - token_ids = _to_token_id_list(pooling_output.get("codes.audio")) - if not token_ids: - nested_codes = pooling_output.get("codes") - if isinstance(nested_codes, dict): - token_ids = _to_token_id_list(nested_codes.get("audio")) - if not embed_out and not token_ids: + if not embed_out: return None - payload: dict[str, Any] = { + return { "meta": { "finished": torch.tensor(True, dtype=torch.bool), - } + }, + "embed": embed_out, } - if embed_out: - payload["embed"] = embed_out - if token_ids: - payload["codes"] = {"audio": torch.tensor(token_ids, dtype=torch.long)} - payload["meta"]["next_stage_prompt_len"] = len(token_ids) - return payload diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index 033d2861e63..fb6be8eb231 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -172,43 +172,6 @@ def _sampling_metadata_for_model_sampler(self, sampling_metadata): return sampling_metadata return replace(sampling_metadata, output_token_ids=output_token_ids) - @staticmethod - def _pooler_payload_has_key(payload: dict[str, object], key: str) -> bool: - if payload.get(key) is not None: - return True - if "." 
not in key: - return False - cur: object = payload - for part in key.split("."): - if not isinstance(cur, dict) or part not in cur: - return False - cur = cur[part] - return cur is not None - - def _attach_model_pooler_payload( - self, - payload: dict[str, object], - req_id: str, - sampled_token_ids: Any, - req_index: int, - invalid_req_indices: set[int] | None, - ) -> None: - build_pooler_payload = getattr(self.model, "build_pooler_payload", None) - if not callable(build_pooler_payload): - return - updates = build_pooler_payload( - req_id=req_id, - req_index=req_index, - input_batch=self.input_batch, - sampled_token_ids=sampled_token_ids, - invalid_req_indices=invalid_req_indices, - ) - if not isinstance(updates, dict): - return - for key, value in updates.items(): - if value is not None and not self._pooler_payload_has_key(payload, key): - payload[key] = value - def _request_final_stage_id(self, req_id: str) -> int | None: info = self.model_intermediate_buffer.get(req_id) if not isinstance(info, dict): @@ -1005,7 +968,6 @@ def propose_draft_token_ids(sampled_token_ids): engine_output_type, downstream_req_ids = self._resolve_pooler_payload_req_ids(req_ids_output_copy) needs_pooler_payload = len(downstream_req_ids) > 0 downstream_req_id_set = set(downstream_req_ids) - invalid_req_indices_set = set(invalid_req_indices) hidden_states_cpu = None req_hidden_states_cpu: dict[str, torch.Tensor] | None = None if needs_pooler_payload: @@ -1130,13 +1092,6 @@ def _unwrap_lists(v): seq_len=seq_len, ) payload.update(mm_payload) - self._attach_model_pooler_payload( - payload, - rid, - sampler_output.sampled_token_ids, - out_idx, - invalid_req_indices_set, - ) # Flatten nested dicts to dotted keys so pooling_output # stays dict[str, torch.Tensor] for msgspec serialization. 
pooler_output.append(flatten_payload(payload)) diff --git a/vllm_omni/worker/gpu_generation_model_runner.py b/vllm_omni/worker/gpu_generation_model_runner.py index a34ef57dd2f..cf59113ca9b 100644 --- a/vllm_omni/worker/gpu_generation_model_runner.py +++ b/vllm_omni/worker/gpu_generation_model_runner.py @@ -76,51 +76,6 @@ def __init__(self, *args, **kwargs): model_config=self.model_config, ) - @staticmethod - def _flatten_audio_codes_to_tensor(codes, device: torch.device) -> torch.Tensor | None: - if codes is None: - return None - if isinstance(codes, torch.Tensor): - return codes.reshape(-1).to(device=device, dtype=torch.long) - if isinstance(codes, (list, tuple)): - if not codes: - return torch.empty(0, device=device, dtype=torch.long) - if all(isinstance(item, torch.Tensor) for item in codes): - return torch.cat([item.reshape(-1).to(device=device, dtype=torch.long) for item in codes], dim=0) - try: - return torch.as_tensor(codes, device=device, dtype=torch.long).reshape(-1) - except (TypeError, ValueError): - return None - - def _overlay_full_payload_input_ids( - self, - input_ids: torch.Tensor | None, - req_ids: list[str], - num_scheduled_tokens_np: np.ndarray, - ) -> None: - if input_ids is None or getattr(self.model_config, "async_chunk", False): - return - - for req_index, req_id in enumerate(req_ids): - scheduled = int(num_scheduled_tokens_np[req_index]) - if scheduled <= 0: - continue - payload = self.model_intermediate_buffer.get(req_id) - codes = self._payload_audio_codes(payload) - if codes is None: - continue - flat_codes = self._flatten_audio_codes_to_tensor(codes, input_ids.device) - if flat_codes is None or flat_codes.numel() == 0: - continue - start = int(self.query_start_loc.cpu[req_index]) - end = start + scheduled - if flat_codes.numel() != scheduled: - message = "full-payload input_ids override length mismatch for req=%s: payload=%d scheduled=%d" - message_args = (req_id, int(flat_codes.numel()), scheduled) - logger.error(message, *message_args) 
- raise RuntimeError(message % message_args) - input_ids[start:end].copy_(flat_codes.to(dtype=input_ids.dtype)) - def _update_request_states(self, scheduler_output: SchedulerOutput): # remove requests for req_id in scheduler_output.finished_req_ids: @@ -352,8 +307,6 @@ def execute_model( num_tokens_padded, intermediate_tensors, ) - self._overlay_full_payload_input_ids(input_ids, req_ids, num_scheduled_tokens_np) - # [Omni] Pass token counts per request for code2wav output slicing model_kwargs["seq_token_counts"] = tokens From 5566b6c838782514ed9b760fa638082c4bf4441a Mon Sep 17 00:00:00 2001 From: natureofnature Date: Thu, 21 May 2026 07:56:05 +0000 Subject: [PATCH 09/19] [PR3] cleanup: scrub PR-internal markers, rename Block-A allowlist, fix docstring accuracy Comment-and-naming cleanup across PR3-touched files. Signed-off-by: natureofnature --- .../sched/test_omni_scheduling_coordinator.py | 10 +-- .../test_qwen3_omni_streaming_helpers.py | 23 +++--- tests/worker/test_omni_gpu_model_runner.py | 14 ++-- .../core/sched/omni_scheduling_coordinator.py | 24 +++--- .../stage_input_processors/cosyvoice3.py | 12 +-- .../stage_input_processors/covo_audio.py | 12 +-- .../stage_input_processors/dynin_omni.py | 11 +-- .../stage_input_processors/mimo_audio.py | 16 ++-- .../stage_input_processors/ming_flash_omni.py | 39 +++++----- .../stage_input_processors/qwen2_5_omni.py | 74 +++++++++++-------- .../stage_input_processors/qwen3_omni.py | 17 ++--- .../stage_input_processors/qwen3_tts.py | 14 ++-- vllm_omni/worker/gpu_ar_model_runner.py | 16 ++-- .../worker/gpu_generation_model_runner.py | 6 +- .../omni_connector_model_runner_mixin.py | 32 ++++---- 15 files changed, 165 insertions(+), 155 deletions(-) diff --git a/tests/core/sched/test_omni_scheduling_coordinator.py b/tests/core/sched/test_omni_scheduling_coordinator.py index 74670b06c7f..a6d4920e3d5 100644 --- a/tests/core/sched/test_omni_scheduling_coordinator.py +++ b/tests/core/sched/test_omni_scheduling_coordinator.py @@ 
-94,11 +94,11 @@ def remove_requests(self, requests): class TestFullPayloadCoordinatorSelection(unittest.TestCase): """Tests for the (model_arch, model_stage) whitelist gate. - The gate scope must stay aligned with init_omni_connectors arch scope in - gpu_ar_model_runner.py / gpu_generation_model_runner.py. Until those init - sites are generalised (planned for a later PR matching the tmp/trim_refactor - branch shape), only Qwen3-Omni talker / code2wav route full_payload stage - input through the worker connector. + The init_omni_connectors arch allowlist is keyed by ``model_arch`` and + is a superset of the stages registered here -- consumer-wait stages + must be registered explicitly in ``_FULL_PAYLOAD_INPUT_STAGES``, while + the init allowlist covers both producer- and consumer-side runners. + These tests pin which ``(arch, stage)`` pairs the gate fires for today. """ def test_all_whitelisted_arch_stage_pairs_fire_gate(self): diff --git a/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py b/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py index 7b176b29b75..72b39ac0546 100644 --- a/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py +++ b/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py @@ -220,11 +220,10 @@ def test_thinker2talker_full_payload_drops_stop_emission_row_when_finished_stopp """FINISHED_STOPPED: drop 1 extra row even when rows == target. vLLM appends the stop-token to output_token_ids before check_stop, so - len(all_token_ids) includes the stop slot AND the accumulator has the - stop emission's forward row. Both counts equal P+O (here 3). Talker + len(all_token_ids) includes the stop slot AND the full-payload + accumulator has the stop emission's forward row. Both counts equal P+O (here 3). Talker target should be P+O-1 (=2), not P+O. 
Without the extra drop the - stop emission's hidden state leaks into talker prefill (fba23325 - spurious-phoneme regression). + stop emission's hidden state leaks into talker prefill. """ request = SimpleNamespace( request_id="thinker-stop-finished", @@ -272,7 +271,7 @@ def test_thinker2talker_full_payload_drops_stop_emission_via_eos_fallback() -> N def test_thinker2talker_full_payload_no_drop_when_finished_length_capped() -> None: - """FINISHED_LENGTH_CAPPED (max_tokens): no extra drop; BK 9702 regression guard.""" + """FINISHED_LENGTH_CAPPED (max_tokens): no extra drop applied.""" request = SimpleNamespace( request_id="thinker-length-capped", prompt_token_ids=[151644, 872], @@ -528,7 +527,7 @@ def __init__(self, tids): def test_covo_audio_llm2code2wav_full_payload_smoke() -> None: - """Smoke: covo_audio producer-side packer returns audio_codes + finished.""" + """Smoke: covo_audio producer-side payload builder returns audio_codes + finished.""" from types import SimpleNamespace from vllm_omni.model_executor.models.covo_audio.config_covo_audio import COVO_AUDIO_TOKEN_INDEX @@ -577,7 +576,7 @@ def __init__(self, outs): def test_dynin_omni_full_payload_smoke() -> None: - """Smoke: dynin_omni producer-side packer returns token_ids + finished.""" + """Smoke: dynin_omni producer-side payload builder returns token_ids + finished.""" from types import SimpleNamespace from vllm_omni.model_executor.stage_input_processors.dynin_omni import ( @@ -621,7 +620,7 @@ def __init__(self, tids): def test_qwen2_5_omni_talker2code2wav_full_payload_smoke() -> None: - """Smoke: qwen2_5_omni producer-side packer strips boundaries.""" + """Smoke: qwen2_5_omni producer-side payload builder strips boundaries.""" from types import SimpleNamespace from vllm_omni.model_executor.stage_input_processors.qwen2_5_omni import ( @@ -668,7 +667,7 @@ def __init__(self, mm): def test_mimo_audio_llm2code2wav_full_payload_smoke() -> None: - """Smoke: mimo_audio producer-side packer reads flat codes.audio 
+ flattens.""" + """Smoke: mimo_audio producer-side payload builder reads flat codes.audio + flattens.""" from types import SimpleNamespace import torch @@ -831,7 +830,7 @@ def __init__(self, output_tids, prompt_tids): self.prompt_token_ids = prompt_tids self.finished = True - # multimodal_output has embed.* + we expect token_only to preserve it (Phase 4 #90 follow-up). + # multimodal_output has embed.* + we expect token_only to preserve it. import torch embed = {"speech_token": torch.zeros(2, 4)} @@ -936,7 +935,7 @@ def __init__(self, info): def test_ming_flash_omni_thinker2talker_full_payload_noop() -> None: - """thinker2talker_full_payload returns None — no heavy tensor migration.""" + """ming_flash_omni thinker2talker_full_payload is a no-op (returns None).""" from vllm_omni.model_executor.stage_input_processors.ming_flash_omni import ( thinker2talker_full_payload, ) @@ -987,7 +986,7 @@ class _Prompt(dict): def test_qwen2_5_omni_thinker2talker_full_payload_noop() -> None: - """thinker2talker_full_payload returns None — no heavy tensor migration today.""" + """thinker2talker_full_payload returns None when pooling_output lacks the "hidden" key (defensive).""" from vllm_omni.model_executor.stage_input_processors.qwen2_5_omni import ( thinker2talker_full_payload, ) diff --git a/tests/worker/test_omni_gpu_model_runner.py b/tests/worker/test_omni_gpu_model_runner.py index 94ffd937ead..d04eab3b334 100644 --- a/tests/worker/test_omni_gpu_model_runner.py +++ b/tests/worker/test_omni_gpu_model_runner.py @@ -425,7 +425,7 @@ def test_accumulate_full_payload_output_preserves_aligned_all_zero_qwen3_omni_co def test_accumulate_full_payload_output_keeps_misaligned_all_zero_qwen3_omni_codec_rows(): - # After removing the sender-side zero filter, the accumulator keeps every + # After removing the sender-side zero filter, the full-payload accumulator keeps every # codec row including misaligned all-zero rows. 
The downstream consumer # (_extract_qwen3_full_payload_codec_rows) is the authoritative crop and # filters by output_token_ids. @@ -474,13 +474,13 @@ def test_accumulate_full_payload_output_keeps_all_zero_qwen3_omni_prefill_placeh def test_full_payload_output_accumulation_hook_matrix(): """Producer-side gate: fires iff custom_process_func is loaded and not async_chunk. - Phase 2a generalized the gate from an arch + stage whitelist to a structural - check on the loaded packer. `_custom_process_func is None` short-circuits; + The gate is a structural check on the loaded payload builder. + `_custom_process_func is None` short-circuits; that maps to terminal stages (e.g. code2wav, qwen3_tts code2wav, qwen2_5 code2wav) whose stage_config has no `custom_process_next_stage_input_func` and no `*_full_payload` derivative of `custom_process_input_func`. """ - # Thinker / talker producer stages: packer loaded -> gate fires. + # Thinker / talker producer stages: payload builder loaded -> gate fires. assert _make_full_payload_accumulation_runner(model_stage="thinker")._should_accumulate_full_payload_output() assert _make_full_payload_accumulation_runner(model_stage="talker")._should_accumulate_full_payload_output() @@ -495,9 +495,9 @@ def test_full_payload_output_accumulation_hook_matrix(): model_stage="talker", async_chunk=True )._should_accumulate_full_payload_output() - # Non-qwen3 arches: gate is now arch-agnostic, but if the fixture's arch - # has no PR3 wire its runtime `_custom_process_func` would be None. - # Emulate that. + # Non-qwen3 arches: gate is arch-agnostic, but if the fixture's arch + # does not configure a connector payload builder, its runtime + # `_custom_process_func` is None. Emulate that. 
runner = _make_full_payload_accumulation_runner(model_arch="Qwen3TTSForConditionalGeneration") runner._custom_process_func = None runner._should_accumulate_full_payload_output_cached = None diff --git a/vllm_omni/core/sched/omni_scheduling_coordinator.py b/vllm_omni/core/sched/omni_scheduling_coordinator.py index effd56e318c..26033f11328 100644 --- a/vllm_omni/core/sched/omni_scheduling_coordinator.py +++ b/vllm_omni/core/sched/omni_scheduling_coordinator.py @@ -33,30 +33,30 @@ # Stage 1 hang (gate parks the request, no transport ever releases it). # # The `_is_sync_input` markers on per-model `*_token_only` builders in -# stage_input_processors/ remain as forward-compat documentation; when init -# is generalised (see tmp/trim_refactor branch) this whitelist can move back -# to a structural marker check or be dropped entirely. +# stage_input_processors/ remain as forward-compat documentation; when +# init is generalised this whitelist can move back to a structural marker +# check or be dropped entirely. _FULL_PAYLOAD_INPUT_STAGES: frozenset[tuple[str, str]] = frozenset( { ("Qwen3OmniMoeForConditionalGeneration", "talker"), ("Qwen3OmniMoeForConditionalGeneration", "code2wav"), - # PR3 Block A: qwen2_5_omni thinker->talker now uses the real - # full-payload producer builder (text_hidden_states routed via + # qwen2_5_omni thinker->talker uses the real full-payload + # producer builder (text_hidden_states routed via # pooler_output["hidden"] -> accumulator -> connector). Both # stages of qwen2_5_omni are enabled. ("Qwen2_5OmniForConditionalGeneration", "talker"), ("Qwen2_5OmniForConditionalGeneration", "code2wav"), - # PR3 Block A: covo_audio is fused_thinker_talker (Stage 0) → code2wav (Stage 1) + # covo_audio: fused_thinker_talker (Stage 0) -> code2wav (Stage 1). ("CovoAudioForConditionalGeneration", "code2wav"), - # PR3 Block A: mimo_audio is fused_thinker_talker (Stage 0) → code2wav (Stage 1) + # mimo_audio: fused_thinker_talker (Stage 0) -> code2wav (Stage 1). 
("MiMoAudioModel", "code2wav"), - # PR3 Block A: qwen3_tts is Qwen3TTSTalkerForConditionalGeneration (Stage 0) - # → Qwen3TTSCode2Wav (Stage 1). Stage 1 is the consumer. + # qwen3_tts: Qwen3TTSTalkerForConditionalGeneration (Stage 0) + # -> Qwen3TTSCode2Wav (Stage 1). Stage 1 is the consumer. ("Qwen3TTSCode2Wav", "code2wav"), - # PR3 Block A: cosyvoice3 stages cosyvoice3_talker (Stage 0) → cosyvoice3_code2wav (Stage 1) + # cosyvoice3: cosyvoice3_talker (Stage 0) -> cosyvoice3_code2wav (Stage 1). ("CosyVoice3Model", "cosyvoice3_code2wav"), - # PR3 dynin migration: token2text (Stage 0) -> token2image (Stage 1) - # -> token2audio (Stage 2). Producer wires via + # dynin: token2text (Stage 0) -> token2image (Stage 1) -> + # token2audio (Stage 2). Producer wires via # custom_process_next_stage_input_func: *_full_payload in deploy yaml. ("DyninOmniForConditionalGeneration", "token2image"), ("DyninOmniForConditionalGeneration", "token2audio"), diff --git a/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py b/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py index f355af8c308..3aed006600b 100644 --- a/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py +++ b/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py @@ -327,13 +327,13 @@ def talker2code2wav_async_chunk( # ============================================================================ -# PR3 worker-connector data plane (non-async-chunk path) — Group D-ish. +# Worker-connector data plane (non-async-chunk path). # cosyvoice3 talker emits `multimodal_outputs={"embed": {"speech_token": t, # "speech_feat": t, "embedding": t}}` ONLY at prefill (decode steps emit -# `{}`). After flatten_payload (data_entry_keys.py:280-302) these become -# flat top-level keys `embed.speech_token` etc., persisted across decode -# steps by the accumulator (decode doesn't re-emit them). Shipping via -# the connector keeps the orchestrator off the heavy-tensor path. +# `{}`). 
After flatten_payload these become flat top-level keys +# `embed.speech_token` etc., persisted across decode steps by the +# full-payload accumulator (decode doesn't re-emit them). Shipping via the connector +# keeps the orchestrator off the heavy-tensor path. # ============================================================================ # All three embed tensors are emitted once at prefill and must REPLACE-not- @@ -391,7 +391,7 @@ def text2flow_full_payload( pooling_output, request, ): - """Producer-side packer. + """Producer-side payload builder. Reads prefill-emitted `embed.{speech_token, speech_feat, embedding}` from the accumulator and ships prompt conditioning as a connector payload. diff --git a/vllm_omni/model_executor/stage_input_processors/covo_audio.py b/vllm_omni/model_executor/stage_input_processors/covo_audio.py index 52ca8a44bca..c5cdc312cc6 100644 --- a/vllm_omni/model_executor/stage_input_processors/covo_audio.py +++ b/vllm_omni/model_executor/stage_input_processors/covo_audio.py @@ -28,7 +28,7 @@ def llm2code2wav( """Legacy orchestrator-path builder (retained for async_chunk + back-compat). The non-async-chunk path now goes through ``llm2code2wav_token_only`` + - worker connector + ``llm2code2wav_full_payload`` (PR3). + worker connector + ``llm2code2wav_full_payload``. """ talker_outputs = source_outputs code2wav_inputs = [] @@ -74,9 +74,9 @@ def llm2code2wav_token_only( return code2wav_inputs -# Mark as the sync-side input builder — the structural full-payload gate -# (omni_connector_model_runner_mixin.should_accumulate_full_payload_output) -# fires only when the resolved custom_process_func carries this marker. +# Mark for forward compatibility; current consumer wait gating is +# _FULL_PAYLOAD_INPUT_STAGES-driven (see the mixin +# should_accumulate_full_payload_output docstring). 
llm2code2wav_token_only._is_sync_input = True @@ -85,10 +85,10 @@ def llm2code2wav_full_payload( pooling_output: dict[str, Any], request: Any, ) -> dict[str, Any] | None: - """Producer-side packer for the worker connector data plane. + """Producer-side payload builder for the worker connector data plane. covo_audio's fused_thinker_talker stage emits codec ids via - ``request.output_token_ids`` (token-id-only Group B shape — no + ``request.output_token_ids`` (token-ids only -- no hidden_states or embed tensors), so the connector payload is just the filtered audio codes plus a finished marker. """ diff --git a/vllm_omni/model_executor/stage_input_processors/dynin_omni.py b/vllm_omni/model_executor/stage_input_processors/dynin_omni.py index e2843d9394b..8c61e43c490 100644 --- a/vllm_omni/model_executor/stage_input_processors/dynin_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/dynin_omni.py @@ -149,7 +149,7 @@ def token2image_to_token2audio( # ============================================================================ -# PR3 worker-connector data plane (non-async-chunk path). +# Worker-connector data plane (non-async-chunk path). # ============================================================================ # Per-model REPLACE-keys for the full-payload accumulator. dynin_omni's @@ -160,7 +160,7 @@ def token2image_to_token2audio( def _build_full_payload(pooling_output: dict[str, Any] | None, request: Any) -> dict[str, Any] | None: - """Producer-side packer: assemble dynin_omni connector payload. + """Producer-side payload builder: assemble dynin_omni connector payload. Reads token_ids from ``pooling_output["token_ids"]`` (preferred) or ``request.output_token_ids`` (fallback). 
Reads structured non-tensor @@ -201,7 +201,7 @@ def token2text_to_token2image_full_payload( pooling_output: dict[str, Any], request: Any, ) -> dict[str, Any] | None: - """Producer-side packer for the Stage-0 → Stage-1 (text → image) transition.""" + """Producer-side payload builder for the Stage-0 → Stage-1 (text → image) transition.""" del transfer_manager return _build_full_payload(pooling_output, request) @@ -211,7 +211,7 @@ def token2image_to_token2audio_full_payload( pooling_output: dict[str, Any], request: Any, ) -> dict[str, Any] | None: - """Producer-side packer for the Stage-1 → Stage-2 (image → audio) transition.""" + """Producer-side payload builder for the Stage-1 → Stage-2 (image → audio) transition.""" del transfer_manager return _build_full_payload(pooling_output, request) @@ -264,6 +264,7 @@ def token2image_to_token2audio_token_only( return _token_only_from_source(source_outputs) -# Mark sync-side builders for the structural full-payload gate. +# Mark sync-side builders for forward compatibility; current consumer +# wait gating is _FULL_PAYLOAD_INPUT_STAGES-driven. token2text_to_token2image_token_only._is_sync_input = True token2image_to_token2audio_token_only._is_sync_input = True diff --git a/vllm_omni/model_executor/stage_input_processors/mimo_audio.py b/vllm_omni/model_executor/stage_input_processors/mimo_audio.py index 3f203a61e4f..be57ef55a5b 100644 --- a/vllm_omni/model_executor/stage_input_processors/mimo_audio.py +++ b/vllm_omni/model_executor/stage_input_processors/mimo_audio.py @@ -147,10 +147,8 @@ def llm2code2wav_async_chunk( Accumulates codes in connector per request_id, returns payload only when chunk_size is full or request is finished; returns None when waiting. """ - # Null guard: under Block A universal-ish init, the producer-side - # chunk_transfer_adapter calls this every emit step including no-output - # steps where pooling_output is None. Pre-Block-A this code path was - # unreachable (no connector init for mimo_audio). 
+ # Null guard: chunk_transfer_adapter calls this every emit step + # including no-output steps where pooling_output is None. if pooling_output is None or not isinstance(pooling_output, dict): if is_finished: connector = getattr(transfer_manager, "connector", None) @@ -317,10 +315,10 @@ def llm2code2wav( # ============================================================================ -# PR3 worker-connector data plane (non-async-chunk path) — Group B. -# AR runner's `flatten_payload` (data_entry_keys.py:280-302) converts the -# model emit `multimodal_outputs={"codes": {"audio": ...}}` to flat -# `pooling_output["codes.audio"]` before the accumulator runs, so default +# Worker-connector data plane (non-async-chunk path). +# AR runner's `flatten_payload` converts the model emit +# `multimodal_outputs={"codes": {"audio": ...}}` to flat +# `pooling_output["codes.audio"]` before the full-payload accumulator runs, so default # CONCAT semantics build the full codec tensor across all decode steps. # ============================================================================ @@ -397,7 +395,7 @@ def llm2code2wav_full_payload( pooling_output: dict, request, ) -> dict | None: - """Producer-side packer for the worker connector data plane. + """Producer-side payload builder for the worker connector data plane. AR runner's ``flatten_payload`` converts the per-step model emit ``{"codes": {"audio": ...}}`` to ``pooling_output["codes.audio"]``. diff --git a/vllm_omni/model_executor/stage_input_processors/ming_flash_omni.py b/vllm_omni/model_executor/stage_input_processors/ming_flash_omni.py index ce9a13807ac..40830bc0a55 100644 --- a/vllm_omni/model_executor/stage_input_processors/ming_flash_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/ming_flash_omni.py @@ -536,13 +536,18 @@ def thinker2talker( # ============================================================================ -# PR3 worker-connector data plane (non-async-chunk path) — Group D minimal. 
-# ming_flash_omni's thinker→talker bridge passes detokenized text only; +# Worker-connector data plane (non-async-chunk path) -- inactive for +# ming_flash_omni. +# ming_flash_omni's thinker->talker bridge passes detokenized text only; # voice/speaker metadata flows through the USER request's -# additional_information, not the model's pooler_output. So there is no -# heavy tensor to migrate — the PR3 change is structural-only: register -# the _is_sync_input marker so the Phase 2a gate applies consistently. -# full_payload returns None (no per-step connector data). +# additional_information, not the model's pooler_output. No heavy +# tensor to migrate, so ``thinker2talker_full_payload`` returns None. +# ming_flash_omni is not in ``_OMNI_CONNECTOR_INIT_ARCHS`` or +# ``_FULL_PAYLOAD_INPUT_STAGES``, so the worker connector is not +# initialised for this arch and the consumer never waits on a connector +# payload; data flows through ``additional_information`` written by +# ``thinker2talker_token_only``. The ``*_full_payload`` definition is +# retained for forward compatibility. # ============================================================================ _FULL_PAYLOAD_REPLACE_KEYS: frozenset[str] = frozenset() @@ -555,9 +560,9 @@ def thinker2talker_token_only( ) -> list[OmniTokensPrompt]: """Sync-side builder for the non-async-chunk thinker→talker path. - Ports the legacy ``thinker2talker`` body to the standard PR3 SIP - signature (``source_outputs`` instead of ``stage_list, - engine_input_source``). Body is otherwise identical: extracts the + Ports the legacy ``thinker2talker`` body to the new stage-input-processor signature + (``source_outputs`` instead of ``stage_list, engine_input_source``). + Body is otherwise identical: extracts the generated text from each thinker output and packages it with the request's voice/speaker additional_information for the talker. 
""" @@ -612,14 +617,14 @@ def thinker2talker_full_payload( pooling_output, request, ): - """Producer-side packer — no-op. - - ming_flash_omni's thinker emits no heavy tensor to ship via the worker - connector (the bridge passes text only, and speaker metadata arrives - through the USER request's additional_information). Returning None - causes the connector to skip the send for this transition. The - structural gate still fires so Phase 2a / 2d infrastructure behavior - is consistent across in-scope models. + """Producer-side payload builder — no-op. + + ming_flash_omni's thinker emits no heavy tensor to ship via the + worker connector (the bridge passes text only, and speaker metadata + arrives through the USER request's additional_information). + ming_flash_omni is not in ``_OMNI_CONNECTOR_INIT_ARCHS`` so this + function is never invoked at runtime; it is retained for forward + compatibility with the connector path. """ del transfer_manager, pooling_output, request return None diff --git a/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py index 411f4973dd1..822300fd3e9 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py @@ -93,9 +93,19 @@ def talker2code2wav( # ============================================================================ -# PR3 worker-connector data plane (non-async-chunk path) — Group B half. -# Only talker→code2wav is migrated in this commit; thinker→talker (Group A) -# requires model-side pooler_output emit and is deferred. +# Worker-connector data plane (non-async-chunk path). 
+# Both transitions ship payloads via the worker connector +# (registered in ``_FULL_PAYLOAD_INPUT_STAGES`` in +# omni_scheduling_coordinator): +# - thinker->talker reads accumulated ``pooling_output["hidden"]`` and +# packs an OmniPayload-shaped dict (embed.prefill / +# hidden_states.output / ids.prompt / ids.output) for the talker. +# ``thinker2talker_token_only`` writes the same shape into +# ``additional_information`` as a legacy sync fallback; the talker's +# ``talker_preprocess`` reads either source through the same payload +# keys. +# - talker->code2wav strips TALKER_CODEC_{START,END} boundary tokens +# and ships the codec token ids. # ============================================================================ # Per-model REPLACE-keys for the full-payload accumulator. qwen2_5_omni's @@ -162,9 +172,9 @@ def talker2code2wav_full_payload( pooling_output: dict, request, ) -> dict | None: - """Producer-side packer: ship the stripped codec ids via connector. + """Producer-side payload builder: ship the stripped codec ids via connector. - Group B shape — token_ids only. The talker stage's output already + Token-ids-only shape. The talker stage's output already carries the codec ids on ``request.output_token_ids``; we strip the boundary tokens and pack a minimal payload. """ @@ -182,18 +192,17 @@ def talker2code2wav_full_payload( # ============================================================================ -# PR3 worker-connector data plane (non-async-chunk path) — Group A reduced -# to D-minimal shape. +# Worker-connector data plane (non-async-chunk path) -- thinker->talker. # -# Three subagent investigations (2026-05-16, audits/) confirmed: -# - qwen2_5_omni talker consumes ONE tensor (last-layer hidden state) via -# Linear(3584, 896); no early-layer-0 consumer, no `accept_hidden_layer` -# HF config field. -# - `text_hidden_states` is NOT plumbed into the AR runner pooler_output -# chain, so the existing accumulator cannot ship it. 
-# So the PR3 migration is structural-only: thinker2talker_token_only mirrors -# the legacy body so additional_information continues to carry the latent -# tensor (same as cosyvoice3's post-fix state). full_payload returns None. +# qwen2_5_omni's talker consumes the thinker's last-layer hidden state +# via Linear(3584, 896). The AR runner publishes those hidden states +# per decode step on ``pooling_output["hidden"]`` (unpacked from +# ``OmniOutput.text_hidden_states``); the full-payload accumulator +# concatenates them so ``thinker2talker_full_payload`` sees the full +# prefill+decode trajectory and packs an OmniPayload-shaped dict. +# ``thinker2talker_token_only`` writes the same shape into +# ``additional_information`` as a legacy sync fallback; the talker's +# ``talker_preprocess`` reads ``info_dict`` regardless of source. # ============================================================================ _FULL_PAYLOAD_REPLACE_KEYS: frozenset[str] = frozenset() @@ -206,17 +215,19 @@ def thinker2talker_token_only( ): """Sync-side builder for the non-async-chunk thinker->talker path. - Body is identical to legacy ``thinker2talker`` above — preserves the - orchestrator-shaped data path (latent in additional_information) so - the talker stage receives thinker hidden states without requiring the - worker connector to deliver them. Filed as a Phase 4 follow-up to - route the latent via connector once the AR runner's text_hidden_states - plumbing is wired into pooler_output / model_intermediate_buffer. - - The ``_is_sync_input = True`` marker below activates the Phase 2a - structural gate so the rest of the PR3 infrastructure (gen scheduler - bridge, runner lifecycle, full-payload accumulator) participates - consistently with the other 8 migrated transitions. 
+ Body mirrors the legacy ``thinker2talker`` above: packs an + OmniPayload-shaped dict (hidden_states.output / embed.prefill / + ids.prompt / ids.output) into ``additional_information``, allocates + TALKER_CODEC_{START,PAD,END} prompt slots, and forwards + ``multi_modal_data``. Serves as a legacy sync fallback; the same + shape is also built by ``thinker2talker_full_payload`` below and + shipped via the worker connector. + + The ``_is_sync_input = True`` marker below is currently dormant + forward-compat documentation -- the consumer-wait gate is + whitelist-driven via ``_FULL_PAYLOAD_INPUT_STAGES`` (see the mixin + ``should_accumulate_full_payload_output`` docstring), not by this + marker. """ thinker_outputs = source_outputs talker_inputs = [] @@ -269,7 +280,7 @@ def thinker2talker_full_payload( pooling_output, request, ): - """Producer-side packer for the worker-connector data plane. + """Producer-side payload builder for the worker-connector data plane. The AR runner emits per-step ``pooling_output["hidden"]`` (the thinker's last-layer hidden states for the request span, unpacked @@ -284,10 +295,9 @@ def thinker2talker_full_payload( decode hidden states, then pack the ``OmniPayload``-shaped dict that the talker's ``thinker_to_talker_process`` already reads (keys ``embed.prefill`` / ``hidden_states.output`` / ``ids.prompt`` / - ``ids.output``). Shape matches what - ``thinker2talker_token_only`` writes into - ``additional_information``, so the consumer-side coordinator gate - flip is a drop-in once the no-touch coordinator file is updated. + ``ids.output``). Shape matches what ``thinker2talker_token_only`` + writes into ``additional_information``, so the talker consumes the + same payload layout from either path. 
Like ``qwen3_omni.thinker2talker_full_payload``, we apply a finish-reason-aware stop-row trim: vLLM v1 appends the sampled diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py index b56e15a7362..83b267d5933 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -484,13 +484,12 @@ def thinker2talker_full_payload( # Length-aware trim of accumulated thinker output, finish-reason-aware. # vLLM appends the sampled token to `output_token_ids` BEFORE - # `check_stop` (scheduler.py:1641-1651), so a stop-finished request - # has accumulator_rows == len(all_token_ids) including the stop - # emission row -- the talker must NOT consume that row (fba23325 - # spurious-phoneme regression). Max-token finishes do not append - # an extra forward, so no drop is needed (BK 9702 long-output - # regression). Primary: distinguish via `request.status`. Fallback - # only when status is absent: last-token-in-stop-id heuristic. + # `check_stop`, so a stop-finished request has accumulator_rows + # == len(all_token_ids) including the stop emission row -- the + # talker must NOT consume that row. Max-token finishes do not + # append an extra forward, so no drop is needed. Primary: + # distinguish via `request.status`. Fallback only when status + # is absent: last-token-in-stop-id heuristic. status = getattr(request, "status", None) status_name = getattr(status, "name", None) or "" if not status_name and status is not None: @@ -958,6 +957,6 @@ def talker2code2wav( return code2wav_inputs -# Mark sync-side builders for the structural full-payload gate (see -# should_accumulate_full_payload_output above). +# Mark for forward compatibility; current consumer wait gating is +# _FULL_PAYLOAD_INPUT_STAGES-driven (see should_accumulate_full_payload_output above). 
thinker2talker_token_only._is_sync_input = True diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py index 8ef188009f2..fb9f3fd0d8f 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py @@ -283,12 +283,12 @@ def talker2code2wav_async_chunk( # ============================================================================ -# PR3 worker-connector data plane (non-async-chunk path) — Group C multi-key. -# AR runner's `flatten_payload` (data_entry_keys.py:280-302) converts the -# model emit `multimodal_outputs={"codes": {"audio": ..., "ref": ...}, -# "meta": {"ref_code_len": ..., "codec_streaming": ...}}` to flat dotted keys -# (`codes.audio`, `codes.ref`, `meta.ref_code_len`, `meta.codec_streaming`) -# before the accumulator runs. +# Worker-connector data plane (non-async-chunk path). +# AR runner's `flatten_payload` converts the model emit +# `multimodal_outputs={"codes": {"audio": ..., "ref": ...}, +# "meta": {"ref_code_len": ..., "codec_streaming": ...}}` to flat dotted +# keys (`codes.audio`, `codes.ref`, `meta.ref_code_len`, +# `meta.codec_streaming`) before the full-payload accumulator runs. # - codes.audio is 2-D so default CONCAT across steps builds the full sequence. # - codes.ref is a list (not Tensor with dim>=2) so accumulator LATEST-wins # keeps the prefill-emitted ref tensor across decode steps (which don't emit @@ -415,7 +415,7 @@ def talker2code2wav_full_payload( pooling_output, request, ): - """Producer-side packer. + """Producer-side payload builder. 
Reads accumulated codec from `pooling_output["codes.audio"]` (CONCAT across steps via flatten_payload), latest `pooling_output["codes.ref"]` diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index fb6be8eb231..60c059c5a44 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -84,13 +84,13 @@ def __init__(self, *args, **kwargs): self.inputs_embeds = self._make_buffer(self.max_num_tokens, self.hidden_size, dtype=self.dtype, numpy=False) # Initialize KV cache manager (preserve vllm_config fallback behavior) self.kv_transfer_manager = OmniKVTransferManager.from_vllm_config(self.vllm_config, self.model_config) - # Worker-connector full-payload init is gated by an arch allowlist that - # grows as each per-arch transition is verified end-to-end (PR3 incremental - # Block A). Adding an arch here without also wiring its scheduler-side - # gate entries in `omni_scheduling_coordinator._FULL_PAYLOAD_INPUT_STAGES` - # produces a Stage-1 hang on the consumer side (request parks but no - # transport ever releases). Keep the two in lockstep. - _BLOCK_A_INIT_ALLOWLIST = { + # Worker-connector init is gated by a per-`model_arch` allowlist + # (covers both producer-side and consumer-side runners for the + # arches below). Consumer-wait stages must be registered + # separately as `(model_arch, model_stage)` tuples in + # `omni_scheduling_coordinator._FULL_PAYLOAD_INPUT_STAGES`; + # forgetting that produces a Stage-1 hang on the consumer. 
+ _OMNI_CONNECTOR_INIT_ARCHS = { "Qwen3OmniMoeForConditionalGeneration", "Qwen2_5OmniForConditionalGeneration", "CovoAudioForConditionalGeneration", @@ -100,7 +100,7 @@ def __init__(self, *args, **kwargs): "CosyVoice3Model", "DyninOmniForConditionalGeneration", } - if getattr(self.model_config, "model_arch", None) in _BLOCK_A_INIT_ALLOWLIST: + if getattr(self.model_config, "model_arch", None) in _OMNI_CONNECTOR_INIT_ARCHS: self.init_omni_connectors( vllm_config=self.vllm_config, model_config=self.model_config, diff --git a/vllm_omni/worker/gpu_generation_model_runner.py b/vllm_omni/worker/gpu_generation_model_runner.py index cf59113ca9b..9f1060ed1aa 100644 --- a/vllm_omni/worker/gpu_generation_model_runner.py +++ b/vllm_omni/worker/gpu_generation_model_runner.py @@ -59,8 +59,8 @@ class GPUGenerationModelRunner(OmniGPUModelRunner, OmniConnectorModelRunnerMixin def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # See gpu_ar_model_runner.py for Block A allowlist policy. - _BLOCK_A_INIT_ALLOWLIST = { + # Mirrors the init allowlist in gpu_ar_model_runner.py. 
+ _OMNI_CONNECTOR_INIT_ARCHS = { "Qwen3OmniMoeForConditionalGeneration", "Qwen2_5OmniForConditionalGeneration", "CovoAudioForConditionalGeneration", @@ -70,7 +70,7 @@ def __init__(self, *args, **kwargs): "CosyVoice3Model", "DyninOmniForConditionalGeneration", } - if getattr(self.model_config, "model_arch", None) in _BLOCK_A_INIT_ALLOWLIST: + if getattr(self.model_config, "model_arch", None) in _OMNI_CONNECTOR_INIT_ARCHS: self.init_omni_connectors( vllm_config=self.vllm_config, model_config=self.model_config, diff --git a/vllm_omni/worker/omni_connector_model_runner_mixin.py b/vllm_omni/worker/omni_connector_model_runner_mixin.py index 032d87d27bc..f5b16532b5d 100644 --- a/vllm_omni/worker/omni_connector_model_runner_mixin.py +++ b/vllm_omni/worker/omni_connector_model_runner_mixin.py @@ -56,17 +56,15 @@ def should_accumulate_full_payload_output(model_config, custom_process_func) -> stage is not in async_chunk mode, and ``model_stage`` is set. NOTE: the ``_is_sync_input`` marker is on the *consumer-side* - ``*_token_only`` builder, not on the ``*_full_payload`` packer that - workers load on the *producer* side. So checking it here would always - return False and the accumulator would never run. The - consumer-side scheduler gate (``uses_full_payload_input_coordinator`` - in ``omni_scheduling_coordinator.py``) is where the marker is - appropriately tested. - - Pre-Phase-2a, this gate was an arch + stage whitelist - (``Qwen3OmniMoeForConditionalGeneration`` and ``thinker``/``talker``). - Phase 2a generalized that to "any stage with a loaded packer + not - async_chunk + model_stage set" — arch-agnostic. + ``*_token_only`` builder, not on the ``*_full_payload`` payload builder that + workers load on the *producer* side, so checking it here would + always return False and the full-payload accumulator would never run. 
The + marker itself is currently dormant forward-compat documentation: + the consumer-side scheduler gate + (``uses_full_payload_input_coordinator`` in + ``omni_scheduling_coordinator.py``) is whitelist-driven on + ``(model_arch, model_stage)`` against ``_FULL_PAYLOAD_INPUT_STAGES`` + -- adding the marker alone does not open a consumer-wait gate. """ if custom_process_func is None: return False @@ -265,7 +263,7 @@ def cleanup_finished_request(self, req_id: str) -> None: self.flush_full_payload_outputs({req_id}) except Exception: # Defensive: connector may not be initialised for archs - # outside the Block A allowlist. Cleanup must still proceed. + # outside the connector init allowlist. Cleanup must still proceed. pass ext_id = self._request_ids_mapping.pop(req_id, None) @@ -776,7 +774,7 @@ def _materialize_full_payload_entry(entry): def _resolve_full_payload_replace_keys(self) -> frozenset: """Per-model REPLACE-key set for the full-payload accumulator. - Looked up from the SIP module that ships the model's sync builder + Looked up from the stage-input-processor module that ships the model's sync builder (`model_config.custom_process_input_func.__module__`). The module declares ``_FULL_PAYLOAD_REPLACE_KEYS: frozenset[str]``; if absent, returns the empty set. @@ -1022,9 +1020,9 @@ def register_chunk_recv(self, request: Any) -> None: if self._stage_id == 0: return request_id = request.request_id - # Codex Issue 3: explicit external_req_id=None should fall back to - # request_id; otherwise recv keys become `None__` and - # collide across requests. + # Explicit external_req_id=None must fall back to request_id; + # otherwise recv keys become `None__` and collide + # across requests. 
ext = getattr(request, "external_req_id", None) self._request_ids_mapping[request_id] = ext if ext is not None else request_id with self._lock: @@ -2210,7 +2208,7 @@ def _resolve_external_req_id(self, request: Any, fallback_req_id: str) -> str: if mapped is not None: return mapped if request is not None: - # Codex Issue 3: external_req_id may be explicitly None; fall back. + # external_req_id may be explicitly None; fall back. ext = getattr(request, "external_req_id", None) if ext is not None: return ext From 39572efd49aac509dbb7ee730938c1ffe1c76d44 Mon Sep 17 00:00:00 2001 From: natureofnature Date: Thu, 21 May 2026 08:59:19 +0000 Subject: [PATCH 10/19] Minor fix Signed-off-by: natureofnature --- .../test_omni_scheduler_mixin_timeouts.py | 104 ++++++++++++++++++ .../sched/test_omni_scheduling_coordinator.py | 43 +++++++- .../test_qwen2_5_omni_thinker2talker.py | 87 +++++++++++++++ vllm_omni/core/sched/omni_ar_scheduler.py | 1 + .../core/sched/omni_generation_scheduler.py | 1 + vllm_omni/core/sched/omni_scheduler_mixin.py | 73 ++++++++++-- .../stage_input_processors/dynin_omni.py | 7 +- .../stage_input_processors/mimo_audio.py | 4 +- .../stage_input_processors/qwen2_5_omni.py | 4 +- .../stage_input_processors/qwen3_tts.py | 15 ++- .../omni_connector_model_runner_mixin.py | 30 ++++- 11 files changed, 344 insertions(+), 25 deletions(-) create mode 100644 tests/core/sched/test_omni_scheduler_mixin_timeouts.py create mode 100644 tests/model_executor/stage_input_processors/test_qwen2_5_omni_thinker2talker.py diff --git a/tests/core/sched/test_omni_scheduler_mixin_timeouts.py b/tests/core/sched/test_omni_scheduler_mixin_timeouts.py new file mode 100644 index 00000000000..053c926cc24 --- /dev/null +++ b/tests/core/sched/test_omni_scheduler_mixin_timeouts.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit coverage for _process_pending_input_timeouts. 
+ +Verifies that the mixin correctly *delegates* timed-out requests to the +base scheduler's ``finish_requests`` API with ``RequestStatus.FINISHED_ERROR``. +The end-to-end effect (queue removal + status set + per-request cleanup + +client-facing FINISHED_ERROR emission) is the responsibility of upstream +vLLM's ``finish_requests`` implementation and is covered by upstream tests; +this file only asserts the wiring from the mixin to that API. +""" + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +from vllm_omni.core.sched.omni_scheduler_mixin import OmniSchedulerMixin + + +class _FakeCoordinator: + def __init__(self, timed_out_ids): + self._timed_out_ids = set(timed_out_ids) + self.calls = [] + + def collect_timed_out_request_ids(self, timeout_s): + self.calls.append(timeout_s) + return set(self._timed_out_ids) + + +class _FakeScheduler(OmniSchedulerMixin): + def __init__(self, requests, coordinator): + self.requests = requests + self.input_coordinator = coordinator + self.finish_calls = [] + + def finish_requests(self, req_ids, status): + self.finish_calls.append((set(req_ids), status)) + + +def test_process_pending_input_timeouts_delegates_to_finish_requests(): + """Timed-out request present in self.requests is forwarded to finish_requests.""" + req_id = "stuck-req" + requests = {req_id: SimpleNamespace(request_id=req_id)} + coord = _FakeCoordinator(timed_out_ids={req_id}) + scheduler = _FakeScheduler(requests, coord) + + scheduler._process_pending_input_timeouts() + + assert len(coord.calls) == 1, "coordinator should be polled once" + assert coord.calls[0] > 0, "timeout must be positive when enabled" + + assert len(scheduler.finish_calls) == 1 + finished_ids, status = scheduler.finish_calls[0] + assert finished_ids == {req_id} + # RequestStatus is the upstream enum; the mixin imports it as + # RequestStatus.FINISHED_ERROR. Check by name to avoid hard import here. 
+ assert getattr(status, "name", str(status)).endswith("FINISHED_ERROR") + + +def test_process_pending_input_timeouts_skips_already_freed_request(): + """Timed-out id no longer in self.requests must not be forwarded.""" + coord = _FakeCoordinator(timed_out_ids={"already-freed"}) + scheduler = _FakeScheduler(requests={}, coordinator=coord) + + scheduler._process_pending_input_timeouts() + + assert coord.calls == [coord.calls[0]] and coord.calls[0] > 0 + assert scheduler.finish_calls == [] + + +def test_process_pending_input_timeouts_noop_without_coordinator(): + """No coordinator => no finish_requests call, no crash.""" + + class _NoCoord(OmniSchedulerMixin): + def __init__(self): + self.requests = {} + self.input_coordinator = None + self.finish_calls = [] + + def finish_requests(self, req_ids, status): + self.finish_calls.append((set(req_ids), status)) + + scheduler = _NoCoord() + scheduler._process_pending_input_timeouts() + assert scheduler.finish_calls == [] + + +def test_process_pending_input_timeouts_disabled_when_timeout_zero(monkeypatch): + """Setting DEFAULT_INPUT_WAIT_TIMEOUT_S <= 0 disables the safety net.""" + from vllm_omni.core.sched import omni_scheduler_mixin + + monkeypatch.setattr(omni_scheduler_mixin, "DEFAULT_INPUT_WAIT_TIMEOUT_S", 0.0) + + coord = _FakeCoordinator(timed_out_ids={"r1"}) + scheduler = _FakeScheduler(requests={"r1": SimpleNamespace(request_id="r1")}, coordinator=coord) + scheduler._process_pending_input_timeouts() + assert coord.calls == [], "coordinator must not be polled when timeout is disabled" + assert scheduler.finish_calls == [] + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/core/sched/test_omni_scheduling_coordinator.py b/tests/core/sched/test_omni_scheduling_coordinator.py index a6d4920e3d5..1b36cd784d8 100644 --- a/tests/core/sched/test_omni_scheduling_coordinator.py +++ b/tests/core/sched/test_omni_scheduling_coordinator.py @@ -101,13 +101,46 @@ class 
TestFullPayloadCoordinatorSelection(unittest.TestCase): These tests pin which ``(arch, stage)`` pairs the gate fires for today. """ - def test_all_whitelisted_arch_stage_pairs_fire_gate(self): - """All (arch, stage) pairs in _FULL_PAYLOAD_INPUT_STAGES must fire - the gate when stage_id > 0 and async_chunk=False. + # Expected whitelist (model_arch, model_stage). Hardcoded to avoid the + # tautology of importing _FULL_PAYLOAD_INPUT_STAGES and asserting it + # against itself; any drift between this matrix and the whitelist will + # fail loudly here. + EXPECTED_FULL_PAYLOAD_INPUT_STAGES: frozenset[tuple[str, str]] = frozenset( + { + ("Qwen3OmniMoeForConditionalGeneration", "talker"), + ("Qwen3OmniMoeForConditionalGeneration", "code2wav"), + ("Qwen2_5OmniForConditionalGeneration", "talker"), + ("Qwen2_5OmniForConditionalGeneration", "code2wav"), + ("CovoAudioForConditionalGeneration", "code2wav"), + ("MiMoAudioModel", "code2wav"), + ("Qwen3TTSCode2Wav", "code2wav"), + ("CosyVoice3Model", "cosyvoice3_code2wav"), + ("DyninOmniForConditionalGeneration", "token2image"), + ("DyninOmniForConditionalGeneration", "token2audio"), + } + ) + + def test_whitelist_matches_expected_matrix(self): + """_FULL_PAYLOAD_INPUT_STAGES must equal the hardcoded expected matrix. + + Catches both accidental additions (which would silently enable the + consumer-wait gate for a new arch) and accidental removals (which + would silently disable an enabled arch). 
""" from vllm_omni.core.sched.omni_scheduling_coordinator import _FULL_PAYLOAD_INPUT_STAGES - for arch, stage in _FULL_PAYLOAD_INPUT_STAGES: + self.assertEqual( + frozenset(_FULL_PAYLOAD_INPUT_STAGES), + self.EXPECTED_FULL_PAYLOAD_INPUT_STAGES, + msg="_FULL_PAYLOAD_INPUT_STAGES drifted from the expected matrix; " + "update EXPECTED_FULL_PAYLOAD_INPUT_STAGES if intentional.", + ) + + def test_all_whitelisted_arch_stage_pairs_fire_gate(self): + """Every (arch, stage) pair in the expected matrix must fire + the gate when stage_id > 0 and async_chunk=False. + """ + for arch, stage in self.EXPECTED_FULL_PAYLOAD_INPUT_STAGES: model_config = SimpleNamespace( stage_id=1, async_chunk=False, @@ -116,7 +149,7 @@ def test_all_whitelisted_arch_stage_pairs_fire_gate(self): ) self.assertTrue( uses_full_payload_input_coordinator(model_config), - msg=f"expected gate to fire for {arch}/{stage} (entry in _FULL_PAYLOAD_INPUT_STAGES)", + msg=f"expected gate to fire for {arch}/{stage}", ) def test_other_arch_or_stage_or_mode_does_not_fire(self): diff --git a/tests/model_executor/stage_input_processors/test_qwen2_5_omni_thinker2talker.py b/tests/model_executor/stage_input_processors/test_qwen2_5_omni_thinker2talker.py new file mode 100644 index 00000000000..0cb61a972f2 --- /dev/null +++ b/tests/model_executor/stage_input_processors/test_qwen2_5_omni_thinker2talker.py @@ -0,0 +1,87 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Light coverage for qwen2_5_omni.thinker2talker_full_payload. + +Covers the finish-reason-aware stop-row trim contract: when the request +status is FINISHED_STOPPED, the builder must drop one row from the +accumulated hidden states (vLLM v1 appends the sampled stop token to +output_token_ids before check_stop, so the trailing hidden-state row +corresponds to the stop emission and must not reach the talker). 
+""" + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest +import torch + +from vllm_omni.model_executor.stage_input_processors.qwen2_5_omni import ( + thinker2talker_full_payload, +) + + +def _make_request( + prompt_token_ids, + output_token_ids, + status_name: str | None = "FINISHED_STOPPED", +): + status = SimpleNamespace(name=status_name) if status_name else None + return SimpleNamespace( + request_id="r1", + prompt_token_ids=prompt_token_ids, + output_token_ids=output_token_ids, + all_token_ids=list(prompt_token_ids) + list(output_token_ids), + status=status, + sampling_params=None, + ) + + +def test_finished_stopped_trims_one_decode_row(): + """FINISHED_STOPPED: drop trailing hidden-state row so talker does not + consume the stop-emission row. + """ + prompt = [1, 2, 3] + output = [10, 11, 12] + request = _make_request(prompt, output, status_name="FINISHED_STOPPED") + # 6 prompt+output rows + 1 stop-emission row = 7 hidden rows total. + hidden = torch.arange(7 * 4, dtype=torch.float32).reshape(7, 4) + pooling = {"hidden": hidden} + + payload = thinker2talker_full_payload(transfer_manager=None, pooling_output=pooling, request=request) + + assert payload is not None + # ids.output had one trailing stop row dropped: 3 - 1 = 2 remaining. + assert payload["ids"]["output"] == output[:-1] + # embed.prefill must cover only the prompt rows. + assert payload["embed"]["prefill"].shape[0] == len(prompt) + # hidden_states.output covers the decode rows minus the dropped stop row. + assert payload["hidden_states"]["output"].shape[0] == len(output) - 1 + + +def test_finished_length_capped_keeps_all_rows(): + """FINISHED_LENGTH_CAPPED: no row drop; hidden_states.output covers + all decode rows. 
+ """ + prompt = [1, 2, 3] + output = [10, 11, 12] + request = _make_request(prompt, output, status_name="FINISHED_LENGTH_CAPPED") + hidden = torch.arange(6 * 4, dtype=torch.float32).reshape(6, 4) + pooling = {"hidden": hidden} + + payload = thinker2talker_full_payload(transfer_manager=None, pooling_output=pooling, request=request) + + assert payload is not None + assert payload["ids"]["output"] == output + assert payload["embed"]["prefill"].shape[0] == len(prompt) + assert payload["hidden_states"]["output"].shape[0] == len(output) + + +def test_missing_hidden_returns_none(): + """Defensive: pooling_output without "hidden" returns None.""" + request = _make_request([1, 2], [3], status_name="FINISHED_STOPPED") + assert thinker2talker_full_payload(transfer_manager=None, pooling_output={}, request=request) is None + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/vllm_omni/core/sched/omni_ar_scheduler.py b/vllm_omni/core/sched/omni_ar_scheduler.py index fb76fd52bf1..c306765b891 100644 --- a/vllm_omni/core/sched/omni_ar_scheduler.py +++ b/vllm_omni/core/sched/omni_ar_scheduler.py @@ -211,6 +211,7 @@ def schedule(self) -> SchedulerOutput: # type: ignore[override] if getattr(req, "status", None) == RequestStatus.FINISHED_ABORTED: queue.remove(req) self._consume_pending_connector_output(model_mode="ar") + self._process_pending_input_timeouts() if self.chunk_transfer_adapter: self.chunk_transfer_adapter.process_pending_chunks(self.waiting, self.running) diff --git a/vllm_omni/core/sched/omni_generation_scheduler.py b/vllm_omni/core/sched/omni_generation_scheduler.py index efff106cf63..403bd4c5db7 100644 --- a/vllm_omni/core/sched/omni_generation_scheduler.py +++ b/vllm_omni/core/sched/omni_generation_scheduler.py @@ -83,6 +83,7 @@ def schedule(self) -> SchedulerOutput: skipped_waiting_requests = create_request_queue(self.policy) req_index = 0 self._consume_pending_connector_output(model_mode="generation") + self._process_pending_input_timeouts() 
if self.chunk_transfer_adapter: self.chunk_transfer_adapter.process_pending_chunks(self.waiting, self.running) diff --git a/vllm_omni/core/sched/omni_scheduler_mixin.py b/vllm_omni/core/sched/omni_scheduler_mixin.py index fba514756a7..16b7c373c99 100644 --- a/vllm_omni/core/sched/omni_scheduler_mixin.py +++ b/vllm_omni/core/sched/omni_scheduler_mixin.py @@ -1,13 +1,32 @@ from __future__ import annotations +import os from typing import Any +from vllm.logger import init_logger from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.engine import EngineCoreEventType from vllm.v1.request import Request, RequestStatus, StreamingUpdate from vllm_omni.core.sched.output import OmniInputRegistration, OmniSchedulerOutput +logger = init_logger(__name__) + +# Upper bound on how long a request may sit in full-payload-input wait +# (the state ``OmniSchedulingCoordinator`` records via ``_waiting_since``) +# before the scheduler force-fails it. Defends against stuck consumer-side +# requests when the producer drops a full-payload, send fails, or recv +# never arrives. Override per-deployment via +# VLLM_OMNI_INPUT_WAIT_TIMEOUT_S; set <=0 to disable the safety net. +# +# Scope: this constant only covers the full-payload coordinator path +# (``input_coordinator``). The async-chunk path uses +# ``chunk_transfer_adapter`` and is not affected by this constant. +try: + DEFAULT_INPUT_WAIT_TIMEOUT_S: float = float(os.environ.get("VLLM_OMNI_INPUT_WAIT_TIMEOUT_S", "300")) +except ValueError: + DEFAULT_INPUT_WAIT_TIMEOUT_S = 300.0 + class OmniSchedulerMixin: """Shared scheduler helpers for omni-specific request handling.""" @@ -44,23 +63,63 @@ def _consume_pending_connector_output(self, model_mode: str) -> None: connector_output.stage_recv_req_ids if connector_output else set(), ) + def _process_pending_input_timeouts(self) -> None: + """Force-fail requests waiting on the full-payload coordinator too long. 
+ + Called at the top of every ``schedule()`` cycle, right after + ``_consume_pending_connector_output``. Without this hook, a request + whose producer dropped a payload would sit in the + full-payload-input wait state indefinitely (the runner mixin + protects ``_pending_load_reqs`` from prune sweeps). + + Reads ``_waiting_since`` timestamps maintained by the input + coordinator and delegates to the base scheduler's + ``finish_requests`` to mark expired requests FINISHED_ERROR. + Disabled when ``DEFAULT_INPUT_WAIT_TIMEOUT_S`` is <= 0. + + Scope: only covers ``input_coordinator`` (full-payload path). + Async-chunk requests park in ``chunk_transfer_adapter`` instead + and are not handled here -- if a similar safety net is needed + for the chunk path, it belongs in the chunk adapter. + """ + if DEFAULT_INPUT_WAIT_TIMEOUT_S <= 0: + return + input_coordinator = getattr(self, "input_coordinator", None) + if input_coordinator is None: + return + timed_out_ids = input_coordinator.collect_timed_out_request_ids(timeout_s=DEFAULT_INPUT_WAIT_TIMEOUT_S) + if not timed_out_ids: + return + present_ids = {req_id for req_id in timed_out_ids if req_id in self.requests} + if not present_ids: + return + logger.warning( + "Marking %d request(s) as FINISHED_ERROR after waiting > %.0fs for connector input: %s", + len(present_ids), + DEFAULT_INPUT_WAIT_TIMEOUT_S, + sorted(present_ids), + ) + self.finish_requests(present_ids, RequestStatus.FINISHED_ERROR) + def _capture_omni_connector_output(self, model_runner_output: Any, model_mode: str) -> None: """Stash the model runner's omni_connector_output for next schedule(). Called at the tail of every ``update_from_output()``. Identical between AR and generation schedulers except for ``model_mode``. + + NOTE: this method only stashes the output. Applying the metadata + is the responsibility of ``_consume_pending_connector_output()`` + at the start of the next ``schedule()`` cycle. 
Applying it twice + (once here, once on consume) is unsafe under + ``update_request_metadata`` in generation mode, which resets + ``prompt_token_ids`` / ``_output_token_ids`` / ``num_computed_tokens`` + and would clobber any progress between the two calls. """ + del model_mode # only used by the (removed) double-apply branch omni_output = getattr(model_runner_output, "omni_connector_output", None) if omni_output is None: return self._latest_omni_connector_output = omni_output - input_coordinator = getattr(self, "input_coordinator", None) - if input_coordinator and omni_output.request_metadata: - input_coordinator.update_request_metadata( - self.requests, - omni_output.request_metadata, - model_mode=model_mode, - ) def _wrap_omni_scheduler_output( self, diff --git a/vllm_omni/model_executor/stage_input_processors/dynin_omni.py b/vllm_omni/model_executor/stage_input_processors/dynin_omni.py index 8c61e43c490..2ff4d55a7cd 100644 --- a/vllm_omni/model_executor/stage_input_processors/dynin_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/dynin_omni.py @@ -191,8 +191,11 @@ def _build_full_payload(pooling_output: dict[str, Any] | None, request: Any) -> payload = _normalize_additional_info(src_additional_info) payload.update(_normalize_additional_info(runtime_bridge_info)) payload["detok_id"] = [_to_int(pooling_output.get("detok_id"), default=_to_int(payload.get("detok_id"), default=0))] - payload["code_predictor_codes"] = token_ids - payload["finished"] = torch.tensor(True, dtype=torch.bool) + # Use nested OmniPayload shape so the scheduling-metadata extractor in + # OmniConnectorModelRunnerMixin reads codes.audio and meta.finished + # (flat keys at the top level are silently dropped with a warning). 
+ payload["codes"] = {"audio": token_ids} + payload["meta"] = {"finished": torch.tensor(True, dtype=torch.bool)} return payload diff --git a/vllm_omni/model_executor/stage_input_processors/mimo_audio.py b/vllm_omni/model_executor/stage_input_processors/mimo_audio.py index be57ef55a5b..c6b1017cf41 100644 --- a/vllm_omni/model_executor/stage_input_processors/mimo_audio.py +++ b/vllm_omni/model_executor/stage_input_processors/mimo_audio.py @@ -339,7 +339,9 @@ def _filter_zero_codec_rows(codec_codes: torch.Tensor) -> torch.Tensor: is_all_zero = (codec_codes == 0).all(dim=(1, 2, 3)) nonzero_idx = (~is_all_zero).nonzero(as_tuple=True)[0] if len(nonzero_idx) == 0: - return codec_codes + # All rows are zero-padded; return an empty tensor so the caller + # can detect this via numel()==0 and skip the request. + return codec_codes[:0] if len(nonzero_idx) < codec_codes.shape[0]: return codec_codes[nonzero_idx] return codec_codes diff --git a/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py index 822300fd3e9..a48a844e3a1 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py @@ -413,5 +413,7 @@ def _ensure_list(x): ), ) ) - # payload["meta"] removed — was the only diff vs legacy payload, causes mix_to_text_audio_001 failure + # Intentionally omit payload["meta"]: the thinker->talker transition + # carries no scheduler-relevant metadata (next_stage_prompt_len / + # left_context_size are not set on this edge). 
return payload diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py index fb9f3fd0d8f..15de87741b0 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py @@ -53,10 +53,13 @@ def talker2code2wav( # audio_codes may still contain zero-padded / invalid rows, so trim only # after filtering valid frames instead of trying to align EOS indices. seq_len = max(len(token_ids) - 1, 0) - # Filter invalid frames: zero-padded (EOS) and frames containing - # out-of-range values (e.g. stop_token_id=2150 exceeds codebook_size=2048). + # Filter invalid frames: zero-padded (EOS), out-of-range values (e.g. + # stop_token_id=2150 exceeds codebook_size=2048), and negative + # sentinels (e.g. -1 padding). _CODEBOOK_SIZE = 2048 - valid_mask = audio_codes.any(dim=1) & (audio_codes.max(dim=1).values < _CODEBOOK_SIZE) + valid_mask = ( + (audio_codes >= 0).all(dim=1) & audio_codes.any(dim=1) & (audio_codes.max(dim=1).values < _CODEBOOK_SIZE) + ) audio_codes = audio_codes[valid_mask] if seq_len > 0 and audio_codes.ndim == 2 and int(audio_codes.shape[0]) > seq_len: audio_codes = audio_codes[-seq_len:] @@ -307,7 +310,7 @@ def talker2code2wav_async_chunk( def _filter_audio_codes_qwen3_tts(audio_codes: torch.Tensor) -> torch.Tensor: - """Filter zero-padded + out-of-range codec frames. + """Filter zero-padded, out-of-range, and negative-padded codec frames. Mirrors the orchestrator-path body in `talker2code2wav` above. 
""" @@ -315,7 +318,9 @@ def _filter_audio_codes_qwen3_tts(audio_codes: torch.Tensor) -> torch.Tensor: return audio_codes if audio_codes.ndim != 2: return audio_codes - valid_mask = audio_codes.any(dim=1) & (audio_codes.max(dim=1).values < _CODEBOOK_SIZE) + valid_mask = ( + (audio_codes >= 0).all(dim=1) & audio_codes.any(dim=1) & (audio_codes.max(dim=1).values < _CODEBOOK_SIZE) + ) return audio_codes[valid_mask] diff --git a/vllm_omni/worker/omni_connector_model_runner_mixin.py b/vllm_omni/worker/omni_connector_model_runner_mixin.py index f5b16532b5d..550971484ff 100644 --- a/vllm_omni/worker/omni_connector_model_runner_mixin.py +++ b/vllm_omni/worker/omni_connector_model_runner_mixin.py @@ -255,10 +255,11 @@ def cleanup_finished_request(self, req_id: str) -> None: # cleanup proceeds. Without this, finished requests with no # downstream consumer (e.g. text-only on multi-modal arch) leave # the entry orphaned in _pending_full_payload_send across requests, - # which empirically destabilises subsequent thinker forwards - # (test_thinker_prefix_caching regression). flush is a near-no-op - # for paths with no consumer, and idempotent when the entry has - # already been flushed by the scheduler-driven path. + # which empirically destabilises subsequent thinker forwards by + # making prefix-cache reuse observe stale accumulator state. + # flush is a near-no-op for paths with no consumer, and idempotent + # when the entry has already been flushed by the scheduler-driven + # path. try: self.flush_full_payload_outputs({req_id}) except Exception: @@ -733,6 +734,27 @@ def _should_accumulate_full_payload_output(self) -> bool: _custom_process_func, both of which are set at init time. Avoid the per-step dynamic import inside the model decode loop. """ + if getattr(self, "_omni_connector", None) is None: + # No connector at all: send_full_payload_outputs would no-op. + # Skip the per-step accumulator+build that would otherwise be + # silently discarded. 
Defends against a terminal stage whose + # custom_process_input_func has a *_full_payload derivative in + # the same module (e.g. dynin stage 2 token2image_to_token2audio + # in pipelines that don't configure any connector at all). + # + # Known limitation: a *terminal-consumer* stage that has a + # connector configured for receiving upstream input is NOT + # caught here -- ``_omni_connector`` is non-None for it, and + # ``_load_custom_func`` may still resolve a ``*_full_payload`` + # derivative from this stage's ``custom_process_input_func``. + # In that case the accumulator builds payloads that + # ``send_full_payload_outputs`` later drops via its own + # connector-side checks (wasted CPU, not a functional bug). + # A topology-aware gate (explicit producer field or pipeline + # is_terminal info) would close the gap; that change is out + # of scope for this PR. + self._should_accumulate_full_payload_output_cached = False + return False cached = getattr(self, "_should_accumulate_full_payload_output_cached", None) if cached is not None: return cached From 31578254d5e1a28c6be31bcc63086da64f746a50 Mon Sep 17 00:00:00 2001 From: natureofnature Date: Thu, 21 May 2026 15:45:17 +0000 Subject: [PATCH 11/19] trim codes Signed-off-by: natureofnature --- .../test_qwen3_omni_streaming_helpers.py | 171 +++++++++++------- tests/test_config_factory.py | 53 +++++- tests/worker/test_omni_gpu_model_runner.py | 3 + .../core/sched/omni_scheduling_coordinator.py | 4 - .../models/ming_flash_omni/pipeline.py | 1 - .../models/qwen3_tts/pipeline.py | 8 +- .../stage_input_processors/cosyvoice3.py | 3 - .../stage_input_processors/covo_audio.py | 6 - .../stage_input_processors/dynin_omni.py | 6 - .../stage_input_processors/mimo_audio.py | 3 - .../stage_input_processors/ming_flash_omni.py | 17 +- .../stage_input_processors/qwen2_5_omni.py | 84 +++------ .../stage_input_processors/qwen3_omni.py | 5 - .../stage_input_processors/qwen3_tts.py | 32 ++-- .../omni_connector_model_runner_mixin.py | 11 
-- 15 files changed, 219 insertions(+), 188 deletions(-) diff --git a/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py b/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py index 72b39ac0546..c01e098373e 100644 --- a/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py +++ b/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py @@ -499,17 +499,13 @@ def __init__(self): def test_covo_audio_llm2code2wav_token_only_smoke() -> None: - """Smoke: covo_audio token-only builder marks `_is_sync_input` - and returns placeholder prompts sized to audio_codes count.""" + """Smoke: covo_audio token-only builder returns placeholder prompts sized to audio_codes count.""" + # source_outputs is a list of objects with .outputs[0].token_ids + from vllm_omni.model_executor.models.covo_audio.config_covo_audio import COVO_AUDIO_TOKEN_INDEX from vllm_omni.model_executor.stage_input_processors.covo_audio import ( llm2code2wav_token_only, ) - assert getattr(llm2code2wav_token_only, "_is_sync_input", False) is True - - # source_outputs is a list of objects with .outputs[0].token_ids - from vllm_omni.model_executor.models.covo_audio.config_covo_audio import COVO_AUDIO_TOKEN_INDEX - class _Out: def __init__(self, tids): self.token_ids = tids @@ -545,15 +541,11 @@ def test_covo_audio_llm2code2wav_full_payload_smoke() -> None: def test_dynin_omni_token_only_smoke() -> None: - """Smoke: dynin_omni token-only builders mark _is_sync_input and return placeholders.""" + """Smoke: dynin_omni token-only builders return placeholders.""" from vllm_omni.model_executor.stage_input_processors.dynin_omni import ( - token2image_to_token2audio_token_only, token2text_to_token2image_token_only, ) - assert getattr(token2text_to_token2image_token_only, "_is_sync_input", False) is True - assert getattr(token2image_to_token2audio_token_only, "_is_sync_input", False) is True - class _Out: def 
__init__(self, tids, mm=None): self.token_ids = tids @@ -576,7 +568,7 @@ def __init__(self, outs): def test_dynin_omni_full_payload_smoke() -> None: - """Smoke: dynin_omni producer-side payload builder returns token_ids + finished.""" + """Smoke: dynin_omni producer-side payload builder returns nested OmniPayload + carries metadata.""" from types import SimpleNamespace from vllm_omni.model_executor.stage_input_processors.dynin_omni import ( @@ -587,9 +579,9 @@ def test_dynin_omni_full_payload_smoke() -> None: req = SimpleNamespace(output_token_ids=[], additional_information={"speaker": ["alice"]}) payload = token2text_to_token2image_full_payload(None, pooling, req) assert payload is not None - assert payload["code_predictor_codes"] == [1, 2, 3] - assert payload["finished"].item() is True - # additional_information carried forward as list-wrapped (speaker) + assert payload["codes"]["audio"] == [1, 2, 3] + assert payload["meta"]["finished"].item() is True + # additional_information is normalized + carried forward (speaker stays list-wrapped). 
assert payload.get("speaker") == ["alice"] @@ -601,8 +593,6 @@ def test_qwen2_5_omni_talker2code2wav_token_only_smoke() -> None: talker2code2wav_token_only, ) - assert getattr(talker2code2wav_token_only, "_is_sync_input", False) is True - class _Out: def __init__(self, tids): self.cumulative_token_ids = tids @@ -639,15 +629,13 @@ def test_qwen2_5_omni_talker2code2wav_full_payload_smoke() -> None: def test_mimo_audio_llm2code2wav_token_only_smoke() -> None: - """Smoke: mimo_audio token-only builder marks _is_sync_input + sizes prompt.""" + """Smoke: mimo_audio token-only builder sizes prompt.""" import torch from vllm_omni.model_executor.stage_input_processors.mimo_audio import ( llm2code2wav_token_only, ) - assert getattr(llm2code2wav_token_only, "_is_sync_input", False) is True - class _Out: def __init__(self, mm): self.multimodal_output = mm @@ -716,15 +704,13 @@ def test_mimo_audio_full_payload_nested_fallback() -> None: def test_qwen3_tts_talker2code2wav_token_only_smoke() -> None: - """Smoke: qwen3_tts token-only marks _is_sync_input + sizes placeholder.""" + """Smoke: qwen3_tts token-only sizes placeholder.""" import torch from vllm_omni.model_executor.stage_input_processors.qwen3_tts import ( talker2code2wav_token_only, ) - assert getattr(talker2code2wav_token_only, "_is_sync_input", False) is True - class _Out: def __init__(self, mm, tids): self.multimodal_output = mm @@ -768,7 +754,9 @@ def test_qwen3_tts_talker2code2wav_full_payload_smoke() -> None: def test_qwen3_tts_full_payload_with_ref_code() -> None: - """Smoke: ref_code prepended via codes.ref + meta.ref_code_len from flat path.""" + """Exact: ref_code is prepended (not appended) to audio, ref_code_len trims + ref, and the flatten is codebook-major. 
Protects against ref-append-position + regressions, ref_code_len-not-applied bugs, and flatten-order regressions.""" from types import SimpleNamespace import torch @@ -777,20 +765,34 @@ def test_qwen3_tts_full_payload_with_ref_code() -> None: talker2code2wav_full_payload, ) - # Audio: 3 frames [3, 16] + # Audio: 3 frames [3, 16] (no filter drops these — all positive, in-range). audio = torch.arange(3 * 16, dtype=torch.long).reshape(3, 16) + 1 - # Ref code: 2 frames [2, 16] (already 2-D) + # Ref code: 2 frames [2, 16] (already 2-D), distinct value range so we can + # detect the prepend ordering. ref = torch.arange(2 * 16, dtype=torch.long).reshape(2, 16) + 100 pooling_output = { "codes.audio": audio, "codes.ref": [ref], "meta.ref_code_len": torch.tensor([2], dtype=torch.int32), } - req = SimpleNamespace(output_token_ids=list(range(10))) + req = SimpleNamespace(output_token_ids=list(range(10))) # seq_len = 9 > 3, no audio crop payload = talker2code2wav_full_payload(None, pooling_output, req) assert payload is not None - # Total frames = 2 (ref) + 3 (audio) = 5; codebook-major: 16 * 5 = 80 - assert len(payload["codes"]["audio"]) == 80 + + # Exact expected: ref (prepended) + audio (no crop since seq_len > rows), then + # transpose [5, 16] -> [16, 5] and flatten row-major (codebook-major). + expected = torch.cat([ref, audio], dim=0).transpose(0, 1).reshape(-1).tolist() + assert payload["codes"]["audio"] == expected, ( + f"codec flatten mismatch -- got first 8 = {payload['codes']['audio'][:8]}, expected first 8 = {expected[:8]}" + ) + assert len(payload["codes"]["audio"]) == 80 # 16 quantizers * (2 ref + 3 audio) frames + + # Sanity guards: first codebook-major column = [ref[0,0], ref[1,0], audio[0,0], ...], + # so the prepend order must put 100 before 1. 
+ first_col = payload["codes"]["audio"][:5] + assert first_col == [100, 116, 1, 17, 33], ( + f"first column wrong: {first_col} -- ref likely appended instead of prepended" + ) def test_qwen3_tts_full_payload_nested_fallback() -> None: @@ -811,14 +813,82 @@ def test_qwen3_tts_full_payload_nested_fallback() -> None: assert len(payload["codes"]["audio"]) == 32 # 16 * 2 +def test_qwen3_tts_codec_filter_and_crop_edge_cases() -> None: + """Regression gate for codec filter + seq_len crop on both token_only and full_payload. + + Mixes valid / all-zero / negative / >=_CODEBOOK_SIZE rows. Asserts: + - Token-only placeholder length matches Q * (#kept rows after crop). + - Full-payload codes.audio matches the exact codebook-major flatten + of the kept-and-cropped rows. + + Protects against future cleanup reverting the codex P2 #3 (negative + codec filter) or the _CODEBOOK_SIZE upper bound. + """ + from types import SimpleNamespace + + import torch + + from vllm_omni.model_executor.stage_input_processors.qwen3_tts import ( + _CODEBOOK_SIZE, + talker2code2wav_full_payload, + talker2code2wav_token_only, + ) + + Q = 4 # simulated num_quantizers (default is 16; small here for readability) + # 7 rows: valid / all-zero / negative / out-of-range / boundary-valid / valid / valid. 
+ audio_rows = [ + [10, 20, 30, 40], # row 0: valid -> KEEP + [0, 0, 0, 0], # row 1: all-zero -> DROP + [50, -1, 60, 70], # row 2: negative -> DROP + [100, _CODEBOOK_SIZE, 110, 120], # row 3: >= 2048 -> DROP + [200, _CODEBOOK_SIZE - 1, 210, 220], # row 4: boundary 2047 -> KEEP + [300, 310, 320, 330], # row 5: valid -> KEEP + [400, 410, 420, 430], # row 6: valid -> KEEP + ] + audio = torch.tensor(audio_rows, dtype=torch.long) + kept = [audio_rows[i] for i in (0, 4, 5, 6)] # 4 rows after filter + + # === token_only path === + # cumulative_token_ids of length 4 -> seq_len = 3 -> crop kept[-3:] = rows {4, 5, 6} + class _Out: + def __init__(self, ctids, mm): + self.cumulative_token_ids = ctids + self.multimodal_output = mm + + class _Wrap: + def __init__(self, ctids, mm): + self.outputs = [_Out(ctids, mm)] + self.finished = True + + mm = {"codes": {"audio": audio}, "meta": {}} + src = [_Wrap(ctids=[1, 2, 3, 4], mm=mm)] + out = talker2code2wav_token_only(src, prompt=None) + assert len(out) == 1 + # No ref_code -> ref_frames = 0; expected prompt_len = Q * (#kept-after-crop) = 4 * 3 = 12 + assert len(out[0]["prompt_token_ids"]) == Q * 3 + + # === full_payload path === + pooling_output = {"codes.audio": audio} + req = SimpleNamespace(output_token_ids=[1, 2, 3, 4]) # seq_len = 3 + payload = talker2code2wav_full_payload(None, pooling_output, req) + assert payload is not None + # After filter + crop, kept rows = [row4, row5, row6] = [[200,2047,210,220],[300,310,320,330],[400,410,420,430]] + # Codebook-major flatten: transpose [3, Q] -> [Q, 3] -> reshape(-1) + cropped = torch.tensor(kept[-3:], dtype=torch.long) + expected = cropped.transpose(0, 1).reshape(-1).tolist() + assert payload["codes"]["audio"] == expected + # Sanity: confirm the boundary-valid 2047 survived (codex P2 #3 regression guard). + assert _CODEBOOK_SIZE - 1 in payload["codes"]["audio"] + # Sanity: confirm no negative or >=_CODEBOOK_SIZE codec id leaked through. 
+ assert all(0 <= v < _CODEBOOK_SIZE for v in payload["codes"]["audio"]) + + def test_cosyvoice3_text2flow_token_only_smoke() -> None: - """Smoke: cosyvoice3 token-only marks _is_sync_input + carries ids.prompt only.""" + """Smoke: cosyvoice3 token-only carries ids.prompt only.""" from vllm_omni.model_executor.stage_input_processors.cosyvoice3 import ( text2flow_token_only, ) - assert getattr(text2flow_token_only, "_is_sync_input", False) is True - class _Out: def __init__(self, tids): self.cumulative_token_ids = tids @@ -903,13 +973,11 @@ def test_cosyvoice3_full_payload_replace_keys_present() -> None: def test_ming_flash_omni_thinker2talker_token_only_smoke() -> None: - """Smoke: ming_flash_omni token-only marks _is_sync_input + carries voice metadata.""" + """Smoke: ming_flash_omni token-only carries voice metadata.""" from vllm_omni.model_executor.stage_input_processors.ming_flash_omni import ( thinker2talker_token_only, ) - assert getattr(thinker2talker_token_only, "_is_sync_input", False) is True - class _Out: def __init__(self, text): self.text = text @@ -934,20 +1002,8 @@ def __init__(self, info): assert info["ming_task"] == "omni" -def test_ming_flash_omni_thinker2talker_full_payload_noop() -> None: - """ming_flash_omni thinker2talker_full_payload is a no-op (returns None).""" - from vllm_omni.model_executor.stage_input_processors.ming_flash_omni import ( - thinker2talker_full_payload, - ) - - payload = thinker2talker_full_payload(None, {"anything": "ignored"}, None) - assert payload is None - - def test_qwen2_5_omni_thinker2talker_token_only_smoke() -> None: - """Smoke: qwen2_5_omni thinker token-only marks _is_sync_input + ports legacy body.""" - import torch - + """Smoke: qwen2_5_omni thinker token-only allocates prompt slots; bulk payload ships via connector.""" from vllm_omni.model_executor.stage_input_processors.qwen2_5_omni import ( TALKER_CODEC_END_TOKEN_ID, TALKER_CODEC_PAD_TOKEN_ID, @@ -955,34 +1011,25 @@ def 
test_qwen2_5_omni_thinker2talker_token_only_smoke() -> None: thinker2talker_token_only, ) - assert getattr(thinker2talker_token_only, "_is_sync_input", False) is True - - class _Out: - def __init__(self, ctids, mm): - self.cumulative_token_ids = ctids - self.multimodal_output = mm - class _Wrap: - def __init__(self, prompt_tids, ctids, mm, rid): - self.outputs = [_Out(ctids, mm)] + def __init__(self, prompt_tids, rid): + self.outputs = [object()] self.prompt_token_ids = prompt_tids self.request_id = rid class _Prompt(dict): pass - # Latent shaped [prompt_len + decode_len, hidden] = [5 + 3, 8] - latent = torch.randn(8, 8) - src = [_Wrap(prompt_tids=[1, 2, 3, 4, 5], ctids=[10, 20, 30], mm={"latent": latent}, rid="r-1")] + src = [_Wrap(prompt_tids=[1, 2, 3, 4, 5], rid="r-1")] prompt = [_Prompt(multi_modal_data=None)] out = thinker2talker_token_only(src, prompt=prompt) assert len(out) == 1 - # Talker prompt = START + PAD*prompt_len + END expected_prompt_len = 1 + len([1, 2, 3, 4, 5]) + 1 assert len(out[0]["prompt_token_ids"]) == expected_prompt_len assert out[0]["prompt_token_ids"][0] == TALKER_CODEC_START_TOKEN_ID assert out[0]["prompt_token_ids"][-1] == TALKER_CODEC_END_TOKEN_ID assert all(t == TALKER_CODEC_PAD_TOKEN_ID for t in out[0]["prompt_token_ids"][1:-1]) + assert out[0]["additional_information"] is None def test_qwen2_5_omni_thinker2talker_full_payload_noop() -> None: diff --git a/tests/test_config_factory.py b/tests/test_config_factory.py index c54a39a1e38..0db1d4fbae8 100644 --- a/tests/test_config_factory.py +++ b/tests/test_config_factory.py @@ -1841,12 +1841,19 @@ def test_async_chunk_dispatches_processors(self): ) assert async_stages[1].custom_process_input_func is None - # async_chunk=False → stage 0 has no streaming processor, stage 1's - # batch-end processor wires up. 
+ # async_chunk=False → stage 0 ships the bulk codec via the + # worker-connector full-payload producer; stage 1 wires the + # ``_token_only`` placeholder so the orchestrator emits no + # legacy ``additional_information``-shaped input (PR3 sync- + # via-connector data plane). sync_stages = merge_pipeline_deploy(pipeline, DeployConfig(async_chunk=False)) - assert "custom_process_next_stage_input_func" not in sync_stages[0].yaml_engine_args + assert ( + sync_stages[0] + .yaml_engine_args["custom_process_next_stage_input_func"] + .endswith("talker2code2wav_full_payload") + ) assert sync_stages[1].custom_process_input_func is not None - assert sync_stages[1].custom_process_input_func.endswith("talker2code2wav") + assert sync_stages[1].custom_process_input_func.endswith("talker2code2wav_token_only") def test_async_chunk_dispatches_qwen3_omni_processors(self): import runpy @@ -1883,6 +1890,44 @@ def test_async_chunk_dispatches_qwen3_omni_processors(self): .endswith("talker2code2wav_full_payload") ) + def test_ming_flash_omni_topology(self): + """Guard ming_flash_omni's PR3 cleanup: stage 0 has no full-payload + producer hook (the connector path was removed as fake -- arch is not + in ``_FULL_PAYLOAD_INPUT_STAGES``), and stage 1 still wires the + legacy ``thinker2talker`` (custom_process_input_func) plus the + ``thinker2talker_token_only`` placeholder (sync_process_input_func). + Merge under either async_chunk mode must not re-introduce a + stage-0 full-payload hook.""" + from vllm_omni.config.stage_config import DeployConfig, merge_pipeline_deploy + + pipeline = _PIPELINE_REGISTRY["ming_flash_omni"] + + stage0, stage1 = pipeline.stages + assert stage0.custom_process_next_stage_input_func is None, ( + "ming_flash_omni stage 0 must not declare a full-payload producer " + "(connector path is not active for this arch)." 
+ ) + assert stage1.custom_process_input_func is not None + assert stage1.custom_process_input_func.endswith("thinker2talker") + assert stage1.sync_process_input_func is not None + assert stage1.sync_process_input_func.endswith("thinker2talker_token_only") + + # async_chunk=True must now be rejected: removing the fake hook means + # there is no next-stage input processor for the validator to accept. + # (Positive consequence -- users can't accidentally enable async_chunk + # on an arch that doesn't actually support it.) + import pytest as _pytest + + with _pytest.raises(ValueError, match="async_chunk=True"): + merge_pipeline_deploy(pipeline, DeployConfig(async_chunk=True)) + + # async_chunk=False merges cleanly and stage-0 yaml_engine_args carries + # no spurious full-payload hook. + merged = merge_pipeline_deploy(pipeline, DeployConfig(async_chunk=False)) + assert "custom_process_next_stage_input_func" not in merged[0].yaml_engine_args, ( + "stage-0 full-payload hook unexpectedly re-appeared in yaml_engine_args" + ) + class TestSamplingConstraintsPrecedence: """Test that pipeline sampling_constraints override deploy defaults.""" diff --git a/tests/worker/test_omni_gpu_model_runner.py b/tests/worker/test_omni_gpu_model_runner.py index d04eab3b334..b5e4ad4a4ec 100644 --- a/tests/worker/test_omni_gpu_model_runner.py +++ b/tests/worker/test_omni_gpu_model_runner.py @@ -410,6 +410,9 @@ def _make_full_payload_accumulation_runner( runner._custom_process_func = object() runner._pending_full_payload_send = {} runner._stage_id = 1 + # Non-None sentinel: the gate short-circuits to False when no connector + # is configured at all (terminal stages in pipelines with no connector). 
+ runner._omni_connector = object() return runner diff --git a/vllm_omni/core/sched/omni_scheduling_coordinator.py b/vllm_omni/core/sched/omni_scheduling_coordinator.py index 26033f11328..a8d68669b82 100644 --- a/vllm_omni/core/sched/omni_scheduling_coordinator.py +++ b/vllm_omni/core/sched/omni_scheduling_coordinator.py @@ -32,10 +32,6 @@ # here without also wiring its worker connector init produces a permanent # Stage 1 hang (gate parks the request, no transport ever releases it). # -# The `_is_sync_input` markers on per-model `*_token_only` builders in -# stage_input_processors/ remain as forward-compat documentation; when -# init is generalised this whitelist can move back to a structural marker -# check or be dropped entirely. _FULL_PAYLOAD_INPUT_STAGES: frozenset[tuple[str, str]] = frozenset( { ("Qwen3OmniMoeForConditionalGeneration", "talker"), diff --git a/vllm_omni/model_executor/models/ming_flash_omni/pipeline.py b/vllm_omni/model_executor/models/ming_flash_omni/pipeline.py index c818aebad3d..a1e1ef4699a 100644 --- a/vllm_omni/model_executor/models/ming_flash_omni/pipeline.py +++ b/vllm_omni/model_executor/models/ming_flash_omni/pipeline.py @@ -42,7 +42,6 @@ # Thinker reads the LLM sub-config of BailingMM2Config hf_config_name="llm_config", engine_output_type="text", - custom_process_next_stage_input_func=f"{_PROC}.thinker2talker_full_payload", sampling_constraints={"detokenize": True}, ), StagePipelineConfig( diff --git a/vllm_omni/model_executor/models/qwen3_tts/pipeline.py b/vllm_omni/model_executor/models/qwen3_tts/pipeline.py index d37dd23c4fe..7f50931ddf7 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/pipeline.py +++ b/vllm_omni/model_executor/models/qwen3_tts/pipeline.py @@ -41,7 +41,13 @@ final_output_type="audio", engine_output_type="audio", model_arch="Qwen3TTSCode2Wav", - custom_process_input_func=f"{_PROC}.talker2code2wav", + # ``sync_process_input_func`` is the only input-proc override for + # this stage in sync (non-async-chunk) 
mode: a length-only + # ``_token_only`` placeholder. The bulk codec payload itself + # ships via the worker connector from stage 0's + # ``talker2code2wav_full_payload`` producer. Under async_chunk + # mode no pre-stage processing is needed -- chunks deliver + # directly to the consumer. sync_process_input_func=f"{_PROC}.talker2code2wav_token_only", sampling_constraints={"detokenize": True}, extras={"tts_args": {"max_instructions_length": 500}}, diff --git a/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py b/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py index 3aed006600b..3bdd9dd8934 100644 --- a/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py +++ b/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py @@ -383,9 +383,6 @@ def text2flow_token_only( return engine_inputs -text2flow_token_only._is_sync_input = True - - def text2flow_full_payload( transfer_manager, pooling_output, diff --git a/vllm_omni/model_executor/stage_input_processors/covo_audio.py b/vllm_omni/model_executor/stage_input_processors/covo_audio.py index c5cdc312cc6..e56bf845b7f 100644 --- a/vllm_omni/model_executor/stage_input_processors/covo_audio.py +++ b/vllm_omni/model_executor/stage_input_processors/covo_audio.py @@ -74,12 +74,6 @@ def llm2code2wav_token_only( return code2wav_inputs -# Mark for forward compatibility; current consumer wait gating is -# _FULL_PAYLOAD_INPUT_STAGES-driven (see the mixin -# should_accumulate_full_payload_output docstring). 
-llm2code2wav_token_only._is_sync_input = True - - def llm2code2wav_full_payload( transfer_manager: Any, pooling_output: dict[str, Any], diff --git a/vllm_omni/model_executor/stage_input_processors/dynin_omni.py b/vllm_omni/model_executor/stage_input_processors/dynin_omni.py index 2ff4d55a7cd..cd9ddd583ff 100644 --- a/vllm_omni/model_executor/stage_input_processors/dynin_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/dynin_omni.py @@ -265,9 +265,3 @@ def token2image_to_token2audio_token_only( source_stage_id = engine_input_source[0] if engine_input_source else 0 source_outputs = stage_list[source_stage_id].engine_outputs return _token_only_from_source(source_outputs) - - -# Mark sync-side builders for forward compatibility; current consumer -# wait gating is _FULL_PAYLOAD_INPUT_STAGES-driven. -token2text_to_token2image_token_only._is_sync_input = True -token2image_to_token2audio_token_only._is_sync_input = True diff --git a/vllm_omni/model_executor/stage_input_processors/mimo_audio.py b/vllm_omni/model_executor/stage_input_processors/mimo_audio.py index c6b1017cf41..67309967e12 100644 --- a/vllm_omni/model_executor/stage_input_processors/mimo_audio.py +++ b/vllm_omni/model_executor/stage_input_processors/mimo_audio.py @@ -389,9 +389,6 @@ def llm2code2wav_token_only( return code2wav_inputs -llm2code2wav_token_only._is_sync_input = True - - def llm2code2wav_full_payload( transfer_manager, pooling_output: dict, diff --git a/vllm_omni/model_executor/stage_input_processors/ming_flash_omni.py b/vllm_omni/model_executor/stage_input_processors/ming_flash_omni.py index 40830bc0a55..938018856f4 100644 --- a/vllm_omni/model_executor/stage_input_processors/ming_flash_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/ming_flash_omni.py @@ -535,22 +535,13 @@ def thinker2talker( return talker_inputs -# ============================================================================ -# Worker-connector data plane (non-async-chunk path) -- inactive for -# 
ming_flash_omni. -# ming_flash_omni's thinker->talker bridge passes detokenized text only; -# voice/speaker metadata flows through the USER request's -# additional_information, not the model's pooler_output. No heavy -# tensor to migrate, so ``thinker2talker_full_payload`` returns None. # ming_flash_omni is not in ``_OMNI_CONNECTOR_INIT_ARCHS`` or # ``_FULL_PAYLOAD_INPUT_STAGES``, so the worker connector is not # initialised for this arch and the consumer never waits on a connector -# payload; data flows through ``additional_information`` written by -# ``thinker2talker_token_only``. The ``*_full_payload`` definition is -# retained for forward compatibility. -# ============================================================================ - -_FULL_PAYLOAD_REPLACE_KEYS: frozenset[str] = frozenset() +# payload. Data flows through ``additional_information`` written by +# ``thinker2talker_token_only`` (wired as ``sync_process_input_func`` +# in the pipeline) or the legacy ``thinker2talker`` (wired as +# ``custom_process_input_func``). def thinker2talker_token_only( diff --git a/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py index a48a844e3a1..cf4d6ad9780 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py @@ -99,11 +99,11 @@ def talker2code2wav( # omni_scheduling_coordinator): # - thinker->talker reads accumulated ``pooling_output["hidden"]`` and # packs an OmniPayload-shaped dict (embed.prefill / -# hidden_states.output / ids.prompt / ids.output) for the talker. -# ``thinker2talker_token_only`` writes the same shape into -# ``additional_information`` as a legacy sync fallback; the talker's -# ``talker_preprocess`` reads either source through the same payload -# keys. 
+# hidden_states.output / ids.prompt / ids.output) for the talker, which +# the talker's ``talker_preprocess`` reads from +# ``model_intermediate_buffer``. The shape matches what legacy +# ``thinker2talker`` writes into ``additional_information`` as a debug +# fallback; ``thinker2talker_token_only`` only allocates prompt slots. # - talker->code2wav strips TALKER_CODEC_{START,END} boundary tokens # and ships the codec token ids. # ============================================================================ @@ -164,9 +164,6 @@ def talker2code2wav_token_only( return code2wav_inputs -talker2code2wav_token_only._is_sync_input = True - - def talker2code2wav_full_payload( transfer_manager, pooling_output: dict, @@ -199,35 +196,31 @@ def talker2code2wav_full_payload( # per decode step on ``pooling_output["hidden"]`` (unpacked from # ``OmniOutput.text_hidden_states``); the full-payload accumulator # concatenates them so ``thinker2talker_full_payload`` sees the full -# prefill+decode trajectory and packs an OmniPayload-shaped dict. -# ``thinker2talker_token_only`` writes the same shape into -# ``additional_information`` as a legacy sync fallback; the talker's -# ``talker_preprocess`` reads ``info_dict`` regardless of source. +# prefill+decode trajectory and packs an OmniPayload-shaped dict that +# the talker's ``talker_preprocess`` reads from +# ``model_intermediate_buffer``. ``thinker2talker_token_only`` only +# allocates the talker's codec prompt slots; legacy +# ``thinker2talker`` above remains as a debug fallback that bundles the +# same shape into ``additional_information``. # ============================================================================ -_FULL_PAYLOAD_REPLACE_KEYS: frozenset[str] = frozenset() - def thinker2talker_token_only( source_outputs, prompt: OmniTokensPrompt | TextPrompt = None, requires_multimodal_data: bool = False, ): - """Sync-side builder for the non-async-chunk thinker->talker path. 
- - Body mirrors the legacy ``thinker2talker`` above: packs an - OmniPayload-shaped dict (hidden_states.output / embed.prefill / - ids.prompt / ids.output) into ``additional_information``, allocates - TALKER_CODEC_{START,PAD,END} prompt slots, and forwards - ``multi_modal_data``. Serves as a legacy sync fallback; the same - shape is also built by ``thinker2talker_full_payload`` below and - shipped via the worker connector. - - The ``_is_sync_input = True`` marker below is currently dormant - forward-compat documentation -- the consumer-wait gate is - whitelist-driven via ``_FULL_PAYLOAD_INPUT_STAGES`` (see the mixin - ``should_accumulate_full_payload_output`` docstring), not by this - marker. + """Placeholder builder for the connector-driven thinker->talker path. + + Allocates the TALKER_CODEC_{START,PAD,END} prompt slots sized to the + thinker prompt length and forwards ``multi_modal_data``. The bulk + payload (hidden_states / embed / ids) ships exclusively through + ``thinker2talker_full_payload`` via the worker connector and lands + in ``model_intermediate_buffer`` before the talker's forward() runs. + + Consumer-wait gating is whitelist-driven via + ``_FULL_PAYLOAD_INPUT_STAGES`` (see the mixin + ``should_accumulate_full_payload_output`` docstring). 
""" thinker_outputs = source_outputs talker_inputs = [] @@ -237,29 +230,14 @@ def thinker2talker_token_only( thinker_output.request_id: p.get("multi_modal_data", None) for thinker_output, p in zip(thinker_outputs, prompt) } - for i, thinker_output in enumerate(thinker_outputs): - output = thinker_output.outputs[0] + for thinker_output in thinker_outputs: prompt_token_ids = thinker_output.prompt_token_ids - thinker_output_ids = output.cumulative_token_ids - prompt_token_ids_len = len(prompt_token_ids) - mm: OmniPayload = output.multimodal_output - latent = mm["latent"] - thinker_hidden_states = latent.clone().detach().to(latent.device) - decode_hidden = thinker_hidden_states[prompt_token_ids_len:].to(torch.float32) - prefill_hidden = thinker_hidden_states[:prompt_token_ids_len].to(torch.float32) - additional_information = to_dict( - OmniPayloadStruct( - hidden_states=HiddenStatesStruct(output=decode_hidden, output_shape=list(decode_hidden.shape)), - embed=EmbeddingsStruct(prefill=prefill_hidden, prefill_shape=list(prefill_hidden.shape)), - ids=IdsStruct(prompt=list(prompt_token_ids), output=list(thinker_output_ids)), - ) - ) talker_inputs.append( OmniTokensPrompt( prompt_token_ids=[TALKER_CODEC_START_TOKEN_ID] + [TALKER_CODEC_PAD_TOKEN_ID] * (len(prompt_token_ids)) + [TALKER_CODEC_END_TOKEN_ID], - additional_information=additional_information, + additional_information=None, multi_modal_data=( multi_modal_data[thinker_output.request_id] if requires_multimodal_data and multi_modal_data is not None @@ -272,9 +250,6 @@ def thinker2talker_token_only( return talker_inputs -thinker2talker_token_only._is_sync_input = True - - def thinker2talker_full_payload( transfer_manager, pooling_output, @@ -293,11 +268,12 @@ def thinker2talker_full_payload( We split it at ``len(prompt_token_ids)`` into prefill embeddings and decode hidden states, then pack the ``OmniPayload``-shaped dict that - the talker's ``thinker_to_talker_process`` already reads (keys - ``embed.prefill`` / 
``hidden_states.output`` / ``ids.prompt`` / - ``ids.output``). Shape matches what ``thinker2talker_token_only`` - writes into ``additional_information``, so the talker consumes the - same payload layout from either path. + the talker's ``thinker_to_talker_process`` reads from + ``model_intermediate_buffer`` (keys ``embed.prefill`` / + ``hidden_states.output`` / ``ids.prompt`` / ``ids.output``). Shape + matches what legacy ``thinker2talker`` writes into + ``additional_information`` as a debug fallback, so the talker + consumes the same payload layout from either path. Like ``qwen3_omni.thinker2talker_full_payload``, we apply a finish-reason-aware stop-row trim: vLLM v1 appends the sampled diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py index 83b267d5933..77a9af6c0c4 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -955,8 +955,3 @@ def talker2code2wav( ) return code2wav_inputs - - -# Mark for forward compatibility; current consumer wait gating is -# _FULL_PAYLOAD_INPUT_STAGES-driven (see should_accumulate_full_payload_output above). -thinker2talker_token_only._is_sync_input = True diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py index 15de87741b0..cb9f47a4f01 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py @@ -324,6 +324,21 @@ def _filter_audio_codes_qwen3_tts(audio_codes: torch.Tensor) -> torch.Tensor: return audio_codes[valid_mask] +def _coerce_ref_code_len(raw) -> int: + """Coerce mm["meta"]["ref_code_len"] / pooling_output["meta.ref_code_len"] + raw value (Tensor | int | None) into a non-negative int. 
Mirrors the + extraction inlined in the legacy ``talker2code2wav`` path; clamps any + negative input to 0 since downstream code treats this as a non-negative + frame count.""" + if isinstance(raw, torch.Tensor): + value = int(raw.reshape(-1)[-1].item()) if raw.numel() > 0 else 0 + elif raw is None: + value = 0 + else: + value = int(raw) + return max(value, 0) + + def _normalize_ref_code(ref_code, num_quantizers: int, ref_code_len: int): """Coerce ref_code into a [ref_len, Q] tensor or None. Mirrors orchestrator path.""" if isinstance(ref_code, list): @@ -383,12 +398,7 @@ def talker2code2wav_token_only( ref_code_raw = mm_codes.get("ref") if isinstance(mm_codes, dict) else None ref_code_len_raw = mm.get("meta", {}).get("ref_code_len") if isinstance(mm.get("meta"), dict) else None - if isinstance(ref_code_len_raw, torch.Tensor): - ref_code_len = int(ref_code_len_raw.reshape(-1)[-1].item()) if ref_code_len_raw.numel() > 0 else 0 - elif ref_code_len_raw is None: - ref_code_len = 0 - else: - ref_code_len = int(ref_code_len_raw) + ref_code_len = _coerce_ref_code_len(ref_code_len_raw) _, ref_frames = _normalize_ref_code(ref_code_raw, num_quantizers, ref_code_len) # Codebook-major flat: Q * (ref_frames + audio_frames) @@ -412,9 +422,6 @@ def talker2code2wav_token_only( return code2wav_inputs -talker2code2wav_token_only._is_sync_input = True - - def talker2code2wav_full_payload( transfer_manager, pooling_output, @@ -458,12 +465,7 @@ def talker2code2wav_full_payload( meta_nested = pooling_output.get("meta") if isinstance(meta_nested, dict): ref_code_len_raw = meta_nested.get("ref_code_len") - if isinstance(ref_code_len_raw, torch.Tensor): - ref_code_len = int(ref_code_len_raw.reshape(-1)[-1].item()) if ref_code_len_raw.numel() > 0 else 0 - elif ref_code_len_raw is None: - ref_code_len = 0 - else: - ref_code_len = int(ref_code_len_raw) + ref_code_len = _coerce_ref_code_len(ref_code_len_raw) # codes.ref — flat dotted then nested fallback. 
ref_code_raw = pooling_output.get("codes.ref") diff --git a/vllm_omni/worker/omni_connector_model_runner_mixin.py b/vllm_omni/worker/omni_connector_model_runner_mixin.py index 550971484ff..e0ea078979e 100644 --- a/vllm_omni/worker/omni_connector_model_runner_mixin.py +++ b/vllm_omni/worker/omni_connector_model_runner_mixin.py @@ -54,17 +54,6 @@ def should_accumulate_full_payload_output(model_config, custom_process_func) -> stage_config's ``custom_process_next_stage_input_func`` or the ``*_full_payload`` derivative of ``custom_process_input_func``), the stage is not in async_chunk mode, and ``model_stage`` is set. - - NOTE: the ``_is_sync_input`` marker is on the *consumer-side* - ``*_token_only`` builder, not on the ``*_full_payload`` payload builder that - workers load on the *producer* side, so checking it here would - always return False and the full-payload accumulator would never run. The - marker itself is currently dormant forward-compat documentation: - the consumer-side scheduler gate - (``uses_full_payload_input_coordinator`` in - ``omni_scheduling_coordinator.py``) is whitelist-driven on - ``(model_arch, model_stage)`` against ``_FULL_PAYLOAD_INPUT_STAGES`` - -- adding the marker alone does not open a consumer-wait gate. 
""" if custom_process_func is None: return False From 485ff273a48406f3f659888ffc8e76c8c678f44a Mon Sep 17 00:00:00 2001 From: natureofnature Date: Thu, 21 May 2026 17:42:14 +0000 Subject: [PATCH 12/19] update for review Signed-off-by: natureofnature --- vllm_omni/core/sched/omni_ar_scheduler.py | 2 +- .../core/sched/omni_generation_scheduler.py | 2 +- vllm_omni/core/sched/omni_scheduler_mixin.py | 26 ++++++------ .../core/sched/omni_scheduling_coordinator.py | 8 ++-- vllm_omni/core/sched/output.py | 4 +- .../stage_input_processors/cosyvoice3.py | 16 ++++++++ .../stage_input_processors/covo_audio.py | 7 ++++ .../stage_input_processors/dynin_omni.py | 9 +++++ .../stage_input_processors/mimo_audio.py | 17 ++++++++ .../stage_input_processors/qwen2_5_omni.py | 40 ++++++++++++++++++- .../stage_input_processors/qwen3_tts.py | 18 +++++++++ .../omni_connector_model_runner_mixin.py | 13 ++++-- 12 files changed, 136 insertions(+), 26 deletions(-) diff --git a/vllm_omni/core/sched/omni_ar_scheduler.py b/vllm_omni/core/sched/omni_ar_scheduler.py index c306765b891..3d51765a9af 100644 --- a/vllm_omni/core/sched/omni_ar_scheduler.py +++ b/vllm_omni/core/sched/omni_ar_scheduler.py @@ -566,7 +566,7 @@ def update_from_output( engine_core_outputs[0] = eco = EngineCoreOutputs() eco.scheduler_stats = stats - self._capture_omni_connector_output(model_runner_output, model_mode="ar") + self._capture_omni_connector_output(model_runner_output) # Free blocks that were held for transfer (kv_ready and # active_kv_transfers updates already done before the per-request loop). 
diff --git a/vllm_omni/core/sched/omni_generation_scheduler.py b/vllm_omni/core/sched/omni_generation_scheduler.py index 403bd4c5db7..8d665574213 100644 --- a/vllm_omni/core/sched/omni_generation_scheduler.py +++ b/vllm_omni/core/sched/omni_generation_scheduler.py @@ -659,7 +659,7 @@ def update_from_output( engine_core_outputs[0] = eco = EngineCoreOutputs() eco.scheduler_stats = stats - self._capture_omni_connector_output(model_runner_output, model_mode="generation") + self._capture_omni_connector_output(model_runner_output) return engine_core_outputs diff --git a/vllm_omni/core/sched/omni_scheduler_mixin.py b/vllm_omni/core/sched/omni_scheduler_mixin.py index 16b7c373c99..78284efa3e9 100644 --- a/vllm_omni/core/sched/omni_scheduler_mixin.py +++ b/vllm_omni/core/sched/omni_scheduler_mixin.py @@ -8,7 +8,7 @@ from vllm.v1.engine import EngineCoreEventType from vllm.v1.request import Request, RequestStatus, StreamingUpdate -from vllm_omni.core.sched.output import OmniInputRegistration, OmniSchedulerOutput +from vllm_omni.core.sched.output import OmniChunkRecvHandle, OmniSchedulerOutput logger = init_logger(__name__) @@ -101,21 +101,19 @@ def _process_pending_input_timeouts(self) -> None: ) self.finish_requests(present_ids, RequestStatus.FINISHED_ERROR) - def _capture_omni_connector_output(self, model_runner_output: Any, model_mode: str) -> None: + def _capture_omni_connector_output(self, model_runner_output: Any) -> None: """Stash the model runner's omni_connector_output for next schedule(). - Called at the tail of every ``update_from_output()``. Identical - between AR and generation schedulers except for ``model_mode``. - - NOTE: this method only stashes the output. Applying the metadata - is the responsibility of ``_consume_pending_connector_output()`` - at the start of the next ``schedule()`` cycle. 
Applying it twice - (once here, once on consume) is unsafe under - ``update_request_metadata`` in generation mode, which resets - ``prompt_token_ids`` / ``_output_token_ids`` / ``num_computed_tokens`` - and would clobber any progress between the two calls. + Called at the tail of every ``update_from_output()`` -- identical + between AR and generation schedulers. Only stashes the output; + applying the metadata is the responsibility of + ``_consume_pending_connector_output()`` at the start of the next + ``schedule()`` cycle. Applying it twice (once here, once on + consume) is unsafe under ``update_request_metadata`` in + generation mode, which resets ``prompt_token_ids`` / + ``_output_token_ids`` / ``num_computed_tokens`` and would + clobber any progress between the two calls. """ - del model_mode # only used by the (removed) double-apply branch omni_output = getattr(model_runner_output, "omni_connector_output", None) if omni_output is None: return @@ -126,7 +124,7 @@ def _wrap_omni_scheduler_output( base: SchedulerOutput, *, finished_requests_needing_kv_transfer: dict | None = None, - pending_input_registrations: list[OmniInputRegistration] | None = None, + pending_input_registrations: list[OmniChunkRecvHandle] | None = None, ) -> OmniSchedulerOutput: """Wrap a base ``SchedulerOutput`` in ``OmniSchedulerOutput``. 
diff --git a/vllm_omni/core/sched/omni_scheduling_coordinator.py b/vllm_omni/core/sched/omni_scheduling_coordinator.py index a8d68669b82..4056fd93861 100644 --- a/vllm_omni/core/sched/omni_scheduling_coordinator.py +++ b/vllm_omni/core/sched/omni_scheduling_coordinator.py @@ -19,7 +19,7 @@ from vllm.logger import init_logger from vllm.v1.request import Request, RequestStatus -from vllm_omni.core.sched.output import OmniInputRegistration +from vllm_omni.core.sched.output import OmniChunkRecvHandle logger = init_logger(__name__) @@ -111,7 +111,7 @@ def __init__(self, scheduler_max_num_seqs: int, stage_id: int = 0, async_chunk: # can call register_chunk_recv(). Typed concretely (not list[Any]) so # the surrounding OmniSchedulerOutput stays msgspec-friendly across # default, PD-disagg, and multi-node executor IPC paths. - self.pending_input_registrations: list[OmniInputRegistration] = [] + self.pending_input_registrations: list[OmniChunkRecvHandle] = [] # Monotonic timestamp recording when each request first entered # WAITING_FOR_CHUNK or WAITING_FOR_INPUT. 
Used by @@ -219,7 +219,7 @@ def process_pending_full_payload_inputs( to_remove.append(request) self._waiting_for_input.append(request) self.pending_input_registrations.append( - OmniInputRegistration( + OmniChunkRecvHandle( request_id=request.request_id, external_req_id=getattr(request, "external_req_id", None), ) @@ -232,7 +232,7 @@ def process_pending_full_payload_inputs( to_remove.append(request) self._waiting_for_input.append(request) self.pending_input_registrations.append( - OmniInputRegistration( + OmniChunkRecvHandle( request_id=request.request_id, external_req_id=getattr(request, "external_req_id", None), ) diff --git a/vllm_omni/core/sched/output.py b/vllm_omni/core/sched/output.py index 800eaf39815..29cd872998f 100644 --- a/vllm_omni/core/sched/output.py +++ b/vllm_omni/core/sched/output.py @@ -72,7 +72,7 @@ class OmniCachedRequestData(CachedRequestData): @dataclass -class OmniInputRegistration: +class OmniChunkRecvHandle: """Minimal identifier carried from scheduler to runner for chunk-recv registration. 
@@ -93,4 +93,4 @@ class OmniSchedulerOutput(SchedulerOutput): """Scheduler output with omni-specific transfer metadata.""" finished_requests_needing_kv_transfer: dict[str, dict] = field(default_factory=dict) - pending_input_registrations: list[OmniInputRegistration] = field(default_factory=list) + pending_input_registrations: list[OmniChunkRecvHandle] = field(default_factory=list) diff --git a/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py b/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py index 3bdd9dd8934..4c7245e773f 100644 --- a/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py +++ b/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py @@ -7,6 +7,7 @@ import numpy as np import torch from vllm.inputs import TextPrompt +from vllm.logger import init_logger from vllm_omni.data_entry_keys import ( CodesStruct, @@ -16,6 +17,8 @@ ) from vllm_omni.inputs.data import OmniTokensPrompt +logger = init_logger(__name__) + _COSYVOICE3_SPEECH_TOKEN_SIZE = 6561 @@ -396,7 +399,14 @@ def text2flow_full_payload( (see cosyvoice3.py:671 in the code2wav forward — runtime_info pickup). 
""" del transfer_manager + rid = getattr(request, "external_req_id", None) or getattr(request, "request_id", "?") if not isinstance(pooling_output, dict): + logger.warning( + "cosyvoice3.text2flow_full_payload: pooling_output not a dict " + "(type=%s) for req=%s; consumer wait gate may hang.", + type(pooling_output).__name__, + rid, + ) return None embed_out: dict[str, Any] = {} for key in ("speech_token", "speech_feat", "embedding"): @@ -408,6 +418,12 @@ def text2flow_full_payload( if isinstance(v, torch.Tensor) and v.numel() > 0: embed_out[key] = v if not embed_out: + logger.warning( + "cosyvoice3.text2flow_full_payload: no embed.{speech_token,speech_feat,embedding} " + "found in pooling_output (keys=%s) for req=%s; consumer wait gate may hang.", + list(pooling_output.keys()), + rid, + ) return None return { "meta": { diff --git a/vllm_omni/model_executor/stage_input_processors/covo_audio.py b/vllm_omni/model_executor/stage_input_processors/covo_audio.py index e56bf845b7f..7b5ed8c266d 100644 --- a/vllm_omni/model_executor/stage_input_processors/covo_audio.py +++ b/vllm_omni/model_executor/stage_input_processors/covo_audio.py @@ -2,10 +2,13 @@ from typing import Any import torch +from vllm.logger import init_logger from vllm_omni.inputs.data import OmniTokensPrompt from vllm_omni.model_executor.models.covo_audio.config_covo_audio import COVO_AUDIO_TOKEN_INDEX +logger = init_logger(__name__) + # Per-model REPLACE-keys for the full-payload accumulator (none for covo_audio: # the producer side does not emit per-step hidden_states / model_outputs; # llm2code2wav_full_payload reads token_ids directly from `request`). 
@@ -88,6 +91,10 @@ def llm2code2wav_full_payload( """ output_token_ids = list(getattr(request, "output_token_ids", None) or []) if not output_token_ids: + logger.warning( + "covo_audio.llm2code2wav_full_payload: empty output_token_ids for req=%s; consumer wait gate may hang.", + getattr(request, "request_id", "?"), + ) return None audio_codes = _filter_audio_codes(output_token_ids) return { diff --git a/vllm_omni/model_executor/stage_input_processors/dynin_omni.py b/vllm_omni/model_executor/stage_input_processors/dynin_omni.py index cd9ddd583ff..5a5804de394 100644 --- a/vllm_omni/model_executor/stage_input_processors/dynin_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/dynin_omni.py @@ -5,9 +5,12 @@ import torch from vllm.inputs import TextPrompt +from vllm.logger import init_logger from vllm_omni.inputs.data import OmniTokensPrompt +logger = init_logger(__name__) + def _to_prompt_dict(prompt_item: OmniTokensPrompt | TextPrompt | str | None) -> dict[str, Any]: if isinstance(prompt_item, dict): @@ -178,6 +181,12 @@ def _build_full_payload(pooling_output: dict[str, Any] | None, request: Any) -> if not token_ids and request is not None: token_ids = _to_token_id_list(getattr(request, "output_token_ids", None)) if not token_ids: + logger.warning( + "dynin_omni._build_full_payload: no token_ids found in pooling_output " + "(keys=%s) or request.output_token_ids for req=%s; consumer wait gate may hang.", + list(pooling_output.keys()), + getattr(request, "request_id", "?"), + ) return None src_additional_info = getattr(request, "additional_information", {}) if request is not None else {} diff --git a/vllm_omni/model_executor/stage_input_processors/mimo_audio.py b/vllm_omni/model_executor/stage_input_processors/mimo_audio.py index 67309967e12..2443ad2c479 100644 --- a/vllm_omni/model_executor/stage_input_processors/mimo_audio.py +++ b/vllm_omni/model_executor/stage_input_processors/mimo_audio.py @@ -405,7 +405,14 @@ def llm2code2wav_full_payload( is kept in 
case a future runtime path bypasses `flatten_payload`. """ del transfer_manager + rid = getattr(request, "request_id", "?") if not isinstance(pooling_output, dict): + logger.warning( + "mimo_audio.llm2code2wav_full_payload: pooling_output not a dict " + "(type=%s) for req=%s; consumer wait gate may hang.", + type(pooling_output).__name__, + rid, + ) return None codec_codes = pooling_output.get("codes.audio") if codec_codes is None: @@ -414,10 +421,20 @@ def llm2code2wav_full_payload( if isinstance(codes, dict): codec_codes = codes.get("audio") if not isinstance(codec_codes, torch.Tensor) or codec_codes.numel() == 0: + logger.warning( + "mimo_audio.llm2code2wav_full_payload: missing/empty codes.audio " + "(keys=%s) for req=%s; consumer wait gate may hang.", + list(pooling_output.keys()), + rid, + ) return None codec_codes = codec_codes.to(torch.long) codec_codes = _filter_zero_codec_rows(codec_codes) if codec_codes.numel() == 0: + logger.warning( + "mimo_audio.llm2code2wav_full_payload: codec_codes empty after _filter_zero_codec_rows for req=%s.", + rid, + ) return None pad_vec = torch.tensor([TALKER_CODEC_PAD_TOKEN_ID] * 4) diff --git a/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py index cf4d6ad9780..8be1ca3b590 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py @@ -176,11 +176,22 @@ def talker2code2wav_full_payload( boundary tokens and pack a minimal payload. 
""" del transfer_manager + rid = getattr(request, "request_id", "?") token_ids = list(getattr(request, "output_token_ids", None) or []) if not token_ids: + logger.warning( + "qwen2_5_omni.talker2code2wav_full_payload: empty output_token_ids " + "for req=%s; consumer wait gate may hang.", + rid, + ) return None token_ids = _strip_codec_boundaries(token_ids) if not token_ids: + logger.warning( + "qwen2_5_omni.talker2code2wav_full_payload: codec ids empty after " + "stripping boundary tokens for req=%s; consumer wait gate may hang.", + rid, + ) return None return { "codes": {"audio": token_ids}, @@ -285,11 +296,24 @@ def thinker2talker_full_payload( heuristic. """ del transfer_manager + rid = getattr(request, "request_id", "?") if not isinstance(pooling_output, dict): + logger.warning( + "qwen2_5_omni.thinker2talker_full_payload: pooling_output not a dict " + "(type=%s) for req=%s; consumer wait gate may hang.", + type(pooling_output).__name__, + rid, + ) return None hidden = pooling_output.get("hidden") if not isinstance(hidden, torch.Tensor): + logger.warning( + "qwen2_5_omni.thinker2talker_full_payload: missing 'hidden' tensor " + "(keys=%s) for req=%s; consumer wait gate may hang.", + list(pooling_output.keys()), + rid, + ) return None def _ensure_list(x): @@ -352,6 +376,13 @@ def _ensure_list(x): h = hidden.detach().cpu().to(torch.float32) target_rows = max(0, len(all_token_ids) - stop_emission_drop) if target_rows <= 0: + logger.warning( + "qwen2_5_omni.thinker2talker_full_payload: target_rows<=0 " + "(all_token_ids=%d, stop_drop=%d) for req=%s; nothing to ship.", + len(all_token_ids), + stop_emission_drop, + getattr(request, "request_id", "?"), + ) return None if h.dim() >= 1 and h.shape[0] > target_rows: logger.warning( @@ -366,8 +397,15 @@ def _ensure_list(x): prompt_len = len(prompt_token_ids) if h.shape[0] < prompt_len: - # Under-captured prefill — defensively skip rather than ship a + # Under-captured prefill -- defensively skip rather than ship a # 
truncated payload that would confuse the talker's prefill path. + logger.warning( + "qwen2_5_omni.thinker2talker_full_payload: hidden rows=%d < prompt_len=%d " + "for req=%s; under-captured prefill, skipping payload.", + int(h.shape[0]), + prompt_len, + getattr(request, "request_id", "?"), + ) return None prefill_hidden = h[:prompt_len] diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py index cb9f47a4f01..abe32a103db 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py @@ -436,7 +436,14 @@ def talker2code2wav_full_payload( crop to seq_len, prepend ref, codebook-major flatten). """ del transfer_manager + rid = getattr(request, "request_id", "?") if not isinstance(pooling_output, dict): + logger.warning( + "qwen3_tts.talker2code2wav_full_payload: pooling_output not a dict " + "(type=%s) for req=%s; consumer wait gate may hang.", + type(pooling_output).__name__, + rid, + ) return None # codes.audio — try flat dotted first (flatten_payload), then nested fallback. 
@@ -446,10 +453,21 @@ def talker2code2wav_full_payload( if isinstance(codes_nested, dict): audio = codes_nested.get("audio") if not isinstance(audio, torch.Tensor) or audio.numel() == 0: + logger.warning( + "qwen3_tts.talker2code2wav_full_payload: missing/empty codes.audio " + "(keys=%s) for req=%s; consumer wait gate may hang.", + list(pooling_output.keys()), + rid, + ) return None audio = audio.to(torch.long) audio = _filter_audio_codes_qwen3_tts(audio) if audio.numel() == 0: + logger.warning( + "qwen3_tts.talker2code2wav_full_payload: audio empty after codec " + "filter (negative/all-zero/out-of-range rows dropped) for req=%s.", + rid, + ) return None output_token_ids = list(getattr(request, "output_token_ids", None) or []) diff --git a/vllm_omni/worker/omni_connector_model_runner_mixin.py b/vllm_omni/worker/omni_connector_model_runner_mixin.py index e0ea078979e..ea89e11c1a4 100644 --- a/vllm_omni/worker/omni_connector_model_runner_mixin.py +++ b/vllm_omni/worker/omni_connector_model_runner_mixin.py @@ -252,9 +252,16 @@ def cleanup_finished_request(self, req_id: str) -> None: try: self.flush_full_payload_outputs({req_id}) except Exception: - # Defensive: connector may not be initialised for archs - # outside the connector init allowlist. Cleanup must still proceed. - pass + # Cleanup must still proceed regardless of flush errors here -- + # we already gated on ``_omni_connector_initialized`` upstream, + # so any exception here reflects a real connector-side issue + # (shared memory corruption, background thread crash) worth + # surfacing rather than silently swallowing. 
+ logger.warning( + "flush_full_payload_outputs(%s) raised during cleanup; continuing tear-down.", + req_id, + exc_info=True, + ) ext_id = self._request_ids_mapping.pop(req_id, None) keys_to_clean: list[str] = [req_id] From a7eab5cd7e560eb018c258ddde547f540540c091 Mon Sep 17 00:00:00 2001 From: natureofnature Date: Tue, 26 May 2026 09:17:03 +0000 Subject: [PATCH 13/19] minor fix Signed-off-by: natureofnature --- tests/worker/test_omni_gpu_model_runner.py | 31 +++++++++----- .../stage_input_processors/qwen3_omni.py | 40 +++++++++++++++++-- .../omni_connector_model_runner_mixin.py | 15 ++++--- 3 files changed, 67 insertions(+), 19 deletions(-) diff --git a/tests/worker/test_omni_gpu_model_runner.py b/tests/worker/test_omni_gpu_model_runner.py index b5e4ad4a4ec..9c8640b51bd 100644 --- a/tests/worker/test_omni_gpu_model_runner.py +++ b/tests/worker/test_omni_gpu_model_runner.py @@ -400,12 +400,16 @@ def _make_full_payload_accumulation_runner( model_arch="Qwen3OmniMoeForConditionalGeneration", model_stage="talker", async_chunk=False, + final_output=False, + custom_process_next_stage_input_func="module.full_payload", ): runner = object.__new__(OmniConnectorModelRunnerMixin) runner.model_config = SimpleNamespace( model_arch=model_arch, model_stage=model_stage, async_chunk=async_chunk, + final_output=final_output, + custom_process_next_stage_input_func=custom_process_next_stage_input_func, ) runner._custom_process_func = object() runner._pending_full_payload_send = {} @@ -475,22 +479,27 @@ def test_accumulate_full_payload_output_keeps_all_zero_qwen3_omni_prefill_placeh def test_full_payload_output_accumulation_hook_matrix(): - """Producer-side gate: fires iff custom_process_func is loaded and not async_chunk. + """Producer-side gate: fires iff an explicit next-stage payload hook is loaded. - The gate is a structural check on the loaded payload builder. - `_custom_process_func is None` short-circuits; - that maps to terminal stages (e.g. 
code2wav, qwen3_tts code2wav, qwen2_5 - code2wav) whose stage_config has no `custom_process_next_stage_input_func` - and no `*_full_payload` derivative of `custom_process_input_func`. + A derived `*_full_payload` helper from `custom_process_input_func` is not + enough: terminal/input-only consumer stages must not enqueue orphan + downstream payloads. """ - # Thinker / talker producer stages: payload builder loaded -> gate fires. + # Thinker / talker producer stages: explicit next-stage payload hook -> gate fires. assert _make_full_payload_accumulation_runner(model_stage="thinker")._should_accumulate_full_payload_output() assert _make_full_payload_accumulation_runner(model_stage="talker")._should_accumulate_full_payload_output() - # Terminal stage: emulate `_load_custom_func` returning None (no downstream). - runner = _make_full_payload_accumulation_runner(model_stage="code2wav") - runner._custom_process_func = None - runner._should_accumulate_full_payload_output_cached = None + # Terminal stage: even if _load_custom_func derived a builder from + # custom_process_input_func, final output stages are not producers. + runner = _make_full_payload_accumulation_runner(model_stage="code2wav", final_output=True) + assert not runner._should_accumulate_full_payload_output() + + # Input-only consumer stage without an explicit producer hook must not + # accumulate/send just because a same-module *_full_payload helper exists. + runner = _make_full_payload_accumulation_runner( + model_stage="token2audio", + custom_process_next_stage_input_func=None, + ) assert not runner._should_accumulate_full_payload_output() # async_chunk mode -> gate off. 
diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py index 77a9af6c0c4..d2f7aa5c34f 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -455,7 +455,13 @@ def thinker2talker_full_payload( request: OmniEngineCoreRequest, ) -> dict[str, Any] | None: """Pack complete thinker output for the non-async connector path.""" + rid = getattr(request, "request_id", None) if not isinstance(pooling_output, dict): + logger.warning( + "thinker2talker_full_payload: pooling_output not a dict (type=%s) for req=%s; consumer wait gate may hang.", + type(pooling_output).__name__, + rid, + ) return None layers = { @@ -468,11 +474,13 @@ def thinker2talker_full_payload( hidden = pooling_output.get("hidden") thinker_emb = hidden if isinstance(hidden, torch.Tensor) else None if thinker_emb is None or thinker_hid is None: - logger.debug( - "thinker2talker_full_payload: missing thinker tensors for req=%s (embed=%s hidden=%s)", - getattr(request, "request_id", None), + logger.warning( + "thinker2talker_full_payload: missing thinker tensors for req=%s " + "(embed=%s hidden=%s keys=%s); consumer wait gate may hang.", + rid, thinker_emb is not None, thinker_hid is not None, + list(pooling_output.keys()), ) return None @@ -860,7 +868,14 @@ def talker2code2wav_full_payload( request: OmniEngineCoreRequest, ) -> dict[str, Any] | None: """Pack complete talker codec output for the non-async connector path.""" + rid = getattr(request, "request_id", None) if not isinstance(pooling_output, dict): + logger.warning( + "talker2code2wav_full_payload: pooling_output not a dict " + "(type=%s) for req=%s; consumer wait gate may hang.", + type(pooling_output).__name__, + rid, + ) return None code_predictor_codes = pooling_output.get("codes.audio") if code_predictor_codes is None: @@ -868,10 +883,19 @@ def talker2code2wav_full_payload( if 
isinstance(codes, dict): code_predictor_codes = codes.get("audio") if code_predictor_codes is None: + logger.warning( + "talker2code2wav_full_payload: missing codes.audio (keys=%s) for req=%s; consumer wait gate may hang.", + list(pooling_output.keys()), + rid, + ) return None if not isinstance(code_predictor_codes, torch.Tensor): code_predictor_codes = torch.as_tensor(code_predictor_codes) if code_predictor_codes.numel() == 0: + logger.warning( + "talker2code2wav_full_payload: empty codes.audio for req=%s; consumer wait gate may hang.", + rid, + ) return None output_token_ids = _ensure_list(getattr(request, "output_token_ids", []) or []) @@ -881,6 +905,16 @@ def talker2code2wav_full_payload( list(output_token_ids), ) if code_predictor_codes.numel() == 0: + logger.warning( + "talker2code2wav_full_payload: no valid codec rows after filtering " + "(raw_shape=%s output_ids_len=%d aligned_rows=%s valid_rows=%s) for req=%s; " + "consumer wait gate may hang.", + raw_shape, + len(output_token_ids), + codec_stats["aligned_rows"], + codec_stats["valid_rows"], + rid, + ) return None codec_codes = code_predictor_codes.transpose(0, 1).cpu().reshape(-1).tolist() diff --git a/vllm_omni/worker/omni_connector_model_runner_mixin.py b/vllm_omni/worker/omni_connector_model_runner_mixin.py index ea89e11c1a4..74dd9aa9b70 100644 --- a/vllm_omni/worker/omni_connector_model_runner_mixin.py +++ b/vllm_omni/worker/omni_connector_model_runner_mixin.py @@ -49,16 +49,21 @@ def should_accumulate_full_payload_output(model_config, custom_process_func) -> bool: """Producer-side structural gate. - Fires iff the worker has a connector payload builder loaded - (``custom_process_func`` resolved via ``_load_custom_func`` from the - stage_config's ``custom_process_next_stage_input_func`` or the - ``*_full_payload`` derivative of ``custom_process_input_func``), the - stage is not in async_chunk mode, and ``model_stage`` is set. 
+ Fires iff the stage explicitly declares a downstream full-payload + producer hook via ``custom_process_next_stage_input_func``. Consumer + stages may have ``custom_process_input_func`` values that can be + mechanically derived to ``*_full_payload`` helper names in the same + module; those are intentionally not enough to make the stage a producer. """ if custom_process_func is None: return False if getattr(model_config, "async_chunk", False): return False + if getattr(model_config, "final_output", False): + return False + next_stage_func = getattr(model_config, "custom_process_next_stage_input_func", None) + if not isinstance(next_stage_func, str) or not next_stage_func: + return False return getattr(model_config, "model_stage", None) is not None From 21dcb1aaa55a2c6e17ac4719a49ac8b17d35381c Mon Sep 17 00:00:00 2001 From: natureofnature Date: Tue, 26 May 2026 14:15:18 +0000 Subject: [PATCH 14/19] minor fix Signed-off-by: natureofnature --- .../test_qwen3_omni_streaming_helpers.py | 40 ++++ .../stage_input_processors/qwen2_5_omni.py | 32 +-- .../stage_input_processors/qwen3_omni.py | 184 +++++++++--------- 3 files changed, 147 insertions(+), 109 deletions(-) diff --git a/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py b/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py index c01e098373e..179ab92b3d9 100644 --- a/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py +++ b/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py @@ -628,6 +628,46 @@ def test_qwen2_5_omni_talker2code2wav_full_payload_smoke() -> None: assert payload["meta"]["finished"].item() is True +def test_qwen2_5_omni_talker2code2wav_filters_control_tokens_and_placeholders() -> None: + """Qwen2.5 code2wav receives codec ids only, not talker prompt/control ids.""" + from types import SimpleNamespace + + from vllm_omni.model_executor.stage_input_processors.qwen2_5_omni import ( + 
TALKER_CODEC_END_TOKEN_ID, + TALKER_CODEC_PAD_TOKEN_ID, + TALKER_CODEC_START_TOKEN_ID, + talker2code2wav_full_payload, + talker2code2wav_token_only, + ) + + class _Out: + def __init__(self, tids): + self.cumulative_token_ids = tids + + class _Wrap: + def __init__(self, tids): + self.outputs = [_Out(tids)] + + raw_ids = [ + TALKER_CODEC_START_TOKEN_ID, + TALKER_CODEC_PAD_TOKEN_ID, + 5, + 6, + TALKER_CODEC_END_TOKEN_ID, + -1, + -1, + ] + + token_only = talker2code2wav_token_only([_Wrap(raw_ids)]) + assert len(token_only) == 1 + assert len(token_only[0]["prompt_token_ids"]) == 4 + + payload = talker2code2wav_full_payload(None, {}, SimpleNamespace(output_token_ids=raw_ids)) + assert payload is not None + assert payload["codes"]["audio"] == [5, 6, 6, 6] + assert payload["meta"]["finished"].item() is True + + def test_mimo_audio_llm2code2wav_token_only_smoke() -> None: """Smoke: mimo_audio token-only builder sizes prompt.""" import torch diff --git a/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py index 8be1ca3b590..5046fa649f2 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py @@ -115,25 +115,27 @@ def talker2code2wav( def _strip_codec_boundaries(token_ids: list[int]) -> list[int]: - """Drop TALKER_CODEC_START/END boundary tokens (mirror talker2code2wav) - and filter sentinel/invalid codec ids. - - The talker emits codec ids on `request.output_token_ids`. Negative ids - (e.g., -1) appear as "stopped early" / "no token sampled this step" - sentinels and are NOT valid codec embedding indices. Passing -1 to - `torch.embedding` triggers a CUDA gather-kernel OOB assert in the - code2wav stage (`vectorized_gather_kernel index out of bounds`). We - filter them here at the producer-side strip so the worker connector - payload only ships valid codec ids. 
+ """Keep only real codec ids for the code2wav stage. + + The talker stream can contain prompt/control ids (START/PAD/END/MASK) in + addition to sampled codec ids. Code2wav expects codec ids only; carrying + the prompt PAD span forward can inflate the sequence enough to OOM on L4. + Async scheduling may also leave trailing ``-1`` placeholders, so preserve + their length by repeating the last valid codec id. """ tids = list(token_ids) - if tids and tids[0] == TALKER_CODEC_START_TOKEN_ID: - tids = tids[1:] + trailing_placeholder_count = 0 + while trailing_placeholder_count < len(tids) and tids[-1 - trailing_placeholder_count] == -1: + trailing_placeholder_count += 1 + if tids and tids[-1] == TALKER_CODEC_END_TOKEN_ID: tids = tids[:-1] - # Filter negative sentinel ids that the talker engine may insert. - tids = [t for t in tids if t >= 0] - return tids + trailing_placeholder_count = 0 + + codec_ids = [tid for tid in tids if 0 <= tid < TALKER_CODEC_PAD_TOKEN_ID] + if trailing_placeholder_count > 0 and codec_ids: + codec_ids.extend([codec_ids[-1]] * trailing_placeholder_count) + return codec_ids def talker2code2wav_token_only( diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py index d2f7aa5c34f..9f7a38bba23 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -125,6 +125,86 @@ def _is_valid_qwen3_codec_token_id(token_id: Any) -> bool: return 0 <= token_id < _QWEN3_CODEC_CODEBOOK_SIZE +def _prepare_qwen3_talker_prefill( + request: Any, + thinker_emb: torch.Tensor, + thinker_hid: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, list[Any], list[Any]]: + """Build token ids and trim thinker rows that talker should consume.""" + prompt_token_ids = _ensure_list(getattr(request, "prompt_token_ids", []) or []) + output_token_ids = _ensure_list(getattr(request, "output_token_ids", []) or []) + 
all_token_ids = _ensure_list(getattr(request, "all_token_ids", None) or []) + if not all_token_ids: + all_token_ids = list(prompt_token_ids) + list(output_token_ids) + + # vLLM appends sampled tokens before stop checking. Stop-finished requests + # therefore have one final hidden row that the talker should not consume. + status = getattr(request, "status", None) + status_name = getattr(status, "name", None) or "" + if not status_name and status is not None: + status_name = str(status).rsplit(".", 1)[-1] + stop_emission_drop = 1 if status_name == "FINISHED_STOPPED" else 0 + + if stop_emission_drop == 0 and not status_name and output_token_ids: + sampling_params = getattr(request, "sampling_params", None) + if sampling_params is not None: + stop_ids: set[int] = set() + for sid in getattr(sampling_params, "stop_token_ids", None) or (): + if isinstance(sid, int): + stop_ids.add(sid) + if not bool(getattr(sampling_params, "ignore_eos", False)): + for eos in ( + getattr(sampling_params, "eos_token_id", None), + getattr(sampling_params, "_eos_token_id", None), + ): + if isinstance(eos, int): + stop_ids.add(eos) + for sid in ( + getattr(sampling_params, "all_stop_token_ids", None) + or getattr(sampling_params, "_all_stop_token_ids", None) + or () + ): + if isinstance(sid, int): + stop_ids.add(sid) + if stop_ids and output_token_ids[-1] in stop_ids: + stop_emission_drop = 1 + + target_rows = max(0, len(all_token_ids) - stop_emission_drop) + request_id = getattr(request, "request_id", None) + + def _trim(tensor: torch.Tensor) -> torch.Tensor: + if tensor.dim() < 1 or tensor.shape[0] == 0 or target_rows <= 0: + return tensor + if tensor.shape[0] > target_rows + 1: + logger.warning( + "thinker2talker_full_payload: unexpected excess rows " + "(got %d, target %d, stop_drop %d) for req=%s; trimming to target", + int(tensor.shape[0]), + target_rows, + stop_emission_drop, + request_id, + ) + if tensor.shape[0] > target_rows: + return tensor[:target_rows] + if tensor.shape[0] < 
target_rows: + logger.debug( + "thinker2talker_full_payload: under-captured rows " + "(got %d, target %d, stop_drop %d) for req=%s; talker may index past end", + int(tensor.shape[0]), + target_rows, + stop_emission_drop, + request_id, + ) + return tensor + + return ( + _trim(thinker_emb), + _trim(thinker_hid), + list(prompt_token_ids), + list(all_token_ids), + ) + + def _extract_qwen3_full_payload_codec_rows( code_predictor_codes: torch.Tensor, output_token_ids: list[int], @@ -484,100 +564,16 @@ def thinker2talker_full_payload( ) return None - prompt_token_ids = _ensure_list(getattr(request, "prompt_token_ids", []) or []) - output_token_ids = _ensure_list(getattr(request, "output_token_ids", []) or []) - all_token_ids = _ensure_list(getattr(request, "all_token_ids", None) or []) - if not all_token_ids: - all_token_ids = list(prompt_token_ids) + list(output_token_ids) - - # Length-aware trim of accumulated thinker output, finish-reason-aware. - # vLLM appends the sampled token to `output_token_ids` BEFORE - # `check_stop`, so a stop-finished request has accumulator_rows - # == len(all_token_ids) including the stop emission row -- the - # talker must NOT consume that row. Max-token finishes do not - # append an extra forward, so no drop is needed. Primary: - # distinguish via `request.status`. Fallback only when status - # is absent: last-token-in-stop-id heuristic. - status = getattr(request, "status", None) - status_name = getattr(status, "name", None) or "" - if not status_name and status is not None: - status_name = str(status).rsplit(".", 1)[-1] - stop_emission_drop = 1 if status_name == "FINISHED_STOPPED" else 0 - if stop_emission_drop == 0 and not status_name and output_token_ids: - # Worker-side CachedRequestState has no `.status` field in vLLM - # v1, so this fallback runs for every production request. 
When - # `sampling_params.ignore_eos=True` vLLM continues past EOS, so - # a length-capped finish whose last sampled token coincidentally - # equals EOS must NOT be trimmed -- skip EOS from the stop set - # in that case. Custom `stop_token_ids` are still treated as - # stops; vLLM's `check_stop` runs stop-id matching before the - # length cap and ignores `ignore_eos` for `stop_token_ids`, so - # a last-token match there is unambiguously a stop finish. - sampling_params = getattr(request, "sampling_params", None) - if sampling_params is not None: - stop_ids: set[int] = set() - ignore_eos = bool(getattr(sampling_params, "ignore_eos", False)) - # Custom stop_token_ids always trigger stop in vLLM, regardless - # of ignore_eos (vLLM v1: `update_from_generation_config` writes - # secondary EOSes here too). Read the public list. - for sid in getattr(sampling_params, "stop_token_ids", None) or (): - if isinstance(sid, int): - stop_ids.add(sid) - # EOS sources are only stops when ignore_eos=False. Read both - # the public @property (`eos_token_id`, `all_stop_token_ids`) - # AND the private fields (`_eos_token_id`, `_all_stop_token_ids`) - # because property behavior can vary across msgspec serialization - # boundaries while the private fields are always serialized. 
- if not ignore_eos: - for eos in ( - getattr(sampling_params, "eos_token_id", None), - getattr(sampling_params, "_eos_token_id", None), - ): - if isinstance(eos, int): - stop_ids.add(eos) - for sid in ( - getattr(sampling_params, "all_stop_token_ids", None) - or getattr(sampling_params, "_all_stop_token_ids", None) - or () - ): - if isinstance(sid, int): - stop_ids.add(sid) - if stop_ids and output_token_ids[-1] in stop_ids: - stop_emission_drop = 1 - target_rows = max(0, len(all_token_ids) - stop_emission_drop) - - def _trim_to_target(t): - if not isinstance(t, torch.Tensor) or t.dim() < 1 or t.shape[0] == 0: - return t - if target_rows <= 0: - # Defensive: empty prompt+output (or stop-only output) should - # not reach this builder; keep all rows rather than slicing - # to zero. - return t - if t.shape[0] > target_rows + 1: - logger.warning( - "thinker2talker_full_payload: unexpected excess rows " - "(got %d, target %d, stop_drop %d) for req=%s; trimming to target", - int(t.shape[0]), - target_rows, - stop_emission_drop, - getattr(request, "request_id", None), - ) - if t.shape[0] > target_rows: - return t[:target_rows] - if t.shape[0] < target_rows: - logger.debug( - "thinker2talker_full_payload: under-captured rows " - "(got %d, target %d, stop_drop %d) for req=%s; talker may index past end", - int(t.shape[0]), - target_rows, - stop_emission_drop, - getattr(request, "request_id", None), - ) - return t - - thinker_emb_prefill = _trim_to_target(thinker_emb) - thinker_hid_prefill = _trim_to_target(thinker_hid) + ( + thinker_emb_prefill, + thinker_hid_prefill, + prompt_token_ids, + all_token_ids, + ) = _prepare_qwen3_talker_prefill( + request, + thinker_emb, + thinker_hid, + ) payload: OmniPayload = { "embed": { From a397d09a89870af61fe1970d4b7f97157f076b8d Mon Sep 17 00:00:00 2001 From: natureofnature Date: Tue, 26 May 2026 17:44:39 +0000 Subject: [PATCH 15/19] minor fix Signed-off-by: natureofnature --- .../test_qwen3_omni_streaming_helpers.py | 23 +++++ 
.../stage_input_processors/qwen3_omni.py | 93 ++++++++++++------- 2 files changed, 85 insertions(+), 31 deletions(-) diff --git a/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py b/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py index 179ab92b3d9..fa25a725eaf 100644 --- a/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py +++ b/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py @@ -246,6 +246,29 @@ def test_thinker2talker_full_payload_drops_stop_emission_row_when_finished_stopp assert payload["hidden_states"]["output"].shape[0] == 2 +def test_thinker2talker_full_payload_returns_empty_when_stop_consumes_all_rows() -> None: + """If the only generated row is a stop emission, ship empty tensors.""" + request = SimpleNamespace( + request_id="thinker-stop-only", + prompt_token_ids=[], + output_token_ids=[151645], + all_token_ids=[151645], + sampling_params=None, + status=SimpleNamespace(name="FINISHED_STOPPED"), + ) + pooling_output = { + "hidden_states.layer_0": torch.ones(1, 2), + "hidden_states.layer_24": torch.full((1, 2), 2.0), + "embed.tts_bos": torch.zeros(1, 2), + } + + payload = q3.thinker2talker_full_payload(None, pooling_output, request) + + assert payload is not None + assert payload["embed"]["prefill"].shape[0] == 0 + assert payload["hidden_states"]["output"].shape[0] == 0 + + def test_thinker2talker_full_payload_drops_stop_emission_via_eos_fallback() -> None: """Stop-detection fallback: last token in sampling_params.eos_token_id.""" EOS = 151645 diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py index 9f7a38bba23..2818cf52a85 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -125,6 +125,57 @@ def _is_valid_qwen3_codec_token_id(token_id: Any) -> bool: 
return 0 <= token_id < _QWEN3_CODEC_CODEBOOK_SIZE +def _qwen3_status_name(request: Any) -> str: + status = getattr(request, "status", None) + status_name = getattr(status, "name", None) or "" + if not status_name and status is not None: + status_name = str(status).rsplit(".", 1)[-1] + return status_name + + +def _qwen3_stop_token_ids(sampling_params: Any) -> set[int]: + stop_ids: set[int] = set() + for sid in getattr(sampling_params, "stop_token_ids", None) or (): + if isinstance(sid, int): + stop_ids.add(sid) + + if bool(getattr(sampling_params, "ignore_eos", False)): + return stop_ids + + for eos in ( + getattr(sampling_params, "eos_token_id", None), + getattr(sampling_params, "_eos_token_id", None), + ): + if isinstance(eos, int): + stop_ids.add(eos) + for sid in ( + getattr(sampling_params, "all_stop_token_ids", None) + or getattr(sampling_params, "_all_stop_token_ids", None) + or () + ): + if isinstance(sid, int): + stop_ids.add(sid) + return stop_ids + + +def _qwen3_stop_emission_drop(request: Any, output_token_ids: list[Any]) -> int: + """Return 1 when the final thinker row is the stop-token emission.""" + status_name = _qwen3_status_name(request) + if status_name == "FINISHED_STOPPED": + return 1 + if status_name or not output_token_ids: + # Explicit non-stop finishes, especially FINISHED_LENGTH_CAPPED, should + # keep all generated rows. + return 0 + + sampling_params = getattr(request, "sampling_params", None) + if sampling_params is None: + return 0 + + stop_ids = _qwen3_stop_token_ids(sampling_params) + return 1 if stop_ids and output_token_ids[-1] in stop_ids else 0 + + def _prepare_qwen3_talker_prefill( request: Any, thinker_emb: torch.Tensor, @@ -139,42 +190,22 @@ def _prepare_qwen3_talker_prefill( # vLLM appends sampled tokens before stop checking. Stop-finished requests # therefore have one final hidden row that the talker should not consume. 
- status = getattr(request, "status", None) - status_name = getattr(status, "name", None) or "" - if not status_name and status is not None: - status_name = str(status).rsplit(".", 1)[-1] - stop_emission_drop = 1 if status_name == "FINISHED_STOPPED" else 0 - - if stop_emission_drop == 0 and not status_name and output_token_ids: - sampling_params = getattr(request, "sampling_params", None) - if sampling_params is not None: - stop_ids: set[int] = set() - for sid in getattr(sampling_params, "stop_token_ids", None) or (): - if isinstance(sid, int): - stop_ids.add(sid) - if not bool(getattr(sampling_params, "ignore_eos", False)): - for eos in ( - getattr(sampling_params, "eos_token_id", None), - getattr(sampling_params, "_eos_token_id", None), - ): - if isinstance(eos, int): - stop_ids.add(eos) - for sid in ( - getattr(sampling_params, "all_stop_token_ids", None) - or getattr(sampling_params, "_all_stop_token_ids", None) - or () - ): - if isinstance(sid, int): - stop_ids.add(sid) - if stop_ids and output_token_ids[-1] in stop_ids: - stop_emission_drop = 1 - + stop_emission_drop = _qwen3_stop_emission_drop(request, output_token_ids) target_rows = max(0, len(all_token_ids) - stop_emission_drop) request_id = getattr(request, "request_id", None) def _trim(tensor: torch.Tensor) -> torch.Tensor: - if tensor.dim() < 1 or tensor.shape[0] == 0 or target_rows <= 0: + if tensor.dim() < 1 or tensor.shape[0] == 0: return tensor + if target_rows <= 0: + logger.warning( + "thinker2talker_full_payload: target_rows<=0 " + "(all_token_ids=%d, stop_drop=%d) for req=%s; shipping empty tensor", + len(all_token_ids), + stop_emission_drop, + request_id, + ) + return tensor[:0] if tensor.shape[0] > target_rows + 1: logger.warning( "thinker2talker_full_payload: unexpected excess rows " From 73a78e915f237b4830cb287439caa153a46ed38a Mon Sep 17 00:00:00 2001 From: natureofnature Date: Wed, 27 May 2026 00:50:02 +0000 Subject: [PATCH 16/19] revert qwen3 omni trim Signed-off-by: natureofnature --- 
.../test_qwen3_omni_streaming_helpers.py | 274 +----------------- .../stage_input_processors/qwen3_omni.py | 145 ++------- 2 files changed, 25 insertions(+), 394 deletions(-) diff --git a/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py b/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py index fa25a725eaf..2e44a8e7776 100644 --- a/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py +++ b/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py @@ -168,7 +168,7 @@ def test_talker2code2wav_full_payload_keeps_all_zero_codec_rows() -> None: def test_thinker2talker_full_payload_packs_complete_tensors() -> None: - """Standard max_tokens finish path: rows == target → no trim.""" + """Full-payload path drops the terminal thinker row before talker prefill.""" request = SimpleNamespace( request_id="thinker", prompt_token_ids=[151644, 872], @@ -188,278 +188,6 @@ def test_thinker2talker_full_payload_packs_complete_tensors() -> None: assert payload["embed"]["prefill"].device.type == "cpu" assert payload["hidden_states"]["output"].device.type == "cpu" assert payload["next_stage_prompt_len"] > 0 - # Lock down the no-trim invariant for rows == target. 
- assert payload["embed"]["prefill"].shape[0] == 3 - assert payload["hidden_states"]["output"].shape[0] == 3 - - -def test_thinker2talker_full_payload_trims_excess_stop_token_row() -> None: - """Excess-rows path: rows == target + 1 → trim trailing row.""" - request = SimpleNamespace( - request_id="thinker-excess", - prompt_token_ids=[151644, 872], - output_token_ids=[3], - all_token_ids=[151644, 872, 3], - sampling_params=None, - status=None, - ) - pooling_output = { - "hidden_states.layer_0": torch.ones(4, 2), - "hidden_states.layer_24": torch.full((4, 2), 2.0), - "embed.tts_bos": torch.zeros(1, 2), - } - - payload = q3.thinker2talker_full_payload(None, pooling_output, request) - - assert payload is not None - assert payload["embed"]["prefill"].shape[0] == 3 - assert payload["hidden_states"]["output"].shape[0] == 3 - - -def test_thinker2talker_full_payload_drops_stop_emission_row_when_finished_stopped() -> None: - """FINISHED_STOPPED: drop 1 extra row even when rows == target. - - vLLM appends the stop-token to output_token_ids before check_stop, so - len(all_token_ids) includes the stop slot AND the full-payload - accumulator has the stop emission's forward row. Both counts equal P+O (here 3). Talker - target should be P+O-1 (=2), not P+O. Without the extra drop the - stop emission's hidden state leaks into talker prefill. 
- """ - request = SimpleNamespace( - request_id="thinker-stop-finished", - prompt_token_ids=[151644, 872], - output_token_ids=[3], - all_token_ids=[151644, 872, 3], - sampling_params=None, - status=SimpleNamespace(name="FINISHED_STOPPED"), - ) - pooling_output = { - "hidden_states.layer_0": torch.ones(3, 2), - "hidden_states.layer_24": torch.full((3, 2), 2.0), - "embed.tts_bos": torch.zeros(1, 2), - } - - payload = q3.thinker2talker_full_payload(None, pooling_output, request) - - assert payload is not None - assert payload["embed"]["prefill"].shape[0] == 2 - assert payload["hidden_states"]["output"].shape[0] == 2 - - -def test_thinker2talker_full_payload_returns_empty_when_stop_consumes_all_rows() -> None: - """If the only generated row is a stop emission, ship empty tensors.""" - request = SimpleNamespace( - request_id="thinker-stop-only", - prompt_token_ids=[], - output_token_ids=[151645], - all_token_ids=[151645], - sampling_params=None, - status=SimpleNamespace(name="FINISHED_STOPPED"), - ) - pooling_output = { - "hidden_states.layer_0": torch.ones(1, 2), - "hidden_states.layer_24": torch.full((1, 2), 2.0), - "embed.tts_bos": torch.zeros(1, 2), - } - - payload = q3.thinker2talker_full_payload(None, pooling_output, request) - - assert payload is not None - assert payload["embed"]["prefill"].shape[0] == 0 - assert payload["hidden_states"]["output"].shape[0] == 0 - - -def test_thinker2talker_full_payload_drops_stop_emission_via_eos_fallback() -> None: - """Stop-detection fallback: last token in sampling_params.eos_token_id.""" - EOS = 151645 - request = SimpleNamespace( - request_id="thinker-stop-fallback", - prompt_token_ids=[151644, 872], - output_token_ids=[3, EOS], - all_token_ids=[151644, 872, 3, EOS], - sampling_params=SimpleNamespace(eos_token_id=EOS, stop_token_ids=None), - status=None, - ) - pooling_output = { - "hidden_states.layer_0": torch.ones(4, 2), - "hidden_states.layer_24": torch.full((4, 2), 2.0), - "embed.tts_bos": torch.zeros(1, 2), - } - - 
payload = q3.thinker2talker_full_payload(None, pooling_output, request) - - assert payload is not None - assert payload["embed"]["prefill"].shape[0] == 3 - assert payload["hidden_states"]["output"].shape[0] == 3 - - -def test_thinker2talker_full_payload_no_drop_when_finished_length_capped() -> None: - """FINISHED_LENGTH_CAPPED (max_tokens): no extra drop applied.""" - request = SimpleNamespace( - request_id="thinker-length-capped", - prompt_token_ids=[151644, 872], - output_token_ids=[3], - all_token_ids=[151644, 872, 3], - sampling_params=SimpleNamespace(eos_token_id=999, stop_token_ids=None), - status=SimpleNamespace(name="FINISHED_LENGTH_CAPPED"), - ) - pooling_output = { - "hidden_states.layer_0": torch.ones(3, 2), - "hidden_states.layer_24": torch.full((3, 2), 2.0), - "embed.tts_bos": torch.zeros(1, 2), - } - - payload = q3.thinker2talker_full_payload(None, pooling_output, request) - - assert payload is not None - assert payload["embed"]["prefill"].shape[0] == 3 - assert payload["hidden_states"]["output"].shape[0] == 3 - - -def test_thinker2talker_full_payload_drops_via_private_eos_field() -> None: - """Worker-side sampling_params where the public `eos_token_id` property is - None but the private `_eos_token_id` / `_all_stop_token_ids` carry the - primary EOS (the msgspec-deserialization shape on the worker boundary). - - The fallback must read the private fields to detect the stop. - """ - EOS = 151643 - request = SimpleNamespace( - request_id="thinker-private-eos", - prompt_token_ids=[151644, 872], - output_token_ids=[3, EOS], - all_token_ids=[151644, 872, 3, EOS], - # Public `eos_token_id` looks empty; only the private fields carry it. 
- sampling_params=SimpleNamespace( - eos_token_id=None, - stop_token_ids=None, - ignore_eos=False, - _eos_token_id=EOS, - _all_stop_token_ids={EOS}, - ), - status=None, - ) - pooling_output = { - "hidden_states.layer_0": torch.ones(4, 2), - "hidden_states.layer_24": torch.full((4, 2), 2.0), - "embed.tts_bos": torch.zeros(1, 2), - } - - payload = q3.thinker2talker_full_payload(None, pooling_output, request) - - assert payload is not None - assert payload["embed"]["prefill"].shape[0] == 3 - assert payload["hidden_states"]["output"].shape[0] == 3 - - -def test_thinker2talker_full_payload_drops_via_all_stop_token_ids() -> None: - """Secondary EOS only in `_all_stop_token_ids` (not in `_eos_token_id`): - multi-EOS Qwen3 case where the model finished on a secondary EOS. - """ - SECONDARY_EOS = 151645 - request = SimpleNamespace( - request_id="thinker-secondary-eos", - prompt_token_ids=[151644, 872], - output_token_ids=[3, SECONDARY_EOS], - all_token_ids=[151644, 872, 3, SECONDARY_EOS], - sampling_params=SimpleNamespace( - eos_token_id=151643, # primary, not the one we hit - stop_token_ids=None, - ignore_eos=False, - _eos_token_id=151643, - _all_stop_token_ids={151643, SECONDARY_EOS}, - ), - status=None, - ) - pooling_output = { - "hidden_states.layer_0": torch.ones(4, 2), - "hidden_states.layer_24": torch.full((4, 2), 2.0), - "embed.tts_bos": torch.zeros(1, 2), - } - - payload = q3.thinker2talker_full_payload(None, pooling_output, request) - - assert payload is not None - assert payload["embed"]["prefill"].shape[0] == 3 - assert payload["hidden_states"]["output"].shape[0] == 3 - - -def test_thinker2talker_full_payload_no_drop_when_ignore_eos_and_trailing_eos() -> None: - """ignore_eos=True + length-capped + last token == EOS: no drop. - - Production worker uses CachedRequestState (no `.status` field), so - the status path doesn't catch this case; we rely on the - `sampling_params.ignore_eos` flag in the fallback to suppress the - EOS-as-stop heuristic. 
- """ - EOS = 151645 - request = SimpleNamespace( - request_id="thinker-ignore-eos-trailing-eos", - prompt_token_ids=[151644, 872], - output_token_ids=[3, EOS], - all_token_ids=[151644, 872, 3, EOS], - sampling_params=SimpleNamespace(eos_token_id=EOS, stop_token_ids=None, ignore_eos=True), - status=None, # production worker state has no status - ) - pooling_output = { - "hidden_states.layer_0": torch.ones(4, 2), - "hidden_states.layer_24": torch.full((4, 2), 2.0), - "embed.tts_bos": torch.zeros(1, 2), - } - - payload = q3.thinker2talker_full_payload(None, pooling_output, request) - - assert payload is not None - assert payload["embed"]["prefill"].shape[0] == 4 - assert payload["hidden_states"]["output"].shape[0] == 4 - - -def test_thinker2talker_full_payload_no_drop_when_length_capped_with_trailing_eos() -> None: - """FINISHED_LENGTH_CAPPED + last token == EOS coincidence: no drop. - - Status path takes precedence over last-token heuristic. Without - this guard the fallback would incorrectly drop a row when a length-capped - request happens to end on the EOS token id. 
- """ - EOS = 151645 - request = SimpleNamespace( - request_id="thinker-len-cap-trailing-eos", - prompt_token_ids=[151644, 872], - output_token_ids=[3, EOS], - all_token_ids=[151644, 872, 3, EOS], - sampling_params=SimpleNamespace(eos_token_id=EOS, stop_token_ids=None), - status=SimpleNamespace(name="FINISHED_LENGTH_CAPPED"), - ) - pooling_output = { - "hidden_states.layer_0": torch.ones(4, 2), - "hidden_states.layer_24": torch.full((4, 2), 2.0), - "embed.tts_bos": torch.zeros(1, 2), - } - - payload = q3.thinker2talker_full_payload(None, pooling_output, request) - - assert payload is not None - assert payload["embed"]["prefill"].shape[0] == 4 - assert payload["hidden_states"]["output"].shape[0] == 4 - - -def test_thinker2talker_full_payload_preserves_under_capture() -> None: - """Under-capture path: rows < target → no trim, safe degrade.""" - request = SimpleNamespace( - request_id="thinker-undercap", - prompt_token_ids=[151644, 872], - output_token_ids=[3], - all_token_ids=[151644, 872, 3], - ) - pooling_output = { - "hidden_states.layer_0": torch.ones(2, 2), - "hidden_states.layer_24": torch.full((2, 2), 2.0), - "embed.tts_bos": torch.zeros(1, 2), - } - - payload = q3.thinker2talker_full_payload(None, pooling_output, request) - - assert payload is not None assert payload["embed"]["prefill"].shape[0] == 2 assert payload["hidden_states"]["output"].shape[0] == 2 diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py index 2818cf52a85..b671951b201 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -125,117 +125,6 @@ def _is_valid_qwen3_codec_token_id(token_id: Any) -> bool: return 0 <= token_id < _QWEN3_CODEC_CODEBOOK_SIZE -def _qwen3_status_name(request: Any) -> str: - status = getattr(request, "status", None) - status_name = getattr(status, "name", None) or "" - if not status_name and 
status is not None: - status_name = str(status).rsplit(".", 1)[-1] - return status_name - - -def _qwen3_stop_token_ids(sampling_params: Any) -> set[int]: - stop_ids: set[int] = set() - for sid in getattr(sampling_params, "stop_token_ids", None) or (): - if isinstance(sid, int): - stop_ids.add(sid) - - if bool(getattr(sampling_params, "ignore_eos", False)): - return stop_ids - - for eos in ( - getattr(sampling_params, "eos_token_id", None), - getattr(sampling_params, "_eos_token_id", None), - ): - if isinstance(eos, int): - stop_ids.add(eos) - for sid in ( - getattr(sampling_params, "all_stop_token_ids", None) - or getattr(sampling_params, "_all_stop_token_ids", None) - or () - ): - if isinstance(sid, int): - stop_ids.add(sid) - return stop_ids - - -def _qwen3_stop_emission_drop(request: Any, output_token_ids: list[Any]) -> int: - """Return 1 when the final thinker row is the stop-token emission.""" - status_name = _qwen3_status_name(request) - if status_name == "FINISHED_STOPPED": - return 1 - if status_name or not output_token_ids: - # Explicit non-stop finishes, especially FINISHED_LENGTH_CAPPED, should - # keep all generated rows. 
- return 0 - - sampling_params = getattr(request, "sampling_params", None) - if sampling_params is None: - return 0 - - stop_ids = _qwen3_stop_token_ids(sampling_params) - return 1 if stop_ids and output_token_ids[-1] in stop_ids else 0 - - -def _prepare_qwen3_talker_prefill( - request: Any, - thinker_emb: torch.Tensor, - thinker_hid: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor, list[Any], list[Any]]: - """Build token ids and trim thinker rows that talker should consume.""" - prompt_token_ids = _ensure_list(getattr(request, "prompt_token_ids", []) or []) - output_token_ids = _ensure_list(getattr(request, "output_token_ids", []) or []) - all_token_ids = _ensure_list(getattr(request, "all_token_ids", None) or []) - if not all_token_ids: - all_token_ids = list(prompt_token_ids) + list(output_token_ids) - - # vLLM appends sampled tokens before stop checking. Stop-finished requests - # therefore have one final hidden row that the talker should not consume. - stop_emission_drop = _qwen3_stop_emission_drop(request, output_token_ids) - target_rows = max(0, len(all_token_ids) - stop_emission_drop) - request_id = getattr(request, "request_id", None) - - def _trim(tensor: torch.Tensor) -> torch.Tensor: - if tensor.dim() < 1 or tensor.shape[0] == 0: - return tensor - if target_rows <= 0: - logger.warning( - "thinker2talker_full_payload: target_rows<=0 " - "(all_token_ids=%d, stop_drop=%d) for req=%s; shipping empty tensor", - len(all_token_ids), - stop_emission_drop, - request_id, - ) - return tensor[:0] - if tensor.shape[0] > target_rows + 1: - logger.warning( - "thinker2talker_full_payload: unexpected excess rows " - "(got %d, target %d, stop_drop %d) for req=%s; trimming to target", - int(tensor.shape[0]), - target_rows, - stop_emission_drop, - request_id, - ) - if tensor.shape[0] > target_rows: - return tensor[:target_rows] - if tensor.shape[0] < target_rows: - logger.debug( - "thinker2talker_full_payload: under-captured rows " - "(got %d, target %d, stop_drop %d) 
for req=%s; talker may index past end", - int(tensor.shape[0]), - target_rows, - stop_emission_drop, - request_id, - ) - return tensor - - return ( - _trim(thinker_emb), - _trim(thinker_hid), - list(prompt_token_ids), - list(all_token_ids), - ) - - def _extract_qwen3_full_payload_codec_rows( code_predictor_codes: torch.Tensor, output_token_ids: list[int], @@ -595,16 +484,30 @@ def thinker2talker_full_payload( ) return None - ( - thinker_emb_prefill, - thinker_hid_prefill, - prompt_token_ids, - all_token_ids, - ) = _prepare_qwen3_talker_prefill( - request, - thinker_emb, - thinker_hid, - ) + prompt_token_ids = _ensure_list(getattr(request, "prompt_token_ids", []) or []) + all_token_ids = _ensure_list(getattr(request, "all_token_ids", None) or []) + if not all_token_ids: + output_token_ids = _ensure_list(getattr(request, "output_token_ids", []) or []) + all_token_ids = list(prompt_token_ids) + list(output_token_ids) + + # Trim the trailing stop-token row from the accumulated thinker output. + # The accumulator captures one hidden-state row per executed thinker + # forward (prefill + every decode step including the one that emitted + # the stop_token), so for a finished request thinker_emb has exactly one + # row more than the rows the talker should consume. async_chunk's + # chunk-0 path naturally captures only the prefill / non-stop portion, + # which is why the [async_chunk] parametrization passes while [default] + # over-generates one codec frame on short outputs (e.g. + # test_one_word_prompt_001[default]: audio extends "London" with + # spurious phonemes). 
+ if isinstance(thinker_emb, torch.Tensor) and thinker_emb.shape[0] > 0: + thinker_emb_prefill = thinker_emb[:-1] + else: + thinker_emb_prefill = thinker_emb + if isinstance(thinker_hid, torch.Tensor) and thinker_hid.shape[0] > 0: + thinker_hid_prefill = thinker_hid[:-1] + else: + thinker_hid_prefill = thinker_hid payload: OmniPayload = { "embed": { From ca614ee5049cbbb3b09ef57d59dfc2128cb61221 Mon Sep 17 00:00:00 2001 From: natureofnature Date: Wed, 27 May 2026 03:19:13 +0000 Subject: [PATCH 17/19] remove frequent useless log Signed-off-by: natureofnature --- .../omni_connector_model_runner_mixin.py | 39 +++++++++++-------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/vllm_omni/worker/omni_connector_model_runner_mixin.py b/vllm_omni/worker/omni_connector_model_runner_mixin.py index 74dd9aa9b70..83f58edba20 100644 --- a/vllm_omni/worker/omni_connector_model_runner_mixin.py +++ b/vllm_omni/worker/omni_connector_model_runner_mixin.py @@ -250,23 +250,24 @@ def cleanup_finished_request(self, req_id: str) -> None: # downstream consumer (e.g. text-only on multi-modal arch) leave # the entry orphaned in _pending_full_payload_send across requests, # which empirically destabilises subsequent thinker forwards by - # making prefix-cache reuse observe stale accumulator state. - # flush is a near-no-op for paths with no consumer, and idempotent - # when the entry has already been flushed by the scheduler-driven - # path. - try: - self.flush_full_payload_outputs({req_id}) - except Exception: - # Cleanup must still proceed regardless of flush errors here -- - # we already gated on ``_omni_connector_initialized`` upstream, - # so any exception here reflects a real connector-side issue - # (shared memory corruption, background thread crash) worth - # surfacing rather than silently swallowing. 
- logger.warning( - "flush_full_payload_outputs(%s) raised during cleanup; continuing tear-down.", - req_id, - exc_info=True, - ) + # making prefix-cache reuse observe stale accumulator state. The + # flush is idempotent when the entry has already been flushed by the + # scheduler-driven path, but this cleanup path runs for every request, + # so skip it entirely when the request never accumulated a payload. + if req_id in self._pending_full_payload_send: + try: + self.flush_full_payload_outputs({req_id}) + except Exception: + # Cleanup must still proceed regardless of flush errors here -- + # we already gated on ``_omni_connector_initialized`` upstream, + # so any exception here reflects a real connector-side issue + # (shared memory corruption, background thread crash) worth + # surfacing rather than silently swallowing. + logger.warning( + "flush_full_payload_outputs(%s) raised during cleanup; continuing tear-down.", + req_id, + exc_info=True, + ) ext_id = self._request_ids_mapping.pop(req_id, None) keys_to_clean: list[str] = [req_id] @@ -899,6 +900,10 @@ def accumulate_full_payload_output( def flush_full_payload_outputs(self, finished_req_ids: set[str]) -> None: """Send accumulated full_payload outputs for requests that just finished.""" + pending_req_ids = set(self._pending_full_payload_send.keys()) + if not (finished_req_ids & pending_req_ids): + return + logger.info( "[Stage-%s] flush_full_payload_outputs: finished_req_ids=%s, pending=%s", self._stage_id, From e7e3458354f94706452d61f76ab389c81bb01c1c Mon Sep 17 00:00:00 2001 From: natureofnature Date: Wed, 27 May 2026 03:46:54 +0000 Subject: [PATCH 18/19] minor optimize tts performance Signed-off-by: natureofnature --- .../test_qwen3_omni_streaming_helpers.py | 119 ++++++++++++++++-- .../models/qwen3_tts/qwen3_tts_code2wav.py | 30 ++++- .../stage_input_processors/qwen3_tts.py | 2 +- 3 files changed, 136 insertions(+), 15 deletions(-) diff --git 
a/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py b/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py index 2e44a8e7776..a04a75de875 100644 --- a/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py +++ b/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py @@ -523,7 +523,7 @@ def __init__(self, mm, tids): def test_qwen3_tts_talker2code2wav_full_payload_smoke() -> None: - """Smoke: qwen3_tts full_payload reads flat codes.audio + flattens col-major.""" + """Smoke: qwen3_tts full_payload reads flat codes.audio + flattens codebook-major.""" from types import SimpleNamespace import torch @@ -540,7 +540,10 @@ def test_qwen3_tts_talker2code2wav_full_payload_smoke() -> None: assert payload is not None assert "codes" in payload and "audio" in payload["codes"] # codebook-major: shape [3, 16] -> [16, 3] -> flatten = 48 entries - assert len(payload["codes"]["audio"]) == 48 + assert isinstance(payload["codes"]["audio"], torch.Tensor) + assert payload["codes"]["audio"].shape == (48,) + expected = audio.transpose(0, 1).reshape(-1) + assert torch.equal(payload["codes"]["audio"], expected) assert payload["meta"]["finished"].item() is True @@ -572,15 +575,16 @@ def test_qwen3_tts_full_payload_with_ref_code() -> None: # Exact expected: ref (prepended) + audio (no crop since seq_len > rows), then # transpose [5, 16] -> [16, 5] and flatten row-major (codebook-major). 
- expected = torch.cat([ref, audio], dim=0).transpose(0, 1).reshape(-1).tolist() - assert payload["codes"]["audio"] == expected, ( - f"codec flatten mismatch -- got first 8 = {payload['codes']['audio'][:8]}, expected first 8 = {expected[:8]}" + expected = torch.cat([ref, audio], dim=0).transpose(0, 1).reshape(-1) + assert torch.equal(payload["codes"]["audio"], expected), ( + f"codec flatten mismatch -- got first 8 = {payload['codes']['audio'][:8].tolist()}, " + f"expected first 8 = {expected[:8].tolist()}" ) - assert len(payload["codes"]["audio"]) == 80 # 16 quantizers * (2 ref + 3 audio) frames + assert payload["codes"]["audio"].shape == (80,) # 16 quantizers * (2 ref + 3 audio) frames # Sanity guards: first codebook-major column = [ref[0,0], ref[1,0], audio[0,0], ...], # so the prepend order must put 100 before 1. - first_col = payload["codes"]["audio"][:5] + first_col = payload["codes"]["audio"][:5].tolist() assert first_col == [100, 116, 1, 17, 33], ( f"first column wrong: {first_col} -- ref likely appended instead of prepended" ) @@ -601,7 +605,98 @@ def test_qwen3_tts_full_payload_nested_fallback() -> None: req = SimpleNamespace(output_token_ids=list(range(10))) payload = talker2code2wav_full_payload(None, pooling_output, req) assert payload is not None - assert len(payload["codes"]["audio"]) == 32 # 16 * 2 + assert isinstance(payload["codes"]["audio"], torch.Tensor) + assert payload["codes"]["audio"].shape == (32,) # 16 * 2 + + +def test_qwen3_tts_code2wav_prefers_connector_tensor_payload() -> None: + """Code2Wav should consume connector codec tensor instead of placeholder zeros.""" + import torch + + from vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_code2wav import ( + _codec_ids_from_payload_or_input, + ) + + placeholder = torch.zeros(6, dtype=torch.long) + codec = torch.arange(12, dtype=torch.long) + + out = _codec_ids_from_payload_or_input( + placeholder, + {"codes": {"audio": codec}}, + ) + + assert torch.equal(out, codec) + + +def 
test_qwen3_tts_code2wav_accepts_legacy_list_payload() -> None: + """Back-compat: old list full-payloads still override placeholder tokens.""" + import torch + + from vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_code2wav import ( + _codec_ids_from_payload_or_input, + ) + + placeholder = torch.zeros(6, dtype=torch.long) + + out = _codec_ids_from_payload_or_input( + placeholder, + {"codes": {"audio": [1, 2, 3, 4]}}, + ) + + assert torch.equal(out, torch.tensor([1, 2, 3, 4], dtype=torch.long)) + + +def test_qwen3_tts_code2wav_forward_decodes_connector_payload() -> None: + """Forward should decode real connector codes, not token-only placeholders.""" + from collections import Counter + + import torch + + from vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_code2wav import ( + Qwen3TTSCode2Wav, + ) + + class _Decoder: + def __init__(self): + self.last_codes = None + + def chunked_decode(self, codes, **_kwargs): + self.last_codes = codes.detach().clone() + return codes.sum(dim=1).to(torch.float32) + + decoder = _Decoder() + model = Qwen3TTSCode2Wav.__new__(Qwen3TTSCode2Wav) + torch.nn.Module.__init__(model) + model.decoder = decoder + model._num_quantizers = 2 + model._total_upsample = 1 + model._output_sample_rate = 24000 + model._decode_chunk_frames = 300 + model._decode_left_context_frames = 25 + model._decode_batch_bucket_frames = [] + model._decode_batch_max_size = 0 + model._decode_variable_chunk_batch_min_frames = 326 + model._logged_codec_stats = True + model._logged_malformed_codec_lengths = set() + model._batch_stats_enabled = False + model._batch_stats_log_every = 0 + model._batch_stats_forwards = 0 + model._batch_stats_groups = 0 + model._batch_stats_requests = 0 + model._batch_stats_padded_frames = 0 + model._batch_stats_decoded_frames = 0 + model._batch_stats_actual_frames = Counter() + model._batch_stats_bucket_groups = Counter() + + payload_codes = torch.tensor([1, 3, 2, 4], dtype=torch.long) + out = model.forward( + input_ids=torch.zeros(4, 
dtype=torch.long), + runtime_additional_information=[{"codes": {"audio": payload_codes}, "meta": {}}], + ) + + assert decoder.last_codes is not None + assert torch.equal(decoder.last_codes, torch.tensor([[[1, 3], [2, 4]]], dtype=torch.long)) + assert torch.equal(out.multimodal_outputs["model_outputs"][0], torch.tensor([3.0, 7.0])) def test_qwen3_tts_codec_filter_and_crop_edge_cases() -> None: @@ -666,12 +761,12 @@ def __init__(self, ctids, mm): # After filter + crop, kept rows = [row4, row5, row6] = [[200,2047,210,220],[300,310,320,330],[400,410,420,430]] # Codebook-major flatten: transpose [3, Q] -> [Q, 3] -> reshape(-1) cropped = torch.tensor(kept[-3:], dtype=torch.long) - expected = cropped.transpose(0, 1).reshape(-1).tolist() - assert payload["codes"]["audio"] == expected + expected = cropped.transpose(0, 1).reshape(-1) + assert torch.equal(payload["codes"]["audio"], expected) # Sanity: confirm the boundary-valid 2047 survived (codex P2 #3 regression guard). - assert _CODEBOOK_SIZE - 1 in payload["codes"]["audio"] + assert _CODEBOOK_SIZE - 1 in payload["codes"]["audio"].tolist() # Sanity: confirm no negative or >=_CODEBOOK_SIZE codec id leaked through. - assert all(0 <= v < _CODEBOOK_SIZE for v in payload["codes"]["audio"]) + assert bool(((payload["codes"]["audio"] >= 0) & (payload["codes"]["audio"] < _CODEBOOK_SIZE)).all()) def test_cosyvoice3_text2flow_token_only_smoke() -> None: diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py index b46018e3616..b7c3c9d6e46 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py @@ -25,6 +25,27 @@ logger = init_logger(__name__) +def _codec_ids_from_payload_or_input( + input_ids: torch.Tensor, + runtime_info: dict[str, Any] | None, +) -> torch.Tensor: + """Prefer connector-delivered codec ids over token placeholders. 
+ + In non-async full-payload mode, the scheduler only needs placeholder + token ids for allocation. The real codec sequence is delivered through + model_intermediate_buffer as ``codes.audio``. + """ + if isinstance(runtime_info, dict): + codes = runtime_info.get("codes") + if isinstance(codes, dict): + audio = codes.get("audio") + if isinstance(audio, torch.Tensor) and audio.numel() > 0: + return audio.reshape(-1).to(device=input_ids.device, dtype=torch.long) + if isinstance(audio, (list, tuple)) and audio: + return torch.as_tensor(audio, device=input_ids.device, dtype=torch.long).reshape(-1) + return input_ids.reshape(-1).to(dtype=torch.long) + + class Qwen3TTSCode2Wav(nn.Module): """Stage-1 code2wav model for Qwen3-TTS (GenerationModelRunner). Consumes frame-aligned codec tokens from input_ids and decodes waveform @@ -239,6 +260,7 @@ def forward( multimodal_outputs={"model_outputs": [empty], "sr": [sr_tensor]}, ) + runtime_infos = runtime_additional_information or [] ids = input_ids.reshape(-1).to(dtype=torch.long) request_ids_list = self._split_request_ids(ids, kwargs.get("seq_token_counts")) @@ -246,14 +268,18 @@ def forward( valid_codes_qf: list[torch.Tensor] = [] valid_indices: list[int] = [] left_context_size = [0] * len(request_ids_list) - if runtime_additional_information is not None: - for i, info in enumerate(runtime_additional_information): + if runtime_infos: + for i, info in enumerate(runtime_infos): if i >= len(left_context_size): break + if not isinstance(info, dict): + continue meta = info.get("meta", {}) if "left_context_size" in meta: left_context_size[i] = meta["left_context_size"] for i, req_ids in enumerate(request_ids_list): + runtime_info = runtime_infos[i] if i < len(runtime_infos) else None + req_ids = _codec_ids_from_payload_or_input(req_ids, runtime_info) if req_ids.numel() < 1: parsed.append((0, 0)) continue diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py 
b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py index abe32a103db..1ffbdb931a9 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py @@ -495,7 +495,7 @@ def talker2code2wav_full_payload( if ref_code is not None: audio = torch.cat([ref_code.to(audio.device), audio], dim=0) - codec_codes = audio.transpose(0, 1).cpu().reshape(-1).tolist() + codec_codes = audio.transpose(0, 1).to(device="cpu", dtype=torch.long).reshape(-1).contiguous() return { "codes": {"audio": codec_codes}, "meta": {"finished": torch.tensor(True, dtype=torch.bool)}, From 8ee05c1ae36e137064e90fdc34953b2db12c5de8 Mon Sep 17 00:00:00 2001 From: natureofnature Date: Wed, 27 May 2026 08:37:37 +0000 Subject: [PATCH 19/19] minor fix Signed-off-by: natureofnature --- vllm_omni/core/sched/omni_scheduler_mixin.py | 7 ++++++- .../stage_input_processors/qwen2_5_omni.py | 1 + .../omni_connector_model_runner_mixin.py | 19 +++++++++++++++++-- 3 files changed, 24 insertions(+), 3 deletions(-) diff --git a/vllm_omni/core/sched/omni_scheduler_mixin.py b/vllm_omni/core/sched/omni_scheduler_mixin.py index 78284efa3e9..570fa554545 100644 --- a/vllm_omni/core/sched/omni_scheduler_mixin.py +++ b/vllm_omni/core/sched/omni_scheduler_mixin.py @@ -22,9 +22,14 @@ # Scope: this constant only covers the full-payload coordinator path # (``input_coordinator``). The async-chunk path uses # ``chunk_transfer_adapter`` and is not affected by this constant. 
+_INPUT_WAIT_TIMEOUT_RAW = os.environ.get("VLLM_OMNI_INPUT_WAIT_TIMEOUT_S", "300") try: - DEFAULT_INPUT_WAIT_TIMEOUT_S: float = float(os.environ.get("VLLM_OMNI_INPUT_WAIT_TIMEOUT_S", "300")) + DEFAULT_INPUT_WAIT_TIMEOUT_S: float = float(_INPUT_WAIT_TIMEOUT_RAW) except ValueError: + logger.warning( + "Invalid VLLM_OMNI_INPUT_WAIT_TIMEOUT_S=%r; falling back to 300 seconds.", + _INPUT_WAIT_TIMEOUT_RAW, + ) DEFAULT_INPUT_WAIT_TIMEOUT_S = 300.0 diff --git a/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py index 5046fa649f2..225674360c2 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py @@ -322,6 +322,7 @@ def _ensure_list(x): if x is None: return [] if hasattr(x, "_x"): + # vLLM wraps cached token-id lists in ConstantList-like objects. return list(x._x) if isinstance(x, list): return list(x) diff --git a/vllm_omni/worker/omni_connector_model_runner_mixin.py b/vllm_omni/worker/omni_connector_model_runner_mixin.py index 83f58edba20..64b7e60e26c 100644 --- a/vllm_omni/worker/omni_connector_model_runner_mixin.py +++ b/vllm_omni/worker/omni_connector_model_runner_mixin.py @@ -820,16 +820,31 @@ def _resolve_full_payload_replace_keys(self) -> frozenset: self._full_payload_replace_keys_cached = frozenset() return self._full_payload_replace_keys_cached try: - import importlib as _il import sys as _sys - mod = _sys.modules.get(module_name) or _il.import_module(module_name) + mod = _sys.modules.get(module_name) or importlib.import_module(module_name) keys = getattr(mod, "_FULL_PAYLOAD_REPLACE_KEYS", frozenset()) except ImportError: + logger.debug( + "Could not import stage input processor module %s while resolving " + "_FULL_PAYLOAD_REPLACE_KEYS; using CONCAT semantics for all keys.", + module_name, + exc_info=True, + ) keys = frozenset() if not isinstance(keys, (frozenset, set)): + logger.debug( + 
"Ignoring non-set _FULL_PAYLOAD_REPLACE_KEYS from %s: %s", + module_name, + type(keys).__name__, + ) keys = frozenset() self._full_payload_replace_keys_cached = frozenset(keys) + logger.debug( + "Resolved _FULL_PAYLOAD_REPLACE_KEYS for %s: %s", + module_name, + sorted(self._full_payload_replace_keys_cached), + ) return self._full_payload_replace_keys_cached def accumulate_full_payload_output(