From 1e0a7659b3394ca4b345159cf6e2c28ce4342f11 Mon Sep 17 00:00:00 2001 From: natureofnature Date: Mon, 1 Jun 2026 08:33:56 +0000 Subject: [PATCH 01/10] [PR4] add async-chunk coordinator gate and chunk signals Introduce the coordinator-side selection, carried registration field, chunk-ready/finished consumption, late-ready retention, and scheduler completion guards needed before enabling async-chunk stages. Signed-off-by: natureofnature --- .../sched/test_omni_scheduling_coordinator.py | 131 +++++++++++++++++- vllm_omni/core/sched/omni_ar_scheduler.py | 13 +- .../core/sched/omni_generation_scheduler.py | 32 ++++- vllm_omni/core/sched/omni_scheduler_mixin.py | 21 ++- .../core/sched/omni_scheduling_coordinator.py | 88 ++++++++++-- vllm_omni/core/sched/output.py | 2 +- vllm_omni/worker/gpu_ar_model_runner.py | 2 +- .../worker/gpu_generation_model_runner.py | 2 +- .../omni_connector_model_runner_mixin.py | 11 +- 9 files changed, 266 insertions(+), 36 deletions(-) diff --git a/tests/core/sched/test_omni_scheduling_coordinator.py b/tests/core/sched/test_omni_scheduling_coordinator.py index 1b36cd784d8..c7afe1020de 100644 --- a/tests/core/sched/test_omni_scheduling_coordinator.py +++ b/tests/core/sched/test_omni_scheduling_coordinator.py @@ -10,12 +10,14 @@ import unittest from types import SimpleNamespace +from unittest import mock import torch import vllm_omni.core.sched.omni_scheduling_coordinator as coord_mod from vllm_omni.core.sched.omni_scheduling_coordinator import ( OmniSchedulingCoordinator, + uses_async_chunk_coordinator, uses_full_payload_input_coordinator, ) @@ -204,6 +206,26 @@ def test_ready_request_transitions_to_waiting(self): self.assertEqual(req.status, RequestStatus.WAITING) self.assertIn("r1", coord.requests_with_ready_chunks) + def test_late_ready_before_queue_insertion_is_retained(self): + # codex r3: a chunk can arrive before the request is surfaced into a + # queue. 
The readiness must be retained (not lost when the connector + # output is cleared) so a later cycle still transitions the request. + coord = OmniSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1, async_chunk=True) + + # Cycle 1: ready for "r1" arrives while no queue holds it yet. + coord.process_pending_chunks(MockQueue([]), [], chunk_ready_req_ids={"r1"}, chunk_finished_req_ids=set()) + self.assertIn("r1", coord.requests_with_ready_chunks, "late ready must be retained") + + # Cycle 2: r1 now appears as a fresh WAITING request, but chunk_ready is + # already empty (the connector output was consumed last cycle). Because + # retention recorded r1, it must NOT be parked into WAITING_FOR_CHUNK -- + # it stays schedulable. Without the retain it would be wrongly parked. + req = _make_request("r1", status=RequestStatus.WAITING) + waiting = MockQueue([req]) + coord.process_pending_chunks(waiting, [], chunk_ready_req_ids=set(), chunk_finished_req_ids=set()) + self.assertEqual(req.status, RequestStatus.WAITING, "ready-before-insertion must not be parked") + self.assertIn(req, waiting, "request must remain schedulable in the waiting queue") + def test_non_ready_stays_waiting_for_chunk(self): coord = OmniSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1, async_chunk=True) @@ -471,7 +493,7 @@ def test_full_payload_mode_auto_transitions_waiting_to_waiting_for_input(self): self.assertEqual(req.status, RequestStatus.WAITING_FOR_INPUT) self.assertEqual(len(coord._waiting_for_input), 1) - self.assertEqual(len(coord.pending_input_registrations), 1) + self.assertEqual(len(coord.pending_connector_registrations), 1) def test_async_chunk_mode_does_not_auto_transition(self): """In async_chunk mode, fresh WAITING requests should NOT be @@ -494,7 +516,7 @@ def test_async_chunk_mode_does_not_auto_transition(self): self.assertEqual(req.status, RequestStatus.WAITING) - def test_pending_input_registrations(self): + def test_pending_connector_registrations(self): coord = 
OmniSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1) req = _make_request("r1", status=RequestStatus.WAITING_FOR_INPUT) @@ -507,8 +529,8 @@ def test_pending_input_registrations(self): stage_recv_req_ids=set(), ) - self.assertEqual(len(coord.pending_input_registrations), 1) - self.assertEqual(coord.pending_input_registrations[0].request_id, "r1") + self.assertEqual(len(coord.pending_connector_registrations), 1) + self.assertEqual(coord.pending_connector_registrations[0].request_id, "r1") def test_idle_cycles_retain_received_marker_before_request_appears(self): coord = OmniSchedulingCoordinator( @@ -533,7 +555,7 @@ def test_idle_cycles_retain_received_marker_before_request_appears(self): coord.process_pending_full_payload_inputs(waiting, running, stage_recv_req_ids=set()) self.assertEqual(late_req.status, RequestStatus.WAITING) - self.assertEqual(coord.pending_input_registrations, []) + self.assertEqual(coord.pending_connector_registrations, []) self.assertIn("late", coord._full_payload_input_received) self.assertIn("late", coord.finished_requests) @@ -861,5 +883,104 @@ def test_overflow_does_not_strand_request(self): self.assertNotEqual(req.status, RequestStatus.RUNNING, "Overflowed request must not keep RUNNING status") +class TestAsyncChunkCoordinatorGate(unittest.TestCase): + """PR4: `uses_async_chunk_coordinator` selects the coordinator+mixin path for + allowlisted async-chunk archs on SharedMemory only; everyone else (empty + allowlist today, Mooncake, sync) stays on the legacy adapter. 
+ """ + + _SM = {"name": "SharedMemoryConnector"} + _MOONCAKE = {"name": "MooncakeStoreConnector"} + + def test_allowlisted_sharedmemory_fires(self): + key = ("Qwen3OmniMoeForConditionalGeneration", "talker") + with mock.patch.object(coord_mod, "_ASYNC_CHUNK_COORDINATOR_STAGES", frozenset({key})): + mc = SimpleNamespace( + async_chunk=True, + model_arch=key[0], + model_stage=key[1], + stage_connector_config=self._SM, + ) + self.assertTrue(uses_async_chunk_coordinator(mc)) + # default (no connector config) is SharedMemory -> also fires + mc_default = SimpleNamespace( + async_chunk=True, + model_arch=key[0], + model_stage=key[1], + stage_connector_config=None, + ) + self.assertTrue(uses_async_chunk_coordinator(mc_default)) + + def test_mooncake_stays_on_adapter(self): + key = ("Qwen3OmniMoeForConditionalGeneration", "talker") + with mock.patch.object(coord_mod, "_ASYNC_CHUNK_COORDINATOR_STAGES", frozenset({key})): + mc = SimpleNamespace( + async_chunk=True, + model_arch=key[0], + model_stage=key[1], + stage_connector_config=self._MOONCAKE, + ) + self.assertFalse(uses_async_chunk_coordinator(mc)) + + def test_sync_or_non_allowlisted_does_not_fire(self): + key = ("Qwen3OmniMoeForConditionalGeneration", "talker") + with mock.patch.object(coord_mod, "_ASYNC_CHUNK_COORDINATOR_STAGES", frozenset({key})): + # async_chunk=False + self.assertFalse( + uses_async_chunk_coordinator( + SimpleNamespace( + async_chunk=False, model_arch=key[0], model_stage=key[1], stage_connector_config=self._SM + ) + ) + ) + # non-allowlisted arch + self.assertFalse( + uses_async_chunk_coordinator( + SimpleNamespace( + async_chunk=True, + model_arch="MiMoAudioModel", + model_stage="code2wav", + stage_connector_config=self._SM, + ) + ) + ) + # non-allowlisted stage of an allowlisted arch + self.assertFalse( + uses_async_chunk_coordinator( + SimpleNamespace( + async_chunk=True, model_arch=key[0], model_stage="thinker", stage_connector_config=self._SM + ) + ) + ) + + +class 
TestAsyncChunkRecvRegistration(unittest.TestCase): + """PR4 regression (flip data-plane bug 2026-06-02): a parked async-chunk + request MUST be registered for bg-thread recv via the CARRIED + ``pending_connector_registrations`` (the old ``pending_chunk_registrations`` + was never carried/consumed -> the runner never called register_chunk_recv, + the bg thread never polled, and the request hung until the 300s timeout). + The full-payload pass runs AFTER process_pending_chunks each cycle in + async-chunk mode, so it must NOT re-clear the chunk registrations. + """ + + def test_parked_chunk_request_registered_and_survives_full_payload_pass(self): + coord = OmniSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1, async_chunk=True) + req = _make_request("r1", status=RequestStatus.WAITING) + waiting = MockQueue([req]) + running: list = [] + + # No chunk ready yet -> park WAITING_FOR_CHUNK AND register for recv. + coord.process_pending_chunks(waiting, running, chunk_ready_req_ids=set(), chunk_finished_req_ids=set()) + self.assertEqual(req.status, RequestStatus.WAITING_FOR_CHUNK) + regs = [h.request_id for h in coord.pending_connector_registrations] + self.assertIn("r1", regs, "parked async-chunk request must be registered for bg recv polling") + + # The full-payload pass (runs after, every cycle) must not wipe it. 
+ coord.process_pending_full_payload_inputs(waiting, running, stage_recv_req_ids=set()) + regs_after = [h.request_id for h in coord.pending_connector_registrations] + self.assertIn("r1", regs_after, "full-payload pass must not drop async-chunk recv registrations") + + if __name__ == "__main__": unittest.main() diff --git a/vllm_omni/core/sched/omni_ar_scheduler.py b/vllm_omni/core/sched/omni_ar_scheduler.py index 67514994ec8..2dfef93f761 100644 --- a/vllm_omni/core/sched/omni_ar_scheduler.py +++ b/vllm_omni/core/sched/omni_ar_scheduler.py @@ -22,6 +22,7 @@ from vllm_omni.core.sched.omni_scheduler_mixin import OmniSchedulerMixin from vllm_omni.core.sched.omni_scheduling_coordinator import ( OmniSchedulingCoordinator, + uses_async_chunk_coordinator, uses_full_payload_input_coordinator, ) from vllm_omni.core.sched.utils import omni_routed_experts_for_request @@ -80,15 +81,21 @@ def __init__(self, *args, **kwargs): # Cache per-request flag to avoid repeated deserialization of additional_information self._omits_kv_transfer_cache: dict[str, bool] = {} model_config = self.vllm_config.model_config + # PR4: allowlisted async-chunk archs (SharedMemory) drive recv through the + # OmniSchedulingCoordinator + runner mixin; everyone else keeps the legacy + # adapter. Empty allowlist today => _async_coord is always False (no behavior + # change). Full-payload (async_chunk=False) and async-chunk are mutually + # exclusive, so a single coordinator instance serves whichever fires. 
+ _async_coord = uses_async_chunk_coordinator(model_config) self.chunk_transfer_adapter = None - if getattr(model_config, "async_chunk", False): + if getattr(model_config, "async_chunk", False) and not _async_coord: self.chunk_transfer_adapter = OmniChunkTransferAdapter(self.vllm_config) self.input_coordinator: OmniSchedulingCoordinator | None = None - if uses_full_payload_input_coordinator(model_config): + if uses_full_payload_input_coordinator(model_config) or _async_coord: self.input_coordinator = OmniSchedulingCoordinator( scheduler_max_num_seqs=self.vllm_config.scheduler_config.max_num_seqs, stage_id=getattr(model_config, "stage_id", 0), - async_chunk=False, + async_chunk=_async_coord, ) self._latest_omni_connector_output: OmniConnectorOutput | None = None # Snapshot prompt length for each streaming input update diff --git a/vllm_omni/core/sched/omni_generation_scheduler.py b/vllm_omni/core/sched/omni_generation_scheduler.py index 1c43b958a59..517442e8349 100644 --- a/vllm_omni/core/sched/omni_generation_scheduler.py +++ b/vllm_omni/core/sched/omni_generation_scheduler.py @@ -26,6 +26,7 @@ from vllm_omni.core.sched.omni_scheduler_mixin import OmniSchedulerMixin from vllm_omni.core.sched.omni_scheduling_coordinator import ( OmniSchedulingCoordinator, + uses_async_chunk_coordinator, uses_full_payload_input_coordinator, ) from vllm_omni.core.sched.output import OmniCachedRequestData, OmniNewRequestData @@ -43,16 +44,26 @@ class OmniGenerationScheduler(OmniSchedulerMixin, VLLMScheduler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) model_config = self.vllm_config.model_config + # PR4: see OmniARScheduler.__init__ -- allowlisted async-chunk archs + # (SharedMemory) use the coordinator + runner mixin; others keep the adapter. + # Empty allowlist today => _async_coord False => no behavior change. 
+ _async_coord = uses_async_chunk_coordinator(model_config) + # When True, stage completion follows the coordinator's terminal chunk + # signal (input_coordinator.finished_requests), NOT the plain-generation + # "num_computed >= num_prompt" heuristic -- otherwise code2wav (stage 2) + # would finish before its terminal chunk arrives. False with the empty + # allowlist, so the legacy branches are unchanged. + self._async_chunk_coordinator_active = _async_coord self.chunk_transfer_adapter = None - if getattr(model_config, "async_chunk", False): + if getattr(model_config, "async_chunk", False) and not _async_coord: self.chunk_transfer_adapter = OmniChunkTransferAdapter(self.vllm_config) self._pending_finish_reqs: list[Request] = [] self.input_coordinator: OmniSchedulingCoordinator | None = None - if uses_full_payload_input_coordinator(model_config): + if uses_full_payload_input_coordinator(model_config) or _async_coord: self.input_coordinator = OmniSchedulingCoordinator( scheduler_max_num_seqs=self.vllm_config.scheduler_config.max_num_seqs, stage_id=getattr(model_config, "stage_id", 0), - async_chunk=False, + async_chunk=_async_coord, ) self._latest_omni_connector_output: OmniConnectorOutput | None = None @@ -487,12 +498,25 @@ def update_from_output( # Diffusion request: completes in one step; mark finished and free resources if ( request.status == RequestStatus.FINISHED_STOPPED - or (self.chunk_transfer_adapter is None and request.num_computed_tokens >= request.num_prompt_tokens) + or ( + self.chunk_transfer_adapter is None + and not self._async_chunk_coordinator_active + and request.num_computed_tokens >= request.num_prompt_tokens + ) or ( self.chunk_transfer_adapter is not None and self.chunk_transfer_adapter.is_done_receiving_chunks(request.request_id) and request.num_computed_tokens >= len(request.prompt_token_ids) ) + or ( + # async-chunk coordinator path: complete only when the terminal + # chunk has arrived (coordinator finished_requests), mirroring the + # 
legacy-adapter clause above instead of the plain-gen heuristic. + self._async_chunk_coordinator_active + and self.input_coordinator is not None + and request.request_id in self.input_coordinator.finished_requests + and request.num_computed_tokens >= len(request.prompt_token_ids) + ) ): request.status = RequestStatus.FINISHED_STOPPED # Optional: set a stop_reason for front-end clarity diff --git a/vllm_omni/core/sched/omni_scheduler_mixin.py b/vllm_omni/core/sched/omni_scheduler_mixin.py index c33d16dac2b..11f9d885bdb 100644 --- a/vllm_omni/core/sched/omni_scheduler_mixin.py +++ b/vllm_omni/core/sched/omni_scheduler_mixin.py @@ -65,6 +65,17 @@ def _consume_pending_connector_output(self, model_mode: str) -> None: input_coordinator.update_request_metadata( self.requests, connector_output.request_metadata, model_mode=model_mode ) + # Both calls self-guard on the coordinator's async_chunk mode + # (process_pending_chunks returns early when async_chunk is False; + # process_pending_full_payload_inputs branches internally), so exactly + # one path is live per deployment. With an empty async allowlist the + # coordinator is never in async_chunk mode -> the chunk call is a no-op. + input_coordinator.process_pending_chunks( + self.waiting, + self.running, + connector_output.chunk_ready_req_ids if connector_output else set(), + connector_output.chunk_finished_req_ids if connector_output else set(), + ) input_coordinator.process_pending_full_payload_inputs( self.waiting, self.running, @@ -132,7 +143,7 @@ def _wrap_omni_scheduler_output( base: SchedulerOutput, *, finished_requests_needing_kv_transfer: dict | None = None, - pending_input_registrations: list[OmniChunkRecvHandle] | None = None, + pending_connector_registrations: list[OmniChunkRecvHandle] | None = None, ) -> OmniSchedulerOutput: """Wrap a base ``SchedulerOutput`` in ``OmniSchedulerOutput``. 
@@ -142,12 +153,14 @@ def _wrap_omni_scheduler_output( """ base_data = {name: getattr(base, name) for name in SchedulerOutput.__dataclass_fields__} input_coordinator = getattr(self, "input_coordinator", None) - if pending_input_registrations is None: - pending_input_registrations = input_coordinator.pending_input_registrations if input_coordinator else [] + if pending_connector_registrations is None: + pending_connector_registrations = ( + input_coordinator.pending_connector_registrations if input_coordinator else [] + ) return OmniSchedulerOutput( **base_data, finished_requests_needing_kv_transfer=finished_requests_needing_kv_transfer or {}, - pending_input_registrations=pending_input_registrations, + pending_connector_registrations=pending_connector_registrations, ) def make_stats(self, *args, **kwargs) -> SchedulerStats | None: diff --git a/vllm_omni/core/sched/omni_scheduling_coordinator.py b/vllm_omni/core/sched/omni_scheduling_coordinator.py index 4056fd93861..b185332d8f4 100644 --- a/vllm_omni/core/sched/omni_scheduling_coordinator.py +++ b/vllm_omni/core/sched/omni_scheduling_coordinator.py @@ -79,6 +79,43 @@ def uses_full_payload_input_coordinator(model_config: Any) -> bool: return key in _FULL_PAYLOAD_INPUT_STAGES +# (model_arch, model_stage) whose async-chunk RECEIVE is coordinated by +# OmniSchedulingCoordinator (+ the runner-level mixin transport) instead of the +# legacy scheduler-owned OmniChunkTransferAdapter. Intentionally EMPTY until the +# qwen3_omni cutover lands (PR4 final commit): an empty allowlist keeps every arch +# on the adapter, so each intermediate commit is behavior-preserving. 
Final +# entries are the recv stages only -- stage-0 producers do not wait on chunks: +# ("Qwen3OmniMoeForConditionalGeneration", "talker") # stage 1, recv from 0 +# ("Qwen3OmniMoeForConditionalGeneration", "code2wav") # stage 2, recv from 1 +_ASYNC_CHUNK_COORDINATOR_STAGES: frozenset[tuple[str, str]] = frozenset() + + +def uses_async_chunk_coordinator(model_config: Any) -> bool: + """Returns True iff this stage's async-chunk receive should be driven by + ``OmniSchedulingCoordinator`` + the runner-level mixin transport rather than + the legacy ``OmniChunkTransferAdapter``. + + Gated by the ``(model_arch, model_stage)`` allowlist AND a + ``SharedMemoryConnector`` connector: the single-connector mixin path is only + valid under SharedMemory's role-neutral put/get, so a Mooncake-served + deployment of the same arch stays on the adapter (graceful fallback). + """ + if not getattr(model_config, "async_chunk", False): + return False + key = ( + getattr(model_config, "model_arch", None), + getattr(model_config, "model_stage", None), + ) + if key not in _ASYNC_CHUNK_COORDINATOR_STAGES: + return False + connector_config = getattr(model_config, "stage_connector_config", None) + if isinstance(connector_config, dict): + name = connector_config.get("name") + else: + name = getattr(connector_config, "name", None) + return (name or "SharedMemoryConnector") == "SharedMemoryConnector" + + class OmniSchedulingCoordinator: """Pure-scheduling coordinator for chunk and full_payload input waiting. @@ -100,18 +137,16 @@ def __init__(self, scheduler_max_num_seqs: int, stage_id: int = 0, async_chunk: self._waiting_for_chunk_waiting: deque[Any] = deque() self._waiting_for_chunk_running: deque[Any] = deque() - # Request IDs that were newly registered for chunk recv this cycle. - # The engine/Model Runner should call register_chunk_recv() for these - # so the bg thread starts polling. 
- self.pending_chunk_registrations: list[Any] = [] - # Requests waiting for full_payload stage input (WAITING_FOR_INPUT). self._waiting_for_input: deque[Any] = deque() - # Per-cycle list of minimal handles to ship to the model runner so it - # can call register_chunk_recv(). Typed concretely (not list[Any]) so - # the surrounding OmniSchedulerOutput stays msgspec-friendly across - # default, PD-disagg, and multi-node executor IPC paths. - self.pending_input_registrations: list[OmniChunkRecvHandle] = [] + # Per-cycle list of minimal handles shipped to the model runner so it + # can call register_chunk_recv() (so the bg thread starts polling). + # Populated by BOTH the async-chunk park (process_pending_chunks) and the + # full-payload park (process_pending_full_payload_inputs); carried on + # OmniSchedulerOutput by _wrap_omni_scheduler_output. Typed concretely + # (not list[Any]) so the surrounding OmniSchedulerOutput stays + # msgspec-friendly across default, PD-disagg, and multi-node executor IPC. + self.pending_connector_registrations: list[OmniChunkRecvHandle] = [] # Monotonic timestamp recording when each request first entered # WAITING_FOR_CHUNK or WAITING_FOR_INPUT. Used by @@ -140,9 +175,19 @@ def process_pending_chunks( if self._stage_id == 0 or not self._async_chunk: return + # Retain readiness for requests not yet surfaced into the waiting/running + # queues: a bg recv can complete before the request appears in a queue, + # and the connector output clears chunk_ready_req_ids after this cycle. + # Without an unconditional retain, that first ready signal is lost during + # the queue scan below and the request hangs in WAITING_FOR_CHUNK forever. + # (Ported from tmp/trim_refactor; see codex review r3.) 
+ self.requests_with_ready_chunks.update(chunk_ready_req_ids) + terminal_ready_req_ids = chunk_ready_req_ids.intersection(chunk_finished_req_ids) self.finished_requests.update(chunk_finished_req_ids - terminal_ready_req_ids) - self.pending_chunk_registrations = [] + # Reset the carried registration list; the full-payload pass (which runs + # after this in async-chunk mode) must NOT re-clear it (see its guard). + self.pending_connector_registrations = [] self._process_chunk_queue( waiting_queue, @@ -191,7 +236,11 @@ def process_pending_full_payload_inputs( self._stage_id, stage_recv_req_ids, ) - self.pending_input_registrations = [] + # Only the full-payload path owns this reset. In async-chunk mode this + # method runs AFTER process_pending_chunks (which already reset + filled + # the list), so re-clearing here would drop the chunk recv registrations. + if not self._async_chunk: + self.pending_connector_registrations = [] remaining: deque[Any] = deque() for request in self._waiting_for_input: @@ -218,7 +267,7 @@ def process_pending_full_payload_inputs( self._waiting_since.setdefault(request.request_id, time.monotonic()) to_remove.append(request) self._waiting_for_input.append(request) - self.pending_input_registrations.append( + self.pending_connector_registrations.append( OmniChunkRecvHandle( request_id=request.request_id, external_req_id=getattr(request, "external_req_id", None), @@ -231,7 +280,7 @@ def process_pending_full_payload_inputs( else: to_remove.append(request) self._waiting_for_input.append(request) - self.pending_input_registrations.append( + self.pending_connector_registrations.append( OmniChunkRecvHandle( request_id=request.request_id, external_req_id=getattr(request, "external_req_id", None), @@ -453,7 +502,16 @@ def _process_chunk_queue( if request.request_id in chunk_ready_req_ids: self.requests_with_ready_chunks.add(request.request_id) continue - self.pending_chunk_registrations.append(request) + # Register for bg-thread chunk recv via the 
CARRIED field so the + # runner's register_chunk_recv() actually starts polling. (The + # old pending_chunk_registrations was never carried/consumed -> + # the request parked forever and timed out at 300s.) + self.pending_connector_registrations.append( + OmniChunkRecvHandle( + request_id=request.request_id, + external_req_id=getattr(request, "external_req_id", None), + ) + ) request.status = RequestStatus.WAITING_FOR_CHUNK self._waiting_since.setdefault(request.request_id, time.monotonic()) else: diff --git a/vllm_omni/core/sched/output.py b/vllm_omni/core/sched/output.py index 29cd872998f..bb09f128a44 100644 --- a/vllm_omni/core/sched/output.py +++ b/vllm_omni/core/sched/output.py @@ -93,4 +93,4 @@ class OmniSchedulerOutput(SchedulerOutput): """Scheduler output with omni-specific transfer metadata.""" finished_requests_needing_kv_transfer: dict[str, dict] = field(default_factory=dict) - pending_input_registrations: list[OmniChunkRecvHandle] = field(default_factory=list) + pending_connector_registrations: list[OmniChunkRecvHandle] = field(default_factory=list) diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index c59b3b4c73a..a250e1f2ca8 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -436,7 +436,7 @@ def execute_model( ) if hasattr(self, "_omni_connector"): - for request in getattr(scheduler_output, "pending_input_registrations", []): + for request in getattr(scheduler_output, "pending_connector_registrations", []): self.register_chunk_recv(request) self.recv_full_payload_inputs(scheduler_output) if self._pending_full_payload_send: diff --git a/vllm_omni/worker/gpu_generation_model_runner.py b/vllm_omni/worker/gpu_generation_model_runner.py index e4b7944d113..ae248886b6c 100644 --- a/vllm_omni/worker/gpu_generation_model_runner.py +++ b/vllm_omni/worker/gpu_generation_model_runner.py @@ -110,7 +110,7 @@ def execute_model( 
self.routed_experts_capturer.clear_buffer() if hasattr(self, "_omni_connector"): - for request in getattr(scheduler_output, "pending_input_registrations", []): + for request in getattr(scheduler_output, "pending_connector_registrations", []): self.register_chunk_recv(request) self.recv_full_payload_inputs(scheduler_output) if self._pending_full_payload_send: diff --git a/vllm_omni/worker/omni_connector_model_runner_mixin.py b/vllm_omni/worker/omni_connector_model_runner_mixin.py index 6d991d0f24f..cfe3c55b459 100644 --- a/vllm_omni/worker/omni_connector_model_runner_mixin.py +++ b/vllm_omni/worker/omni_connector_model_runner_mixin.py @@ -1806,8 +1806,15 @@ def _poll_single_request(self, req_id: str) -> bool: payload_data = self._accumulate_payload(external_req_id, payload_data) payload_consumable = incoming_payload_consumable else: - new_ids = self._payload_audio_codes(payload_data) or [] - if not new_ids and not is_finished: + # codes.audio may be a multi-element tensor -> use a numel/len + # check, never `or []` / `not new_ids` (bool(tensor) raises + # "ambiguous truth value" on a >1-element tensor). + audio_codes = self._payload_audio_codes(payload_data) + if isinstance(audio_codes, torch.Tensor): + has_codes = audio_codes.numel() > 0 + else: + has_codes = bool(audio_codes) + if not has_codes and not is_finished: return False payload_consumable = self._payload_is_consumable(payload_data) From 56ed1c248066e4884b97d5194fbe5679563096d1 Mon Sep 17 00:00:00 2001 From: natureofnature Date: Mon, 1 Jun 2026 10:28:55 +0000 Subject: [PATCH 02/10] [PR4] add runner async-chunk send and sentinel plumbing Add runner-side async-chunk sends, duplicate-send guards, finish sentinel enqueue/consume paths, code2wav terminal flushing, deep metadata merge semantics, and quiet per-chunk transport logging. 
Signed-off-by: natureofnature --- .../test_qwen3_omni_finish_sentinel.py | 79 +++++++++ .../test_async_chunk_request_adapter.py | 76 +++++++++ tests/worker/test_omni_connector_mixin.py | 156 ++++++++++++++++++ vllm_omni/data_entry_keys.py | 8 + .../stage_input_processors/qwen3_omni.py | 42 +++++ vllm_omni/worker/gpu_ar_model_runner.py | 137 +++++++++++++++ .../omni_connector_model_runner_mixin.py | 136 +++++++++++++-- 7 files changed, 624 insertions(+), 10 deletions(-) create mode 100644 tests/model_executor/stage_input_processors/test_qwen3_omni_finish_sentinel.py create mode 100644 tests/worker/test_async_chunk_request_adapter.py diff --git a/tests/model_executor/stage_input_processors/test_qwen3_omni_finish_sentinel.py b/tests/model_executor/stage_input_processors/test_qwen3_omni_finish_sentinel.py new file mode 100644 index 00000000000..e7a731d6079 --- /dev/null +++ b/tests/model_executor/stage_input_processors/test_qwen3_omni_finish_sentinel.py @@ -0,0 +1,79 @@ +"""PR4 MC-C.3b: code2wav async finish-sentinel terminal flush. + +The producer runner sends every in-step codec chunk with ``finished=False`` and +emits a separate finish sentinel next cycle (empty payload + the +``ASYNC_FINISH_SENTINEL_KEY`` marker the legacy adapter never sets). On that +marker, ``talker2code2wav_async_chunk`` must flush the trailing partial codec +chunk that the live ``is_finished`` branch would otherwise have flushed, reusing +the same context math, without re-appending. 
+""" + +from types import SimpleNamespace + +import torch + +from vllm_omni.data_entry_keys import ASYNC_FINISH_SENTINEL_KEY +from vllm_omni.model_executor.stage_input_processors.qwen3_omni import talker2code2wav_async_chunk + + +def _tm(accumulated, chunk_frames=4, left_frames=25): + return SimpleNamespace( + code_prompt_token_ids=dict(accumulated), + connector=SimpleNamespace( + config={"extra": {"codec_chunk_frames": chunk_frames, "codec_left_context_frames": left_frames}} + ), + ) + + +def _sentinel_payload(): + return {ASYNC_FINISH_SENTINEL_KEY: True} + + +def test_finish_sentinel_flushes_partial_tail(): + # 6 frames accumulated, chunk size 4 -> a 2-frame partial tail is still held. + tm = _tm({"r": [[1], [2], [3], [4], [5], [6]]}, chunk_frames=4, left_frames=25) + req = SimpleNamespace(external_req_id="r") + + out = talker2code2wav_async_chunk(tm, _sentinel_payload(), req, is_finished=True) + + assert out is not None + assert bool(out.meta.finished) is True + # context_length = 6 % 4 = 2; left = min(6-2, 25) = 4; end_index = min(6, 4+2) = 6. + assert out.meta.left_context_size == 4 + assert isinstance(out.codes.audio, torch.Tensor) + # 6 single-codebook frames -> flattened length 6. + assert out.codes.audio.numel() == 6 + + +def test_finish_sentinel_on_chunk_boundary_emits_flag_only(): + # 4 frames, chunk size 4 -> the last full chunk was already sent in-step; + # no unsent tail, so the sentinel must NOT re-send codec (flag only). 
+ tm = _tm({"r": [[1], [2], [3], [4]]}, chunk_frames=4) + req = SimpleNamespace(external_req_id="r") + + out = talker2code2wav_async_chunk(tm, _sentinel_payload(), req, is_finished=True) + + assert out is not None + assert bool(out.meta.finished) is True + assert out.codes is None, "boundary finish must not re-send the last full chunk" + + +def test_finish_sentinel_with_no_sent_chunks_emits_flag_only(): + tm = _tm({}, chunk_frames=4) + req = SimpleNamespace(external_req_id="missing") + + out = talker2code2wav_async_chunk(tm, _sentinel_payload(), req, is_finished=True) + + assert out is not None + assert bool(out.meta.finished) is True + assert out.codes is None + + +def test_non_sentinel_empty_call_is_unchanged(): + # Without the marker, an empty/codeless call returns None as before -> the + # adapter path (which never sets the marker) is byte-identical. + tm = _tm({"r": [[1], [2]]}, chunk_frames=4) + req = SimpleNamespace(external_req_id="r") + + assert talker2code2wav_async_chunk(tm, {"codes": {}}, req, is_finished=True) is None + assert talker2code2wav_async_chunk(tm, {}, req, is_finished=True) is None diff --git a/tests/worker/test_async_chunk_request_adapter.py b/tests/worker/test_async_chunk_request_adapter.py new file mode 100644 index 00000000000..c44a1c830fa --- /dev/null +++ b/tests/worker/test_async_chunk_request_adapter.py @@ -0,0 +1,76 @@ +"""PR4 MC-C.1: unit coverage for the runner-side async-chunk request shim. + +Covers the pure, GPU-free helpers that the AR runner uses to feed the +async-chunk stage-input processors from a worker-side ``CachedRequestState``: +``_strip_trailing_placeholder_tokens`` and ``_AsyncChunkRequestAdapter``. + +The actual ``send_chunk`` call is gated by ``uses_async_chunk_coordinator`` +(empty allowlist -> dormant) and exercised end-to-end at the allowlist flip; +that gate is covered separately in test_omni_scheduling_coordinator.py. 
+""" + +from types import SimpleNamespace + +from vllm_omni.worker.gpu_ar_model_runner import ( + _AsyncChunkRequestAdapter, + _strip_trailing_placeholder_tokens, +) + + +def test_strip_trailing_placeholder_tokens(): + assert _strip_trailing_placeholder_tokens(None) == [] + assert _strip_trailing_placeholder_tokens([]) == [] + assert _strip_trailing_placeholder_tokens([1, 2, 3]) == [1, 2, 3] + # Only trailing -1 placeholders are dropped; interior values are kept. + assert _strip_trailing_placeholder_tokens([1, 2, -1, -1]) == [1, 2] + assert _strip_trailing_placeholder_tokens([-1, -1]) == [] + assert _strip_trailing_placeholder_tokens([5, -1, 6]) == [5, -1, 6] + + +def _cached_state(): + return SimpleNamespace( + req_id="internal-1", + prompt_token_ids=[10, 11, 12], + output_token_ids=[20, 21, -1], + additional_information="add-info-sentinel", + ) + + +def test_adapter_exposes_request_identity_fields(): + inner = _cached_state() + adapter = _AsyncChunkRequestAdapter(inner, external_req_id="ext-99", finished=False) + + # external id is the passed cross-stage id; request_id/req_id mirror the inner req_id. + assert adapter.external_req_id == "ext-99" + assert adapter.request_id == "internal-1" + assert adapter.req_id == "internal-1" + + +def test_adapter_all_token_ids_strips_placeholder_output(): + inner = _cached_state() + adapter = _AsyncChunkRequestAdapter(inner, external_req_id="ext-99", finished=False) + + # output_token_ids drops the trailing -1; all_token_ids = prompt + stripped output. 
+ assert adapter.output_token_ids == [20, 21] + assert adapter.prompt_token_ids == [10, 11, 12] + assert adapter.all_token_ids == [10, 11, 12, 20, 21] + + +def test_adapter_is_finished_reflects_constructor_flag(): + inner = _cached_state() + assert _AsyncChunkRequestAdapter(inner, external_req_id="e", finished=False).is_finished() is False + assert _AsyncChunkRequestAdapter(inner, external_req_id="e", finished=True).is_finished() is True + + +def test_adapter_delegates_unknown_attrs_to_inner(): + inner = _cached_state() + adapter = _AsyncChunkRequestAdapter(inner, external_req_id="e", finished=False) + # additional_information (used by speaker/language extractors) is not a + # declared property -> must delegate to the wrapped CachedRequestState. + assert adapter.additional_information == "add-info-sentinel" + + +def test_adapter_handles_empty_prompt(): + inner = SimpleNamespace(req_id="r", prompt_token_ids=None, output_token_ids=[7]) + adapter = _AsyncChunkRequestAdapter(inner, external_req_id="e", finished=False) + assert adapter.all_token_ids == [7] diff --git a/tests/worker/test_omni_connector_mixin.py b/tests/worker/test_omni_connector_mixin.py index b8eb899fa32..b8f27460a11 100644 --- a/tests/worker/test_omni_connector_mixin.py +++ b/tests/worker/test_omni_connector_mixin.py @@ -20,6 +20,7 @@ from vllm_omni.outputs import OmniConnectorOutput from vllm_omni.worker.omni_connector_model_runner_mixin import ( OmniConnectorModelRunnerMixin, + _deep_merge_chunk_payload, ) pytestmark = [pytest.mark.core_model, pytest.mark.cpu] @@ -170,6 +171,117 @@ def broken_process(transfer_manager, pooling_output, request, is_finished=""): sender.shutdown_omni_connectors() + def test_send_chunk_skips_preempted_replay(self): + # PR4 MC-C.2: preemption dup-send guard parity with the adapter. 
+ connector = MockConnector(stage_id=0) + sender = MixinHost() + sender.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(stage_id=0, async_chunk=True), + ) + sender._omni_connector = connector + sender._stage_id = 0 + sender._async_chunk = True + + calls = {"n": 0} + + def proc(transfer_manager, pooling_output, request, is_finished=False): + calls["n"] += 1 + return {"data": pooling_output} + + sender._custom_process_func = proc + + req = _make_request("req-1", "ext-1") + req.num_computed_tokens = 10 + self.assertTrue(sender.send_chunk(req, pooling_output={"v": 1})) + self.assertEqual(calls["n"], 1) + self.assertEqual(sender._requests_num_chunks_sent["ext-1"], 10) + self.assertEqual(sender._put_req_chunk["ext-1"], 1) + + # Preemption: committed-token count regresses below the watermark -> + # the replayed span must be skipped (no re-run, no chunk-id advance). + req.num_computed_tokens = 5 + self.assertTrue(sender.send_chunk(req, pooling_output={"v": 2})) + self.assertEqual(calls["n"], 1, "preempted replay must not re-run the process func") + self.assertEqual(sender._put_req_chunk["ext-1"], 1, "chunk id must not advance on a skipped replay") + self.assertEqual(sender._requests_num_chunks_sent["ext-1"], 10, "watermark unchanged on skip") + + # Forward progress past the watermark resumes sending. + req.num_computed_tokens = 15 + self.assertTrue(sender.send_chunk(req, pooling_output={"v": 3})) + self.assertEqual(calls["n"], 2) + self.assertEqual(sender._put_req_chunk["ext-1"], 2) + self.assertEqual(sender._requests_num_chunks_sent["ext-1"], 15) + + sender.shutdown_omni_connectors() + + def _quiesce_save_thread(self, sender): + # Stop the background save loop so the enqueued sentinel task stays in + # _pending_save_reqs for deterministic synchronous inspection. 
+ sender._stop_event.set() + sender._work_available.set() + thread = getattr(sender, "_save_thread", None) + if thread is not None: + thread.join(timeout=2) + + def _last_enqueued_data(self, sender, request_id): + with sender._lock: + dq = sender._pending_save_reqs.get(request_id) + return dq[-1]["data"] if dq else None + + def test_finish_sentinel_falls_back_to_finished_flag(self): + # PR4 MC-C.3: hook returns None for an empty terminal -> mixin enqueues a + # bare finished=True flag so the downstream stage still terminates. + sender = MixinHost() + sender.init_omni_connectors(vllm_config=None, model_config=_make_model_config(stage_id=0, async_chunk=True)) + self._quiesce_save_thread(sender) + sender._omni_connector = MockConnector(stage_id=0) + sender._stage_id = 0 + sender._async_chunk = True + + calls = {"n": 0} + + def proc(transfer_manager, pooling_output, request, is_finished=False): + calls["n"] += 1 + return None # empty terminal -> hook produces nothing + + sender._custom_process_func = proc + sender._put_req_chunk["ext-1"] = 3 # already sent 3 data chunks + + req = SimpleNamespace(request_id="req-1", req_id="req-1", external_req_id="ext-1", is_finished=lambda: True) + self.assertTrue(sender.enqueue_finish_sentinel(req, "ext-1")) + self.assertEqual(calls["n"], 1, "hook must be consulted before the flag fallback") + self.assertEqual(sender._put_req_chunk["ext-1"], 4, "sentinel must take the next contiguous chunk id") + + data = self._last_enqueued_data(sender, "ext-1") + self.assertIsNotNone(data, "finish sentinel must be enqueued") + self.assertIn("meta", data) + self.assertTrue(bool(data["meta"]["finished"])) + + def test_finish_sentinel_prefers_hook_terminal_payload(self): + # When the hook DOES build a terminal payload (e.g. code2wav trailing + # codec flush), that payload is sent, not the bare flag. 
+ sender = MixinHost() + sender.init_omni_connectors(vllm_config=None, model_config=_make_model_config(stage_id=0, async_chunk=True)) + self._quiesce_save_thread(sender) + sender._omni_connector = MockConnector(stage_id=0) + sender._stage_id = 0 + sender._async_chunk = True + + def proc(transfer_manager, pooling_output, request, is_finished=False): + return {"meta": {"finished": True}, "tail": "terminal-codec"} + + sender._custom_process_func = proc + sender._put_req_chunk["ext-9"] = 5 + + req = SimpleNamespace(request_id="r9", req_id="r9", external_req_id="ext-9", is_finished=lambda: True) + self.assertTrue(sender.enqueue_finish_sentinel(req, "ext-9")) + self.assertEqual(sender._put_req_chunk["ext-9"], 6) + + data = self._last_enqueued_data(sender, "ext-9") + self.assertIsNotNone(data) + self.assertEqual(data.get("tail"), "terminal-codec", "hook terminal payload must be preferred over the flag") + class TestMixinKVCacheTransfer(unittest.TestCase): """Test 3: KV cache delegation to OmniKVTransferManager.""" @@ -1440,5 +1552,49 @@ def run_one_loop(): sender.shutdown_omni_connectors() +class TestDeepMergeChunkPayload(unittest.TestCase): + """PR4 BUG2 regression: the recv cache must DEEP-merge nested dicts so a later + decode chunk's embed={decode} does not clobber chunk-0's embed={prefill} + (shallow dict.update lost the prefill -> KeyError 'prefill' in + talker_preprocess_prefill).""" + + def test_decode_chunk_does_not_clobber_chunk0_prefill(self): + existing = { + "embed": {"prefill": "P", "tts_bos": "B"}, + "hidden_states": {"output": "H"}, + "ids": [1, 2, 3], + "meta": {"finished": False, "left_context_size": 1}, + } + incoming = {"embed": {"decode": "D"}, "meta": {"finished": False}} + _deep_merge_chunk_payload(existing, incoming) + + # prefill survives AND decode is added. + self.assertEqual(existing["embed"], {"prefill": "P", "tts_bos": "B", "decode": "D"}) + # untouched top-level keys retained. 
+ self.assertEqual(existing["hidden_states"], {"output": "H"}) + self.assertEqual(existing["ids"], [1, 2, 3]) + # non-finished meta keys retained. + self.assertEqual(existing["meta"]["left_context_size"], 1) + + def test_latest_decode_embed_wins(self): + existing = {"embed": {"prefill": "P", "decode": "D1"}} + _deep_merge_chunk_payload(existing, {"embed": {"decode": "D2"}}) + self.assertEqual(existing["embed"], {"prefill": "P", "decode": "D2"}) + + def test_meta_finished_not_overwritten_by_intermediate_chunk(self): + existing = {"meta": {"finished": True, "x": 1}} + _deep_merge_chunk_payload(existing, {"meta": {"finished": False, "y": 2}}) + # finished is skipped (mirrors adapter); other meta keys merge. + self.assertTrue(existing["meta"]["finished"]) + self.assertEqual(existing["meta"]["y"], 2) + + def test_non_dict_value_replaced_and_new_key_added(self): + existing = {"a": 1, "nested": {"k": "v"}} + _deep_merge_chunk_payload(existing, {"a": 2, "b": 3}) + self.assertEqual(existing["a"], 2) + self.assertEqual(existing["b"], 3) + self.assertEqual(existing["nested"], {"k": "v"}) + + if __name__ == "__main__": unittest.main() diff --git a/vllm_omni/data_entry_keys.py b/vllm_omni/data_entry_keys.py index 2584a7242f4..031c81e2f44 100644 --- a/vllm_omni/data_entry_keys.py +++ b/vllm_omni/data_entry_keys.py @@ -22,6 +22,14 @@ from vllm_omni.engine import AdditionalInformationEntry, AdditionalInformationPayload +# PR4 MC-C.3b: marker key the runner-level finish sentinel sets on its (otherwise +# empty) ``pooling_output`` so a model's async-chunk stage-input hook can flush a +# terminal payload (e.g. code2wav's trailing partial codec). The legacy +# scheduler-driven ``OmniChunkTransferAdapter`` never sets it, so adapter-driven +# hook calls are unaffected. 
def _code2wav_codec_config(transfer_manager: Any) -> tuple[int, int]:
    """Return ``(codec_chunk_frames, codec_left_context_frames)`` for the connector.

    The values are read from ``transfer_manager.connector.config``. When that
    config is a dict, its ``"extra"`` entry is used if present, otherwise the
    dict itself. A missing connector, a missing/empty config, or a non-dict
    config all fall back to the defaults ``(25, 25)``.
    """
    connector = getattr(transfer_manager, "connector", None)
    connector_cfg = getattr(connector, "config", {}) or {}
    if isinstance(connector_cfg, dict):
        settings = connector_cfg.get("extra", connector_cfg)
    else:
        settings = {}
    chunk_frames = int(settings.get("codec_chunk_frames", 25))
    left_context_frames = int(settings.get("codec_left_context_frames", 25))
    return chunk_frames, left_context_frames


def _flush_code2wav_finish_tail(transfer_manager: Any, request: OmniEngineCoreRequest) -> OmniPayloadStruct:
    """Build the terminal payload for the runner's async finish sentinel.

    Reads the frames held in ``transfer_manager.code_prompt_token_ids`` for
    ``request.external_req_id`` and never appends to them. The remainder of
    the held length modulo the configured chunk size is treated as the unsent
    trailing partial chunk:

    * remainder ``0`` (or nothing held, or a chunk size of ``0``): return a
      payload carrying only ``meta.finished=True`` so no full chunk is re-sent;
    * otherwise: return the trailing partial chunk preceded by up to
      ``codec_left_context_frames`` frames of left context, flattened after a
      ``transpose(0, 1)``, with ``meta.finished=True``.
    """
    chunk_frames, max_left_context = _code2wav_codec_config(transfer_manager)
    request_id = request.external_req_id
    held_frames = transfer_manager.code_prompt_token_ids.get(request_id, [])
    total_frames = len(held_frames)
    tail_frames = total_frames % chunk_frames if chunk_frames else 0
    finished = torch.tensor(True, dtype=torch.bool)

    if total_frames == 0 or tail_frames == 0:
        # Nothing held, or the stream ended exactly on a chunk boundary:
        # emit the finish flag alone.
        return OmniPayloadStruct(meta=MetaStruct(finished=finished))

    # Left context is capped both by what precedes the tail and by the config.
    left_context = max(0, min(total_frames - tail_frames, max_left_context))
    window = min(total_frames, left_context + tail_frames)
    flat_codes = torch.tensor(held_frames[-window:]).transpose(0, 1).reshape(-1)
    return OmniPayloadStruct(
        codes=CodesStruct(audio=flat_codes),
        meta=MetaStruct(left_context_size=left_context, finished=finished),
    )
+ if pooling_output.get(ASYNC_FINISH_SENTINEL_KEY): + return _flush_code2wav_finish_tail(transfer_manager, request) talker_codes = pooling_output.get("codes", {}) if not isinstance(talker_codes, dict): return None diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index a250e1f2ca8..6bbf8ec0c3c 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -9,6 +9,7 @@ from contextlib import nullcontext from copy import copy from dataclasses import replace +from types import SimpleNamespace from typing import Any, NamedTuple import numpy as np @@ -36,6 +37,7 @@ from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices from vllm.v1.worker.utils import is_residual_scattered_for_sp +from vllm_omni.core.sched.omni_scheduling_coordinator import uses_async_chunk_coordinator from vllm_omni.data_entry_keys import flatten_payload from vllm_omni.distributed.omni_connectors.kv_transfer_manager import OmniKVTransferManager from vllm_omni.outputs import OmniModelRunnerOutput @@ -46,6 +48,68 @@ logger = init_logger(__name__) +def _strip_trailing_placeholder_tokens(token_ids: list[int] | None) -> list[int]: + """Drop trailing ``-1`` placeholder tokens that pad ``output_token_ids`` + before the real sampled token is written back.""" + if not token_ids: + return [] + end = len(token_ids) + while end > 0 and int(token_ids[end - 1]) == -1: + end -= 1 + return list(token_ids[:end]) + + +class _AsyncChunkRequestAdapter: + """Wrap a ``CachedRequestState`` to expose the attributes the async-chunk + stage-input processors expect (``external_req_id``, ``all_token_ids``, + ``is_finished()``). + + ``CachedRequestState`` uses ``req_id`` while the processors expect + ``request_id`` / ``external_req_id``; this bridges that gap without mutating + the wrapped state. Any other attribute read (e.g. ``additional_information`` + for speaker/language) delegates to the inner state. 
+ """ + + __slots__ = ("_inner", "_external_req_id", "_finished") + + def __init__(self, cached_state: Any, external_req_id: str, finished: bool): + self._inner = cached_state + self._external_req_id = external_req_id + self._finished = finished + + @property + def external_req_id(self) -> str: + return self._external_req_id + + @property + def request_id(self) -> str: # for send_chunk + return self._inner.req_id + + @property + def req_id(self) -> str: + return self._inner.req_id + + @property + def prompt_token_ids(self) -> list[int] | None: + return self._inner.prompt_token_ids + + @property + def output_token_ids(self) -> list[int]: + return _strip_trailing_placeholder_tokens(self._inner.output_token_ids) + + @property + def all_token_ids(self) -> list[int]: + prompt = self._inner.prompt_token_ids or [] + return list(prompt) + self.output_token_ids + + def is_finished(self) -> bool: + return self._finished + + def __getattr__(self, name: str) -> Any: + # Delegate everything else to the inner CachedRequestState. 
+ return getattr(self._inner, name) + + class ExecuteModelState(NamedTuple): scheduler_output: SchedulerOutput logits: torch.Tensor | None @@ -191,6 +255,46 @@ def _request_needs_downstream_stage_payload(self, req_id: str) -> bool: self._downstream_payload_cache[req_id] = needs_payload return needs_payload + def _resolve_transfer_request_id(self, req_id: str) -> str: + """Resolve the cross-stage (external) request id used for connector + keys: prefer the recv-registration mapping, else the request state's + ``external_req_id``, else the local ``req_id``.""" + mapped = self._request_ids_mapping.get(req_id) + if mapped is not None: + return mapped + req_state = self.requests.get(req_id) + if req_state is None: + return req_id + external_req_id = getattr(req_state, "external_req_id", None) + if external_req_id is not None: + return str(external_req_id) + return req_id + + def _send_async_chunk_finish_sentinels(self, finished_req_ids: set[str]) -> None: + """Emit one terminal chunk per just-finished async-chunk request. + + In-step ``send_chunk`` calls run with ``finished=False`` (the runner does + not yet know a request will stop at ``sample_tokens`` time). This runs at + the start of the next ``execute_model`` cycle, when + ``scheduler_output.finished_req_ids`` lists the requests that finished in + the previous cycle, and delegates to the mixin's generic + ``enqueue_finish_sentinel`` (model hook builds the terminal payload, else a + bare ``finished=True`` flag). The request may already be freed, so it + rebuilds from the send-time snapshot with ``is_finished`` forced True. + """ + for rid in finished_req_ids: + ext_id = self._request_ids_mapping.get(rid) or self._resolve_transfer_request_id(rid) + # Only requests that actually sent ≥1 chunk need a terminal. 
+ if ext_id not in self._put_req_chunk: + continue + snapshot = self._send_side_request_snapshot.get(ext_id) + if snapshot is not None: + snapshot.is_finished = lambda: True + request: Any = snapshot + else: + request = SimpleNamespace(request_id=rid, req_id=rid, external_req_id=ext_id, is_finished=lambda: True) + self.enqueue_finish_sentinel(request, ext_id) + def _resolve_pooler_payload_req_ids(self, req_ids_output_copy: list[str]) -> tuple[str, list[str]]: downstream_req_ids = [rid for rid in req_ids_output_copy if self._request_needs_downstream_stage_payload(rid)] engine_output_type = (self.vllm_config.model_config.engine_output_type or "").lower() @@ -444,6 +548,14 @@ def execute_model( flush_ids.update({rid for rid in self._pending_full_payload_send if rid not in self.requests}) if flush_ids: self.flush_full_payload_outputs(flush_ids) + # PR4 MC-C.3: emit the terminal chunk for async-chunk requests that + # finished last cycle (in-step sends were finished=False; the runner + # learns finish only here). Before _update_states frees them. + # Dormant until the allowlist flip (predicate False -> no-op). + if uses_async_chunk_coordinator(self.model_config): + async_finished = set(getattr(scheduler_output, "finished_req_ids", set())) + if async_finished: + self._send_async_chunk_finish_sentinels(async_finished) if self.omni_prefix_cache is not None and scheduler_output.finished_req_ids: self.omni_prefix_cache.commit_deferred_mm_outputs( @@ -1165,6 +1277,31 @@ def _unwrap_lists(v): if req_state is not None and pooler_output[i]: self.accumulate_full_payload_output(rid, pooler_output[i], req_state) + # PR4 MC-C.1: runner-side async-chunk producer send. Gated by the SAME + # predicate the scheduler uses to select the coordinator (arch + # allowlisted AND SharedMemoryConnector), NOT a bare async_chunk flag -- + # a non-coordinator async arch keeps the legacy adapter (scheduler-side + # save_async), so a bare-flag guard here would double-send. 
Only the + # V5 downstream-pooler-filtered requests are sent (never an ungated + # terminal None -> starvation). is_finished is always False here: the + # runner cannot know finish at sample_tokens time (engine core decides + # after), so the terminal is handled separately (later sub-commit), + # mirroring the adapter. With the empty allowlist this branch is + # dormant (uses_async_chunk_coordinator always returns False). + if pooler_output and uses_async_chunk_coordinator(self.model_config): + for i, rid in enumerate(req_ids_output_copy): + if rid not in downstream_req_id_set or not pooler_output[i]: + continue + req_state = self.requests.get(rid) + if req_state is None: + continue + ext_id = self._resolve_transfer_request_id(rid) + wrapped = _AsyncChunkRequestAdapter(req_state, external_req_id=ext_id, finished=False) + # Snapshot for the next-cycle finish sentinel: the live request + # may be freed by _update_states before we learn it finished. + self._send_side_request_snapshot[ext_id] = self._snapshot_request_for_send(wrapped, ext_id) + self.send_chunk(request=wrapped, pooling_output=pooler_output[i]) + with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"): routed_experts_lists = None if self.routed_experts_initialized: diff --git a/vllm_omni/worker/omni_connector_model_runner_mixin.py b/vllm_omni/worker/omni_connector_model_runner_mixin.py index cfe3c55b459..a9404d05d5d 100644 --- a/vllm_omni/worker/omni_connector_model_runner_mixin.py +++ b/vllm_omni/worker/omni_connector_model_runner_mixin.py @@ -24,7 +24,7 @@ from vllm.distributed.parallel_state import get_tp_group from vllm.logger import init_logger -from vllm_omni.data_entry_keys import OmniPayload +from vllm_omni.data_entry_keys import ASYNC_FINISH_SENTINEL_KEY, OmniPayload from vllm_omni.distributed.omni_connectors.factory import OmniConnectorFactory from vllm_omni.distributed.omni_connectors.utils.config import ConnectorSpec from vllm_omni.outputs import OmniConnectorOutput @@ 
def _deep_merge_chunk_payload(existing: dict, incoming: dict) -> None:
    """Fold one received async-chunk payload into the cached payload in place.

    For every top-level key of ``incoming``:

    * a non-dict value replaces ``existing[key]`` outright;
    * a dict value is merged one level deep: a copy of the cached sub-dict (or
      an empty dict when the cached value is missing or not a dict) receives
      the incoming sub-keys and is stored back under ``key``. Earlier sub-keys
      such as ``embed.prefill`` therefore survive a later ``embed.decode``.

    ``meta.finished`` from ``incoming`` is never copied, so a later chunk
    cannot overwrite the cached flag.
    """
    for key, value in incoming.items():
        if not isinstance(value, dict):
            existing[key] = value
            continue

        previous = existing.get(key)
        combined = dict(previous) if isinstance(previous, dict) else {}
        combined.update(
            (sub_key, sub_val)
            for sub_key, sub_val in value.items()
            if not (key == "meta" and sub_key == "finished")
        )
        existing[key] = combined
self._send_side_request_payload: dict[str, dict[str, Any]] = {} self._code_prompt_token_ids: dict[str, list[list[int]]] = defaultdict(list) self._cached_ic: dict[str, int] = {} self._request_ids_mapping: dict[str, str] = {} + # Frozen request snapshot captured at send time so the next-cycle finish + # sentinel can rebuild its terminal payload after ``_update_states`` has + # freed the live request (keyed by external req id). + self._send_side_request_snapshot: dict[str, Any] = {} # -- async I/O state (shared by chunk + full_payload_mode) -- self._pending_load_reqs: dict[str, Any] = {} @@ -282,6 +315,8 @@ def cleanup_finished_request(self, req_id: str) -> None: if k in keys_pending: continue self._put_req_chunk.pop(k, None) + self._requests_num_chunks_sent.pop(k, None) + self._send_side_request_snapshot.pop(k, None) self._send_side_request_payload.pop(k, None) self._code_prompt_token_ids.pop(k, None) self._cached_ic.pop(k, None) @@ -1108,6 +1143,21 @@ def recv_chunk(self) -> dict[str, Any]: self._chunk_ready_req_ids.update(finished) return result + @staticmethod + def _confirmed_num_computed_tokens(request: Any) -> int: + """Committed-token watermark for the preemption dup-send guard. + + vLLM async scheduling advances ``num_computed_tokens`` with output + placeholders before the corresponding token is committed; the send + watermark must count only committed tokens. ``num_output_placeholders`` + is absent on the worker ``CachedRequestState`` (defaults to 0), which is + harmless: the guard is relative (compare to the stored watermark), so a + consistent count is all that matters. + """ + num_computed = int(getattr(request, "num_computed_tokens", 0) or 0) + num_placeholders = int(getattr(request, "num_output_placeholders", 0) or 0) + return max(0, num_computed - num_placeholders) + def send_chunk( self, request: Any, @@ -1131,6 +1181,23 @@ def send_chunk( # resolve the external ID even after the request is freed. 
def enqueue_finish_sentinel(self, request: Any, request_id: str) -> bool:
    """Queue the terminal chunk for a finished async-chunk request.

    The model hook is asked first, via ``_build_custom_process_payload``, with
    a ``pooling_output`` that contains only the finish-sentinel marker. If the
    hook returns ``None``, a bare ``{"meta": {"finished": True}}`` payload is
    used instead. The payload is queued under the next chunk id for
    ``request_id`` on the pending-save queue, and the worker is signalled
    through ``_work_available``.

    Args:
        request: Request object handed to the model hook.
        request_id: Id used for chunk numbering, the put key and the queue.

    Returns:
        Always ``True``; when there is no connector or this rank does not
        transfer data, nothing is queued.
    """
    if self._omni_connector is None:
        return True
    if not self.is_data_transfer_rank():
        return True

    # Marker-only pooling output: lets a hook recognise the sentinel call.
    sentinel_output = {ASYNC_FINISH_SENTINEL_KEY: True}
    terminal_payload = self._build_custom_process_payload(
        request_id=request_id,
        request=request,
        pooling_output=sentinel_output,
    )
    if terminal_payload is None:
        terminal_payload = {"meta": {"finished": torch.tensor(True, dtype=torch.bool)}}

    # Claim the next chunk id for this request, then advance the counter.
    sentinel_chunk_id = self._put_req_chunk[request_id]
    self._put_req_chunk[request_id] = sentinel_chunk_id + 1

    save_task = {
        "stage_id": self._stage_id,
        "next_stage_id": self._next_stage_id,
        "put_key": f"{request_id}_{self._stage_id}_{sentinel_chunk_id}",
        "data": terminal_payload,
        "request_id": request_id,
    }
    with self._lock:
        self._pending_save_reqs.setdefault(request_id, deque()).append(save_task)
        self._pending_save_counts[request_id] += 1
    self._work_available.set()
    return True
stage_recv=%s, chunk_finished=%s, chunk_ready=%s", self._stage_id, output.stage_recv_req_ids, @@ -1786,7 +1898,7 @@ def _poll_single_request(self, req_id: str) -> bool: if not payload_data: return False if isinstance(payload_data, dict): - logger.info( + logger.debug( "[Stage-%s] recv_chunk_result: req=%s ext=%s key=%s keys=%s finished=%s", self._stage_id, req_id, @@ -1822,12 +1934,14 @@ def _poll_single_request(self, req_id: str) -> bool: if is_finished: self._chunk_finished_req_ids.add(req_id) self._chunk_stream_completed.add(req_id) - # Local cache (RFC §2.4) — merge, don't replace, so that - # earlier chunk keys (e.g. thinker_prefill_embeddings from - # chunk 0) are not overwritten by later chunks. + # Local cache (RFC §2.4) — DEEP-merge, don't replace, so that + # earlier chunk keys (e.g. chunk-0's embed.prefill) survive when a + # later decode chunk arrives with embed.decode. Shallow dict.update + # replaces the nested 'embed' wholesale -> lost prefill -> KeyError + # in talker_preprocess_prefill. 
existing = self._local_stage_payload_cache.get(req_id) if existing is not None and isinstance(existing, dict) and isinstance(payload_data, dict): - existing.update(payload_data) + _deep_merge_chunk_payload(existing, payload_data) else: self._local_stage_payload_cache[req_id] = payload_data staged_payload = self._local_stage_payload_cache[req_id] @@ -1981,7 +2095,7 @@ def _send_single_request(self, task: dict) -> bool: put_key=put_key, data=payload_data, ) - logger.info( + logger.debug( "[Stage-%s] _send_single_request: put_key=%s success=%s size=%s", task["stage_id"], put_key, @@ -2009,6 +2123,8 @@ def _decrement_pending_save_count(self, request_id: str) -> None: cleanup_req_id = request_id if cleanup_req_id is not None: self._put_req_chunk.pop(cleanup_req_id, None) + self._requests_num_chunks_sent.pop(cleanup_req_id, None) + self._send_side_request_snapshot.pop(cleanup_req_id, None) self._send_side_request_payload.pop(cleanup_req_id, None) self._code_prompt_token_ids.pop(cleanup_req_id, None) self._cached_ic.pop(cleanup_req_id, None) From a45282cfd2986d6a1752fd4739c45b38befa255a Mon Sep 17 00:00:00 2001 From: natureofnature Date: Wed, 3 Jun 2026 13:58:00 +0000 Subject: [PATCH 03/10] [PR4] enable qwen3 omni async-chunk coordinator path Flip qwen3_omni async-chunk stages onto the coordinator+mixin transport, preserve audio finish behavior, and keep allowlist-active branches safe for partial unit-test mocks. 
Signed-off-by: natureofnature --- vllm_omni/core/sched/omni_ar_scheduler.py | 7 + .../core/sched/omni_generation_scheduler.py | 236 +++++++++++++++++- .../core/sched/omni_scheduling_coordinator.py | 7 +- .../models/qwen3_omni/qwen3_omni.py | 55 ++-- .../stage_input_processors/qwen3_omni.py | 8 + vllm_omni/worker/gpu_ar_model_runner.py | 8 +- vllm_omni/worker/gpu_model_runner.py | 11 + 7 files changed, 299 insertions(+), 33 deletions(-) diff --git a/vllm_omni/core/sched/omni_ar_scheduler.py b/vllm_omni/core/sched/omni_ar_scheduler.py index 2dfef93f761..41c18190b5c 100644 --- a/vllm_omni/core/sched/omni_ar_scheduler.py +++ b/vllm_omni/core/sched/omni_ar_scheduler.py @@ -270,6 +270,13 @@ def schedule(self) -> SchedulerOutput: # type: ignore[override] scheduler_output.scheduled_new_reqs = new_list # type: ignore[assignment] if self.chunk_transfer_adapter: self.chunk_transfer_adapter.postprocess_scheduler_output(scheduler_output, self.requests) + if self.input_coordinator: + # PR4: mirror the adapter postprocess on the coordinator path so the + # per-cycle ready-chunk flags (requests_with_ready_chunks) are cleared + # after each scheduler step. Without this, a streamed multi-chunk request + # stays flagged "ready" after its first chunk and never re-enters the + # per-cycle wait/clear cadence the adapter path gets for free. 
+ self.input_coordinator.postprocess_scheduler_output(scheduler_output, self.requests) # Add information about requests needing KV cache transfer finished_reqs = self.get_finished_requests_needing_kv_transfer() except Exception: diff --git a/vllm_omni/core/sched/omni_generation_scheduler.py b/vllm_omni/core/sched/omni_generation_scheduler.py index 517442e8349..1281d20f17f 100644 --- a/vllm_omni/core/sched/omni_generation_scheduler.py +++ b/vllm_omni/core/sched/omni_generation_scheduler.py @@ -40,13 +40,22 @@ logger = init_logger(__name__) +def _extend_all_token_ids_if_available(request: Request, padding: int) -> None: + if padding <= 0: + return + all_token_ids = getattr(request, "_all_token_ids", None) + if isinstance(all_token_ids, list): + all_token_ids.extend([0] * padding) + + class OmniGenerationScheduler(OmniSchedulerMixin, VLLMScheduler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) model_config = self.vllm_config.model_config # PR4: see OmniARScheduler.__init__ -- allowlisted async-chunk archs # (SharedMemory) use the coordinator + runner mixin; others keep the adapter. - # Empty allowlist today => _async_coord False => no behavior change. + # _async_coord is True only for allowlisted stages (qwen3_omni + # talker+code2wav); False (no behavior change) for everything else. _async_coord = uses_async_chunk_coordinator(model_config) # When True, stage completion follows the coordinator's terminal chunk # signal (input_coordinator.finished_requests), NOT the plain-generation @@ -66,6 +75,14 @@ def __init__(self, *args, **kwargs): async_chunk=_async_coord, ) self._latest_omni_connector_output: OmniConnectorOutput | None = None + # PR4: consumer-side terminal-completeness trackers (async-chunk coordinator + # path only). 
Code2wav's terminal stage must finish on the producer's true + # total code length, not on chunks-arrived-so-far; these carry the terminal + # chunk marker / metadata across the finish-only connector cycle and remember + # which requests have emitted a pooler payload. Ported from tmp/trim_refactor. + self._deferred_terminal_chunk_req_ids: set[str] = set() + self._deferred_terminal_request_metadata: dict[str, dict] = {} + self._reqs_with_pooler_history: set[str] = set() def schedule(self) -> SchedulerOutput: """Diffusion fast path: @@ -92,6 +109,26 @@ def schedule(self) -> SchedulerOutput: cached_prompt_token_ids: dict[str, list[int]] = {} cached_additional_information: dict[str, dict | None] = {} + def _ensure_terminal_placeholder(request: Request) -> int: + # PR4: a terminal-ready coordinator request can still have a connector- + # delivered payload queued even when its prompt length has not grown. + # Grow prompt_token_ids by one placeholder so the request is scheduled + # for a one-token step that drains the ready payload, instead of being + # stranded behind required_tokens == 0. Ported from tmp/trim_refactor. 
+ if not isinstance(request.prompt_token_ids, list): + request.prompt_token_ids = list(request.prompt_token_ids) + current_len = len(request.prompt_token_ids) + target_len = max(current_len, request.num_computed_tokens + 1) + max_model_len = getattr(self, "max_model_len", None) + if isinstance(max_model_len, int) and max_model_len > 0: + target_len = min(target_len, max_model_len) + missing = target_len - current_len + if missing > 0: + request.prompt_token_ids.extend([0] * missing) + request.num_prompt_tokens = len(request.prompt_token_ids) + _extend_all_token_ids_if_available(request, missing) + return max(0, len(request.prompt_token_ids) - request.num_computed_tokens) + # Temporary queue: preserve waiting order, do not disturb non-diffusion requests skipped_waiting_requests = create_request_queue(self.policy) req_index = 0 @@ -121,12 +158,45 @@ def schedule(self) -> SchedulerOutput: break # async_chunk: don't schedule placeholder tokens when no new chunk is available. if required_tokens <= 0: - if self.chunk_transfer_adapter is not None and self.chunk_transfer_adapter.is_done_receiving_chunks( - request.request_id - ): - self._pending_finish_reqs.append(request) - req_index += 1 - continue + if self.chunk_transfer_adapter is not None: + # Adapter path (legacy, unchanged): defer the finish to + # update_from_output via _pending_finish_reqs. + if self.chunk_transfer_adapter.is_done_receiving_chunks(request.request_id): + self._pending_finish_reqs.append(request) + req_index += 1 + continue + if self._async_chunk_coordinator_active and self.input_coordinator is not None: + # Coordinator drain path: a terminal-ready request may still have + # a connector-delivered payload queued even though prompt length + # has not grown. Schedule a one-token placeholder so the terminal + # stage can drain it; only complete-immediately when nothing is + # left to drain. 
The finish itself is deferred to + # update_from_output via _pending_finish_reqs (parity with the + # adapter path above) so a real EngineCoreOutput is emitted: a + # finish surfaced only through finished_requests is dropped by the + # orchestrator output poller (StagePool._poll_stage_raw), leaving + # the client waiting until the outer timeout aborts it. + if request.request_id in self.input_coordinator.requests_with_ready_chunks: + required_tokens = _ensure_terminal_placeholder(request) + if required_tokens <= 0: + request.stop_reason = ( + None if request.request_id in self.input_coordinator.finished_requests else "length" + ) + self._pending_finish_reqs.append(request) + req_index += 1 + continue + # required_tokens > 0: fall through to allocation below so the + # placeholder step actually runs and drains the ready payload. + elif request.request_id in self.input_coordinator.finished_requests: + self._pending_finish_reqs.append(request) + req_index += 1 + continue + else: + req_index += 1 + continue + else: + req_index += 1 + continue num_new_tokens = min(required_tokens, token_budget) new_blocks = self.kv_cache_manager.allocate_slots( request, @@ -181,6 +251,30 @@ def schedule(self) -> SchedulerOutput: skipped_waiting_requests.prepend_request(request) continue + # async-chunk coordinator path: same empty-prompt guard, but a + # terminal-ready request still gets a one-token placeholder so the + # terminal stage can drain its ready payload before finishing. 
+ if ( + self._async_chunk_coordinator_active + and self.input_coordinator is not None + and len(request.prompt_token_ids) == 0 + ): + if ( + request.request_id in self.input_coordinator.finished_requests + and request.request_id not in self.input_coordinator.requests_with_ready_chunks + ): + request = self.waiting.pop_request() + self._pending_finish_reqs.append(request) + continue + if request.request_id in self.input_coordinator.finished_requests: + # Terminal chunk arrived but a payload is still queued: grow a + # placeholder and fall through to schedule the drain step. + _ensure_terminal_placeholder(request) + else: + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue + # Uniformly treat as diffusion. A feature flag can be added later # via config or request tag. @@ -340,6 +434,10 @@ def schedule(self) -> SchedulerOutput: if self.chunk_transfer_adapter: self.chunk_transfer_adapter.postprocess_scheduler_output(scheduler_output) + if self.input_coordinator: + # PR4: mirror the adapter postprocess on the coordinator path so per-cycle + # ready-chunk flags are cleared (see omni_ar_scheduler for full rationale). + self.input_coordinator.postprocess_scheduler_output(scheduler_output) except Exception: # If anything goes wrong, leave the original output unchanged @@ -382,6 +480,12 @@ def finish_requests(self, request_ids, finished_status: RequestStatus) -> list[t return finished def _free_request(self, request: Request, delay_free_blocks: bool = False) -> dict[str, Any] | None: + # PR4: drop consumer-side terminal-completeness trackers for the freed + # request (no-op outside the async-chunk coordinator path). 
+ self._deferred_terminal_chunk_req_ids.discard(request.request_id) + self._deferred_terminal_request_metadata.pop(request.request_id, None) + self._reqs_with_pooler_history.discard(request.request_id) + if self.input_coordinator is None: return super()._free_request(request, delay_free_blocks) @@ -416,6 +520,96 @@ def update_from_output( num_nans_in_logits = model_runner_output.num_nans_in_logits kv_connector_output = model_runner_output.kv_connector_output + # PR4: read this cycle's terminal-chunk signal and producer metadata. + # Only consumed on the async-chunk coordinator path; for the adapter / + # full-payload paths these stay empty and the blocks below are skipped. + omni_output = getattr(model_runner_output, "omni_connector_output", None) + chunk_finished_req_ids_now = set(omni_output.chunk_finished_req_ids) if omni_output else set() + chunk_ready_req_ids_now = set(omni_output.chunk_ready_req_ids) if omni_output else set() + request_metadata_now = ( + dict(omni_output.request_metadata) if omni_output and omni_output.request_metadata else {} + ) + + # Under async scheduling a terminal chunk-finished signal can arrive in an + # empty connector-only output one cycle before the final pooler payload. + # Preserve that terminal marker and merge it into the next real model + # output for the same request so the last audio frame and terminal state + # are evaluated together. Gated on the async-chunk coordinator path so the + # adapter / full-payload deployments are untouched. 
+ if ( + self._async_chunk_coordinator_active + and chunk_finished_req_ids_now + and not getattr(model_runner_output, "req_ids", None) + and not pooler_outputs + ): + handled_finish_only_req_ids: set[str] = set() + for req_id in set(chunk_finished_req_ids_now): + request = self.requests.get(req_id) + if request is None or req_id not in self._reqs_with_pooler_history: + continue + terminal_prompt_len = len(request.prompt_token_ids) + if req_metadata := request_metadata_now.get(req_id): + code_predictor_codes = req_metadata.get("code_predictor_codes") + if code_predictor_codes is not None and len(code_predictor_codes) > 0: + terminal_prompt_len = len(code_predictor_codes) + else: + next_stage_prompt_len = req_metadata.get("next_stage_prompt_len") + if isinstance(next_stage_prompt_len, int) and next_stage_prompt_len > 0: + terminal_prompt_len = next_stage_prompt_len + if request.num_computed_tokens < terminal_prompt_len: + continue + # A consumable terminal chunk (real trailing codec) co-arrived this + # cycle alongside the finish signal; its decode step runs next cycle. + # Defer the finish so that decode emits the final audio frame WITH + # finish via the reached_terminal_chunk path (one combined output), + # matching the adapter. Without this the last received chunk is never + # decoded -- the finish surfaces as a bare empty audio-typed output and + # crashes _create_audio_choice ("no audio produced", HTTP 400). + # Finish-only sentinels (no codes) are NOT deferred: there is no audio + # frame to wait for, so the hard termination guarantee is preserved. + _co_arrived_codes = request_metadata_now.get(req_id, {}).get("code_predictor_codes") + if req_id in chunk_ready_req_ids_now and _co_arrived_codes is not None and len(_co_arrived_codes) > 0: + continue + # Defer the terminal finish to the _pending_finish_reqs drain below + # so a real EngineCoreOutput is emitted (parity with the adapter and + # the schedule()-side coordinator finishes). 
A finish surfaced only + # via finished_requests is dropped by StagePool._poll_stage_raw, + # leaving the client waiting until the outer timeout aborts it. + self._pending_finish_reqs.append(request) + handled_finish_only_req_ids.add(req_id) + remaining_ids = set(chunk_finished_req_ids_now) - handled_finish_only_req_ids + if remaining_ids: + self._deferred_terminal_chunk_req_ids.update(remaining_ids) + for req_id in remaining_ids: + if req_id in request_metadata_now: + self._deferred_terminal_request_metadata[req_id] = dict(request_metadata_now[req_id]) + chunk_finished_req_ids_now = remaining_ids + elif ( + self._async_chunk_coordinator_active + and self._deferred_terminal_chunk_req_ids + and getattr(model_runner_output, "req_ids", None) + ): + carried_req_ids = self._deferred_terminal_chunk_req_ids.intersection( + getattr(model_runner_output, "req_ids", None) + ) + if carried_req_ids: + chunk_finished_req_ids_now.update(carried_req_ids) + for req_id in carried_req_ids: + if req_id not in request_metadata_now: + if (deferred_meta := self._deferred_terminal_request_metadata.get(req_id)) is not None: + request_metadata_now[req_id] = dict(deferred_meta) + self._deferred_terminal_chunk_req_ids.discard(req_id) + self._deferred_terminal_request_metadata.pop(req_id, None) + + # Sweep stale deferred entries for requests no longer tracked, preventing + # unbounded growth if a request disappears without being freed. 
+ if self._deferred_terminal_chunk_req_ids: + stale = self._deferred_terminal_chunk_req_ids - set(self.requests) + if stale: + self._deferred_terminal_chunk_req_ids -= stale + for rid in stale: + self._deferred_terminal_request_metadata.pop(rid, None) + cudagraph_stats: CUDAGraphStat | None = model_runner_output.cudagraph_stats perf_stats: PerfStats | None = None if self.perf_metrics and self.perf_metrics.is_enabled(): @@ -491,10 +685,33 @@ def update_from_output( new_token_ids = generated_token_ids kv_transfer_params = None pooler_output = pooler_outputs[req_index] if pooler_outputs else None + if pooler_output is not None: + self._reqs_with_pooler_history.add(req_id) status_before_stop = request.status finish_reason = None routed_experts = None + # PR4: terminal-completeness gate (async-chunk coordinator path only). + # The true total code length comes from producer metadata -- not the + # chunks-arrived-so-far prompt length -- so resolve it before deciding + # the request is done. Defaults to the current prompt length for every + # other path (request_metadata_now is empty there). 
+ terminal_prompt_len = len(request.prompt_token_ids) + if req_metadata := request_metadata_now.get(req_id): + code_predictor_codes = req_metadata.get("code_predictor_codes") + if code_predictor_codes is not None and len(code_predictor_codes) > 0: + terminal_prompt_len = len(code_predictor_codes) + else: + next_stage_prompt_len = req_metadata.get("next_stage_prompt_len") + if isinstance(next_stage_prompt_len, int) and next_stage_prompt_len > 0: + terminal_prompt_len = next_stage_prompt_len + reached_terminal_chunk = ( + self._async_chunk_coordinator_active + and req_id in chunk_finished_req_ids_now + and pooler_output is not None + and request.num_computed_tokens >= terminal_prompt_len + ) + # Diffusion request: completes in one step; mark finished and free resources if ( request.status == RequestStatus.FINISHED_STOPPED @@ -517,6 +734,7 @@ def update_from_output( and request.request_id in self.input_coordinator.finished_requests and request.num_computed_tokens >= len(request.prompt_token_ids) ) + or reached_terminal_chunk ): request.status = RequestStatus.FINISHED_STOPPED # Optional: set a stop_reason for front-end clarity @@ -602,6 +820,10 @@ def update_from_output( request.request_id, getattr(request, "external_req_id", None), ) + # Coordinator path: release the request from the coordinator's + # tracking sets too (the adapter path uses cleanup() above). 
+ if self._async_chunk_coordinator_active and self.input_coordinator is not None: + self.input_coordinator.free_finished_request(request.request_id) outputs[request.client_index].append( OmniEngineCoreOutput( request_id=request.request_id, diff --git a/vllm_omni/core/sched/omni_scheduling_coordinator.py b/vllm_omni/core/sched/omni_scheduling_coordinator.py index b185332d8f4..8ce090eecc5 100644 --- a/vllm_omni/core/sched/omni_scheduling_coordinator.py +++ b/vllm_omni/core/sched/omni_scheduling_coordinator.py @@ -87,7 +87,12 @@ def uses_full_payload_input_coordinator(model_config: Any) -> bool: # entries are the recv stages only -- stage-0 producers do not wait on chunks: # ("Qwen3OmniMoeForConditionalGeneration", "talker") # stage 1, recv from 0 # ("Qwen3OmniMoeForConditionalGeneration", "code2wav") # stage 2, recv from 1 -_ASYNC_CHUNK_COORDINATOR_STAGES: frozenset[tuple[str, str]] = frozenset() +_ASYNC_CHUNK_COORDINATOR_STAGES: frozenset[tuple[str, str]] = frozenset( + { + ("Qwen3OmniMoeForConditionalGeneration", "talker"), + ("Qwen3OmniMoeForConditionalGeneration", "code2wav"), + } +) def uses_async_chunk_coordinator(model_config: Any) -> bool: diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py index 4a144bff027..ac19ab70558 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py @@ -721,7 +721,14 @@ def talker_preprocess(self, input_ids: torch.Tensor, input_embeds: torch.Tensor, ) update_dict["mtp_inputs"] = last_talker_hidden, text_step - update_dict.setdefault("meta", {})["num_processed_tokens"] = meta.get("num_processed_tokens", 0) + span_len + # PR4 async-chunk decode: the talker may emit a pad while waiting for the + # next thinker decode embed to stream in; in that case it must NOT advance + # num_processed_tokens (else the awaited token is skipped). 
Prefill and the + # non-async path leave the flag unset -> default True -> advance as before. + advance_num_processed_tokens = update_dict.pop("_advance_num_processed_tokens", True) + update_dict.setdefault("meta", {})["num_processed_tokens"] = meta.get("num_processed_tokens", 0) + ( + span_len if advance_num_processed_tokens else 0 + ) return input_ids, input_embeds, update_dict def talker_mtp( @@ -1016,33 +1023,39 @@ def _thinker_decode_to_talker_decode( embed = payload.get("embed", {}) meta = payload.get("meta", {}) - cached_thinker_decode_embeds = embed.get("cached_decode", None) thinker_decode_embed = embed.get("decode", None) start_index = meta.get("num_processed_tokens", 0) - - if cached_thinker_decode_embeds is not None and start_index < cached_thinker_decode_embeds.shape[0]: - cached_thinker_decode_embeds = cached_thinker_decode_embeds.to(device) - thinker_embed = cached_thinker_decode_embeds[start_index] - if thinker_decode_embed is not None: - thinker_decode_embed = thinker_decode_embed.to(device) - cached_thinker_decode_embeds = torch.cat([cached_thinker_decode_embeds, thinker_decode_embed], dim=0) - update_dict.setdefault("embed", {})["cached_decode"] = cached_thinker_decode_embeds - - elif thinker_decode_embed is not None: - thinker_embed = thinker_decode_embed - if thinker_embed.device != device: - thinker_embed = thinker_embed.to(device) - - else: - # When the tokens output by the thinker are exhausted, an EOS token needs to be appended. - # Use the finished_flag to mark that all tokens output by thinker have been consumed. + # PR4: the worker re-syncs the FULL cumulative thinker decode embeds into + # ``embed.decode`` every step (OmniChunkTransferAdapter parity via + # _accumulate_payload + _sync_local_stage_payloads), so consume by an + # advancing absolute index rather than popping the tensor (a pop would be + # clobbered by the next sync). 
``embed.decode`` row 0 is the first + # POST-prefill thinker token, while ``num_processed_tokens`` counts from + # the prefill base, hence the ``- prefill_consumed_text_tokens`` offset. + base = meta.get("prefill_consumed_text_tokens", 0) + avail = thinker_decode_embed.shape[0] if isinstance(thinker_decode_embed, torch.Tensor) else 0 + idx = start_index - base + + if isinstance(thinker_decode_embed, torch.Tensor) and 0 <= idx < avail: + thinker_embed = thinker_decode_embed[idx].to(device) + update_dict["_advance_num_processed_tokens"] = True + return self.talker.text_projection(thinker_embed).to(device) + + # No (more) thinker decode embed available at this index yet. + if bool(meta.get("finished", False)) and idx >= avail: + # Thinker finished and the talker consumed everything: emit one EOS + # then pad. Do not advance past the terminal token. if meta.get("eos_emitted", False): + update_dict["_advance_num_processed_tokens"] = False return self.tts_pad_embed.to(device) update_dict.setdefault("meta", {})["eos_emitted"] = True + update_dict["_advance_num_processed_tokens"] = False return self.tts_eos_embed.to(device) - update_dict.setdefault("embed", {})["decode"] = None - return self.talker.text_projection(thinker_embed).to(device) + # Talker outran the thinker stream: emit pad and WAIT (do not advance), + # so the not-yet-arrived thinker token is consumed once it lands. 
+ update_dict["_advance_num_processed_tokens"] = False + return self.tts_pad_embed.to(device) def talker_preprocess_decode( self, input_ids: torch.Tensor, input_embeds: torch.Tensor, update_dict: OmniPayload, payload: OmniPayload diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py index e2f727c2e87..311ef746dd8 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -21,6 +21,7 @@ OmniPayload, OmniPayloadStruct, to_dict, + unflatten_payload, ) from vllm_omni.engine import OmniEngineCoreRequest from vllm_omni.inputs.data import OmniTokensPrompt @@ -852,6 +853,13 @@ def talker2code2wav_async_chunk( # never sets) -> flush the held trailing partial codec as the terminal chunk. if pooling_output.get(ASYNC_FINISH_SENTINEL_KEY): return _flush_code2wav_finish_tail(transfer_manager, request) + # PR4 BUG5: the coordinator+mixin runner send (MC-C) hands this builder a + # FLATTENED pooling_output (flatten_payload -> dotted "codes.audio"), whereas + # the legacy adapter send passes it nested. Reads below expect nested + # ("codes"->"audio"), so without this they silently get None and NO codes + # reach code2wav (audio truncated to a fraction). unflatten is a passthrough + # for already-nested (adapter) input, so it is safe for both paths. + pooling_output = unflatten_payload(pooling_output) talker_codes = pooling_output.get("codes", {}) if not isinstance(talker_codes, dict): return None diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index 6bbf8ec0c3c..48ecc435ddd 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -552,7 +552,7 @@ def execute_model( # finished last cycle (in-step sends were finished=False; the runner # learns finish only here). Before _update_states frees them. 
# Dormant until the allowlist flip (predicate False -> no-op). - if uses_async_chunk_coordinator(self.model_config): + if uses_async_chunk_coordinator(getattr(self, "model_config", None)): async_finished = set(getattr(scheduler_output, "finished_req_ids", set())) if async_finished: self._send_async_chunk_finish_sentinels(async_finished) @@ -1286,9 +1286,9 @@ def _unwrap_lists(v): # terminal None -> starvation). is_finished is always False here: the # runner cannot know finish at sample_tokens time (engine core decides # after), so the terminal is handled separately (later sub-commit), - # mirroring the adapter. With the empty allowlist this branch is - # dormant (uses_async_chunk_coordinator always returns False). - if pooler_output and uses_async_chunk_coordinator(self.model_config): + # mirroring the adapter. This branch is active for allowlisted stages + # (qwen3_omni talker+code2wav) and dormant for every other arch/stage. + if pooler_output and uses_async_chunk_coordinator(getattr(self, "model_config", None)): for i, rid in enumerate(req_ids_output_copy): if rid not in downstream_req_id_set or not pooler_output[i]: continue diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index 70f1f58dce5..7b2de000ed5 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -29,6 +29,7 @@ from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices from vllm_omni.core.prefix_cache import OmniTensorPrefixCache +from vllm_omni.core.sched.omni_scheduling_coordinator import uses_async_chunk_coordinator from vllm_omni.engine.serialization import deserialize_additional_information from vllm_omni.model_executor.layers.rotary_embedding.mrope import OmniMRotaryEmbedding as MRotaryEmbedding from vllm_omni.model_executor.models.output_templates import OmniOutput @@ -1563,6 +1564,16 @@ def _preprocess( if hasattr(self.model, "has_preprocess") or hasattr(self.model, 
"enable_update_additional_information"): if self.vllm_config.model_config.async_chunk: self._update_additional_information(scheduler_output) + # PR4: the coordinator+mixin async-chunk path delivers the heavy + # stage payload (e.g. thinker embed.prefill) via the worker-local + # _local_stage_payload_cache, NOT via scheduler + # additional_information (which _update_additional_information + # reads). Bridge the cache into model_intermediate_buffer the + # same way full-payload mode does, before the model's preprocess + # reads it. The legacy adapter async-chunk path still carries the + # payload in additional_information, so it needs no extra sync. + if uses_async_chunk_coordinator(self.vllm_config.model_config): + self._sync_local_stage_payloads() else: # In full-payload (non-async-chunk) mode, connector-delivered # stage payloads must override any earlier engine-level From 3ee11217031b6c7e018d3c19327169fcc9eab817 Mon Sep 17 00:00:00 2001 From: natureofnature Date: Wed, 10 Jun 2026 02:56:36 +0000 Subject: [PATCH 04/10] [PR4] propagate qwen3 omni segment-finished terminals Carry segment-finished state across the legacy thinker to coordinator talker/code2wav boundary with a distinct connector signal so realtime async-chunk requests flush and finish correctly. 
Signed-off-by: natureofnature --- .../test_qwen3_omni_realtime_websocket.py | 8 ++ .../test_qwen3_omni_finish_sentinel.py | 6 +- .../test_qwen3_omni_streaming_helpers.py | 6 +- tests/worker/test_omni_connector_mixin.py | 79 +++++++++++++++++ vllm_omni/core/sched/omni_ar_scheduler.py | 48 +++++++++++ vllm_omni/core/sched/omni_scheduler_mixin.py | 14 +++ .../core/sched/omni_scheduling_coordinator.py | 2 +- vllm_omni/core/sched/output.py | 1 + .../chunk_transfer_adapter.py | 5 +- .../stage_input_processors/qwen3_omni.py | 21 +++-- vllm_omni/outputs.py | 2 + vllm_omni/worker/gpu_ar_model_runner.py | 41 +++++++++ .../omni_connector_model_runner_mixin.py | 85 ++++++++++++++++--- 13 files changed, 292 insertions(+), 26 deletions(-) diff --git a/tests/entrypoints/openai_api/test_qwen3_omni_realtime_websocket.py b/tests/entrypoints/openai_api/test_qwen3_omni_realtime_websocket.py index bbe50a9255d..e036a5602ac 100644 --- a/tests/entrypoints/openai_api/test_qwen3_omni_realtime_websocket.py +++ b/tests/entrypoints/openai_api/test_qwen3_omni_realtime_websocket.py @@ -198,6 +198,10 @@ def _assert_realtime_smoke(result: dict) -> None: assert result["output_sample_rate"] > 0 +def _has_cjk(text: str) -> bool: + return any("\u4e00" <= ch <= "\u9fff" for ch in text) + + def _assert_realtime_accuracy(result: dict) -> None: final_text = (result["transcription_text"] or "").strip() assert final_text, "Expected non-empty transcription (model text stream)" @@ -207,6 +211,10 @@ def _assert_realtime_accuracy(result: dict) -> None: assert whisper_text, "Whisper returned empty string for synthesized output audio" sim = cosine_similarity_text(whisper_text.lower(), final_text.lower()) + if sim <= 0.8 and _has_cjk(whisper_text + final_text): + # Chinese ASR may drop a very short function phrase, which heavily + # penalizes 3-grams while preserving close 2-gram overlap. 
+ sim = max(sim, cosine_similarity_text(whisper_text.lower(), final_text.lower(), n=2)) assert sim > 0.8, ( f"Output audio transcript should match model text (sim={sim:.3f}): " f"whisper={whisper_text!r}, model_text={final_text!r}" diff --git a/tests/model_executor/stage_input_processors/test_qwen3_omni_finish_sentinel.py b/tests/model_executor/stage_input_processors/test_qwen3_omni_finish_sentinel.py index e7a731d6079..9ba710bdc96 100644 --- a/tests/model_executor/stage_input_processors/test_qwen3_omni_finish_sentinel.py +++ b/tests/model_executor/stage_input_processors/test_qwen3_omni_finish_sentinel.py @@ -31,7 +31,7 @@ def _sentinel_payload(): def test_finish_sentinel_flushes_partial_tail(): # 6 frames accumulated, chunk size 4 -> a 2-frame partial tail is still held. - tm = _tm({"r": [[1], [2], [3], [4], [5], [6]]}, chunk_frames=4, left_frames=25) + tm = _tm({"r": [torch.tensor([[i]]) for i in range(1, 7)]}, chunk_frames=4, left_frames=25) req = SimpleNamespace(external_req_id="r") out = talker2code2wav_async_chunk(tm, _sentinel_payload(), req, is_finished=True) @@ -48,7 +48,7 @@ def test_finish_sentinel_flushes_partial_tail(): def test_finish_sentinel_on_chunk_boundary_emits_flag_only(): # 4 frames, chunk size 4 -> the last full chunk was already sent in-step; # no unsent tail, so the sentinel must NOT re-send codec (flag only). - tm = _tm({"r": [[1], [2], [3], [4]]}, chunk_frames=4) + tm = _tm({"r": [torch.tensor([[i]]) for i in range(1, 5)]}, chunk_frames=4) req = SimpleNamespace(external_req_id="r") out = talker2code2wav_async_chunk(tm, _sentinel_payload(), req, is_finished=True) @@ -72,7 +72,7 @@ def test_finish_sentinel_with_no_sent_chunks_emits_flag_only(): def test_non_sentinel_empty_call_is_unchanged(): # Without the marker, an empty/codeless call returns None as before -> the # adapter path (which never sets the marker) is byte-identical. 
- tm = _tm({"r": [[1], [2]]}, chunk_frames=4) + tm = _tm({"r": [torch.tensor([[1]]), torch.tensor([[2]])]}, chunk_frames=4) req = SimpleNamespace(external_req_id="r") assert talker2code2wav_async_chunk(tm, {"codes": {}}, req, is_finished=True) is None diff --git a/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py b/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py index 4b2bdf3fee7..d831df1f0fb 100644 --- a/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py +++ b/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py @@ -97,7 +97,11 @@ def test_streaming_input_prefill_chunk_is_cached() -> None: transfer_manager, ) - assert payload is None + assert payload is not None + assert payload.embed is None + assert payload.hidden_states is None + assert payload.ids is None + assert bool(payload.meta.finished) is False cached = transfer_manager._pending_streaming_prefills["rt-1"] assert torch.equal(cached["embed"]["prefill"], thinker_emb) assert torch.equal(cached["hidden_states"]["output"], thinker_hid) diff --git a/tests/worker/test_omni_connector_mixin.py b/tests/worker/test_omni_connector_mixin.py index b8f27460a11..331fd22ab60 100644 --- a/tests/worker/test_omni_connector_mixin.py +++ b/tests/worker/test_omni_connector_mixin.py @@ -1172,6 +1172,85 @@ def test_send_side_request_payload_not_cleared_before_payload_is_consumable(self self.assertIn("r1", host._send_side_request_payload) host.shutdown_omni_connectors() + def test_accumulate_payload_overwrites_scalar_meta(self): + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(stage_id=1, async_chunk=True, worker_type="ar"), + ) + first = { + "embed": {"decode": torch.ones(1, 2)}, + "meta": { + "finished": torch.tensor(False), + "is_segment_finished": torch.tensor(False), + "next_stage_prompt_len": torch.tensor(4), + "num_processed_tokens": torch.tensor(1), 
+ }, + } + second = { + "embed": {"decode": torch.ones(1, 2)}, + "meta": { + "finished": torch.tensor(False), + "is_segment_finished": torch.tensor(True), + "next_stage_prompt_len": torch.tensor(7), + "num_processed_tokens": torch.tensor(2), + }, + } + + host._accumulate_payload("r1", first) + merged = host._accumulate_payload("r1", second) + + self.assertEqual(tuple(merged["embed"]["decode"].shape), (2, 2)) + self.assertEqual(merged["meta"]["is_segment_finished"].shape, torch.Size([])) + self.assertTrue(bool(merged["meta"]["is_segment_finished"].item())) + self.assertEqual(int(merged["meta"]["next_stage_prompt_len"].item()), 7) + self.assertEqual(int(merged["meta"]["num_processed_tokens"].item()), 2) + host.shutdown_omni_connectors() + + def test_segment_finish_is_not_request_finish(self): + payload = { + "meta": { + "finished": torch.tensor(False), + "is_segment_finished": torch.tensor(True), + } + } + + self.assertFalse(MixinHost._payload_finished(payload)) + self.assertTrue(MixinHost._payload_segment_finished(payload)) + self.assertFalse(MixinHost._payload_is_consumable(payload)) + + def test_poll_segment_finish_wakes_without_closing_stream(self): + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(stage_id=1, async_chunk=True, worker_type="ar"), + ) + host._omni_connector = MagicMock() + host._stage_id = 1 + host._async_chunk = True + host._model_mode = "ar" + host._request_ids_mapping["r1"] = "ext-r1" + host._get_req_chunk["r1"] = 0 + host._pending_load_reqs["r1"] = object() + host._omni_connector.get.return_value = ( + { + "embed": {"decode": torch.ones(1, 2)}, + "meta": { + "finished": torch.tensor(False), + "is_segment_finished": torch.tensor(True), + }, + }, + 1, + ) + + self.assertTrue(host._poll_single_request("r1")) + + self.assertIn("r1", host._finished_load_reqs) + self.assertNotIn("r1", host._chunk_finished_req_ids) + self.assertNotIn("r1", host._chunk_stream_completed) + self.assertIn("r1", 
host._pending_load_reqs) + host.shutdown_omni_connectors() + def test_payload_consumable_ignores_token_horizon_only_updates(self): host = MixinHost() host.init_omni_connectors( diff --git a/vllm_omni/core/sched/omni_ar_scheduler.py b/vllm_omni/core/sched/omni_ar_scheduler.py index 41c18190b5c..b331481179e 100644 --- a/vllm_omni/core/sched/omni_ar_scheduler.py +++ b/vllm_omni/core/sched/omni_ar_scheduler.py @@ -88,6 +88,10 @@ def __init__(self, *args, **kwargs): # exclusive, so a single coordinator instance serves whichever fires. _async_coord = uses_async_chunk_coordinator(model_config) self.chunk_transfer_adapter = None + # Coordinator-path segment-finished ids pending emission to the runner + # (resumable realtime stops); drained into OmniSchedulerOutput each schedule. + self._omni_pending_segment_finished: set[str] = set() + self._omni_pending_upstream_segment_finished: set[str] = set() if getattr(model_config, "async_chunk", False) and not _async_coord: self.chunk_transfer_adapter = OmniChunkTransferAdapter(self.vllm_config) self.input_coordinator: OmniSchedulingCoordinator | None = None @@ -101,6 +105,30 @@ def __init__(self, *args, **kwargs): # Snapshot prompt length for each streaming input update self._new_prompt_len_snapshot: dict[str, int] = {} + def add_request(self, request: Request) -> None: + existing = self.requests.get(request.request_id) + if ( + existing is not None + and existing.streaming_queue is not None + and existing.status == RequestStatus.WAITING_FOR_STREAMING_REQ + and StreamingUpdate.from_request(request) is None + and self.chunk_transfer_adapter is not None + and getattr(self.vllm_config.model_config, "stage_id", 0) == 0 + ): + # A realtime final commit can arrive after stage 0 has already parked + # waiting for the next input segment. Base vLLM treats that as an + # external finish and removes the request without producing an + # EngineCoreOutput, which means no async-chunk terminal reaches stage 1. 
+ old_status = existing.status + existing.status = RequestStatus.FINISHED_STOPPED + existing.resumable = False + self.chunk_transfer_adapter.save_async(None, existing, is_segment_finished=True) + existing.status = old_status + self.finish_requests(request.request_id, RequestStatus.FINISHED_STOPPED) + return + + super().add_request(request) + def _get_confirmed_num_computed_tokens(self, request: Request) -> int: """num_computed_tokens minus async placeholders (KV actually on GPU).""" # Output placeholders are zero when async scheduling isn't used @@ -420,6 +448,18 @@ def update_from_output( if not stopped and self._process_kv_transfer_trigger(request, new_token_ids): stopped = True + upstream_segment_finished = req_id in self._omni_pending_upstream_segment_finished + if ( + upstream_segment_finished + and not stopped + and getattr(self.vllm_config.model_config, "final_output", False) + ): + # Final-output stages must flush each realtime segment locally. + # Intermediate stages keep running so they can forward later + # segments and emit their own downstream segment sentinels. 
+ request.status = RequestStatus.FINISHED_STOPPED + stopped = True + if new_token_ids and self.structured_output_manager.should_advance(request): struct_output_request = request.structured_output_request assert struct_output_request is not None @@ -434,6 +474,8 @@ def update_from_output( request.resumable = False stopped = True + self._omni_pending_upstream_segment_finished.discard(req_id) + if stopped: if model_runner_output.routed_experts is not None: routed_experts = omni_routed_experts_for_request(model_runner_output.routed_experts, request) @@ -444,6 +486,12 @@ def update_from_output( is_segment_finished = True finished = self._handle_stopped_request(request) if not finished: + # Coordinator-path stages (talker/code2wav) have no scheduler-side + # save_async(is_segment_finished); record the segment stop so the + # runner emits an is_segment_finished terminal (flushing this + # segment's audio tail) to the downstream stage next cycle. + if self.input_coordinator is not None and not upstream_segment_finished: + self._omni_pending_segment_finished.add(request.request_id) # for streaming input request only if self.chunk_transfer_adapter: if self.vllm_config.model_config.stage_id != 0: diff --git a/vllm_omni/core/sched/omni_scheduler_mixin.py b/vllm_omni/core/sched/omni_scheduler_mixin.py index 11f9d885bdb..c2f5787a0cb 100644 --- a/vllm_omni/core/sched/omni_scheduler_mixin.py +++ b/vllm_omni/core/sched/omni_scheduler_mixin.py @@ -65,6 +65,10 @@ def _consume_pending_connector_output(self, model_mode: str) -> None: input_coordinator.update_request_metadata( self.requests, connector_output.request_metadata, model_mode=model_mode ) + if connector_output and connector_output.chunk_segment_finished_req_ids: + pending_segment_stop = getattr(self, "_omni_pending_upstream_segment_finished", None) + if pending_segment_stop is not None: + pending_segment_stop.update(connector_output.chunk_segment_finished_req_ids) # Both calls self-guard on the coordinator's async_chunk mode 
# (process_pending_chunks returns early when async_chunk is False; # process_pending_full_payload_inputs branches internally), so exactly @@ -157,10 +161,20 @@ def _wrap_omni_scheduler_output( pending_connector_registrations = ( input_coordinator.pending_connector_registrations if input_coordinator else [] ) + # Drain segment-finished ids recorded by update_from_output for resumable + # realtime stops (coordinator stages only); the runner emits one + # is_segment_finished terminal per id next cycle, mirroring finished_req_ids. + pending_segment = getattr(self, "_omni_pending_segment_finished", None) + if pending_segment: + segment_finished_req_ids = set(pending_segment) + pending_segment.clear() + else: + segment_finished_req_ids = set() return OmniSchedulerOutput( **base_data, finished_requests_needing_kv_transfer=finished_requests_needing_kv_transfer or {}, pending_connector_registrations=pending_connector_registrations, + segment_finished_req_ids=segment_finished_req_ids, ) def make_stats(self, *args, **kwargs) -> SchedulerStats | None: diff --git a/vllm_omni/core/sched/omni_scheduling_coordinator.py b/vllm_omni/core/sched/omni_scheduling_coordinator.py index 8ce090eecc5..407ef000191 100644 --- a/vllm_omni/core/sched/omni_scheduling_coordinator.py +++ b/vllm_omni/core/sched/omni_scheduling_coordinator.py @@ -520,7 +520,7 @@ def _process_chunk_queue( request.status = RequestStatus.WAITING_FOR_CHUNK self._waiting_since.setdefault(request.request_id, time.monotonic()) else: - if request.request_id in chunk_ready_req_ids: + if request.request_id in chunk_ready_req_ids or request.request_id in self.finished_requests: request.status = target_status self.requests_with_ready_chunks.add(request.request_id) self._waiting_since.pop(request.request_id, None) diff --git a/vllm_omni/core/sched/output.py b/vllm_omni/core/sched/output.py index bb09f128a44..37ce5965cd3 100644 --- a/vllm_omni/core/sched/output.py +++ b/vllm_omni/core/sched/output.py @@ -94,3 +94,4 @@ class 
OmniSchedulerOutput(SchedulerOutput): finished_requests_needing_kv_transfer: dict[str, dict] = field(default_factory=dict) pending_connector_registrations: list[OmniChunkRecvHandle] = field(default_factory=list) + segment_finished_req_ids: set[str] = field(default_factory=set) diff --git a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py index 121fb79ecf5..f778945b77b 100644 --- a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py +++ b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py @@ -167,7 +167,10 @@ def save_async( request: Request object is_segment_finished: whether the segment of request is finished """ - is_finished = request.is_finished() and not request.resumable + # A final realtime chunk can be both segment-finished and request-finished. + # Do not mask request completion just because the session is resumable; + # non-final segment stops already report request.is_finished() as False. + is_finished = request.is_finished() confirmed_num_computed_tokens = self._confirmed_num_computed_tokens(request) diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py index 311ef746dd8..b2b01faafad 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -273,10 +273,10 @@ def _construct_thinker2talker_streaming_input_async_chunk( """Build Thinker -> Talker payloads for realtime streaming input chunks. A resumable realtime request reuses the same logical request id across - audio segments. The first streaming prefill chunk is cached and returns ``None`` so the - connector does not emit an incomplete downstream chunk. 
The following - decode chunk flushes that cached prefill together with the current Thinker - output, keeping Talker ids and tensor rows aligned. + audio segments. The first streaming prefill chunk is cached and returns a + metadata-only payload so async chunk ids keep advancing without waking the + Talker. The following decode chunk flushes that cached prefill together + with the current Thinker output, keeping Talker ids and tensor rows aligned. """ request_id = request.external_req_id output_token_ids = request.output_token_ids @@ -305,7 +305,9 @@ def _construct_thinker2talker_streaming_input_async_chunk( language=language, ) transfer_manager._pending_streaming_prefills[request_id] = to_dict(payload) - return None + # Keep async chunk ids monotonic even when this step only caches + # prefill state for the following decode chunk. + return OmniPayloadStruct(meta=MetaStruct(finished=finished)) else: save_payload = transfer_manager._pending_streaming_prefills.pop(request_id, None) if save_payload is not None: @@ -335,8 +337,9 @@ def _construct_thinker2talker_streaming_input_async_chunk( ) else: if not is_finished: - # do not send async chunk mode placeholder token or embedding/hidden of the stop token - return None + # Keep the transport moving, but do not wake the Talker for a + # placeholder-only chunk. 
+ return OmniPayloadStruct(meta=MetaStruct(finished=finished)) return OmniPayloadStruct( meta=MetaStruct(finished=finished), embed=EmbeddingsStruct(decode=emb_cpu), @@ -831,7 +834,9 @@ def _flush_code2wav_finish_tail(transfer_manager: Any, request: OmniEngineCoreRe context_length = chunk_length left_context_size = max(0, min(length - context_length, left_context_size_config)) end_index = min(length, left_context_size + context_length) - codes = torch.tensor(transfer_manager.code_prompt_token_ids[request_id][-end_index:]).transpose(0, 1).reshape(-1) + codes = ( + torch.cat(transfer_manager.code_prompt_token_ids[request_id][-end_index:], dim=0).transpose(0, 1).reshape(-1) + ) return OmniPayloadStruct( codes=CodesStruct(audio=codes), meta=MetaStruct(left_context_size=left_context_size, finished=finished_flag), diff --git a/vllm_omni/outputs.py b/vllm_omni/outputs.py index b4d308ebf8f..02d6e90b590 100644 --- a/vllm_omni/outputs.py +++ b/vllm_omni/outputs.py @@ -19,6 +19,7 @@ class OmniConnectorOutput: Attributes: chunk_ready_req_ids: Request IDs with newly arrived chunks this cycle. chunk_finished_req_ids: Request IDs whose final chunk has arrived. + chunk_segment_finished_req_ids: Request IDs whose current realtime segment ended. request_metadata: Lightweight scheduling metadata keyed by request ID (e.g. next_stage_prompt_len, code_predictor_codes, left_context_size). Full payloads are owned by the Model Runner's local cache. 
@@ -30,6 +31,7 @@ class OmniConnectorOutput: chunk_ready_req_ids: set[str] = field(default_factory=set) chunk_finished_req_ids: set[str] = field(default_factory=set) + chunk_segment_finished_req_ids: set[str] = field(default_factory=set) request_metadata: dict[str, dict[str, Any]] = field(default_factory=dict) kv_sent_req_ids: list[str] = field(default_factory=list) stage_recv_req_ids: set[str] = field(default_factory=set) diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index 48ecc435ddd..c1e285da58c 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -295,6 +295,27 @@ def _send_async_chunk_finish_sentinels(self, finished_req_ids: set[str]) -> None request = SimpleNamespace(request_id=rid, req_id=rid, external_req_id=ext_id, is_finished=lambda: True) self.enqueue_finish_sentinel(request, ext_id) + def _send_async_chunk_segment_sentinels(self, segment_finished_req_ids: set[str]) -> None: + """Emit one SEGMENT terminal per just-segment-finished async-chunk request. + + Mirrors ``_send_async_chunk_finish_sentinels`` but for a resumable + ``/v1/realtime`` segment stop: the request is NOT finished/freed (it + continues with the next segment), so this flushes the segment's trailing + audio tail to the downstream stage marked ``is_segment_finished=True`` / + ``finished=False`` -- it does NOT force ``is_finished`` on the request. + """ + for rid in segment_finished_req_ids: + ext_id = self._request_ids_mapping.get(rid) or self._resolve_transfer_request_id(rid) + # A realtime segment may produce zero downstream chunks, but the next + # stage is already registered and waiting for chunk 0. 
+ self._put_req_chunk.setdefault(ext_id, 0) + snapshot = self._send_side_request_snapshot.get(ext_id) + if snapshot is not None: + request: Any = snapshot + else: + request = SimpleNamespace(request_id=rid, req_id=rid, external_req_id=ext_id, is_finished=lambda: False) + self.enqueue_finish_sentinel(request, ext_id, is_segment_finished=True) + def _resolve_pooler_payload_req_ids(self, req_ids_output_copy: list[str]) -> tuple[str, list[str]]: downstream_req_ids = [rid for rid in req_ids_output_copy if self._request_needs_downstream_stage_payload(rid)] engine_output_type = (self.vllm_config.model_config.engine_output_type or "").lower() @@ -556,6 +577,9 @@ def execute_model( async_finished = set(getattr(scheduler_output, "finished_req_ids", set())) if async_finished: self._send_async_chunk_finish_sentinels(async_finished) + segment_finished = set(getattr(scheduler_output, "segment_finished_req_ids", set())) + if segment_finished: + self._send_async_chunk_segment_sentinels(segment_finished) if self.omni_prefix_cache is not None and scheduler_output.finished_req_ids: self.omni_prefix_cache.commit_deferred_mm_outputs( @@ -1302,6 +1326,23 @@ def _unwrap_lists(v): self._send_side_request_snapshot[ext_id] = self._snapshot_request_for_send(wrapped, ext_id) self.send_chunk(request=wrapped, pooling_output=pooler_output[i]) + if uses_async_chunk_coordinator(getattr(self, "model_config", None)): + segment_finished_now: set[str] = set() + for rid in downstream_req_id_set: + info = self.model_intermediate_buffer.get(rid) + if not isinstance(info, dict): + continue + meta = info.get("meta") + if not isinstance(meta, dict): + continue + if self._payload_truthy_scalar(meta.get("is_segment_finished")): + segment_finished_now.add(rid) + # The terminal is emitted below for this segment. Clear the + # transient segment marker so later idle steps do not resend it. 
+ meta["is_segment_finished"] = False + if segment_finished_now: + self._send_async_chunk_segment_sentinels(segment_finished_now) + with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"): routed_experts_lists = None if self.routed_experts_initialized: diff --git a/vllm_omni/worker/omni_connector_model_runner_mixin.py b/vllm_omni/worker/omni_connector_model_runner_mixin.py index a9404d05d5d..c65b54dcc08 100644 --- a/vllm_omni/worker/omni_connector_model_runner_mixin.py +++ b/vllm_omni/worker/omni_connector_model_runner_mixin.py @@ -34,6 +34,15 @@ ) _EMBED_SPAN_GROUPS: tuple[tuple[str, str, str], ...] = (("decode", "decode_token_start", "decode_token_end"),) +_META_OVERWRITE_KEYS: frozenset[str] = frozenset( + ( + "finished", + "is_segment_finished", + "stream_finished", + "next_stage_prompt_len", + "num_processed_tokens", + ) +) if TYPE_CHECKING: from vllm_omni.distributed.omni_connectors.connectors.base import ( @@ -185,6 +194,7 @@ def init_omni_connectors( # -- per-cycle output accumulator -- self._chunk_ready_req_ids: set[str] = set() self._chunk_finished_req_ids: set[str] = set() + self._chunk_segment_finished_req_ids: set[str] = set() self._stage_recv_req_ids: set[str] = set() self._full_payload_pending_broadcast_req_ids: set[str] = set() self._async_chunk_updated_req_ids: set[str] = set() @@ -357,6 +367,7 @@ def _clear_recv_delivery_state(self, req_id: str) -> None: self._finished_load_reqs.discard(req_id) self._chunk_ready_req_ids.discard(req_id) self._chunk_finished_req_ids.discard(req_id) + self._chunk_segment_finished_req_ids.discard(req_id) self._chunk_stream_completed.discard(req_id) self._stage_recv_req_ids.discard(req_id) self._full_payload_pending_broadcast_req_ids.discard(req_id) @@ -417,6 +428,7 @@ def prune_inactive_requests(self, active_req_ids: Any) -> set[str]: "_finished_load_reqs", "_chunk_ready_req_ids", "_chunk_finished_req_ids", + "_chunk_segment_finished_req_ids", "_chunk_stream_completed", "_stage_recv_req_ids", 
"_full_payload_pending_broadcast_req_ids", @@ -495,6 +507,8 @@ def _extract_scheduling_metadata(cls, payload: OmniPayload) -> dict[str, Any]: _NON_CONSUMABLE_PAYLOAD_KEYS: set[tuple[str, str]] = { ("meta", "finished"), + ("meta", "is_segment_finished"), + ("meta", "stream_finished"), ("meta", "override_keys"), ("meta", "next_stage_prompt_len"), ("meta", "left_context_size"), @@ -513,19 +527,31 @@ def _payload_value_has_content(value: Any) -> bool: return len(value) > 0 return True + @staticmethod + def _payload_truthy_scalar(flag: Any) -> bool: + if isinstance(flag, torch.Tensor): + return flag.numel() == 1 and bool(flag.item()) + return bool(flag) + @staticmethod def _payload_finished(payload: Any) -> bool: if not isinstance(payload, dict): return False if "finished" in payload: - logger.warning_once("legacy flat 'finished' key in payload; expected 'meta.finished'") + logger.warning_once("legacy flat finished key in payload; expected meta.finished") meta = payload.get("meta") - if not isinstance(meta, dict) or "finished" not in meta: + if not isinstance(meta, dict): return False - flag = meta["finished"] - if isinstance(flag, torch.Tensor): - return flag.numel() == 1 and bool(flag.item()) - return bool(flag) + return OmniConnectorModelRunnerMixin._payload_truthy_scalar(meta.get("finished")) + + @staticmethod + def _payload_segment_finished(payload: Any) -> bool: + if not isinstance(payload, dict): + return False + meta = payload.get("meta") + if not isinstance(meta, dict): + return False + return OmniConnectorModelRunnerMixin._payload_truthy_scalar(meta.get("is_segment_finished")) @staticmethod def _payload_audio_codes(payload: Any) -> Any: @@ -674,9 +700,14 @@ def _collect_async_chunk_fanout_packet_locked(self) -> dict[str, Any] | None: payload_req_ids = set(self._async_chunk_updated_req_ids) payload_req_ids.update(self._finished_load_reqs) payload_req_ids.update(self._chunk_finished_req_ids) + payload_req_ids.update(self._chunk_segment_finished_req_ids) 
payload_req_ids.update(self._local_request_metadata) if not ( - payload_req_ids or self._finished_load_reqs or self._chunk_finished_req_ids or self._local_request_metadata + payload_req_ids + or self._finished_load_reqs + or self._chunk_finished_req_ids + or self._chunk_segment_finished_req_ids + or self._local_request_metadata ): return None @@ -690,11 +721,13 @@ def _collect_async_chunk_fanout_packet_locked(self) -> dict[str, Any] | None: "request_metadata": dict(self._local_request_metadata), "newly_finished": set(self._finished_load_reqs), "chunk_finished": set(self._chunk_finished_req_ids), + "chunk_segment_finished": set(self._chunk_segment_finished_req_ids), } self._async_chunk_updated_req_ids.clear() self._finished_load_reqs.clear() self._chunk_finished_req_ids.clear() + self._chunk_segment_finished_req_ids.clear() self._local_request_metadata.clear() for req_id in packet["chunk_finished"]: @@ -1241,7 +1274,7 @@ def send_chunk( self._work_available.set() return True - def enqueue_finish_sentinel(self, request: Any, request_id: str) -> bool: + def enqueue_finish_sentinel(self, request: Any, request_id: str, *, is_segment_finished: bool = False) -> bool: """Enqueue the terminal chunk for a finished async-chunk request. 
The producer runner cannot know at ``sample_tokens`` time that a request @@ -1268,7 +1301,26 @@ def enqueue_finish_sentinel(self, request: Any, request_id: str) -> bool: pooling_output={ASYNC_FINISH_SENTINEL_KEY: True}, ) if payload_data is None: - payload_data = {"meta": {"finished": torch.tensor(True, dtype=torch.bool)}} + payload_data = { + "meta": { + "finished": torch.tensor(not is_segment_finished, dtype=torch.bool), + "is_segment_finished": torch.tensor(is_segment_finished, dtype=torch.bool), + } + } + elif is_segment_finished: + # Realtime SEGMENT terminal: the tail content was built above via the + # sentinel marker, but mark it segment-finished (NOT request-finished): + # a resumable /v1/realtime talker continues with the next segment, so the + # downstream stage must flush THIS segment's audio tail without the + # request being treated as finished/freed. + _meta = getattr(payload_data, "meta", None) + if _meta is not None: + _meta.finished = torch.tensor(False, dtype=torch.bool) + _meta.is_segment_finished = torch.tensor(True, dtype=torch.bool) + elif isinstance(payload_data, dict): + _m = payload_data.setdefault("meta", {}) + _m["finished"] = torch.tensor(False, dtype=torch.bool) + _m["is_segment_finished"] = torch.tensor(True, dtype=torch.bool) chunk_id = self._put_req_chunk[request_id] self._put_req_chunk[request_id] += 1 @@ -1652,12 +1704,14 @@ def get_omni_connector_output(self) -> OmniConnectorOutput: if fanout_packet is None: newly_finished = set() chunk_finished = set() + chunk_segment_finished = set() request_metadata = {} else: if not self.is_data_transfer_rank(): self._apply_async_chunk_fanout_packet(fanout_packet) newly_finished = set(fanout_packet["newly_finished"]) chunk_finished = set(fanout_packet["chunk_finished"]) + chunk_segment_finished = set(fanout_packet.get("chunk_segment_finished", set())) request_metadata = dict(fanout_packet["request_metadata"]) else: with self._lock: @@ -1665,6 +1719,8 @@ def get_omni_connector_output(self) -> 
OmniConnectorOutput: self._finished_load_reqs.clear() chunk_finished = set(self._chunk_finished_req_ids) self._chunk_finished_req_ids.clear() + chunk_segment_finished = set(self._chunk_segment_finished_req_ids) + self._chunk_segment_finished_req_ids.clear() request_metadata = dict(self._local_request_metadata) self._local_request_metadata.clear() # _send_side_request_payload is the async accumulation buffer for @@ -1684,6 +1740,7 @@ def get_omni_connector_output(self) -> OmniConnectorOutput: output = OmniConnectorOutput( chunk_ready_req_ids=set(self._chunk_ready_req_ids), chunk_finished_req_ids=chunk_finished, + chunk_segment_finished_req_ids=chunk_segment_finished, request_metadata=request_metadata, kv_sent_req_ids=list(self._kv_sent_req_ids), stage_recv_req_ids=set(self._stage_recv_req_ids), @@ -1707,6 +1764,7 @@ def _connector_output_has_signals(output: OmniConnectorOutput) -> bool: return bool( output.chunk_ready_req_ids or output.chunk_finished_req_ids + or output.chunk_segment_finished_req_ids or output.request_metadata or output.kv_sent_req_ids or output.stage_recv_req_ids @@ -1912,6 +1970,7 @@ def _poll_single_request(self, req_id: str) -> bool: if self._async_chunk: is_finished = self._payload_finished(payload_data) + is_segment_finished = self._payload_segment_finished(payload_data) incoming_payload_consumable = self._payload_is_consumable(payload_data) if self._model_mode == "ar": @@ -1926,7 +1985,7 @@ def _poll_single_request(self, req_id: str) -> bool: has_codes = audio_codes.numel() > 0 else: has_codes = bool(audio_codes) - if not has_codes and not is_finished: + if not has_codes and not (is_finished or is_segment_finished): return False payload_consumable = self._payload_is_consumable(payload_data) @@ -1934,6 +1993,8 @@ def _poll_single_request(self, req_id: str) -> bool: if is_finished: self._chunk_finished_req_ids.add(req_id) self._chunk_stream_completed.add(req_id) + if is_segment_finished: + self._chunk_segment_finished_req_ids.add(req_id) # 
Local cache (RFC §2.4) — DEEP-merge, don't replace, so that # earlier chunk keys (e.g. chunk-0's embed.prefill) survive when a # later decode chunk arrives with embed.decode. Shallow dict.update @@ -1951,7 +2012,7 @@ def _poll_single_request(self, req_id: str) -> bool: # the downstream stage can sync the merged local payload and # flush/finish even when the last recv carries no new # consumable chunk bytes. - if payload_consumable or is_finished: + if payload_consumable or is_finished or is_segment_finished: self._finished_load_reqs.add(req_id) if is_finished and not payload_consumable: logger.debug( @@ -2167,7 +2228,7 @@ def _accumulate_payload(self, req_id: str, payload_data: OmniPayload) -> OmniPay for qual, qval in value.items(): if qual in span_handled: continue - if key == "meta" and qual == "finished": + if key == "meta" and qual in _META_OVERWRITE_KEYS: merged_sub[qual] = qval continue if (key, qual) in override_keys: From 4d46b5abc44135d42d5777bb666dd54a974a1aec Mon Sep 17 00:00:00 2001 From: natureofnature Date: Wed, 10 Jun 2026 07:27:21 +0000 Subject: [PATCH 05/10] [PR4] fix qwen3 async-chunk decode handoff alignment Keep prefill and decode rows separate across the streaming handoff, then consume cached decode rows as the talker decode prefix before current decode rows. 
Signed-off-by: natureofnature --- .../test_qwen3_omni_realtime_websocket.py | 8 -- .../test_qwen3_omni_talker_decode.py | 95 +++++++++++++++++++ .../test_qwen3_omni_streaming_helpers.py | 6 +- .../models/qwen3_omni/qwen3_omni.py | 33 +++++-- .../stage_input_processors/qwen3_omni.py | 10 +- 5 files changed, 130 insertions(+), 22 deletions(-) create mode 100644 tests/model_executor/models/qwen3_omni/test_qwen3_omni_talker_decode.py diff --git a/tests/entrypoints/openai_api/test_qwen3_omni_realtime_websocket.py b/tests/entrypoints/openai_api/test_qwen3_omni_realtime_websocket.py index e036a5602ac..bbe50a9255d 100644 --- a/tests/entrypoints/openai_api/test_qwen3_omni_realtime_websocket.py +++ b/tests/entrypoints/openai_api/test_qwen3_omni_realtime_websocket.py @@ -198,10 +198,6 @@ def _assert_realtime_smoke(result: dict) -> None: assert result["output_sample_rate"] > 0 -def _has_cjk(text: str) -> bool: - return any("\u4e00" <= ch <= "\u9fff" for ch in text) - - def _assert_realtime_accuracy(result: dict) -> None: final_text = (result["transcription_text"] or "").strip() assert final_text, "Expected non-empty transcription (model text stream)" @@ -211,10 +207,6 @@ def _assert_realtime_accuracy(result: dict) -> None: assert whisper_text, "Whisper returned empty string for synthesized output audio" sim = cosine_similarity_text(whisper_text.lower(), final_text.lower()) - if sim <= 0.8 and _has_cjk(whisper_text + final_text): - # Chinese ASR may drop a very short function phrase, which heavily - # penalizes 3-grams while preserving close 2-gram overlap. 
- sim = max(sim, cosine_similarity_text(whisper_text.lower(), final_text.lower(), n=2)) assert sim > 0.8, ( f"Output audio transcript should match model text (sim={sim:.3f}): " f"whisper={whisper_text!r}, model_text={final_text!r}" diff --git a/tests/model_executor/models/qwen3_omni/test_qwen3_omni_talker_decode.py b/tests/model_executor/models/qwen3_omni/test_qwen3_omni_talker_decode.py new file mode 100644 index 00000000000..74eef071c77 --- /dev/null +++ b/tests/model_executor/models/qwen3_omni/test_qwen3_omni_talker_decode.py @@ -0,0 +1,95 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from types import SimpleNamespace + +import pytest +import torch + +from vllm_omni.model_executor.models.qwen3_omni.qwen3_omni import ( + Qwen3OmniMoeForConditionalGeneration, +) + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +def _make_minimal_omni() -> Qwen3OmniMoeForConditionalGeneration: + model = Qwen3OmniMoeForConditionalGeneration.__new__(Qwen3OmniMoeForConditionalGeneration) + model.talker = SimpleNamespace(text_projection=lambda x: x + 10) + model.tts_pad_embed = torch.full((2,), -1.0) + model.tts_eos_embed = torch.full((2,), -2.0) + return model + + +def test_async_chunk_decode_consumes_cached_handoff_decode() -> None: + model = _make_minimal_omni() + payload = { + "embed": { + "cached_decode": torch.tensor( + [ + [1.0, 2.0], + [3.0, 4.0], + ] + ) + }, + "meta": { + "num_processed_tokens": 1, + "prefill_consumed_text_tokens": 1, + }, + } + update: dict = {} + + out = model._thinker_decode_to_talker_decode(payload, torch.device("cpu"), update) + + assert torch.equal(out, torch.tensor([11.0, 12.0])) + assert update["_advance_num_processed_tokens"] is True + + +def test_async_chunk_decode_appends_current_decode_after_cached_prefix() -> None: + model = _make_minimal_omni() + payload = { + "embed": { + "cached_decode": torch.tensor( + [ + [1.0, 2.0], + [3.0, 4.0], + ] + ), + "decode": 
torch.tensor([[5.0, 6.0]]), + }, + "meta": { + "num_processed_tokens": 3, + "prefill_consumed_text_tokens": 1, + }, + } + update: dict = {} + + out = model._thinker_decode_to_talker_decode(payload, torch.device("cpu"), update) + + assert torch.equal(out, torch.tensor([15.0, 16.0])) + assert update["_advance_num_processed_tokens"] is True + + +def test_async_chunk_decode_uses_accumulated_decode_when_cache_is_prefix() -> None: + model = _make_minimal_omni() + payload = { + "embed": { + "cached_decode": torch.tensor([[1.0, 2.0]]), + "decode": torch.tensor( + [ + [1.0, 2.0], + [3.0, 4.0], + ] + ), + }, + "meta": { + "num_processed_tokens": 2, + "prefill_consumed_text_tokens": 1, + }, + } + update: dict = {} + + out = model._thinker_decode_to_talker_decode(payload, torch.device("cpu"), update) + + assert torch.equal(out, torch.tensor([13.0, 14.0])) + assert update["_advance_num_processed_tokens"] is True diff --git a/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py b/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py index d831df1f0fb..4a8e62d3f06 100644 --- a/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py +++ b/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py @@ -139,10 +139,12 @@ def test_streaming_input_prefill_flushes_with_next_decode_chunk() -> None: ) assert payload is not None - assert payload.embed.prefill.shape == (3, 3) - assert payload.hidden_states.output.shape == (3, 3) + assert payload.embed.prefill.shape == (2, 3) + assert torch.equal(payload.embed.decode, thinker_emb) + assert payload.hidden_states.output.shape == (2, 3) assert payload.ids.all == [151644, 872, 100] assert payload.ids.prompt == [151644, 872] + assert payload.ids.output == [101] assert "rt-2" not in transfer_manager._pending_streaming_prefills diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py 
b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py index ac19ab70558..281bd830608 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py @@ -1023,15 +1023,32 @@ def _thinker_decode_to_talker_decode( embed = payload.get("embed", {}) meta = payload.get("meta", {}) - thinker_decode_embed = embed.get("decode", None) + cached_decode_embed = embed.get("cached_decode", None) + current_decode_embed = embed.get("decode", None) + if isinstance(cached_decode_embed, torch.Tensor) and isinstance(current_decode_embed, torch.Tensor): + cached_decode_embed = cached_decode_embed.to( + device=current_decode_embed.device, + dtype=current_decode_embed.dtype, + ) + cached_len = int(cached_decode_embed.shape[0]) + if ( + current_decode_embed.shape[0] >= cached_len + and torch.equal(current_decode_embed[:cached_len], cached_decode_embed) + ): + thinker_decode_embed = current_decode_embed + else: + thinker_decode_embed = torch.cat([cached_decode_embed, current_decode_embed], dim=0) + elif isinstance(cached_decode_embed, torch.Tensor): + thinker_decode_embed = cached_decode_embed + else: + thinker_decode_embed = current_decode_embed start_index = meta.get("num_processed_tokens", 0) - # PR4: the worker re-syncs the FULL cumulative thinker decode embeds into - # ``embed.decode`` every step (OmniChunkTransferAdapter parity via - # _accumulate_payload + _sync_local_stage_payloads), so consume by an - # advancing absolute index rather than popping the tensor (a pop would be - # clobbered by the next sync). ``embed.decode`` row 0 is the first - # POST-prefill thinker token, while ``num_processed_tokens`` counts from - # the prefill base, hence the ``- prefill_consumed_text_tokens`` offset. + # PR4 async chunk: prefill may cache the first post-prefill Thinker + # decode row in ``cached_decode`` before the request enters decode. 
+ # Treat that cache as the prefix of the cumulative decode sequence and + # consume by absolute index; popping would be clobbered by the next sync. + # Row 0 is the first POST-prefill thinker token, while + # ``num_processed_tokens`` counts from the prefill base. base = meta.get("prefill_consumed_text_tokens", 0) avail = thinker_decode_embed.shape[0] if isinstance(thinker_decode_embed, torch.Tensor) else 0 idx = start_index - base diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py index b2b01faafad..fecb9472cf2 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -275,8 +275,9 @@ def _construct_thinker2talker_streaming_input_async_chunk( A resumable realtime request reuses the same logical request id across audio segments. The first streaming prefill chunk is cached and returns a metadata-only payload so async chunk ids keep advancing without waking the - Talker. The following decode chunk flushes that cached prefill together - with the current Thinker output, keeping Talker ids and tensor rows aligned. + Talker. The following decode chunk flushes the cached prefill and carries + the current Thinker output as decode data, keeping Talker ids and tensor + rows aligned. 
""" request_id = request.external_req_id output_token_ids = request.output_token_ids @@ -316,11 +317,12 @@ def _construct_thinker2talker_streaming_input_async_chunk( if isinstance(saved_prefill, torch.Tensor) and isinstance(saved_output, torch.Tensor): return OmniPayloadStruct( meta=MetaStruct(finished=finished), - embed=EmbeddingsStruct(prefill=torch.cat((saved_prefill, emb_cpu), dim=0)), - hidden_states=HiddenStatesStruct(output=torch.cat((saved_output, hid_cpu), dim=0)), + embed=EmbeddingsStruct(prefill=saved_prefill, decode=emb_cpu), + hidden_states=HiddenStatesStruct(output=saved_output), ids=IdsStruct( all=save_payload.get("ids", {}).get("all"), prompt=save_payload.get("ids", {}).get("prompt"), + output=output_token_ids, ), speaker=speaker, language=language, From e3e957bf41756c438e53cb72bb88a8c8e6929877 Mon Sep 17 00:00:00 2001 From: natureofnature Date: Wed, 10 Jun 2026 10:28:45 +0000 Subject: [PATCH 06/10] [PR4] preserve terminal finished meta in async cache Signed-off-by: natureofnature --- tests/worker/test_omni_connector_mixin.py | 19 +++++++++++++++++++ .../omni_connector_model_runner_mixin.py | 19 ++++++++++++++++--- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/tests/worker/test_omni_connector_mixin.py b/tests/worker/test_omni_connector_mixin.py index 331fd22ab60..e83d5efec68 100644 --- a/tests/worker/test_omni_connector_mixin.py +++ b/tests/worker/test_omni_connector_mixin.py @@ -1667,6 +1667,25 @@ def test_meta_finished_not_overwritten_by_intermediate_chunk(self): self.assertTrue(existing["meta"]["finished"]) self.assertEqual(existing["meta"]["y"], 2) + def test_meta_finished_true_overwrites_intermediate_false(self): + existing = {"meta": {"finished": torch.tensor(False), "x": 1}} + _deep_merge_chunk_payload(existing, {"meta": {"finished": torch.tensor(True), "y": 2}}) + + self.assertTrue(bool(existing["meta"]["finished"].item())) + self.assertEqual(existing["meta"]["x"], 1) + self.assertEqual(existing["meta"]["y"], 2) + + def 
test_terminal_finished_chunk_updates_model_visible_cache(self): + existing = { + "embed": {"decode": torch.ones(1, 2)}, + "meta": {"finished": torch.tensor(False)}, + } + incoming = {"meta": {"finished": torch.tensor(True)}} + + _deep_merge_chunk_payload(existing, incoming) + + self.assertTrue(bool(existing["meta"]["finished"].item())) + def test_non_dict_value_replaced_and_new_key_added(self): existing = {"a": 1, "nested": {"k": "v"}} _deep_merge_chunk_payload(existing, {"a": 2, "b": 3}) diff --git a/vllm_omni/worker/omni_connector_model_runner_mixin.py b/vllm_omni/worker/omni_connector_model_runner_mixin.py index c65b54dcc08..f90b9c86a49 100644 --- a/vllm_omni/worker/omni_connector_model_runner_mixin.py +++ b/vllm_omni/worker/omni_connector_model_runner_mixin.py @@ -62,8 +62,9 @@ def _deep_merge_chunk_payload(existing: dict, incoming: dict) -> None: later decode chunk's ``embed={decode: ...}`` must not clobber chunk-0's ``embed={prefill: ..., tts_bos: ...}`` (a shallow ``dict.update`` replaces the whole ``embed`` sub-dict, losing ``prefill`` -> ``KeyError: 'prefill'`` in - ``talker_preprocess_prefill``). An intermediate chunk's ``meta.finished`` is - not allowed to overwrite, mirroring the adapter recv merge + ``talker_preprocess_prefill``). A terminal ``meta.finished=True`` must stay + visible to the model, while an intermediate ``False`` must not clear an + already-terminal value, mirroring the adapter recv merge (``OmniChunkTransferAdapter._update_request_payload``). 
""" for key, value in incoming.items(): @@ -72,7 +73,19 @@ def _deep_merge_chunk_payload(existing: dict, incoming: dict) -> None: merged = dict(sub) if isinstance(sub, dict) else {} for sub_key, sub_val in value.items(): if key == "meta" and sub_key == "finished": - continue + incoming_finished = ( + sub_val.numel() == 1 and bool(sub_val.item()) + if isinstance(sub_val, torch.Tensor) + else bool(sub_val) + ) + existing_val = merged.get(sub_key) + existing_finished = ( + existing_val.numel() == 1 and bool(existing_val.item()) + if isinstance(existing_val, torch.Tensor) + else bool(existing_val) + ) + if existing_finished and not incoming_finished: + continue merged[sub_key] = sub_val existing[key] = merged else: From 2bf11d5dfc5f995ffe1298479db2cd1b39723b73 Mon Sep 17 00:00:00 2001 From: natureofnature Date: Wed, 10 Jun 2026 13:29:35 +0000 Subject: [PATCH 07/10] Format Qwen3 Omni decode handoff Signed-off-by: natureofnature --- vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py index 281bd830608..4bf88c6ab76 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py @@ -1031,9 +1031,8 @@ def _thinker_decode_to_talker_decode( dtype=current_decode_embed.dtype, ) cached_len = int(cached_decode_embed.shape[0]) - if ( - current_decode_embed.shape[0] >= cached_len - and torch.equal(current_decode_embed[:cached_len], cached_decode_embed) + if current_decode_embed.shape[0] >= cached_len and torch.equal( + current_decode_embed[:cached_len], cached_decode_embed ): thinker_decode_embed = current_decode_embed else: From 53e77b0221fec0d005b0bc3fa2f534b81fa152ad Mon Sep 17 00:00:00 2001 From: natureofnature Date: Wed, 10 Jun 2026 15:51:37 +0000 Subject: [PATCH 08/10] Clean async chunk comments Signed-off-by: 
natureofnature --- .../core/sched/test_omni_scheduling_coordinator.py | 4 ++-- .../test_qwen3_omni_finish_sentinel.py | 2 +- tests/worker/test_async_chunk_request_adapter.py | 2 +- tests/worker/test_omni_connector_mixin.py | 8 ++++---- vllm_omni/core/sched/omni_ar_scheduler.py | 4 ++-- vllm_omni/core/sched/omni_generation_scheduler.py | 14 +++++++------- .../core/sched/omni_scheduling_coordinator.py | 2 +- vllm_omni/data_entry_keys.py | 2 +- .../model_executor/models/qwen3_omni/qwen3_omni.py | 4 ++-- .../stage_input_processors/qwen3_omni.py | 6 +++--- vllm_omni/worker/gpu_ar_model_runner.py | 8 ++++---- vllm_omni/worker/gpu_model_runner.py | 2 +- .../worker/omni_connector_model_runner_mixin.py | 2 +- 13 files changed, 30 insertions(+), 30 deletions(-) diff --git a/tests/core/sched/test_omni_scheduling_coordinator.py b/tests/core/sched/test_omni_scheduling_coordinator.py index c7afe1020de..caa5e46729d 100644 --- a/tests/core/sched/test_omni_scheduling_coordinator.py +++ b/tests/core/sched/test_omni_scheduling_coordinator.py @@ -884,7 +884,7 @@ def test_overflow_does_not_strand_request(self): class TestAsyncChunkCoordinatorGate(unittest.TestCase): - """PR4: `uses_async_chunk_coordinator` selects the coordinator+mixin path for + """`uses_async_chunk_coordinator` selects the coordinator+mixin path for allowlisted async-chunk archs on SharedMemory only; everyone else (empty allowlist today, Mooncake, sync) stays on the legacy adapter. 
""" @@ -955,7 +955,7 @@ def test_sync_or_non_allowlisted_does_not_fire(self): class TestAsyncChunkRecvRegistration(unittest.TestCase): - """PR4 regression (flip data-plane bug 2026-06-02): a parked async-chunk + """Regression coverage: a parked async-chunk request MUST be registered for bg-thread recv via the CARRIED ``pending_connector_registrations`` (the old ``pending_chunk_registrations`` was never carried/consumed -> the runner never called register_chunk_recv, diff --git a/tests/model_executor/stage_input_processors/test_qwen3_omni_finish_sentinel.py b/tests/model_executor/stage_input_processors/test_qwen3_omni_finish_sentinel.py index 9ba710bdc96..9986f4aed38 100644 --- a/tests/model_executor/stage_input_processors/test_qwen3_omni_finish_sentinel.py +++ b/tests/model_executor/stage_input_processors/test_qwen3_omni_finish_sentinel.py @@ -1,4 +1,4 @@ -"""PR4 MC-C.3b: code2wav async finish-sentinel terminal flush. +"""code2wav async finish-sentinel terminal flush. The producer runner sends every in-step codec chunk with ``finished=False`` and emits a separate finish sentinel next cycle (empty payload + the diff --git a/tests/worker/test_async_chunk_request_adapter.py b/tests/worker/test_async_chunk_request_adapter.py index c44a1c830fa..8614ddbd983 100644 --- a/tests/worker/test_async_chunk_request_adapter.py +++ b/tests/worker/test_async_chunk_request_adapter.py @@ -1,4 +1,4 @@ -"""PR4 MC-C.1: unit coverage for the runner-side async-chunk request shim. +"""unit coverage for the runner-side async-chunk request shim. 
Covers the pure, GPU-free helpers that the AR runner uses to feed the async-chunk stage-input processors from a worker-side ``CachedRequestState``: diff --git a/tests/worker/test_omni_connector_mixin.py b/tests/worker/test_omni_connector_mixin.py index e83d5efec68..0abf06b5028 100644 --- a/tests/worker/test_omni_connector_mixin.py +++ b/tests/worker/test_omni_connector_mixin.py @@ -172,7 +172,7 @@ def broken_process(transfer_manager, pooling_output, request, is_finished=""): sender.shutdown_omni_connectors() def test_send_chunk_skips_preempted_replay(self): - # PR4 MC-C.2: preemption dup-send guard parity with the adapter. + # Keep preemption duplicate-send guard parity with the adapter. connector = MockConnector(stage_id=0) sender = MixinHost() sender.init_omni_connectors( @@ -230,8 +230,8 @@ def _last_enqueued_data(self, sender, request_id): return dq[-1]["data"] if dq else None def test_finish_sentinel_falls_back_to_finished_flag(self): - # PR4 MC-C.3: hook returns None for an empty terminal -> mixin enqueues a - # bare finished=True flag so the downstream stage still terminates. + # Empty terminal payloads fall back to a bare finished=True flag so the + # downstream stage still terminates. 
sender = MixinHost() sender.init_omni_connectors(vllm_config=None, model_config=_make_model_config(stage_id=0, async_chunk=True)) self._quiesce_save_thread(sender) @@ -1632,7 +1632,7 @@ def run_one_loop(): class TestDeepMergeChunkPayload(unittest.TestCase): - """PR4 BUG2 regression: the recv cache must DEEP-merge nested dicts so a later + """Regression coverage: the recv cache must DEEP-merge nested dicts so a later decode chunk's embed={decode} does not clobber chunk-0's embed={prefill} (shallow dict.update lost the prefill -> KeyError 'prefill' in talker_preprocess_prefill).""" diff --git a/vllm_omni/core/sched/omni_ar_scheduler.py b/vllm_omni/core/sched/omni_ar_scheduler.py index b331481179e..7a940fba137 100644 --- a/vllm_omni/core/sched/omni_ar_scheduler.py +++ b/vllm_omni/core/sched/omni_ar_scheduler.py @@ -81,7 +81,7 @@ def __init__(self, *args, **kwargs): # Cache per-request flag to avoid repeated deserialization of additional_information self._omits_kv_transfer_cache: dict[str, bool] = {} model_config = self.vllm_config.model_config - # PR4: allowlisted async-chunk archs (SharedMemory) drive recv through the + # Allowlisted async-chunk archs (SharedMemory) drive recv through the # OmniSchedulingCoordinator + runner mixin; everyone else keeps the legacy # adapter. Empty allowlist today => _async_coord is always False (no behavior # change). Full-payload (async_chunk=False) and async-chunk are mutually @@ -299,7 +299,7 @@ def schedule(self) -> SchedulerOutput: # type: ignore[override] if self.chunk_transfer_adapter: self.chunk_transfer_adapter.postprocess_scheduler_output(scheduler_output, self.requests) if self.input_coordinator: - # PR4: mirror the adapter postprocess on the coordinator path so the + # Mirror the adapter postprocess on the coordinator path so the # per-cycle ready-chunk flags (requests_with_ready_chunks) are cleared # after each scheduler step. 
Without this, a streamed multi-chunk request # stays flagged "ready" after its first chunk and never re-enters the diff --git a/vllm_omni/core/sched/omni_generation_scheduler.py b/vllm_omni/core/sched/omni_generation_scheduler.py index 1281d20f17f..4cb6ea04f1b 100644 --- a/vllm_omni/core/sched/omni_generation_scheduler.py +++ b/vllm_omni/core/sched/omni_generation_scheduler.py @@ -52,7 +52,7 @@ class OmniGenerationScheduler(OmniSchedulerMixin, VLLMScheduler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) model_config = self.vllm_config.model_config - # PR4: see OmniARScheduler.__init__ -- allowlisted async-chunk archs + # See OmniARScheduler.__init__: allowlisted async-chunk archs # (SharedMemory) use the coordinator + runner mixin; others keep the adapter. # _async_coord is True only for allowlisted stages (qwen3_omni # talker+code2wav); False (no behavior change) for everything else. @@ -75,7 +75,7 @@ def __init__(self, *args, **kwargs): async_chunk=_async_coord, ) self._latest_omni_connector_output: OmniConnectorOutput | None = None - # PR4: consumer-side terminal-completeness trackers (async-chunk coordinator + # Consumer-side terminal-completeness trackers (async-chunk coordinator # path only). Code2wav's terminal stage must finish on the producer's true # total code length, not on chunks-arrived-so-far; these carry the terminal # chunk marker / metadata across the finish-only connector cycle and remember @@ -110,7 +110,7 @@ def schedule(self) -> SchedulerOutput: cached_additional_information: dict[str, dict | None] = {} def _ensure_terminal_placeholder(request: Request) -> int: - # PR4: a terminal-ready coordinator request can still have a connector- + # A terminal-ready coordinator request can still have a connector- # delivered payload queued even when its prompt length has not grown. 
# Grow prompt_token_ids by one placeholder so the request is scheduled # for a one-token step that drains the ready payload, instead of being @@ -435,7 +435,7 @@ def _ensure_terminal_placeholder(request: Request) -> int: if self.chunk_transfer_adapter: self.chunk_transfer_adapter.postprocess_scheduler_output(scheduler_output) if self.input_coordinator: - # PR4: mirror the adapter postprocess on the coordinator path so per-cycle + # Mirror the adapter postprocess on the coordinator path so per-cycle # ready-chunk flags are cleared (see omni_ar_scheduler for full rationale). self.input_coordinator.postprocess_scheduler_output(scheduler_output) @@ -480,7 +480,7 @@ def finish_requests(self, request_ids, finished_status: RequestStatus) -> list[t return finished def _free_request(self, request: Request, delay_free_blocks: bool = False) -> dict[str, Any] | None: - # PR4: drop consumer-side terminal-completeness trackers for the freed + # Drop consumer-side terminal-completeness trackers for the freed # request (no-op outside the async-chunk coordinator path). self._deferred_terminal_chunk_req_ids.discard(request.request_id) self._deferred_terminal_request_metadata.pop(request.request_id, None) @@ -520,7 +520,7 @@ def update_from_output( num_nans_in_logits = model_runner_output.num_nans_in_logits kv_connector_output = model_runner_output.kv_connector_output - # PR4: read this cycle's terminal-chunk signal and producer metadata. + # Read this cycle's terminal-chunk signal and producer metadata. # Only consumed on the async-chunk coordinator path; for the adapter / # full-payload paths these stay empty and the blocks below are skipped. omni_output = getattr(model_runner_output, "omni_connector_output", None) @@ -691,7 +691,7 @@ def update_from_output( finish_reason = None routed_experts = None - # PR4: terminal-completeness gate (async-chunk coordinator path only). + # Terminal-completeness gate (async-chunk coordinator path only). 
# The true total code length comes from producer metadata -- not the # chunks-arrived-so-far prompt length -- so resolve it before deciding # the request is done. Defaults to the current prompt length for every diff --git a/vllm_omni/core/sched/omni_scheduling_coordinator.py b/vllm_omni/core/sched/omni_scheduling_coordinator.py index 407ef000191..0e9882834a9 100644 --- a/vllm_omni/core/sched/omni_scheduling_coordinator.py +++ b/vllm_omni/core/sched/omni_scheduling_coordinator.py @@ -82,7 +82,7 @@ def uses_full_payload_input_coordinator(model_config: Any) -> bool: # (model_arch, model_stage) whose async-chunk RECEIVE is coordinated by # OmniSchedulingCoordinator (+ the runner-level mixin transport) instead of the # legacy scheduler-owned OmniChunkTransferAdapter. Intentionally EMPTY until the -# qwen3_omni cutover lands (PR4 final commit): an empty allowlist keeps every arch +# An empty allowlist keeps every arch # on the adapter, so each intermediate commit is behavior-preserving. Final # entries are the recv stages only -- stage-0 producers do not wait on chunks: # ("Qwen3OmniMoeForConditionalGeneration", "talker") # stage 1, recv from 0 diff --git a/vllm_omni/data_entry_keys.py b/vllm_omni/data_entry_keys.py index 031c81e2f44..b8ff0ea30c3 100644 --- a/vllm_omni/data_entry_keys.py +++ b/vllm_omni/data_entry_keys.py @@ -22,7 +22,7 @@ from vllm_omni.engine import AdditionalInformationEntry, AdditionalInformationPayload -# PR4 MC-C.3b: marker key the runner-level finish sentinel sets on its (otherwise +# Marker key the runner-level finish sentinel sets on its (otherwise # empty) ``pooling_output`` so a model's async-chunk stage-input hook can flush a # terminal payload (e.g. code2wav's trailing partial codec). 
The legacy # scheduler-driven ``OmniChunkTransferAdapter`` never sets it, so adapter-driven diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py index 4bf88c6ab76..2e98904ffc8 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py @@ -721,7 +721,7 @@ def talker_preprocess(self, input_ids: torch.Tensor, input_embeds: torch.Tensor, ) update_dict["mtp_inputs"] = last_talker_hidden, text_step - # PR4 async-chunk decode: the talker may emit a pad while waiting for the + # Async-chunk decode: the talker may emit a pad while waiting for the # next thinker decode embed to stream in; in that case it must NOT advance # num_processed_tokens (else the awaited token is skipped). Prefill and the # non-async path leave the flag unset -> default True -> advance as before. @@ -1042,7 +1042,7 @@ def _thinker_decode_to_talker_decode( else: thinker_decode_embed = current_decode_embed start_index = meta.get("num_processed_tokens", 0) - # PR4 async chunk: prefill may cache the first post-prefill Thinker + # Async chunk: prefill may cache the first post-prefill Thinker # decode row in ``cached_decode`` before the request enters decode. # Treat that cache as the prefix of the cumulative decode sequence and # consume by absolute index; popping would be clobbered by the next sync. 
diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py index fecb9472cf2..e10b75103a6 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -814,7 +814,7 @@ def _code2wav_codec_config(transfer_manager: Any) -> tuple[int, int]: def _flush_code2wav_finish_tail(transfer_manager: Any, request: OmniEngineCoreRequest) -> OmniPayloadStruct: - """PR4 MC-C.3b: terminal payload for the runner's async finish sentinel. + """Terminal payload for the runner's async finish sentinel. The producer runner sends every in-step codec chunk with ``finished=False`` (it cannot know finish at sample time), so the trailing partial chunk that the @@ -856,11 +856,11 @@ def talker2code2wav_async_chunk( """ if not isinstance(pooling_output, dict): return None - # PR4 MC-C.3b: runner finish sentinel (empty payload + marker the adapter + # Runner finish sentinel (empty payload + marker the adapter # never sets) -> flush the held trailing partial codec as the terminal chunk. if pooling_output.get(ASYNC_FINISH_SENTINEL_KEY): return _flush_code2wav_finish_tail(transfer_manager, request) - # PR4 BUG5: the coordinator+mixin runner send (MC-C) hands this builder a + # The coordinator/mixin runner send path hands this builder a # FLATTENED pooling_output (flatten_payload -> dotted "codes.audio"), whereas # the legacy adapter send passes it nested.
Reads below expect nested # ("codes"->"audio"), so without this they silently get None and NO codes diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index c1e285da58c..eafd57bfe52 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -569,7 +569,7 @@ def execute_model( flush_ids.update({rid for rid in self._pending_full_payload_send if rid not in self.requests}) if flush_ids: self.flush_full_payload_outputs(flush_ids) - # PR4 MC-C.3: emit the terminal chunk for async-chunk requests that + # Emit the terminal chunk for async-chunk requests that # finished last cycle (in-step sends were finished=False; the runner # learns finish only here). Before _update_states frees them. # Dormant until the allowlist flip (predicate False -> no-op). @@ -1301,15 +1301,15 @@ def _unwrap_lists(v): if req_state is not None and pooler_output[i]: self.accumulate_full_payload_output(rid, pooler_output[i], req_state) - # PR4 MC-C.1: runner-side async-chunk producer send. Gated by the SAME + # Runner-side async-chunk producer send. Gated by the SAME # predicate the scheduler uses to select the coordinator (arch # allowlisted AND SharedMemoryConnector), NOT a bare async_chunk flag -- # a non-coordinator async arch keeps the legacy adapter (scheduler-side # save_async), so a bare-flag guard here would double-send. Only the - # V5 downstream-pooler-filtered requests are sent (never an ungated + # downstream-pooler-filtered requests are sent (never an ungated # terminal None -> starvation). is_finished is always False here: the # runner cannot know finish at sample_tokens time (engine core decides - # after), so the terminal is handled separately (later sub-commit), + # after), so the terminal is handled separately below, # mirroring the adapter. This branch is active for allowlisted stages # (qwen3_omni talker+code2wav) and dormant for every other arch/stage. 
if pooler_output and uses_async_chunk_coordinator(getattr(self, "model_config", None)): diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index 7b2de000ed5..73f6ed8d373 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -1564,7 +1564,7 @@ def _preprocess( if hasattr(self.model, "has_preprocess") or hasattr(self.model, "enable_update_additional_information"): if self.vllm_config.model_config.async_chunk: self._update_additional_information(scheduler_output) - # PR4: the coordinator+mixin async-chunk path delivers the heavy + # The coordinator/mixin async-chunk path delivers the heavy # stage payload (e.g. thinker embed.prefill) via the worker-local # _local_stage_payload_cache, NOT via scheduler # additional_information (which _update_additional_information diff --git a/vllm_omni/worker/omni_connector_model_runner_mixin.py b/vllm_omni/worker/omni_connector_model_runner_mixin.py index f90b9c86a49..70ddd207299 100644 --- a/vllm_omni/worker/omni_connector_model_runner_mixin.py +++ b/vllm_omni/worker/omni_connector_model_runner_mixin.py @@ -2008,7 +2008,7 @@ def _poll_single_request(self, req_id: str) -> bool: self._chunk_stream_completed.add(req_id) if is_segment_finished: self._chunk_segment_finished_req_ids.add(req_id) - # Local cache (RFC §2.4) — DEEP-merge, don't replace, so that + # Local cache: deep-merge instead of replacing, so that # earlier chunk keys (e.g. chunk-0's embed.prefill) survive when a # later decode chunk arrives with embed.decode. 
Shallow dict.update # replaces the nested 'embed' wholesale -> lost prefill -> KeyError From 9d920c282649580d72957ef02dea6bb77e237552 Mon Sep 17 00:00:00 2001 From: natureofnature Date: Wed, 10 Jun 2026 16:29:19 +0000 Subject: [PATCH 09/10] Clean async chunk segment send state Signed-off-by: natureofnature --- tests/worker/test_omni_connector_mixin.py | 35 +++++++++++++++++++ .../omni_connector_model_runner_mixin.py | 15 +++++--- 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/tests/worker/test_omni_connector_mixin.py b/tests/worker/test_omni_connector_mixin.py index 0abf06b5028..18da073d15b 100644 --- a/tests/worker/test_omni_connector_mixin.py +++ b/tests/worker/test_omni_connector_mixin.py @@ -282,6 +282,41 @@ def proc(transfer_manager, pooling_output, request, is_finished=False): self.assertIsNotNone(data) self.assertEqual(data.get("tail"), "terminal-codec", "hook terminal payload must be preferred over the flag") + def test_segment_finish_sentinel_cleans_segment_state_after_put(self): + sender = MixinHost() + sender.init_omni_connectors(vllm_config=None, model_config=_make_model_config(stage_id=0, async_chunk=True)) + self._quiesce_save_thread(sender) + sender._omni_connector = MockConnector(stage_id=0) + sender._stage_id = 0 + sender._async_chunk = True + + def proc(transfer_manager, pooling_output, request, is_finished=False): + return {"meta": {"finished": torch.tensor(True)}} + + sender._custom_process_func = proc + sender._put_req_chunk["ext-1"] = 7 + sender._requests_num_chunks_sent["ext-1"] = 42 + sender._code_prompt_token_ids["ext-1"] = [[1, 2, 3]] + sender._cached_ic["ext-1"] = 16 + + req = SimpleNamespace(request_id="req-1", req_id="req-1", external_req_id="ext-1", is_finished=lambda: False) + self.assertTrue(sender.enqueue_finish_sentinel(req, "ext-1", is_segment_finished=True)) + + self.assertEqual(sender._put_req_chunk["ext-1"], 8) + self.assertIn("ext-1", sender._requests_num_chunks_sent) + self.assertIn("ext-1", 
sender._code_prompt_token_ids) + self.assertIn("ext-1", sender._cached_ic) + + with sender._lock: + task = sender._pending_save_reqs["ext-1"].popleft() + + self.assertTrue(sender._send_single_request(task)) + + self.assertEqual(sender._put_req_chunk["ext-1"], 8) + self.assertNotIn("ext-1", sender._requests_num_chunks_sent) + self.assertNotIn("ext-1", sender._code_prompt_token_ids) + self.assertNotIn("ext-1", sender._cached_ic) + class TestMixinKVCacheTransfer(unittest.TestCase): """Test 3: KV cache delegation to OmniKVTransferManager.""" diff --git a/vllm_omni/worker/omni_connector_model_runner_mixin.py b/vllm_omni/worker/omni_connector_model_runner_mixin.py index 70ddd207299..0aaee99693a 100644 --- a/vllm_omni/worker/omni_connector_model_runner_mixin.py +++ b/vllm_omni/worker/omni_connector_model_runner_mixin.py @@ -362,8 +362,13 @@ def drop_inactive_request_delivery_state(self, req_id: str) -> None: def _drop_send_side_payload_state(self, req_id: str, ext_id: str | None) -> None: if ext_id is not None: self._send_side_request_payload.pop(ext_id, None) - self._cached_ic.pop(ext_id, None) + self._clear_segment_send_state(ext_id) self._send_side_request_payload.pop(req_id, None) + self._clear_segment_send_state(req_id) + + def _clear_segment_send_state(self, req_id: str) -> None: + self._requests_num_chunks_sent.pop(req_id, None) + self._code_prompt_token_ids.pop(req_id, None) self._cached_ic.pop(req_id, None) def _cleanup_recv_delivery_state(self, req_id: str) -> None: @@ -1344,6 +1349,7 @@ def enqueue_finish_sentinel(self, request: Any, request_id: str, *, is_segment_f "put_key": connector_put_key, "data": payload_data, "request_id": request_id, + "is_segment_finished": is_segment_finished, } with self._lock: self._pending_save_reqs.setdefault(request_id, deque()).append(task) @@ -2181,6 +2187,9 @@ def _send_single_request(self, task: dict) -> bool: return False self._decrement_pending_save_count(request_id) + if task.get("is_segment_finished"): + with self._lock: 
+ self._clear_segment_send_state(request_id) return True def _decrement_pending_save_count(self, request_id: str) -> None: @@ -2197,11 +2206,9 @@ def _decrement_pending_save_count(self, request_id: str) -> None: cleanup_req_id = request_id if cleanup_req_id is not None: self._put_req_chunk.pop(cleanup_req_id, None) - self._requests_num_chunks_sent.pop(cleanup_req_id, None) self._send_side_request_snapshot.pop(cleanup_req_id, None) self._send_side_request_payload.pop(cleanup_req_id, None) - self._code_prompt_token_ids.pop(cleanup_req_id, None) - self._cached_ic.pop(cleanup_req_id, None) + self._clear_segment_send_state(cleanup_req_id) # ------------------------------------------------------------------ # # Payload accumulation (ported from OmniChunkTransferAdapter) From 088cfba90e98bbf5e1773d10b40349704a59f940 Mon Sep 17 00:00:00 2001 From: natureofnature Date: Thu, 11 Jun 2026 03:30:10 +0000 Subject: [PATCH 10/10] Move segment cleanup to enqueue time Signed-off-by: natureofnature --- tests/worker/test_omni_connector_mixin.py | 19 +++++++++++-------- .../omni_connector_model_runner_mixin.py | 6 ++---- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/tests/worker/test_omni_connector_mixin.py b/tests/worker/test_omni_connector_mixin.py index 18da073d15b..ccca8d80c85 100644 --- a/tests/worker/test_omni_connector_mixin.py +++ b/tests/worker/test_omni_connector_mixin.py @@ -282,7 +282,7 @@ def proc(transfer_manager, pooling_output, request, is_finished=False): self.assertIsNotNone(data) self.assertEqual(data.get("tail"), "terminal-codec", "hook terminal payload must be preferred over the flag") - def test_segment_finish_sentinel_cleans_segment_state_after_put(self): + def test_segment_finish_sentinel_cleans_segment_state_at_enqueue(self): sender = MixinHost() sender.init_omni_connectors(vllm_config=None, model_config=_make_model_config(stage_id=0, async_chunk=True)) self._quiesce_save_thread(sender) @@ -303,19 +303,22 @@ def proc(transfer_manager, 
pooling_output, request, is_finished=False): self.assertTrue(sender.enqueue_finish_sentinel(req, "ext-1", is_segment_finished=True)) self.assertEqual(sender._put_req_chunk["ext-1"], 8) - self.assertIn("ext-1", sender._requests_num_chunks_sent) - self.assertIn("ext-1", sender._code_prompt_token_ids) - self.assertIn("ext-1", sender._cached_ic) + self.assertNotIn("ext-1", sender._requests_num_chunks_sent) + self.assertNotIn("ext-1", sender._code_prompt_token_ids) + self.assertNotIn("ext-1", sender._cached_ic) with sender._lock: task = sender._pending_save_reqs["ext-1"].popleft() - self.assertTrue(sender._send_single_request(task)) + sender._requests_num_chunks_sent["ext-1"] = 1 + sender._code_prompt_token_ids["ext-1"] = [[9]] + sender._cached_ic["ext-1"] = 2 + self.assertTrue(sender._send_single_request(task)) self.assertEqual(sender._put_req_chunk["ext-1"], 8) - self.assertNotIn("ext-1", sender._requests_num_chunks_sent) - self.assertNotIn("ext-1", sender._code_prompt_token_ids) - self.assertNotIn("ext-1", sender._cached_ic) + self.assertEqual(sender._requests_num_chunks_sent["ext-1"], 1) + self.assertEqual(sender._code_prompt_token_ids["ext-1"], [[9]]) + self.assertEqual(sender._cached_ic["ext-1"], 2) class TestMixinKVCacheTransfer(unittest.TestCase): diff --git a/vllm_omni/worker/omni_connector_model_runner_mixin.py b/vllm_omni/worker/omni_connector_model_runner_mixin.py index 0aaee99693a..b8d788f1450 100644 --- a/vllm_omni/worker/omni_connector_model_runner_mixin.py +++ b/vllm_omni/worker/omni_connector_model_runner_mixin.py @@ -1349,11 +1349,12 @@ def enqueue_finish_sentinel(self, request: Any, request_id: str, *, is_segment_f "put_key": connector_put_key, "data": payload_data, "request_id": request_id, - "is_segment_finished": is_segment_finished, } with self._lock: self._pending_save_reqs.setdefault(request_id, deque()).append(task) self._pending_save_counts[request_id] += 1 + if is_segment_finished: + self._clear_segment_send_state(request_id) 
self._work_available.set() return True @@ -2187,9 +2188,6 @@ def _send_single_request(self, task: dict) -> bool: return False self._decrement_pending_save_count(request_id) - if task.get("is_segment_finished"): - with self._lock: - self._clear_segment_send_state(request_id) return True def _decrement_pending_save_count(self, request_id: str) -> None: