diff --git a/tests/core/sched/test_chunk_scheduling_coordinator.py b/tests/core/sched/test_chunk_scheduling_coordinator.py new file mode 100644 index 00000000000..5e19465e224 --- /dev/null +++ b/tests/core/sched/test_chunk_scheduling_coordinator.py @@ -0,0 +1,690 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for OmniSchedulingCoordinator (formerly ChunkSchedulingCoordinator). + +These tests use mock request objects and mock queues. They do not require +GPU, vLLM runtime, or any connector. +""" + +from __future__ import annotations + +import unittest +from types import SimpleNamespace + +import vllm_omni.core.sched.omni_scheduling_coordinator as coord_mod +from vllm_omni.core.sched.omni_scheduling_coordinator import ( + ChunkSchedulingCoordinator, + OmniSchedulingCoordinator, +) + +# ------------------------------------------------------------------ # +# Mock helpers +# ------------------------------------------------------------------ # + + +class _RequestStatus: + WAITING = "waiting" + RUNNING = "running" + WAITING_FOR_CHUNK = "waiting_for_chunk" + WAITING_FOR_INPUT = "waiting_for_input" + FINISHED_STOPPED = "finished_stopped" + + +# Patch RequestStatus for tests that don't import vllm +try: + from vllm.v1.request import RequestStatus +except ImportError: + RequestStatus = _RequestStatus # type: ignore[misc,assignment] + +if not hasattr(RequestStatus, "WAITING_FOR_INPUT"): + coord_mod.RequestStatus = _RequestStatus # type: ignore[assignment] + RequestStatus = _RequestStatus # type: ignore[misc,assignment] + + +def _make_request(req_id: str, status: str = "waiting") -> SimpleNamespace: + return SimpleNamespace( + request_id=req_id, + external_req_id=req_id, + status=status, + additional_information=None, + prompt_token_ids=[], + num_prompt_tokens=0, + num_computed_tokens=0, + _all_token_ids=[], + _output_token_ids=[], + ) + + +class MockQueue: + """Simplified queue that mimics the 
Scheduler waiting queue interface.""" + + def __init__(self, items: list | None = None): + self._items: list = list(items or []) + + def __iter__(self): + return iter(self._items) + + def __len__(self): + return len(self._items) + + def __contains__(self, item): + return item in self._items + + def add_request(self, request): + self._items.append(request) + + def prepend_requests(self, requests): + self._items = list(requests) + self._items + + def remove(self, request): + self._items.remove(request) + + def remove_requests(self, requests): + remove_set = set(id(r) for r in requests) + self._items = [r for r in self._items if id(r) not in remove_set] + + +# ------------------------------------------------------------------ # +# Tests +# ------------------------------------------------------------------ # + + +class TestChunkCoordinatorStateTransition(unittest.TestCase): + """Test 5: process_pending_chunks transitions WAITING_FOR_CHUNK → target.""" + + def test_ready_request_transitions_to_waiting(self): + coord = ChunkSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1, async_chunk=True) + + req = _make_request("r1", status=RequestStatus.WAITING_FOR_CHUNK) + waiting = MockQueue([req]) + running: list = [] + + coord.process_pending_chunks( + waiting, + running, + chunk_ready_req_ids={"r1"}, + chunk_finished_req_ids=set(), + ) + + self.assertEqual(req.status, RequestStatus.WAITING) + self.assertIn("r1", coord.requests_with_ready_chunks) + + def test_non_ready_stays_waiting_for_chunk(self): + coord = ChunkSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1, async_chunk=True) + + req = _make_request("r1", status=RequestStatus.WAITING_FOR_CHUNK) + waiting = MockQueue([req]) + running: list = [] + + coord.process_pending_chunks( + waiting, + running, + chunk_ready_req_ids=set(), + chunk_finished_req_ids=set(), + ) + + self.assertEqual(req.status, RequestStatus.WAITING_FOR_CHUNK) + + def test_stage_0_is_noop(self): + coord = 
ChunkSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=0) + req = _make_request("r1") + waiting = MockQueue([req]) + running: list = [] + + coord.process_pending_chunks( + waiting, + running, + chunk_ready_req_ids={"r1"}, + chunk_finished_req_ids=set(), + ) + self.assertNotEqual(req.status, RequestStatus.WAITING_FOR_CHUNK) + + +class TestChunkCoordinatorRestoreQueues(unittest.TestCase): + """Test 6: restore_queues returns waiting-for-chunk requests.""" + + def test_restore(self): + coord = ChunkSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1) + + r1 = _make_request("r1") + r2 = _make_request("r2") + coord._waiting_for_chunk_waiting.append(r1) + coord._waiting_for_chunk_running.append(r2) + + waiting = MockQueue() + running: list = [] + + coord.restore_queues(waiting, running) + + self.assertIn(r1, waiting) + self.assertIn(r2, running) + self.assertEqual(len(coord._waiting_for_chunk_waiting), 0) + self.assertEqual(len(coord._waiting_for_chunk_running), 0) + + +class TestChunkCoordinatorFinishedSignal(unittest.TestCase): + """Test 8: chunk_finished_req_ids → finished_requests.""" + + def test_finished_signal(self): + coord = ChunkSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1, async_chunk=True) + + req = _make_request("r1", status=RequestStatus.WAITING_FOR_CHUNK) + waiting = MockQueue([req]) + running: list = [] + + coord.process_pending_chunks( + waiting, + running, + chunk_ready_req_ids={"r1"}, + chunk_finished_req_ids={"r1"}, + ) + + self.assertIn("r1", coord.finished_requests) + + +class TestChunkCoordinatorUpdateRequestMetadata(unittest.TestCase): + """Test update_request_metadata applies scheduling metadata to requests.""" + + def test_ar_mode_no_longer_sets_additional_information(self): + """AR mode only processes scheduling metadata, not full payloads.""" + coord = ChunkSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1) + + req = _make_request("r1") + requests = {"r1": req} + + # Only scheduling metadata is passed 
now (full payload stays in model runner) + request_metadata = {"r1": {"next_stage_prompt_len": 50}} + + coord.update_request_metadata(requests, request_metadata, model_mode="ar") + + # next_stage_prompt_len should update prompt_token_ids + self.assertEqual(len(req.prompt_token_ids), 50) + self.assertEqual(req.num_prompt_tokens, 50) + # additional_information should NOT be set + self.assertIsNone(getattr(req, "additional_information", None)) + + def test_generation_mode(self): + coord = ChunkSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1) + + req = _make_request("r1") + req.prompt_token_ids = [0, 0, 0] + requests = {"r1": req} + + request_metadata = { + "r1": { + "code_predictor_codes": [10, 20, 30], + "left_context_size": 25, + } + } + + coord.update_request_metadata(requests, request_metadata, model_mode="generation") + + self.assertEqual(req.prompt_token_ids, [10, 20, 30]) + self.assertEqual(req.num_computed_tokens, 0) + self.assertIsNone(req.additional_information) + self.assertEqual(req._omni_initial_model_buffer, {"left_context_size": 25}) + + +class TestChunkCoordinatorPostprocess(unittest.TestCase): + """Test postprocess_scheduler_output clears ready chunks.""" + + def test_clear_ready(self): + coord = ChunkSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1) + coord.requests_with_ready_chunks = {"r1", "r2"} + + new_req = SimpleNamespace(req_id="r1") + cached_reqs = SimpleNamespace(req_ids=["r2"]) + scheduler_output = SimpleNamespace( + scheduled_new_reqs=[new_req], + scheduled_cached_reqs=cached_reqs, + ) + + coord.postprocess_scheduler_output(scheduler_output) + + self.assertEqual(coord.requests_with_ready_chunks, set()) + + +class TestWaitingForInputTransition(unittest.TestCase): + """Test B8: process_pending_full_payload_inputs transitions WAITING_FOR_INPUT.""" + + def test_transition_on_recv(self): + coord = OmniSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1) + + req = _make_request("r1", 
status=RequestStatus.WAITING_FOR_INPUT) + waiting = MockQueue([req]) + running: list = [] + + coord.process_pending_full_payload_inputs( + waiting, + running, + stage_recv_req_ids={"r1"}, + ) + + self.assertEqual(req.status, RequestStatus.WAITING) + + def test_stays_waiting_for_input_if_not_received(self): + coord = OmniSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1) + + req = _make_request("r1", status=RequestStatus.WAITING_FOR_INPUT) + waiting = MockQueue([req]) + running: list = [] + + coord.process_pending_full_payload_inputs( + waiting, + running, + stage_recv_req_ids=set(), + ) + + self.assertEqual(req.status, RequestStatus.WAITING_FOR_INPUT) + self.assertEqual(len(coord._waiting_for_input), 1) + + def test_stage_0_is_noop(self): + coord = OmniSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=0) + + req = _make_request("r1", status=RequestStatus.WAITING_FOR_INPUT) + waiting = MockQueue([req]) + running: list = [] + + coord.process_pending_full_payload_inputs( + waiting, + running, + stage_recv_req_ids={"r1"}, + ) + self.assertEqual(req.status, RequestStatus.WAITING_FOR_INPUT) + + def test_restore_queues_includes_waiting_for_input(self): + coord = OmniSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1) + + r1 = _make_request("r1") + coord._waiting_for_input.append(r1) + + waiting = MockQueue() + running: list = [] + + coord.restore_queues(waiting, running) + + self.assertIn(r1, waiting) + self.assertEqual(len(coord._waiting_for_input), 0) + + def test_full_payload_mode_auto_transitions_waiting_to_waiting_for_input(self): + """In full_payload_mode (async_chunk=False), fresh WAITING requests on + non-Stage-0 should be transitioned to WAITING_FOR_INPUT.""" + coord = OmniSchedulingCoordinator( + scheduler_max_num_seqs=10, + stage_id=1, + async_chunk=False, + ) + + req = _make_request("r1", status=RequestStatus.WAITING) + waiting = MockQueue([req]) + running: list = [] + + coord.process_pending_full_payload_inputs( + waiting, + 
running, + stage_recv_req_ids=set(), + ) + + self.assertEqual(req.status, RequestStatus.WAITING_FOR_INPUT) + self.assertEqual(len(coord._waiting_for_input), 1) + self.assertEqual(len(coord.pending_input_registrations), 1) + + def test_async_chunk_mode_does_not_auto_transition(self): + """In async_chunk mode, fresh WAITING requests should NOT be + transitioned to WAITING_FOR_INPUT.""" + coord = OmniSchedulingCoordinator( + scheduler_max_num_seqs=10, + stage_id=1, + async_chunk=True, + ) + + req = _make_request("r1", status=RequestStatus.WAITING) + waiting = MockQueue([req]) + running: list = [] + + coord.process_pending_full_payload_inputs( + waiting, + running, + stage_recv_req_ids=set(), + ) + + self.assertEqual(req.status, RequestStatus.WAITING) + + def test_pending_input_registrations(self): + coord = OmniSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1) + + req = _make_request("r1", status=RequestStatus.WAITING_FOR_INPUT) + waiting = MockQueue([req]) + running: list = [] + + coord.process_pending_full_payload_inputs( + waiting, + running, + stage_recv_req_ids=set(), + ) + + self.assertEqual(len(coord.pending_input_registrations), 1) + self.assertEqual(coord.pending_input_registrations[0].request_id, "r1") + + +class TestTimeoutDetection(unittest.TestCase): + """Regression tests for orphaned pending-recv timeout detection. + + Covers the full lifecycle: + 1. Request enters WAITING_FOR_CHUNK from either waiting or running queue + 2. restore_queues() moves it back to the scheduler queue + 3. Timeout fires via collect_timed_out_request_ids() + 4. 
Scheduler removes from both queues and calls _free_request() + """ + + def test_waiting_since_recorded_on_chunk_wait(self): + """_waiting_since is set when a request enters WAITING_FOR_CHUNK.""" + coord = OmniSchedulingCoordinator( + scheduler_max_num_seqs=10, + stage_id=1, + async_chunk=True, + ) + req = _make_request("r1", status=RequestStatus.WAITING) + waiting = MockQueue([req]) + + coord.process_pending_chunks( + waiting, + [], + chunk_ready_req_ids=set(), + chunk_finished_req_ids=set(), + ) + + self.assertIn("r1", coord._waiting_since) + self.assertEqual(req.status, RequestStatus.WAITING_FOR_CHUNK) + + def test_waiting_since_cleared_on_chunk_arrival(self): + """_waiting_since is cleared when a chunk arrives.""" + coord = OmniSchedulingCoordinator( + scheduler_max_num_seqs=10, + stage_id=1, + async_chunk=True, + ) + req = _make_request("r1", status=RequestStatus.WAITING_FOR_CHUNK) + waiting = MockQueue([req]) + + coord.process_pending_chunks( + waiting, + [], + chunk_ready_req_ids={"r1"}, + chunk_finished_req_ids=set(), + ) + + self.assertNotIn("r1", coord._waiting_since) + + def test_waiting_since_recorded_on_input_wait(self): + """_waiting_since is set when a request enters WAITING_FOR_INPUT.""" + coord = OmniSchedulingCoordinator( + scheduler_max_num_seqs=10, + stage_id=1, + async_chunk=False, + ) + req = _make_request("r1", status=RequestStatus.WAITING) + waiting = MockQueue([req]) + + coord.process_pending_full_payload_inputs( + waiting, + [], + stage_recv_req_ids=set(), + ) + + self.assertIn("r1", coord._waiting_since) + + def test_waiting_since_cleared_on_input_arrival(self): + """_waiting_since is cleared when input data arrives.""" + coord = OmniSchedulingCoordinator( + scheduler_max_num_seqs=10, + stage_id=1, + async_chunk=False, + ) + req = _make_request("r1", status=RequestStatus.WAITING_FOR_INPUT) + coord._waiting_for_input.append(req) + coord._waiting_since["r1"] = 0.0 + + waiting = MockQueue() + coord.process_pending_full_payload_inputs( + 
waiting, + [], + stage_recv_req_ids={"r1"}, + ) + + self.assertNotIn("r1", coord._waiting_since) + self.assertEqual(req.status, RequestStatus.WAITING) + + def test_collect_timed_out_request_ids_no_timeout(self): + """No IDs returned when nothing has timed out.""" + coord = OmniSchedulingCoordinator( + scheduler_max_num_seqs=10, + stage_id=1, + ) + import time + + coord._waiting_since["r1"] = time.monotonic() + + result = coord.collect_timed_out_request_ids(timeout_s=300.0) + self.assertEqual(result, set()) + + def test_collect_timed_out_request_ids_expired(self): + """Timed-out IDs are returned and _waiting_since is cleared.""" + coord = OmniSchedulingCoordinator( + scheduler_max_num_seqs=10, + stage_id=1, + ) + coord._waiting_since["r1"] = 0.0 # epoch → definitely expired + coord._waiting_since["r2"] = 0.0 + + import time + + coord._waiting_since["r3"] = time.monotonic() + 9999 # far future + + result = coord.collect_timed_out_request_ids(timeout_s=1.0) + + self.assertEqual(result, {"r1", "r2"}) + self.assertNotIn("r1", coord._waiting_since) + self.assertNotIn("r2", coord._waiting_since) + self.assertIn("r3", coord._waiting_since) + + def test_collect_removes_from_coordinator_queues(self): + """Timed-out requests are defensively removed from internal queues.""" + coord = OmniSchedulingCoordinator( + scheduler_max_num_seqs=10, + stage_id=1, + ) + r1 = _make_request("r1") + r2 = _make_request("r2") + coord._waiting_for_chunk_waiting.append(r1) + coord._waiting_for_input.append(r2) + coord._waiting_since["r1"] = 0.0 + coord._waiting_since["r2"] = 0.0 + + result = coord.collect_timed_out_request_ids(timeout_s=1.0) + + self.assertEqual(result, {"r1", "r2"}) + self.assertEqual(len(coord._waiting_for_chunk_waiting), 0) + self.assertEqual(len(coord._waiting_for_input), 0) + + def test_free_finished_request_clears_waiting_since(self): + """free_finished_request clears _waiting_since.""" + coord = OmniSchedulingCoordinator( + scheduler_max_num_seqs=10, + stage_id=1, + ) + 
coord._waiting_since["r1"] = 0.0 + coord.free_finished_request("r1") + self.assertNotIn("r1", coord._waiting_since) + + def test_timeout_from_running_queue_full_lifecycle(self): + """End-to-end: request from running → WAITING_FOR_CHUNK → restore → + timeout → removed from running list. + + This is the critical regression case: WAITING_FOR_CHUNK requests + that originated from self.running are placed back into self.running + by restore_queues(), but their status remains WAITING_FOR_CHUNK. + The scheduler must remove from BOTH queues unconditionally. + """ + coord = OmniSchedulingCoordinator( + scheduler_max_num_seqs=10, + stage_id=1, + async_chunk=True, + ) + + # 1) Request starts in running queue with WAITING status + req = _make_request("r1", status=RequestStatus.WAITING) + running = [req] + waiting = MockQueue() + + # 2) process_pending_chunks: moves to WAITING_FOR_CHUNK + coord.process_pending_chunks( + waiting, + running, + chunk_ready_req_ids=set(), + chunk_finished_req_ids=set(), + ) + self.assertEqual(req.status, RequestStatus.WAITING_FOR_CHUNK) + self.assertIn("r1", coord._waiting_since) + self.assertEqual(len(coord._waiting_for_chunk_running), 1) + + # 3) restore_queues: back to running (status stays WAITING_FOR_CHUNK) + coord.restore_queues(waiting, running) + self.assertIn(req, running) + self.assertEqual(len(coord._waiting_for_chunk_running), 0) + self.assertEqual(req.status, RequestStatus.WAITING_FOR_CHUNK) + + # 4) Force timeout by setting _waiting_since to epoch + coord._waiting_since["r1"] = 0.0 + + timed_out_ids = coord.collect_timed_out_request_ids(timeout_s=1.0) + self.assertEqual(timed_out_ids, {"r1"}) + + # 5) Scheduler removes from both queues (simulating the scheduler path) + timed_out_id_set = {id(req)} + running = [r for r in running if id(r) not in timed_out_id_set] + waiting.remove_requests([req]) + + self.assertNotIn(req, running) + self.assertEqual(len(waiting), 0) + + def test_timeout_from_waiting_queue_full_lifecycle(self): + 
"""End-to-end: request from waiting → WAITING_FOR_CHUNK → restore → + timeout → removed from waiting queue.""" + coord = OmniSchedulingCoordinator( + scheduler_max_num_seqs=10, + stage_id=1, + async_chunk=True, + ) + + req = _make_request("r1", status=RequestStatus.WAITING) + waiting = MockQueue([req]) + running: list = [] + + coord.process_pending_chunks( + waiting, + running, + chunk_ready_req_ids=set(), + chunk_finished_req_ids=set(), + ) + self.assertEqual(len(coord._waiting_for_chunk_waiting), 1) + + coord.restore_queues(waiting, running) + self.assertIn(req, waiting) + + coord._waiting_since["r1"] = 0.0 + timed_out_ids = coord.collect_timed_out_request_ids(timeout_s=1.0) + self.assertEqual(timed_out_ids, {"r1"}) + + waiting.remove_requests([req]) + self.assertEqual(len(waiting), 0) + + +class TestOverflowPreemption(unittest.TestCase): + """Tests for P1-1: overflow requests must get WAITING status. + + Overflow happens when multiple WAITING_FOR_CHUNK requests in + ``_waiting_for_chunk_running`` receive their chunk in the same cycle. + ``_process_chunk_queue`` restores them to RUNNING (``continue`` + path) while RUNNING requests without chunks are moved out. If the + net result exceeds ``scheduler_max_num_seqs``, the tail is pushed + to ``waiting_queue`` and must have status == WAITING. + """ + + def test_overflow_sets_waiting_status(self): + coord = OmniSchedulingCoordinator( + scheduler_max_num_seqs=1, + stage_id=1, + async_chunk=True, + ) + + # r1 is currently RUNNING in the queue. + # r2, r3 were previously moved to _waiting_for_chunk_running. 
+ r1 = _make_request("r1", status=RequestStatus.RUNNING) + r2 = _make_request("r2", status=RequestStatus.WAITING_FOR_CHUNK) + r3 = _make_request("r3", status=RequestStatus.WAITING_FOR_CHUNK) + + running = [r1] + waiting = MockQueue([]) + coord._waiting_for_chunk_running.extend([r2, r3]) + + # restore_queues puts r2, r3 back into running + coord.restore_queues(waiting, running) + self.assertEqual(len(running), 3) + + # Now process_pending_chunks with r2, r3 chunks ready: + # _process_chunk_queue will: + # r1 (RUNNING) → no chunk → move to _waiting_for_chunk_running + # r2 (WAITING_FOR_CHUNK, chunk ready) → set RUNNING, stay in running + # r3 (WAITING_FOR_CHUNK, chunk ready) → set RUNNING, stay in running + # running = [r2, r3], len=2 > max=1 → overflow + coord.process_pending_chunks( + waiting, + running, + chunk_ready_req_ids={"r2", "r3"}, + chunk_finished_req_ids=set(), + ) + + self.assertEqual(len(running), 1) + self.assertEqual(len(waiting), 1) + overflow_req = list(waiting)[0] + self.assertEqual( + overflow_req.status, + RequestStatus.WAITING, + f"Overflowed request should have WAITING status, got {overflow_req.status}", + ) + + def test_overflow_does_not_strand_request(self): + """Without the fix, the overflowed request would keep its + RUNNING status in the waiting queue and never be re-scheduled.""" + coord = OmniSchedulingCoordinator( + scheduler_max_num_seqs=1, + stage_id=1, + async_chunk=True, + ) + + r1 = _make_request("r1", status=RequestStatus.WAITING_FOR_CHUNK) + r2 = _make_request("r2", status=RequestStatus.WAITING_FOR_CHUNK) + coord._waiting_for_chunk_running.extend([r1, r2]) + + running: list = [] + waiting = MockQueue([]) + + coord.restore_queues(waiting, running) + self.assertEqual(len(running), 2) + + coord.process_pending_chunks( + waiting, + running, + chunk_ready_req_ids={"r1", "r2"}, + chunk_finished_req_ids=set(), + ) + + self.assertEqual(len(running), 1) + self.assertEqual(len(waiting), 1) + for req in waiting: + 
self.assertNotEqual(req.status, RequestStatus.RUNNING, "Overflowed request must not keep RUNNING status") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/worker/test_omni_connector_mixin.py b/tests/worker/test_omni_connector_mixin.py new file mode 100644 index 00000000000..0e162a37e5b --- /dev/null +++ b/tests/worker/test_omni_connector_mixin.py @@ -0,0 +1,1419 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for OmniConnectorModelRunnerMixin. + +These tests use a mock connector (in-memory dict store) and do not require +GPU or vLLM runtime. +""" + +from __future__ import annotations + +import time +import unittest +from types import SimpleNamespace +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest +import torch + +from vllm_omni.outputs import OmniConnectorOutput +from vllm_omni.worker.omni_connector_model_runner_mixin import ( + OmniConnectorModelRunnerMixin, +) + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + +# ------------------------------------------------------------------ # +# Mock helpers +# ------------------------------------------------------------------ # + + +class MockConnector: + """In-memory connector for testing (mimics OmniConnectorBase).""" + + def __init__(self, stage_id: int = 0): + self.stage_id = stage_id + self._store: dict[str, Any] = {} + + def put(self, from_stage, to_stage, put_key, data): + key = f"{from_stage}_{to_stage}_{put_key}" + self._store[key] = data + return True, len(str(data)), None + + def get(self, from_stage, to_stage, get_key, metadata=None): + key = f"{from_stage}_{to_stage}_{get_key}" + data = self._store.pop(key, None) + if data is None: + return None + return data, len(str(data)) + + def close(self): + pass + + +def _make_model_config( + stage_id: int = 0, + async_chunk: bool = False, + worker_type: str = "ar", + custom_func: str | None = None, +) -> SimpleNamespace: + 
return SimpleNamespace( + stage_connector_config=None, + async_chunk=async_chunk, + worker_type=worker_type, + custom_process_next_stage_input_func=custom_func, + ) + + +def _make_request(req_id: str, external_req_id: str | None = None): + r = SimpleNamespace( + request_id=req_id, + external_req_id=external_req_id or req_id, + additional_information=None, + prompt_token_ids=[], + num_computed_tokens=0, + ) + return r + + +class MixinHost(OmniConnectorModelRunnerMixin): + """Minimal class that mixes in the mixin for testing.""" + + pass + + +class _FakeTPGroup: + def __init__(self, *, world_size: int, rank_in_group: int, follower_result: Any = None): + self.world_size = world_size + self.rank_in_group = rank_in_group + self.follower_result = follower_result + self.broadcast_inputs: list[Any] = [] + + def broadcast_object(self, obj: Any | None = None, src: int = 0): + self.broadcast_inputs.append(obj) + if self.rank_in_group == src: + return obj + return self.follower_result + + +# ------------------------------------------------------------------ # +# Test cases +# ------------------------------------------------------------------ # + + +class TestMixinAsyncChunkSendRecv(unittest.TestCase): + """Test 2: Async chunk send/recv + bg threads.""" + + def test_send_chunk_passes_is_finished_and_connector(self): + connector = MockConnector(stage_id=0) + + sender = MixinHost() + sender.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(stage_id=0, async_chunk=True), + ) + sender._omni_connector = connector + sender._stage_id = 0 + sender._async_chunk = True + + seen = {} + + def mock_process(transfer_manager, pooling_output, request, is_finished=False): + seen["connector"] = transfer_manager.connector + seen["is_finished"] = is_finished + return {"data": pooling_output, "finished": is_finished} + + sender._custom_process_func = mock_process + + request = _make_request("req-1", "ext-req-1") + request.is_finished = lambda: True + 
sender._send_single_request( + { + "stage_id": 0, + "next_stage_id": 1, + "request_id": "ext-req-1", + "request": request, + "pooling_output": {"value": 42}, + } + ) + self.assertIs(seen["connector"], connector) + self.assertTrue(seen["is_finished"]) + + sender.shutdown_omni_connectors() + + def test_send_chunk_does_not_retry_real_type_error(self): + connector = MockConnector(stage_id=0) + + sender = MixinHost() + sender.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(stage_id=0, async_chunk=True), + ) + sender._omni_connector = connector + sender._stage_id = 0 + sender._async_chunk = True + + seen = {"calls": 0} + + def broken_process(transfer_manager, pooling_output, request, is_finished=""): + seen["calls"] += 1 + return {"data": is_finished + "tail"} + + sender._custom_process_func = broken_process + + request = _make_request("req-1", "ext-req-1") + request.is_finished = lambda: True + ok = sender.send_chunk(request, pooling_output={"value": 42}) + self.assertFalse(ok) + self.assertEqual(seen["calls"], 1) + + sender.shutdown_omni_connectors() + + +class TestMixinKVCacheTransfer(unittest.TestCase): + """Test 3: KV cache delegation to OmniKVTransferManager.""" + + def test_send_kv_delegates(self): + mock_kvm = MagicMock() + mock_kvm.handle_finished_requests_kv_transfer.return_value = ["req-1"] + + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(), + kv_transfer_manager=mock_kvm, + ) + + result = host.send_kv_cache( + finished_reqs={"req-1": {"seq_len": 10, "block_ids": [0]}}, + kv_caches=[], + block_size=16, + cache_dtype="float16", + ) + self.assertEqual(result, ["req-1"]) + mock_kvm.handle_finished_requests_kv_transfer.assert_called_once() + + host.shutdown_omni_connectors() + + def test_recv_kv_delegates(self): + mock_kvm = MagicMock() + mock_kvm.receive_kv_cache_for_request.return_value = ({"layer_blocks": {}}, 100) + + host = MixinHost() + host.init_omni_connectors( + 
vllm_config=None, + model_config=_make_model_config(), + kv_transfer_manager=mock_kvm, + ) + + data, size = host.recv_kv_cache("req-1") + self.assertIsNotNone(data) + self.assertEqual(size, 100) + mock_kvm.receive_kv_cache_for_request.assert_called_once() + + host.shutdown_omni_connectors() + + def test_receive_multi_kv_fetches_companions_via_mixin(self): + mock_kvm = MagicMock() + + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(), + kv_transfer_manager=mock_kvm, + ) + + host.recv_kv_cache = MagicMock( + side_effect=[({"layer_blocks": {"k": [1]}}, 64), ({"layer_blocks": {"k": [2]}}, 32)] + ) + seen = {} + + def collect_cfg(request_id, cfg_role_payloads): + seen["request_id"] = request_id + seen["cfg_role_payloads"] = cfg_role_payloads + return {"cfg_text_kv_metadata": {"seq_len": 3}} + + req = SimpleNamespace( + request_id="req-1", + sampling_params=SimpleNamespace(cfg_kv_request_ids={"cfg_text": "req-1__cfg_text"}), + ) + ok = host.receive_multi_kv_cache(req, cfg_kv_collect_func=collect_cfg) + self.assertTrue(ok) + host.recv_kv_cache.assert_any_call("req-1", target_device=None) + host.recv_kv_cache.assert_any_call("req-1__cfg_text", target_device=None) + mock_kvm.apply_kv_cache_to_request.assert_called_once_with(req, {"layer_blocks": {"k": [1]}}) + self.assertEqual(seen["request_id"], "req-1") + self.assertEqual( + seen["cfg_role_payloads"], + {"cfg_text": ({"layer_blocks": {"k": [2]}}, 32)}, + ) + self.assertEqual(req.sampling_params.cfg_text_kv_metadata, {"seq_len": 3}) + + host.shutdown_omni_connectors() + + def test_receive_multi_kv_skips_inactive_request(self): + mock_kvm = MagicMock() + + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(), + kv_transfer_manager=mock_kvm, + ) + + host.requests = {} + host.recv_kv_cache = MagicMock(return_value=({"layer_blocks": {"k": [1]}}, 64)) + req = SimpleNamespace(request_id="req-1", sampling_params=None) + + ok = 
host.receive_multi_kv_cache(req) + + self.assertFalse(ok) + host.recv_kv_cache.assert_not_called() + mock_kvm.apply_kv_cache_to_request.assert_not_called() + + host.shutdown_omni_connectors() + + +class TestOmniConnectorOutput(unittest.TestCase): + """Test 4: Output aggregation across transfer modes.""" + + def test_output_aggregation(self): + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(), + ) + + host._chunk_ready_req_ids.add("req-1") + host._chunk_finished_req_ids.add("req-2") + host._local_request_metadata["req-1"] = {"next_stage_prompt_len": 10} + host._stage_recv_req_ids.add("req-3") + + output = host.get_omni_connector_output() + self.assertIsInstance(output, OmniConnectorOutput) + self.assertEqual(output.chunk_ready_req_ids, {"req-1"}) + self.assertEqual(output.chunk_finished_req_ids, {"req-2"}) + self.assertEqual(output.request_metadata, {"req-1": {"next_stage_prompt_len": 10}}) + self.assertEqual(output.stage_recv_req_ids, {"req-3"}) + + output2 = host.get_omni_connector_output() + self.assertEqual(output2.chunk_ready_req_ids, set()) + self.assertEqual(output2.request_metadata, {}) + + host.shutdown_omni_connectors() + + +class TestMixinNoConnector(unittest.TestCase): + """Edge case: mixin works gracefully without a connector.""" + + def test_no_connector(self): + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(), + ) + self.assertIsNone(host._omni_connector) + + results = host.recv_full_payload_inputs(scheduler_output=None) + self.assertIsNone(results) + + sent = host.send_full_payload_outputs(None, {"req-1": {}}) + self.assertEqual(sent, []) + + ok = host.send_chunk(_make_request("req-1"), pooling_output={}) + self.assertFalse(ok) + + output = host.get_omni_connector_output() + self.assertIsInstance(output, OmniConnectorOutput) + + host.shutdown_omni_connectors() + + +class TestFinishedLoadReqsDrain(unittest.TestCase): + """Test A1 fix: 
get_omni_connector_output drains _finished_load_reqs.""" + + def test_finished_load_reqs_flow_to_chunk_ready(self): + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(), + ) + + host._finished_load_reqs.add("req-1") + host._finished_load_reqs.add("req-2") + + output = host.get_omni_connector_output() + self.assertIn("req-1", output.chunk_ready_req_ids) + self.assertIn("req-2", output.chunk_ready_req_ids) + + self.assertEqual(len(host._finished_load_reqs), 0) + self.assertEqual(len(host._chunk_ready_req_ids), 0) + + host.shutdown_omni_connectors() + + +class TestLoadCustomFuncSelection(unittest.TestCase): + def test_skips_legacy_stage_list_processors_for_full_payload_mode(self): + legacy_paths = [ + "vllm_omni.model_executor.stage_input_processors.mimo_audio.llm2code2wav", + "vllm_omni.model_executor.stage_input_processors.mammoth_moda2.ar2dit", + "vllm_omni.model_executor.stage_input_processors.cosyvoice3.text2flow", + "vllm_omni.model_executor.stage_input_processors.glm_image.ar2diffusion", + ] + + for func_path in legacy_paths: + selected_path, func = MixinHost._load_custom_func( + SimpleNamespace( + async_chunk=False, + custom_process_input_func=func_path, + custom_process_next_stage_input_func=None, + ) + ) + assert selected_path != func_path + assert func is None or MixinHost._is_connector_payload_builder(func) + + +class TestFullPayloadSendWithCustomFunc(unittest.TestCase): + """Test B4: send_full_payload_outputs with full_payload_mode custom process func.""" + + def test_full_payload_send_passes_is_finished_and_connector(self): + seen = {} + + def full_payload_func(transfer_manager, pooling_output, request, is_finished=False): + seen["connector"] = transfer_manager.connector + seen["is_finished"] = is_finished + seen["data"] = pooling_output + seen["rid"] = request.request_id if request else None + return {"processed": True, "finished": is_finished} + + host = MixinHost() + host.init_omni_connectors( + 
vllm_config=None, + model_config=_make_model_config(), + ) + host._omni_connector = MockConnector(stage_id=0) + host._stage_id = 0 + host._custom_process_func = full_payload_func + + req = _make_request("req-1") + req.is_finished = lambda: True + sent = host.send_full_payload_outputs( + scheduler_output=None, + outputs={"req-1": ({"raw": 100}, req)}, + ) + self.assertEqual(sent, ["req-1"]) + self.assertEqual( + seen, + { + "connector": host._omni_connector, + "is_finished": True, + "data": {"raw": 100}, + "rid": "req-1", + }, + ) + + host.shutdown_omni_connectors() + + def test_accumulate_and_flush(self): + call_log = [] + + def full_payload_func(transfer_manager, pooling_output, request): + call_log.append(request.request_id if request else None) + return {"processed": True} + + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(), + ) + host._omni_connector = MockConnector(stage_id=0) + host._stage_id = 0 + host._custom_process_func = full_payload_func + + req = _make_request("req-1") + host.accumulate_full_payload_output("req-1", {"raw": 42}, req) + self.assertEqual(len(host._pending_full_payload_send), 1) + + host.flush_full_payload_outputs({"req-1"}) + self.assertEqual(len(host._pending_full_payload_send), 0) + self.assertEqual(len(call_log), 1) + self.assertEqual(call_log[0], "req-1") + + time.sleep(0.1) + host.shutdown_omni_connectors() + + +class TestKVSentReqIdsAccumulation(unittest.TestCase): + """Test that kv_sent_req_ids accumulates results from send_kv_cache.""" + + def test_kv_sent_accumulation(self): + mock_kvm = MagicMock() + mock_kvm.handle_finished_requests_kv_transfer.return_value = ["req-1", "req-2"] + + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(), + kv_transfer_manager=mock_kvm, + ) + + host.send_kv_cache( + finished_reqs={"req-1": {}, "req-2": {}}, + kv_caches=[], + block_size=16, + cache_dtype="float16", + ) + + output = 
host.get_omni_connector_output() + self.assertIn("req-1", output.kv_sent_req_ids) + self.assertIn("req-2", output.kv_sent_req_ids) + + output2 = host.get_omni_connector_output() + self.assertEqual(output2.kv_sent_req_ids, []) + + host.shutdown_omni_connectors() + + +class TestChunkStreamCompletedGuard(unittest.TestCase): + """Test that register_chunk_recv is skipped after finish sentinel. + + This validates the fix for the race condition where the scheduling + coordinator re-registers a request for chunk polling after its + upstream chunk stream has already finished (is_finished sentinel + received), causing the bg recv thread to poll for a non-existent + shared-memory segment (e.g. ``_0_7`` when only 7 chunks 0–6 exist). + """ + + def _make_host(self, stage_id: int = 1) -> MixinHost: + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(stage_id=stage_id, async_chunk=True), + ) + host._omni_connector = MockConnector(stage_id=stage_id) + host._stage_id = stage_id + host._async_chunk = True + return host + + def test_register_blocked_after_finish_sentinel(self): + """register_chunk_recv must be a no-op after the finish sentinel.""" + host = self._make_host(stage_id=1) + + req = _make_request("req-1", "ext-req-1") + + # Simulate the bg thread having received the finish sentinel: + with host._lock: + host._chunk_stream_completed.add("req-1") + + # Now try to re-register — this mimics the coordinator asking + # the model runner to poll for the next (non-existent) chunk. 
+ host.register_chunk_recv(req) + + # The request must NOT appear in _pending_load_reqs + self.assertNotIn( + "req-1", + host._pending_load_reqs, + "register_chunk_recv should skip requests whose chunk stream is already complete", + ) + + host.shutdown_omni_connectors() + + def test_register_allowed_before_finish(self): + """register_chunk_recv works normally before finish sentinel.""" + host = self._make_host(stage_id=1) + req = _make_request("req-1", "ext-req-1") + + host.register_chunk_recv(req) + self.assertIn( + "req-1", + host._pending_load_reqs, + "register_chunk_recv should add request to pending when stream is not yet complete", + ) + + host.shutdown_omni_connectors() + + def test_finish_sentinel_populates_completed_set(self): + """Receiving is_finished=True adds to _chunk_stream_completed.""" + host = self._make_host(stage_id=1) + + # Simulate _poll_single_request receiving is_finished=True + req_id = "req-1" + with host._lock: + host._chunk_finished_req_ids.add(req_id) + host._chunk_stream_completed.add(req_id) + host._local_stage_payload_cache[req_id] = {"finished": True} + host._local_request_metadata[req_id] = {} + host._finished_load_reqs.add(req_id) + host._pending_load_reqs.pop(req_id, None) + + self.assertIn(req_id, host._chunk_stream_completed) + + # Subsequent register_chunk_recv should be blocked + req = _make_request(req_id, f"ext-{req_id}") + host.register_chunk_recv(req) + self.assertNotIn(req_id, host._pending_load_reqs) + + host.shutdown_omni_connectors() + + def test_stage_0_always_skipped(self): + """Stage-0 has no upstream, register_chunk_recv is always no-op.""" + host = self._make_host(stage_id=0) + host._stage_id = 0 + + req = _make_request("req-1") + host.register_chunk_recv(req) + self.assertNotIn("req-1", host._pending_load_reqs) + + host.shutdown_omni_connectors() + + def test_full_payload_recv_guard_still_works(self): + """Pre-existing guard: staged full-payload results prevent registration.""" + host = 
self._make_host(stage_id=1) + + with host._lock: + host._stage_recv_req_ids.add("req-1") + + req = _make_request("req-1", "ext-req-1") + host.register_chunk_recv(req) + self.assertNotIn("req-1", host._pending_load_reqs) + + host.shutdown_omni_connectors() + + +class TestCleanupFinishedRequest(unittest.TestCase): + """Test cleanup_finished_request frees per-request mixin state.""" + + def _make_host(self, stage_id: int = 1) -> MixinHost: + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(stage_id=stage_id, async_chunk=True), + ) + host._omni_connector = MockConnector(stage_id=stage_id) + host._stage_id = stage_id + host._async_chunk = True + return host + + def test_cleanup_removes_all_state(self): + """cleanup_finished_request removes all tracking dicts/sets.""" + host = self._make_host(stage_id=1) + req_id = "req-1" + ext_id = "ext-req-1" + + # Simulate state accumulated during a request's lifetime + host._request_ids_mapping[req_id] = ext_id + host._put_req_chunk[ext_id] = 5 + host._get_req_chunk[req_id] = 3 + host._send_side_request_payload[ext_id] = {"some": "data"} + host._code_prompt_token_ids[ext_id] = [[1, 2, 3]] + host._chunk_stream_completed.add(req_id) + host._stage_recv_req_ids.add(req_id) + host._local_stage_payload_cache[req_id] = {"engine_inputs": {}} + host._local_request_metadata[req_id] = {"prompt_len": 10} + + # Cleanup + host.cleanup_finished_request(req_id) + + # All state should be gone + self.assertNotIn(req_id, host._request_ids_mapping) + self.assertNotIn(ext_id, host._put_req_chunk) + self.assertNotIn(req_id, host._get_req_chunk) + self.assertNotIn(ext_id, host._send_side_request_payload) + self.assertNotIn(ext_id, host._code_prompt_token_ids) + self.assertNotIn(req_id, host._chunk_stream_completed) + self.assertNotIn(req_id, host._stage_recv_req_ids) + self.assertNotIn(req_id, host._local_stage_payload_cache) + self.assertNotIn(req_id, host._local_request_metadata) + + 
host.shutdown_omni_connectors() + + def test_cleanup_removes_per_cycle_ready_state(self): + """cleanup_finished_request clears ready/finished carry-over for req-id reuse.""" + host = self._make_host(stage_id=1) + req_id = "req-1" + + host._pending_load_reqs[req_id] = _make_request(req_id, "ext-req-1") + host._finished_load_reqs.add(req_id) + host._chunk_ready_req_ids.add(req_id) + host._chunk_finished_req_ids.add(req_id) + + host.cleanup_finished_request(req_id) + + self.assertNotIn(req_id, host._pending_load_reqs) + self.assertNotIn(req_id, host._finished_load_reqs) + self.assertNotIn(req_id, host._chunk_ready_req_ids) + self.assertNotIn(req_id, host._chunk_finished_req_ids) + + host.shutdown_omni_connectors() + + def test_cleanup_without_mapping(self): + """cleanup works for Stage-0 where _request_ids_mapping isn't set.""" + host = self._make_host(stage_id=0) + host._stage_id = 0 + req_id = "req-1" + + # Stage-0 uses req_id directly (no ext_id mapping) + host._put_req_chunk[req_id] = 3 + host._get_req_chunk[req_id] = 0 + + host.cleanup_finished_request(req_id) + + self.assertNotIn(req_id, host._put_req_chunk) + self.assertNotIn(req_id, host._get_req_chunk) + + host.shutdown_omni_connectors() + + def test_prune_inactive_requests_cleans_stale_state_but_keeps_active(self): + """Inactive request IDs should be pruned without touching active ones.""" + host = self._make_host(stage_id=1) + active_req_id = "req-active" + stale_req_id = "req-stale" + stale_ext_id = "ext-stale" + + host._request_ids_mapping[active_req_id] = "ext-active" + host._request_ids_mapping[stale_req_id] = stale_ext_id + host._put_req_chunk[stale_ext_id] = 2 + host._get_req_chunk[stale_req_id] = 1 + host._finished_load_reqs.add(stale_req_id) + host._chunk_ready_req_ids.update({active_req_id, stale_req_id}) + host._chunk_finished_req_ids.add(stale_req_id) + host._chunk_stream_completed.add(stale_req_id) + host._stage_recv_req_ids.add(active_req_id) + host._send_side_request_payload[stale_ext_id] = 
{"stale": True} + host._code_prompt_token_ids[stale_ext_id] = [[1, 2, 3]] + + pruned = host.prune_inactive_requests({active_req_id}) + + self.assertEqual(pruned, {stale_req_id}) + self.assertIn(active_req_id, host._request_ids_mapping) + self.assertIn(active_req_id, host._chunk_ready_req_ids) + self.assertIn(active_req_id, host._stage_recv_req_ids) + self.assertNotIn(stale_req_id, host._request_ids_mapping) + self.assertNotIn(stale_ext_id, host._put_req_chunk) + self.assertNotIn(stale_req_id, host._get_req_chunk) + self.assertNotIn(stale_req_id, host._pending_load_reqs) + self.assertNotIn(stale_req_id, host._finished_load_reqs) + self.assertNotIn(stale_req_id, host._chunk_ready_req_ids) + self.assertNotIn(stale_req_id, host._chunk_finished_req_ids) + self.assertNotIn(stale_req_id, host._chunk_stream_completed) + self.assertNotIn(stale_req_id, host._stage_recv_req_ids) + self.assertNotIn(stale_ext_id, host._send_side_request_payload) + self.assertNotIn(stale_ext_id, host._code_prompt_token_ids) + + host.shutdown_omni_connectors() + + def test_prune_inactive_requests_keeps_recently_received_full_payload_state(self): + """Late bg-thread receives must survive until the scheduler catches up.""" + host = self._make_host(stage_id=1) + req_id = "req-recv-race" + ext_id = "ext-recv-race" + + host._request_ids_mapping[req_id] = ext_id + host._put_req_chunk[ext_id] = 1 + host._local_stage_payload_cache[req_id] = {"engine_inputs": {"ids": [1, 2, 3]}} + host._local_request_metadata[req_id] = {"next_stage_prompt_len": 3} + host._stage_recv_req_ids.add(req_id) + + pruned = host.prune_inactive_requests(set()) + + self.assertEqual(pruned, set()) + self.assertIn(req_id, host._request_ids_mapping) + self.assertIn(req_id, host._local_stage_payload_cache) + self.assertIn(req_id, host._local_request_metadata) + self.assertIn(req_id, host._stage_recv_req_ids) + self.assertIn(ext_id, host._put_req_chunk) + + # Once the scheduler has consumed the wake-up and the request really + # 
disappears from all protected sets, prune should clean it up. + host._stage_recv_req_ids.clear() + host._local_stage_payload_cache.clear() + host._local_request_metadata.clear() + + pruned = host.prune_inactive_requests(set()) + + self.assertEqual(pruned, {req_id}) + self.assertNotIn(req_id, host._request_ids_mapping) + self.assertNotIn(ext_id, host._put_req_chunk) + + host.shutdown_omni_connectors() + + +class TestSendChunkCachesMapping(unittest.TestCase): + """Test that send_chunk caches internal→external req ID mapping.""" + + def test_send_chunk_populates_request_ids_mapping(self): + """send_chunk should cache the internal→external mapping.""" + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(stage_id=0, async_chunk=True), + ) + host._omni_connector = MockConnector(stage_id=0) + host._stage_id = 0 + host._async_chunk = True + + def mock_process(transfer_manager, pooling_output, request): + return {"data": "test", "finished": False} + + host._custom_process_func = mock_process + + request = _make_request("internal-1", "external-1") + host.send_chunk(request, pooling_output={"v": 1}) + + # The mapping should be cached + self.assertEqual( + host._request_ids_mapping.get("internal-1"), + "external-1", + ) + + time.sleep(0.1) + host.shutdown_omni_connectors() + + +class TestLocalPayloadCacheLifecycle(unittest.TestCase): + """Unit tests for the local payload cache API (RFC §2.4).""" + + def _make_host(self) -> MixinHost: + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(stage_id=0), + ) + host._omni_connector = MockConnector(stage_id=0) + host._stage_id = 0 + return host + + def test_put_get_pop(self): + host = self._make_host() + payload = {"engine_inputs": {"ids": [1, 2, 3]}} + host.put_local_stage_payload("r1", payload) + + self.assertEqual(host.get_local_stage_payload("r1"), payload) + popped = host.pop_local_stage_payload("r1") + self.assertEqual(popped, 
payload) + self.assertIsNone(host.get_local_stage_payload("r1")) + host.shutdown_omni_connectors() + + def test_recv_full_payload_inputs_populates_local_cache(self): + host = self._make_host() + host._omni_connector = MockConnector(stage_id=0) + host._stage_id = 0 + + # Simulate a full payload already staged by the bg recv path + with host._lock: + host._local_stage_payload_cache["r1"] = {"tok": [10]} + host._stage_recv_req_ids.add("r1") + + host.recv_full_payload_inputs(scheduler_output=None) + self.assertEqual(host.get_local_stage_payload("r1"), {"tok": [10]}) + host.shutdown_omni_connectors() + + def test_rank0_only_polls_connector_for_tp_full_payload(self): + host = self._make_host() + host._omni_connector = MagicMock() + host._stage_id = 2 + host._local_rank = 0 + host._request_ids_mapping["r1"] = "ext-r1" + host._get_req_chunk["r1"] = 0 + payload = {"tok": [10], "finished": torch.tensor(True)} + connector_result = (payload, 123) + host._omni_connector.get.return_value = connector_result + tp_group = _FakeTPGroup(world_size=2, rank_in_group=0) + + with patch("vllm_omni.worker.omni_connector_model_runner_mixin.get_tp_group", return_value=tp_group): + made_progress = host._poll_single_request("r1") + + self.assertTrue(made_progress) + host._omni_connector.get.assert_called_once_with("1", "2", "ext-r1_1_0") + self.assertEqual(tp_group.broadcast_inputs, []) + self.assertEqual(host.get_local_stage_payload("r1"), payload) + self.assertIn("r1", host._full_payload_pending_broadcast_req_ids) + self.assertNotIn("r1", host._stage_recv_req_ids) + self.assertIsNone(host.get_local_request_metadata("r1")) + host.shutdown_omni_connectors() + + def test_tp_follower_skips_connector_poll_for_full_payload(self): + host = self._make_host() + host._omni_connector = MagicMock() + host._stage_id = 2 + host._local_rank = 1 + host._request_ids_mapping["r1"] = "ext-r1" + host._get_req_chunk["r1"] = 0 + tp_group = _FakeTPGroup(world_size=2, rank_in_group=1) + + with 
patch("vllm_omni.worker.omni_connector_model_runner_mixin.get_tp_group", return_value=tp_group): + made_progress = host._poll_single_request("r1") + + self.assertFalse(made_progress) + host._omni_connector.get.assert_not_called() + self.assertEqual(tp_group.broadcast_inputs, []) + self.assertNotIn("r1", host._local_stage_payload_cache) + host.shutdown_omni_connectors() + + def test_recv_full_payload_inputs_broadcasts_tp_leader_results_to_followers(self): + host = self._make_host() + host._omni_connector = MagicMock() + host._stage_id = 2 + host._local_rank = 1 + host._pending_load_reqs["r1"] = object() + payload = {"tok": [10], "finished": torch.tensor(True)} + tp_group = _FakeTPGroup(world_size=2, rank_in_group=1, follower_result={"r1": payload}) + + with patch("vllm_omni.worker.omni_connector_model_runner_mixin.get_tp_group", return_value=tp_group): + results = host.recv_full_payload_inputs(scheduler_output=None) + + self.assertEqual(results, {"r1": payload}) + self.assertEqual(host.get_local_stage_payload("r1"), payload) + self.assertEqual(host.get_local_request_metadata("r1"), {}) + self.assertEqual(host._stage_recv_req_ids, {"r1"}) + self.assertNotIn("r1", host._pending_load_reqs) + self.assertEqual(tp_group.broadcast_inputs, [None]) + host.shutdown_omni_connectors() + + +class TestTPAsyncChunkFanout(unittest.TestCase): + def _make_host(self, rank: int) -> MixinHost: + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(stage_id=2, async_chunk=True, worker_type="gen"), + ) + host._omni_connector = MagicMock() + host._stage_id = 2 + host._async_chunk = True + host._model_mode = "gen" + host._local_rank = rank + host._request_ids_mapping["r1"] = "ext-r1" + host._get_req_chunk["r1"] = 0 + return host + + def test_rank0_only_polls_connector_for_tp_async_chunk(self): + host = self._make_host(rank=0) + payload = { + "code_predictor_codes": [10, 11], + "left_context_size": 0, + "finished": torch.tensor(False), + } + 
host._omni_connector.get.return_value = (payload, 123) + tp_group = _FakeTPGroup(world_size=2, rank_in_group=0) + + with patch("vllm_omni.worker.omni_connector_model_runner_mixin.get_tp_group", return_value=tp_group): + made_progress = host._poll_single_request("r1") + + self.assertTrue(made_progress) + host._omni_connector.get.assert_called_once_with("1", "2", "ext-r1_1_0") + self.assertEqual(host.get_local_stage_payload("r1"), payload) + self.assertIn("r1", host._finished_load_reqs) + self.assertIn("r1", host._async_chunk_updated_req_ids) + self.assertEqual(tp_group.broadcast_inputs, []) + host.shutdown_omni_connectors() + + def test_tp_follower_skips_connector_poll_for_async_chunk(self): + host = self._make_host(rank=1) + tp_group = _FakeTPGroup(world_size=2, rank_in_group=1) + + with patch("vllm_omni.worker.omni_connector_model_runner_mixin.get_tp_group", return_value=tp_group): + made_progress = host._poll_single_request("r1") + + self.assertFalse(made_progress) + host._omni_connector.get.assert_not_called() + self.assertIsNone(host.get_local_stage_payload("r1")) + self.assertEqual(tp_group.broadcast_inputs, []) + host.shutdown_omni_connectors() + + def test_get_output_broadcasts_tp_async_chunk_payloads_to_followers(self): + host = self._make_host(rank=1) + host._pending_load_reqs["r1"] = object() + payload = { + "code_predictor_codes": [10, 11], + "left_context_size": 0, + "finished": torch.tensor(True), + } + packet = { + "staged_payloads": {"r1": payload}, + "request_metadata": {"r1": {"code_predictor_codes": [10, 11], "left_context_size": 0}}, + "newly_finished": {"r1"}, + "chunk_finished": {"r1"}, + } + tp_group = _FakeTPGroup(world_size=2, rank_in_group=1, follower_result=packet) + + with patch("vllm_omni.worker.omni_connector_model_runner_mixin.get_tp_group", return_value=tp_group): + output = host.get_omni_connector_output() + + self.assertEqual(output.chunk_ready_req_ids, {"r1"}) + self.assertEqual(output.chunk_finished_req_ids, {"r1"}) + 
self.assertEqual( + output.request_metadata, + {"r1": {"code_predictor_codes": [10, 11], "left_context_size": 0}}, + ) + self.assertEqual(host.get_local_stage_payload("r1"), payload) + self.assertNotIn("r1", host._pending_load_reqs) + self.assertIn("r1", host._chunk_stream_completed) + self.assertEqual(tp_group.broadcast_inputs, [None]) + host.shutdown_omni_connectors() + + +class TestKVTransferLifecycle(unittest.TestCase): + """Unit tests for KV transfer lifecycle methods.""" + + def _make_host(self) -> MixinHost: + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(stage_id=0), + ) + return host + + def test_mark_drain_ack_complete(self): + host = self._make_host() + self.assertFalse(host.has_pending_kv_work()) + + host.mark_kv_transfer("r1", seq_len=100, block_ids=[0, 1, 2]) + self.assertTrue(host.has_pending_kv_work()) + self.assertTrue(host.is_kv_transfer_triggered("r1")) + + # Drain moves pending → active + pending = host.drain_pending_kv_transfers() + self.assertEqual(pending, {"r1": {"seq_len": 100, "block_ids": [0, 1, 2]}}) + self.assertIn("r1", host._kv_active_transfers) + self.assertTrue(host.has_pending_kv_work()) + + # Ack moves active → completed + host.ack_kv_transfers(["r1"]) + self.assertNotIn("r1", host._kv_active_transfers) + self.assertIn("r1", host._kv_completed_transfers) + + # Drain completed + completed = host.drain_completed_kv_transfers() + self.assertEqual(completed, {"r1"}) + self.assertFalse(host.has_pending_kv_work()) + host.shutdown_omni_connectors() + + def test_mark_dedup(self): + host = self._make_host() + host.mark_kv_transfer("r1", seq_len=100, block_ids=[0]) + host.mark_kv_transfer("r1", seq_len=200, block_ids=[0, 1]) + # Second mark is a no-op + self.assertEqual(host._kv_pending_transfers["r1"]["seq_len"], 100) + host.shutdown_omni_connectors() + + def test_cleanup_removes_kv_state(self): + host = self._make_host() + host.mark_kv_transfer("r1", seq_len=50, block_ids=[0]) + 
host.drain_pending_kv_transfers() + host.cleanup_finished_request("r1") + self.assertFalse(host.is_kv_transfer_triggered("r1")) + self.assertNotIn("r1", host._kv_active_transfers) + self.assertFalse(host.has_pending_kv_work()) + host.shutdown_omni_connectors() + + +class TestAsyncPayloadLifecycle(unittest.TestCase): + """Regression tests for async payload delivery lifecycle.""" + + def test_send_side_request_payload_not_cleared_before_payload_is_consumable(self): + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(stage_id=1, async_chunk=True, worker_type="ar"), + ) + host._request_ids_mapping["r1"] = "r1" + payload = { + "thinker_decode_embeddings": torch.ones(1, 2), + "thinker_output_token_ids": [1], + "override_keys": ["thinker_decode_embeddings", "thinker_output_token_ids"], + "finished": torch.tensor(False), + } + + host._accumulate_payload("r1", dict(payload)) + with host._lock: + host._finished_load_reqs.add("r1") + + host.get_omni_connector_output() + self.assertIn("r1", host._send_side_request_payload) + host.shutdown_omni_connectors() + + def test_payload_consumable_ignores_token_horizon_only_updates(self): + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(stage_id=1, async_chunk=True, worker_type="ar"), + ) + payload = { + "thinker_output_token_ids": [1, 2, 3], + "finished": torch.tensor(False), + "override_keys": [ + "thinker_output_token_ids", + "thinker_decode_embeddings_token_start", + "thinker_decode_embeddings_token_end", + ], + "thinker_decode_embeddings_token_start": 2, + "thinker_decode_embeddings_token_end": 3, + } + self.assertFalse(host._payload_is_consumable(payload)) + host.shutdown_omni_connectors() + + def test_payload_consumable_accepts_decode_embeddings(self): + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(stage_id=1, async_chunk=True, worker_type="ar"), + ) + payload = { + 
"thinker_output_token_ids": [1, 2, 3], + "thinker_decode_embeddings": torch.ones(1, 2), + "finished": torch.tensor(False), + } + self.assertTrue(host._payload_is_consumable(payload)) + host.shutdown_omni_connectors() + + def test_ar_metadata_only_followup_chunk_does_not_rewake_request(self): + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(stage_id=1, async_chunk=True, worker_type="ar"), + ) + host._omni_connector = MagicMock() + host._stage_id = 1 + host._async_chunk = True + host._model_mode = "ar" + host._request_ids_mapping["r1"] = "ext-r1" + host._get_req_chunk["r1"] = 0 + + host._omni_connector.get.side_effect = [ + ( + { + "thinker_decode_embeddings": torch.ones(1, 2), + "finished": torch.tensor(False), + }, + 1, + ), + ( + { + "next_stage_prompt_len": 7, + "finished": torch.tensor(False), + }, + 1, + ), + ] + + host._poll_single_request("r1") + output1 = host.get_omni_connector_output() + self.assertEqual(output1.chunk_ready_req_ids, {"r1"}) + + host._poll_single_request("r1") + output2 = host.get_omni_connector_output() + self.assertEqual(output2.chunk_ready_req_ids, set()) + self.assertEqual(output2.request_metadata, {"r1": {"next_stage_prompt_len": 7}}) + + host.shutdown_omni_connectors() + + def test_non_ar_recv_does_not_overwrite_unconsumed_staged_chunk(self): + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(stage_id=2, async_chunk=True, worker_type="gen"), + ) + host._omni_connector = MagicMock() + host._stage_id = 2 + host._async_chunk = True + host._model_mode = "gen" + host._request_ids_mapping["r1"] = "ext-r1" + host._get_req_chunk["r1"] = 1 + host._local_stage_payload_cache["r1"] = { + "code_predictor_codes": [1, 2, 3], + "left_context_size": 0, + "finished": torch.tensor(False), + } + + made_progress = host._poll_single_request("r1") + + self.assertFalse(made_progress) + host._omni_connector.get.assert_not_called() + 
self.assertEqual(host._get_req_chunk["r1"], 1) + + host.shutdown_omni_connectors() + + def test_non_ar_recv_waits_for_scheduler_handoff_before_fetching_next_chunk(self): + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(stage_id=2, async_chunk=True, worker_type="gen"), + ) + host._omni_connector = MagicMock() + host._stage_id = 2 + host._async_chunk = True + host._model_mode = "gen" + host._request_ids_mapping["r1"] = "ext-r1" + host._get_req_chunk["r1"] = 1 + host._local_request_metadata["r1"] = { + "code_predictor_codes": [10, 11, 12], + "left_context_size": 0, + } + host._finished_load_reqs.add("r1") + + made_progress = host._poll_single_request("r1") + + self.assertFalse(made_progress) + host._omni_connector.get.assert_not_called() + self.assertEqual(host._get_req_chunk["r1"], 1) + + output = host.get_omni_connector_output() + self.assertEqual(output.request_metadata["r1"]["code_predictor_codes"], [10, 11, 12]) + self.assertEqual(output.chunk_ready_req_ids, {"r1"}) + + host._omni_connector.get.return_value = ( + { + "code_predictor_codes": [20, 21, 22], + "left_context_size": 0, + "finished": torch.tensor(False), + }, + 1, + ) + made_progress = host._poll_single_request("r1") + + self.assertTrue(made_progress) + host._omni_connector.get.assert_called_once() + self.assertEqual(host._get_req_chunk["r1"], 2) + + host.shutdown_omni_connectors() + + +class TestRankAwareKVRouting(unittest.TestCase): + def _make_host(self, *, from_tp: int, to_tp: int, local_rank: int) -> MixinHost: + host = MixinHost() + host.init_omni_connectors(vllm_config=None, model_config=_make_model_config(stage_id=1)) + host._from_tp = from_tp + host._to_tp = to_tp + host._local_rank = local_rank + return host + + def test_recv_keys_use_remote_rank_as_from_rank(self): + host = self._make_host(from_tp=4, to_tp=2, local_rank=1) + self.assertEqual( + host.get_rank_aware_kv_keys("req", from_stage=0), + ["req_0_0_2_1", "req_0_0_3_1"], + ) + 
host.shutdown_omni_connectors() + + def test_send_keys_route_from_rank_gt_to_rank(self): + host = self._make_host(from_tp=4, to_tp=2, local_rank=3) + self.assertEqual(host.get_rank_aware_kv_send_keys("req", from_stage=0), ["req_0_0_3_1"]) + host.shutdown_omni_connectors() + + def test_invalid_recv_rank_mapping_raises(self): + host = self._make_host(from_tp=3, to_tp=2, local_rank=1) + with self.assertRaises(ValueError): + host.get_rank_aware_kv_keys("req", from_stage=0) + host.shutdown_omni_connectors() + + def test_invalid_send_rank_mapping_raises(self): + host = self._make_host(from_tp=3, to_tp=2, local_rank=1) + with self.assertRaises(ValueError): + host.get_rank_aware_kv_send_keys("req", from_stage=0) + host.shutdown_omni_connectors() + + def test_merge_rank_sharded_payloads_concatenates_head_dimension(self): + host = self._make_host(from_tp=4, to_tp=2, local_rank=0) + payloads = [ + {"layer_blocks": {"key_cache": [torch.ones(2, 1, 3)], "value_cache": [torch.ones(2, 1, 3)]}}, + {"layer_blocks": {"key_cache": [torch.full((2, 1, 3), 2.0)], "value_cache": [torch.full((2, 1, 3), 2.0)]}}, + ] + merged = host._merge_rank_sharded_kv_payloads(payloads) + self.assertEqual(tuple(merged["layer_blocks"]["key_cache"][0].shape), (2, 2, 3)) + self.assertTrue(torch.equal(merged["layer_blocks"]["key_cache"][0][:, 0], torch.ones(2, 3))) + self.assertTrue(torch.equal(merged["layer_blocks"]["key_cache"][0][:, 1], torch.full((2, 3), 2.0))) + host.shutdown_omni_connectors() + + def test_slice_rank_sharded_payload_splits_head_dimension(self): + host = self._make_host(from_tp=2, to_tp=4, local_rank=1) + payload = { + "layer_blocks": { + "key_cache": [torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)], + "value_cache": [torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)], + }, + "metadata": {}, + } + sliced = host._slice_rank_sharded_kv_payload(payload) + self.assertEqual(tuple(sliced["layer_blocks"]["key_cache"][0].shape), (2, 2, 3)) + expected = torch.arange(24, 
dtype=torch.float32).reshape(2, 4, 3)[:, 2:4, :] + self.assertTrue(torch.equal(sliced["layer_blocks"]["key_cache"][0], expected)) + host.shutdown_omni_connectors() + + +class TestAttachOmniConnectorOutput(unittest.TestCase): + def test_wraps_empty_model_runner_output_when_signals_exist(self): + from vllm.v1.worker.gpu_model_runner import EMPTY_MODEL_RUNNER_OUTPUT + + host = MixinHost() + host.get_omni_connector_output = lambda: OmniConnectorOutput(chunk_ready_req_ids={"req-1"}) + + wrapped = host.attach_omni_connector_output(EMPTY_MODEL_RUNNER_OUTPUT) + + self.assertIsNot(wrapped, EMPTY_MODEL_RUNNER_OUTPUT) + self.assertEqual(wrapped.omni_connector_output.chunk_ready_req_ids, {"req-1"}) + + +class TestConnectorConfigValidation(unittest.TestCase): + def test_invalid_connector_name_raises(self): + host = MixinHost() + model_config = _make_model_config(stage_id=1) + model_config.stage_connector_config = {"name": " "} + + with self.assertRaisesRegex(RuntimeError, "missing connector name"): + host.init_omni_connectors(vllm_config=None, model_config=model_config) + + +class _FailingConnector: + """Connector whose put() fails a configurable number of times.""" + + def __init__(self, fail_count: int = 1, raise_on_fail: bool = False): + self._fail_count = fail_count + self._raise_on_fail = raise_on_fail + self.attempt = 0 + + def put(self, from_stage, to_stage, put_key, data): + self.attempt += 1 + if self.attempt <= self._fail_count: + if self._raise_on_fail: + raise ConnectionError("transient connector error") + return False, 0, None + return True, len(str(data)), None + + def get(self, *a, **kw): + return None + + def close(self): + pass + + +class TestSendRetry(unittest.TestCase): + """Tests for P1-2: failed connector sends must be retried.""" + + def _make_sender(self, connector): + sender = MixinHost() + sender.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(stage_id=0, async_chunk=True), + ) + sender._omni_connector = connector + 
sender._stage_id = 0 + sender._async_chunk = True + return sender + + def _make_task(self, req_id="r1"): + return { + "stage_id": 0, + "next_stage_id": 1, + "request_id": req_id, + "data": {"payload": "test"}, + } + + def test_send_single_request_returns_false_on_put_failure(self): + connector = _FailingConnector(fail_count=999) + sender = self._make_sender(connector) + + result = sender._send_single_request(self._make_task()) + self.assertFalse(result) + sender.shutdown_omni_connectors() + + def test_send_single_request_does_not_decrement_on_failure(self): + connector = _FailingConnector(fail_count=999) + sender = self._make_sender(connector) + sender._pending_save_counts["r1"] = 1 + + sender._send_single_request(self._make_task()) + self.assertEqual(sender._pending_save_counts.get("r1"), 1, "pending count must NOT be decremented on failure") + sender.shutdown_omni_connectors() + + def test_send_single_request_decrements_on_success(self): + connector = MockConnector(stage_id=0) + sender = self._make_sender(connector) + sender._pending_save_counts["r1"] = 1 + + result = sender._send_single_request(self._make_task()) + self.assertTrue(result) + self.assertNotIn("r1", sender._pending_save_counts, "pending count should be zero/removed on success") + sender.shutdown_omni_connectors() + + def test_requeue_or_drop_requeues_on_first_failure(self): + sender = self._make_sender(MockConnector(stage_id=0)) + task = self._make_task() + + sender._requeue_or_drop_failed_send(task) + + self.assertEqual(task.get("_retry_count"), 1) + with sender._lock: + dq = sender._pending_save_reqs.get("r1") + self.assertIsNotNone(dq) + self.assertEqual(len(dq), 1) + sender.shutdown_omni_connectors() + + def test_requeue_or_drop_drops_after_max_retries(self): + sender = self._make_sender(MockConnector(stage_id=0)) + sender._pending_save_counts["r1"] = 1 + task = self._make_task() + task["_retry_count"] = sender._MAX_SEND_RETRIES # already at max + + sender._requeue_or_drop_failed_send(task) + + 
with sender._lock: + dq = sender._pending_save_reqs.get("r1") + self.assertTrue(dq is None or len(dq) == 0, "task should NOT be re-enqueued after max retries") + self.assertNotIn("r1", sender._pending_save_counts, "pending count should be cleaned up on final drop") + sender.shutdown_omni_connectors() + + def test_save_loop_retries_on_exception(self): + """Integration: _save_loop retries a task when put() raises.""" + from collections import deque + + connector = _FailingConnector(fail_count=1, raise_on_fail=True) + sender = self._make_sender(connector) + task = self._make_task() + + with sender._lock: + sender._pending_save_reqs["r1"] = deque([task]) + sender._pending_save_counts["r1"] = 1 + + sender._stop_event.clear() + + def run_one_loop(): + sender._save_loop() + + sender._stop_event.set() # will exit after one iteration + # Run manually instead of threading + # Simulate: pop task, send fails, requeue + popped_task = None + with sender._lock: + dq = sender._pending_save_reqs.get("r1") + if dq: + popped_task = dq.popleft() + if not dq: + del sender._pending_save_reqs["r1"] + + if popped_task is not None: + success = False + try: + success = sender._send_single_request(popped_task) + except Exception: + pass + if not success: + sender._requeue_or_drop_failed_send(popped_task) + + # After first failure, task should be re-enqueued + with sender._lock: + dq = sender._pending_save_reqs.get("r1") + self.assertIsNotNone(dq) + self.assertEqual(len(dq), 1) + requeued = dq[0] + self.assertEqual(requeued.get("_retry_count"), 1) + + # Second attempt should succeed (connector now returns True) + success = sender._send_single_request(requeued) + self.assertTrue(success) + sender.shutdown_omni_connectors() + + +if __name__ == "__main__": + unittest.main() diff --git a/vllm_omni/core/sched/omni_scheduling_coordinator.py b/vllm_omni/core/sched/omni_scheduling_coordinator.py new file mode 100644 index 00000000000..c9d891afb41 --- /dev/null +++ 
class OmniSchedulingCoordinator:
    """Scheduler-side bookkeeping for requests that wait on upstream data.

    The coordinator is driven only by sets of request IDs (readiness
    signals).  It flips request statuses between ``WAITING`` / ``RUNNING``
    and ``WAITING_FOR_CHUNK`` / ``WAITING_FOR_INPUT`` and shuttles request
    objects between the scheduler's queues and its own holding deques.
    It performs no connector I/O.
    """

    def __init__(self, scheduler_max_num_seqs: int, stage_id: int = 0, async_chunk: bool = False):
        self._scheduler_max_num_seqs = scheduler_max_num_seqs
        self._stage_id = stage_id
        self._async_chunk = async_chunk

        # Readiness bookkeeping, all keyed by request ID.
        self.finished_requests: set[str] = set()
        self.requests_with_ready_chunks: set[str] = set()
        self._full_payload_input_received: set[str] = set()

        # Holding areas for requests pulled out of the scheduler queues
        # while they wait; restore_queues() hands them back.
        self._waiting_for_chunk_waiting: deque[Any] = deque()
        self._waiting_for_chunk_running: deque[Any] = deque()
        self._waiting_for_input: deque[Any] = deque()

        # Requests that started waiting during the latest processing call.
        # They are published so the engine / model runner can register them
        # for background polling.
        self.pending_chunk_registrations: list[Any] = []
        self.pending_input_registrations: list[Any] = []

        # time.monotonic() stamp of when each request began waiting; read
        # by collect_timed_out_request_ids().
        self._waiting_since: dict[str, float] = {}

    # ------------------------------------------------------------------ #
    # Core scheduling methods
    # ------------------------------------------------------------------ #

    def process_pending_chunks(
        self,
        waiting_queue: Any,
        running_queue: list[Request],
        chunk_ready_req_ids: set[str],
        chunk_finished_req_ids: set[str],
    ) -> None:
        """Wake requests whose chunk arrived and park those still waiting.

        Does nothing on stage 0 or when async chunking is disabled.

        Args:
            waiting_queue: Scheduler's waiting request queue.
            running_queue: Scheduler's running request list.
            chunk_ready_req_ids: IDs with a newly arrived chunk this cycle.
            chunk_finished_req_ids: IDs whose final chunk has arrived.
        """
        if not self._async_chunk or self._stage_id == 0:
            return

        # IDs that are both "ready" and "finished" this cycle are recorded
        # as finished only after the queues were processed, because
        # _process_chunk_queue() skips IDs already in finished_requests.
        ready_and_finished = chunk_ready_req_ids.intersection(chunk_finished_req_ids)
        self.finished_requests.update(chunk_finished_req_ids - ready_and_finished)
        self.pending_chunk_registrations = []

        passes = (
            (waiting_queue, self._waiting_for_chunk_waiting, RequestStatus.WAITING),
            (running_queue, self._waiting_for_chunk_running, RequestStatus.RUNNING),
        )
        for source_queue, holding, wake_status in passes:
            self._process_chunk_queue(source_queue, holding, wake_status, chunk_ready_req_ids)

        self.finished_requests.update(ready_and_finished)

        # Spill running requests beyond the sequence limit back to the front
        # of the waiting queue.  Status is reset to WAITING (not PREEMPTED)
        # because KV blocks are not freed here, unlike a real preemption.
        limit = self._scheduler_max_num_seqs
        while len(running_queue) > limit:
            spilled = running_queue.pop()
            spilled.status = RequestStatus.WAITING
            waiting_queue.prepend_requests([spilled])

    def process_pending_full_payload_inputs(
        self,
        waiting_queue: Any,
        running_queue: list[Request],
        stage_recv_req_ids: set[str],
    ) -> None:
        """Drive the WAITING_FOR_INPUT lifecycle on non-zero stages.

        Requests held in the internal input deque whose ID is in
        ``stage_recv_req_ids`` are set back to WAITING and re-added to
        ``waiting_queue``.  When ``async_chunk`` is off, the waiting queue
        is additionally scanned: WAITING requests without any received
        input are parked as WAITING_FOR_INPUT, and already-parked ones are
        either woken (input arrived) or parked again.
        """
        if self._stage_id == 0:
            return

        full_payload_mode = not self._async_chunk
        self._full_payload_input_received.update(stage_recv_req_ids)
        if full_payload_mode and stage_recv_req_ids:
            self.finished_requests.update(stage_recv_req_ids)
            logger.debug(
                "[Coordinator stage-%s] full_payload recv -> finished_requests: %s",
                self._stage_id,
                stage_recv_req_ids,
            )
        self.pending_input_registrations = []

        # Wake anything still held internally whose input has arrived.
        still_parked: deque[Any] = deque()
        for parked in self._waiting_for_input:
            if parked.request_id not in stage_recv_req_ids:
                still_parked.append(parked)
                continue
            parked.status = RequestStatus.WAITING
            self._waiting_since.pop(parked.request_id, None)
            waiting_queue.add_request(parked)
        self._waiting_for_input = still_parked

        if not full_payload_mode:
            return

        to_park: list[Any] = []
        for candidate in list(waiting_queue):
            status = candidate.status
            if status == RequestStatus.WAITING:
                input_satisfied = (
                    candidate.request_id in self._full_payload_input_received
                    or candidate.request_id in self.requests_with_ready_chunks
                    or candidate.request_id in self.finished_requests
                )
                if input_satisfied:
                    continue
                candidate.status = RequestStatus.WAITING_FOR_INPUT
                self._waiting_since.setdefault(candidate.request_id, time.monotonic())
            elif status == RequestStatus.WAITING_FOR_INPUT:
                if candidate.request_id in stage_recv_req_ids:
                    candidate.status = RequestStatus.WAITING
                    self._waiting_since.pop(candidate.request_id, None)
                    continue
            else:
                continue
            # Reached only by requests that must (keep) waiting for input.
            to_park.append(candidate)
            self._waiting_for_input.append(candidate)
            self.pending_input_registrations.append(candidate)

        for candidate in to_park:
            waiting_queue.remove(candidate)

    def process_pending_full_payload_inputs_legacy(
        self,
        waiting_queue: Any,
        running_queue: list[Request],
        stage_recv_req_ids: set[str],
    ) -> None:
        """Legacy name; forwards to ``process_pending_full_payload_inputs``."""
        return self.process_pending_full_payload_inputs(waiting_queue, running_queue, stage_recv_req_ids)

    def free_finished_request(self, request_id: str) -> None:
        """Forget *request_id* in every ID-keyed tracking structure."""
        self._waiting_since.pop(request_id, None)
        for tracked in (
            self._full_payload_input_received,
            self.finished_requests,
            self.requests_with_ready_chunks,
        ):
            tracked.discard(request_id)

    def collect_timed_out_request_ids(
        self,
        timeout_s: float,
    ) -> set[str]:
        """Return IDs that have been waiting for more than *timeout_s* seconds.

        Detection is based on the ``_waiting_since`` timestamps, not on the
        internal holding deques.  Timed-out IDs have their timestamp
        cleared, are filtered out of the holding deques if present, and are
        logged at warning level.  A non-positive *timeout_s* disables the
        check and yields an empty set.  The caller is expected to finish
        and free the returned requests.
        """
        if timeout_s <= 0:
            return set()

        now = time.monotonic()
        timed_out_ids: set[str] = {
            req_id for req_id, started in self._waiting_since.items() if now - started > timeout_s
        }
        if not timed_out_ids:
            return set()

        def _without_timed_out(held: deque[Any]) -> deque[Any]:
            return deque(request for request in held if request.request_id not in timed_out_ids)

        # The holding deques may already be empty (after restore_queues());
        # filtering them is purely defensive.
        self._waiting_for_chunk_waiting = _without_timed_out(self._waiting_for_chunk_waiting)
        self._waiting_for_chunk_running = _without_timed_out(self._waiting_for_chunk_running)
        self._waiting_for_input = _without_timed_out(self._waiting_for_input)

        for req_id in timed_out_ids:
            self._waiting_since.pop(req_id, None)
            logger.warning(
                "[Coordinator stage-%s] Request %s timed out waiting for chunk/input (waited > %.0fs)",
                self._stage_id,
                req_id,
                timeout_s,
            )

        return timed_out_ids

    def restore_queues(
        self,
        waiting_queue: Any,
        running_queue: list[Request],
    ) -> None:
        """Hand every held request back to the scheduler and empty the holds."""
        for held in self._waiting_for_chunk_waiting:
            waiting_queue.add_request(held)
        self._waiting_for_chunk_waiting = deque()

        held_running = self._waiting_for_chunk_running
        if held_running:
            running_queue.extend(held_running)
        self._waiting_for_chunk_running = deque()

        for held in self._waiting_for_input:
            waiting_queue.add_request(held)
        self._waiting_for_input = deque()

    def update_request_metadata(
        self,
        requests: dict[str, Request],
        request_metadata: dict[str, dict[str, Any]],
        model_mode: str = "ar",
    ) -> None:
        """Copy received scheduling metadata onto the matching requests.

        Unknown request IDs are ignored.  A positive integer
        ``next_stage_prompt_len`` resizes the request's placeholder prompt
        (any mode).  For every ``model_mode`` other than ``"ar"`` the
        request additionally receives ``_omni_initial_model_buffer`` and,
        if ``code_predictor_codes`` is non-empty, a new ``prompt_token_ids``.
        """
        for req_id, metadata in request_metadata.items():
            request = requests.get(req_id)
            if request is None:
                continue

            if "next_stage_prompt_len" in metadata:
                next_len = metadata["next_stage_prompt_len"]
                if isinstance(next_len, int) and next_len > 0:
                    self._apply_next_stage_prompt_len(req_id, request, next_len)

            if model_mode == "ar":
                continue

            new_ids = metadata.get("code_predictor_codes", [])
            if "left_context_size" in metadata:
                runtime_seed = {"left_context_size": metadata["left_context_size"]}
            else:
                runtime_seed = None
            request._omni_initial_model_buffer = runtime_seed
            if new_ids:
                request.prompt_token_ids = new_ids
                request.num_computed_tokens = 0

    def postprocess_scheduler_output(
        self,
        scheduler_output: Any,
        requests: dict[str, Request] | None = None,
    ) -> None:
        """Drop the ready-chunk marks of requests present in *scheduler_output*."""
        self._clear_chunk_ready(scheduler_output)

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    def _apply_next_stage_prompt_len(self, req_id: str, request: Any, next_len: int) -> None:
        """Resize *request*'s placeholder prompt to *next_len* zero tokens.

        Requests that already hold output tokens are left untouched, since
        resetting them would discard generated tokens.
        """
        output_token_ids = getattr(request, "_output_token_ids", None)
        if output_token_ids is not None and len(output_token_ids) > 0:
            logger.debug(
                "[Coordinator stage-%s] Skipping prompt resize for req %s: request already has %s output tokens",
                self._stage_id,
                req_id,
                len(output_token_ids),
            )
            return

        current_prompt_len = len(getattr(request, "prompt_token_ids", []) or [])
        lengths_agree = current_prompt_len == next_len and getattr(request, "num_prompt_tokens", None) == next_len
        if lengths_agree:
            return

        placeholder = [0] * next_len
        request.prompt_token_ids = placeholder
        request.num_prompt_tokens = next_len
        request._all_token_ids.clear()
        request._all_token_ids.extend(placeholder)
        request._output_token_ids.clear()
        request.num_computed_tokens = 0
        logger.debug(
            "[Coordinator stage-%s] Updated prompt_token_ids length to %s for req %s",
            self._stage_id,
            next_len,
            req_id,
        )

    def _process_chunk_queue(
        self,
        queue: Any,
        waiting_for_chunk_list: deque[Any],
        target_status: RequestStatus,
        chunk_ready_req_ids: set[str],
    ) -> None:
        """Run one chunk-readiness pass over *queue*.

        Requests that must wait are removed from *queue* and appended to
        *waiting_for_chunk_list*; requests whose chunk is ready stay in
        *queue* (already-waiting ones get *target_status*).
        """
        for request in list(queue):
            if request.status == RequestStatus.WAITING_FOR_CHUNK:
                if request.request_id in chunk_ready_req_ids:
                    # Chunk arrived: wake the request and leave it queued.
                    request.status = target_status
                    self.requests_with_ready_chunks.add(request.request_id)
                    self._waiting_since.pop(request.request_id, None)
                    continue
            else:
                leave_untouched = (
                    request.request_id in self.requests_with_ready_chunks
                    or request.request_id in self.finished_requests
                    or request.status == RequestStatus.WAITING_FOR_INPUT
                )
                if leave_untouched:
                    continue
                if request.request_id in chunk_ready_req_ids:
                    self.requests_with_ready_chunks.add(request.request_id)
                    continue
                # No chunk yet: start waiting and publish for registration.
                self.pending_chunk_registrations.append(request)
                request.status = RequestStatus.WAITING_FOR_CHUNK
                self._waiting_since.setdefault(request.request_id, time.monotonic())

            queue.remove(request)
            waiting_for_chunk_list.append(request)

    def _clear_chunk_ready(self, scheduler_output: Any) -> None:
        """Discard ready-chunk marks for new and cached requests in the output."""
        ready = self.requests_with_ready_chunks

        new_reqs = scheduler_output.scheduled_new_reqs
        if new_reqs:
            for req_data in new_reqs:
                ready.discard(getattr(req_data, "req_id", None))

        cached_reqs = scheduler_output.scheduled_cached_reqs
        if cached_reqs:
            for req_id in cached_reqs.req_ids:
                ready.discard(req_id)
@dataclass
class OmniConnectorOutput:
    """Communication results from Model Runner to Scheduler.

    Carries transfer readiness signals so the Scheduler can make scheduling
    decisions without ever calling connector.put()/get() directly.

    Every container field uses ``default_factory``, so each instance gets
    its own empty container rather than sharing one across instances.

    Attributes:
        chunk_ready_req_ids: Request IDs with newly arrived chunks this cycle.
        chunk_finished_req_ids: Request IDs whose final chunk has arrived.
        request_metadata: Lightweight scheduling metadata keyed by request ID
            (e.g. next_stage_prompt_len, code_predictor_codes, left_context_size).
            Full payloads are owned by the Model Runner's local cache.
        kv_sent_req_ids: Request IDs whose KV cache was successfully sent.
        stage_recv_req_ids: Request IDs that received batch stage inputs.
        has_pending_kv_work: True if the mixin has pending, active, or
            completed KV transfers that the scheduler should account for.
    """

    # Streaming (async_chunk) readiness signals.
    chunk_ready_req_ids: set[str] = field(default_factory=set)
    chunk_finished_req_ids: set[str] = field(default_factory=set)
    # Per-request scheduling hints: request ID -> metadata dict.
    request_metadata: dict[str, dict[str, Any]] = field(default_factory=dict)
    # NOTE: a list, unlike the set-typed ID fields in this class.
    kv_sent_req_ids: list[str] = field(default_factory=list)
    # full_payload_mode readiness signal.
    stage_recv_req_ids: set[str] = field(default_factory=set)
    has_pending_kv_work: bool = False
Follows the v0.12 two-phase execute/sample flow from GPUModelRunner, and diff --git a/vllm_omni/worker/gpu_generation_model_runner.py b/vllm_omni/worker/gpu_generation_model_runner.py index d95b676f6d6..f10115c8e90 100644 --- a/vllm_omni/worker/gpu_generation_model_runner.py +++ b/vllm_omni/worker/gpu_generation_model_runner.py @@ -39,11 +39,12 @@ from vllm_omni.outputs import OmniModelRunnerOutput from vllm_omni.worker.gpu_ar_model_runner import ExecuteModelState from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner +from vllm_omni.worker.omni_connector_model_runner_mixin import OmniConnectorModelRunnerMixin logger = logging.getLogger(__name__) -class GPUGenerationModelRunner(OmniGPUModelRunner): +class GPUGenerationModelRunner(OmniGPUModelRunner, OmniConnectorModelRunnerMixin): """Generation model runner for vLLM-Omni (non-autoregressive). - Reuses GPUModelRunner preparation, multimodal handling, and TP/PP/DP glue. diff --git a/vllm_omni/worker/omni_connector_model_runner_mixin.py b/vllm_omni/worker/omni_connector_model_runner_mixin.py new file mode 100644 index 00000000000..e0df3ba3d7a --- /dev/null +++ b/vllm_omni/worker/omni_connector_model_runner_mixin.py @@ -0,0 +1,2125 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unified data-plane communication mixin for Model Runners. + +All connector.put()/get() calls are consolidated here. Background I/O +threads handle async_chunk and full_payload_mode transfers; KV cache is delegated to +the existing OmniKVTransferManager (to be absorbed later). + +The mixin reports transfer results via OmniConnectorOutput so that the +Scheduler can make scheduling decisions without ever touching a connector. 
+""" + +from __future__ import annotations + +import importlib +import inspect +import os +import threading +from collections import defaultdict, deque +from types import SimpleNamespace +from typing import TYPE_CHECKING, Any + +import torch +from vllm.distributed.parallel_state import get_tp_group +from vllm.logger import init_logger + +from vllm_omni.distributed.omni_connectors.factory import OmniConnectorFactory +from vllm_omni.distributed.omni_connectors.utils.config import ConnectorSpec +from vllm_omni.outputs import OmniConnectorOutput +from vllm_omni.worker.payload_span import ( + THINKER_DECODE_EMBEDDINGS_KEY, + THINKER_DECODE_TOKEN_END_KEY, + THINKER_DECODE_TOKEN_START_KEY, + THINKER_OUTPUT_TOKEN_IDS_KEY, + get_tensor_span, + merge_tensor_spans, +) + +if TYPE_CHECKING: + from vllm_omni.distributed.omni_connectors.connectors.base import ( + OmniConnectorBase, + ) + from vllm_omni.distributed.omni_connectors.kv_transfer_manager import ( + OmniKVTransferManager, + ) + +logger = init_logger(__name__) + + +class OmniConnectorModelRunnerMixin: + """Unified data-plane communication mixin for Model Runners. + + Provides three transfer modes through a single pair of bg I/O threads: + - **full_payload_mode**: ``recv_full_payload_inputs`` / ``send_full_payload_outputs`` + - **Streaming (async_chunk)**: ``recv_chunk`` / ``send_chunk`` + - **KV cache**: ``send_kv_cache`` / ``recv_kv_cache`` (delegates to + the existing ``OmniKVTransferManager``) + + The mixin owns connector instances and background threads. It never + touches scheduling queues -- readiness is communicated to the Scheduler + via ``OmniConnectorOutput``. 
+ """ + + # ------------------------------------------------------------------ # + # Init / Shutdown + # ------------------------------------------------------------------ # + + def init_omni_connectors( + self, + vllm_config: Any, + model_config: Any, + kv_transfer_manager: OmniKVTransferManager | None = None, + ) -> None: + """Initialize connectors and background threads. + + Args: + vllm_config: Full vLLM config object. + model_config: Stage-level model config with connector settings. + kv_transfer_manager: Existing KV transfer manager to delegate to. + """ + self._omni_connector: OmniConnectorBase | None = self._create_connector(model_config) + self._kv_transfer_manager = kv_transfer_manager + + self._async_chunk: bool = getattr(model_config, "async_chunk", False) + self._model_mode: str = getattr(model_config, "worker_type", "ar") + stage_id = getattr(model_config, "stage_id", 0) + if isinstance(stage_id, str): + stage_id = int(stage_id) + self._stage_id: int = stage_id if isinstance(stage_id, int) else 0 + + self._custom_process_func_path, self._custom_process_func = self._load_custom_func(model_config) + self._custom_process_supports_is_finished = self._custom_process_supports_is_finished_kwarg() + logger.info( + "[Stage-%s] init_omni_connectors: async_chunk=%s, custom_process_func=%s, connector=%s, func_path=%s", + self._stage_id, + self._async_chunk, + self._custom_process_func, + type(self._omni_connector).__name__ if self._omni_connector else None, + self._custom_process_func_path, + ) + + # -- next stage ID (from connector config or default stage_id + 1) -- + self._next_stage_id: int = self._resolve_next_stage_id(model_config) + + # -- heterogeneous TP rank support -- + rank_cfg = self._parse_rank_mapping(model_config) + self._from_tp: int = rank_cfg["from_tp"] + self._to_tp: int = rank_cfg["to_tp"] + self._local_rank: int = rank_cfg["local_rank"] + if self._kv_transfer_manager is not None: + self._kv_transfer_manager.kv_send_key_builder = 
self.get_rank_aware_kv_send_keys + self._kv_transfer_manager.kv_recv_key_builder = self.get_rank_aware_kv_keys + self._kv_transfer_manager.kv_payload_merger = self._merge_rank_sharded_kv_payloads + self._kv_transfer_manager.kv_payload_slicer = self._slice_rank_sharded_kv_payload + + # -- chunk index tracking (ported from OmniChunkTransferAdapter) -- + self._put_req_chunk: dict[str, int] = defaultdict(int) + self._get_req_chunk: dict[str, int] = defaultdict(int) + # Send-side async accumulation / staging buffer. Receive-side payload + # ownership lives in ``_local_stage_payload_cache``. + self._send_side_request_payload: dict[str, dict[str, Any]] = {} + self._code_prompt_token_ids: dict[str, list[list[int]]] = defaultdict(list) + self._request_ids_mapping: dict[str, str] = {} + + # -- async I/O state (shared by chunk + full_payload_mode) -- + self._pending_load_reqs: dict[str, Any] = {} + self._finished_load_reqs: set[str] = set() + self._pending_save_reqs: dict[str, deque] = {} + self._pending_save_counts: dict[str, int] = defaultdict(int) + self._deferred_send_cleanup: set[str] = set() + # -- per-cycle output accumulator -- + self._chunk_ready_req_ids: set[str] = set() + self._chunk_finished_req_ids: set[str] = set() + self._stage_recv_req_ids: set[str] = set() + self._full_payload_pending_broadcast_req_ids: set[str] = set() + self._async_chunk_updated_req_ids: set[str] = set() + + # -- Model Runner local payload cache (RFC §2.4) -- + # Full stage payloads land here first on the recv side. We + # intentionally do not write connector recv results straight into + # `model_intermediate_buffer`: runner-owned runtime state is + # materialized later by `_sync_local_stage_payloads()` on the + # model thread. This keeps recv timing separate from execute-step + # visibility and avoids mixing connector I/O with model runtime + # ownership. + self._local_stage_payload_cache: dict[str, dict[str, Any]] = {} + # Lightweight scheduling metadata pending delivery to the Scheduler. 
+ self._local_request_metadata: dict[str, dict[str, Any]] = {} + + # -- persistent set of request IDs whose chunk stream is complete -- + # Prevents re-registration after the finish sentinel has been received. + self._chunk_stream_completed: set[str] = set() + + # -- full_payload_mode: accumulate latest pooler_output per request, + # send only when the request finishes (next-cycle flush) -- + self._pending_full_payload_send: dict[str, tuple[Any, Any]] = {} + + # -- KV sent accumulator -- + self._kv_sent_req_ids: list[str] = [] + + # -- KV transfer lifecycle (absorbed from scheduler) -- + # Requests marked for KV transfer: {req_id: {seq_len, block_ids}} + self._kv_pending_transfers: dict[str, dict[str, Any]] = {} + # Requests whose KV transfer has been submitted but not yet acked + self._kv_active_transfers: set[str] = set() + # Requests whose KV transfer is complete (acked by kv_extracted_req_ids) + self._kv_completed_transfers: set[str] = set() + # Dedup guard: requests that have already triggered KV transfer + self._kv_triggered_requests: set[str] = set() + + self._lock = threading.Lock() + self._stop_event = threading.Event() + self._work_available = threading.Event() + + # Start background threads only when there's a connector + self._recv_thread: threading.Thread | None = None + self._save_thread: threading.Thread | None = None + if self._omni_connector is not None: + self._recv_thread = threading.Thread( + target=self._recv_loop, + daemon=True, + name="omni-mixin-recv", + ) + self._recv_thread.start() + self._save_thread = threading.Thread( + target=self._save_loop, + daemon=True, + name="omni-mixin-save", + ) + self._save_thread.start() + + def shutdown_omni_connectors(self) -> None: + """Stop background threads and release connector resources.""" + self._stop_event.set() + if self._recv_thread is not None: + self._recv_thread.join(timeout=5) + if self._save_thread is not None: + self._save_thread.join(timeout=5) + if self._omni_connector is not None: + try: 
+ self._omni_connector.close() + except Exception: + pass + + def cleanup_finished_request(self, req_id: str) -> None: + """Clean up per-request state after a request is fully finished. + + Call this when a request is freed from the model runner to prevent + memory leaks in the mixin's tracking dicts/sets. The external + request ID is resolved before cleaning up ``_put_req_chunk`` which + is keyed by external ID. + """ + ext_id = self._request_ids_mapping.pop(req_id, None) + send_req_id = ext_id if ext_id is not None else req_id + + with self._lock: + if self._pending_save_counts.get(send_req_id, 0): + self._deferred_send_cleanup.add(send_req_id) + else: + self._put_req_chunk.pop(send_req_id, None) + self._send_side_request_payload.pop(send_req_id, None) + self._code_prompt_token_ids.pop(send_req_id, None) + self._kv_pending_transfers.pop(req_id, None) + self._kv_active_transfers.discard(req_id) + self._kv_completed_transfers.discard(req_id) + self._kv_triggered_requests.discard(req_id) + self._cleanup_recv_delivery_state(req_id) + + def drop_inactive_request_delivery_state(self, req_id: str) -> None: + """Clear recv-side state for inactive requests.""" + ext_id = self._request_ids_mapping.pop(req_id, None) + if hasattr(self, "_lock"): + with self._lock: + self._drop_send_side_payload_state(req_id, ext_id) + else: + self._drop_send_side_payload_state(req_id, ext_id) + self._cleanup_recv_delivery_state(req_id) + + def _drop_send_side_payload_state(self, req_id: str, ext_id: str | None) -> None: + if ext_id is not None: + self._send_side_request_payload.pop(ext_id, None) + self._send_side_request_payload.pop(req_id, None) + + def _cleanup_recv_delivery_state(self, req_id: str) -> None: + """Clear recv-side delivery-cycle state.""" + if hasattr(self, "_lock"): + with self._lock: + self._clear_recv_delivery_state(req_id) + else: + self._clear_recv_delivery_state(req_id) + + def _clear_recv_delivery_state(self, req_id: str) -> None: + self._get_req_chunk.pop(req_id, 
def prune_inactive_requests(self, active_req_ids: Any) -> set[str]:
    """Clean up connector state left behind by requests that are gone.

    Every request ID found in the tracked bookkeeping containers that is
    neither active nor protected is passed to
    ``cleanup_finished_request``.

    Args:
        active_req_ids: Iterable of request IDs the caller still owns, or
            ``None`` to skip pruning entirely.

    Returns:
        The set of request IDs that were cleaned up (empty when
        ``active_req_ids`` is ``None``).
    """
    if active_req_ids is None:
        return set()

    # IDs that must survive this pass. Besides the caller's active set we
    # protect:
    #   * pending recv requests - they may not be in the caller's active
    #     set yet (e.g. WAITING_FOR_CHUNK requests live in the
    #     coordinator's internal queues, not in the model runner's
    #     ``self.requests``);
    #   * requests with a freshly received full payload / local metadata -
    #     the background recv thread can deliver them after the
    #     scheduler_output snapshot of the current execute_model() cycle
    #     was taken, so their wake-up may lag by one cycle. Pruning them
    #     here would drop the payload before stage_recv is published.
    protected_ids = set(active_req_ids)
    protected_ids.update(getattr(self, "_pending_load_reqs", {}).keys())
    protected_ids.update(getattr(self, "_stage_recv_req_ids", set()))
    protected_ids.update(getattr(self, "_full_payload_pending_broadcast_req_ids", set()))
    protected_ids.update(getattr(self, "_local_request_metadata", {}).keys())

    # NOTE: ``_pending_load_reqs`` is deliberately not scanned: all of its
    # entries are protected above, and the mixin cannot tell a legitimately
    # waiting pending recv from an orphaned one (only the
    # coordinator/scheduler knows). Orphaned entries (e.g. after an
    # upstream stage crash) are handled by
    # OmniSchedulingCoordinator.collect_timed_out_request_ids(); the
    # scheduler then frees the request, which ultimately triggers
    # cleanup_finished_request() here.
    scanned_attrs = (
        "_request_ids_mapping",
        "_get_req_chunk",
        "_finished_load_reqs",
        "_chunk_ready_req_ids",
        "_chunk_finished_req_ids",
        "_chunk_stream_completed",
        "_stage_recv_req_ids",
        "_full_payload_pending_broadcast_req_ids",
        "_async_chunk_updated_req_ids",
        "_local_stage_payload_cache",
        "_local_request_metadata",
        "_kv_pending_transfers",
        "_kv_active_transfers",
        "_kv_completed_transfers",
        "_kv_triggered_requests",
    )

    stale_req_ids: set[str] = set()
    for attr_name in scanned_attrs:
        container = getattr(self, attr_name, None)
        # Dicts contribute their keys, sets their members; anything else
        # (including a missing attribute) is ignored.
        if isinstance(container, (dict, set)):
            stale_req_ids |= set(container) - protected_ids

    for stale_id in stale_req_ids:
        self.cleanup_finished_request(stale_id)

    return stale_req_ids
@classmethod
def _payload_is_consumable(cls, payload: dict[str, Any] | None) -> bool:
    """Tell whether an async payload carries data for a real forward step.

    Metadata-only wake-ups must not move WAITING_FOR_CHUNK requests back
    to a schedulable state. In particular, a widened token horizon that
    brings no newly visible thinker decode embeds should not force a
    placeholder-only talker decode step.

    Decision order:
      1. Anything that is not a non-empty dict is not consumable.
      2. If the thinker decode embeddings entry is a tensor, it alone
         decides the result.
      3. Otherwise, if ``"code_predictor_codes"`` is present, it alone
         decides the result.
      4. Otherwise the payload is consumable when any key outside
         ``_NON_CONSUMABLE_PAYLOAD_KEYS`` holds content.
    """
    if not (isinstance(payload, dict) and payload):
        return False

    embeds = payload.get(THINKER_DECODE_EMBEDDINGS_KEY)
    if isinstance(embeds, torch.Tensor):
        # A 0-dim tensor is always treated as content.
        if embeds.ndim == 0:
            return True
        return embeds.numel() > 0 and embeds.shape[0] > 0

    if "code_predictor_codes" in payload:
        codes = payload.get("code_predictor_codes")
        if isinstance(codes, torch.Tensor):
            return codes.numel() > 0
        # Codec code 0 is valid; non-empty code payloads are consumable.
        if hasattr(codes, "__len__"):
            return len(codes) > 0
        return codes is not None

    return any(
        cls._payload_value_has_content(value)
        for key, value in payload.items()
        if key not in cls._NON_CONSUMABLE_PAYLOAD_KEYS
    )
def _collect_full_payload_results_locked(self) -> dict[str, Any] | None:
    """Gather snapshots of full payloads that are pending broadcast.

    The ``_locked`` suffix signals that the caller is expected to hold
    ``self._lock``. Every pending request whose payload is present in the
    local cache is snapshotted and removed from the pending set; requests
    without a cached payload stay pending and are reported in a warning.

    Returns:
        ``{req_id: payload_snapshot}`` or ``None`` when nothing was
        collected.
    """
    pending = self._full_payload_pending_broadcast_req_ids
    if not pending:
        return None

    collected: dict[str, Any] = {}
    absent: list[str] = []
    # Walk a snapshot because collected IDs are discarded from the set.
    for req_id in tuple(pending):
        cached = self._local_stage_payload_cache.get(req_id)
        if cached is not None:
            collected[req_id] = self._snapshot_payload(cached)
            pending.discard(req_id)
        else:
            absent.append(req_id)

    if absent:
        logger.warning(
            "[Stage-%s] _collect_full_payload_results_locked: "
            "pending full-payload reqs missing from local cache: %s",
            self._stage_id,
            absent,
        )
    return collected if collected else None
def _apply_async_chunk_fanout_packet(self, packet: dict[str, Any]) -> None:
    """Apply an async-chunk fan-out packet to this instance's local state.

    Only the ``"staged_payloads"`` and ``"chunk_finished"`` entries of the
    packet are read here. Staged payloads are written into the local
    cache, and every finished request is removed from the pending-load map
    and recorded as stream-completed.
    """
    finished_ids = set(packet.get("chunk_finished", ()))
    payloads = packet.get("staged_payloads", {})
    with self._lock:
        self._apply_staged_payloads_locked(payloads)
        # NOTE(review): SOURCE whitespace was collapsed; the bookkeeping
        # below is assumed to run under the lock - confirm.
        for req_id in finished_ids:
            self._pending_load_reqs.pop(req_id, None)
        self._chunk_stream_completed.update(finished_ids)
def accumulate_full_payload_output(
    self,
    req_id: str,
    pooler_output: Any,
    request: Any,
) -> None:
    """Accumulate pooler_output for a request across steps (full_payload_mode).

    Merge rules, applied per key when an earlier output already exists:

    * tensors that are both at least 2-D with matching trailing dims are
      concatenated along dim 0 (old first, new second);
    * any other value is replaced by the latest one;
    * a key missing (or ``None``) on one side keeps the other side's value.

    All-zero tensors in the incoming output (e.g. ``code_predictor_codes``
    emitted during prefill) are dropped before merging so they do not
    reach downstream stages as garbage / noise frames.

    The accumulated data is sent when ``flush_full_payload_outputs`` is
    called with the finished request IDs.

    Args:
        req_id: Key under which the output is accumulated.
        pooler_output: Mapping of output name to value for this step.
        request: Request object stored alongside the output; the most
            recent one wins.
    """
    # ---- Filter out all-zero tensors from the incoming pooler_output ----
    kept: dict[str, Any] = {}
    dropped_zero_keys: list[tuple[str, tuple[int, ...]]] = []
    for key, value in pooler_output.items():
        if self._is_all_zero_tensor(value):
            # Prefill zero-filled placeholder: remember it for the log only.
            dropped_zero_keys.append((key, tuple(value.shape)))
        else:
            kept[key] = value
    if dropped_zero_keys:
        logger.info(
            "[Stage-%s] accumulate_full_payload_output: req=%s dropped_zero_keys=%s",
            self._stage_id,
            req_id,
            dropped_zero_keys,
        )

    existing = self._pending_full_payload_send.get(req_id)
    if existing is None:
        self._pending_full_payload_send[req_id] = (kept, request)
        return

    prev_output, _ = existing
    merged: dict[str, Any] = {}
    # dict.fromkeys gives an order-preserving union (previous keys first,
    # then new ones). A plain set union would make the merged key order
    # depend on string hash randomization.
    for key in dict.fromkeys((*prev_output, *kept)):
        v_new = kept.get(key)
        v_old = prev_output.get(key)
        if v_new is None:
            merged[key] = v_old
        elif v_old is None:
            merged[key] = v_new
        elif (
            isinstance(v_new, torch.Tensor)
            and isinstance(v_old, torch.Tensor)
            and v_new.dim() >= 2
            and v_old.dim() >= 2
            and v_new.shape[1:] == v_old.shape[1:]
        ):
            merged[key] = torch.cat([v_old, v_new], dim=0)
        else:
            merged[key] = v_new
    self._pending_full_payload_send[req_id] = (merged, request)
def send_full_payload_outputs(
    self,
    scheduler_output: Any,
    outputs: dict[str, tuple[Any, Any] | Any],
) -> list[str]:
    """Queue full_payload stage outputs for delivery to the next stage.

    Each output is optionally transformed through the custom process
    function, wrapped in a save task and appended to the pending-save
    queue; ``_work_available`` is set once at least one task was queued.

    Args:
        scheduler_output: Not used by this method.
        outputs: Mapping of ``req_id`` to either a
            ``(pooling_output, request)`` tuple (preferred) or a raw
            payload dict. When a tuple is supplied the request object
            is forwarded to ``_build_custom_process_payload``.

    Returns:
        The request IDs that were enqueued. An empty list when no
        connector is configured. On ranks that are not the data-transfer
        rank nothing is enqueued and all keys of ``outputs`` are returned.
    """
    if self._omni_connector is None:
        logger.info("[Stage-%s] send_full_payload_outputs: connector is None, skip", self._stage_id)
        return []
    if not self.is_data_transfer_rank():
        logger.info(
            "[Stage-%s] send_full_payload_outputs: not data_transfer_rank (rank=%s), skip",
            self._stage_id,
            self._local_rank,
        )
        return list(outputs.keys())

    sent_ids: list[str] = []
    next_stage_id = self._next_stage_id
    for req_id, value in outputs.items():
        if isinstance(value, tuple) and len(value) == 2:
            raw_output, request = value
        else:
            raw_output, request = value, None

        payload = raw_output
        if self._custom_process_func is not None:
            payload = self._build_custom_process_payload(
                request_id=req_id,
                request=request,
                pooling_output=raw_output,
            )
        # Single None check for both the raw and the custom-processed
        # payload, so a skipped request is always visible in the log.
        if payload is None:
            logger.info("[Stage-%s] send_full_payload_outputs: payload is None for %s", self._stage_id, req_id)
            continue

        if isinstance(payload, dict):
            code_predictor_codes = payload.get("code_predictor_codes")
            if isinstance(code_predictor_codes, torch.Tensor):
                code_len = int(code_predictor_codes.numel())
            elif hasattr(code_predictor_codes, "__len__"):
                code_len = len(code_predictor_codes)
            else:
                code_len = None
            logger.info(
                "[Stage-%s] send_full_payload_outputs: req=%s payload_keys=%s code_len=%s left_context_size=%s",
                self._stage_id,
                req_id,
                sorted(payload.keys()),
                code_len,
                payload.get("left_context_size"),
            )

        external_req_id = self._resolve_external_req_id(request, req_id)
        # NOTE(review): the chunk counter and pending-save maps are keyed by
        # ``req_id`` here, while ``send_chunk`` keys them by the resolved
        # external ID - confirm the two paths are meant to differ.
        chunk_id = self._put_req_chunk[req_id]
        self._put_req_chunk[req_id] += 1
        connector_put_key = f"{external_req_id}_{self._stage_id}_{chunk_id}"

        logger.info(
            "[Stage-%s] send_full_payload_outputs: enqueue req=%s put_key=%s next_stage=%s",
            self._stage_id,
            req_id,
            connector_put_key,
            next_stage_id,
        )
        task = {
            "stage_id": self._stage_id,
            "next_stage_id": next_stage_id,
            "put_key": connector_put_key,
            "data": payload,
            "request_id": req_id,
        }
        with self._lock:
            self._pending_save_reqs.setdefault(req_id, deque()).append(task)
            self._pending_save_counts[req_id] += 1
        sent_ids.append(req_id)

    if sent_ids:
        self._work_available.set()
    return sent_ids
def send_chunk(
    self,
    request: Any,
    pooling_output: Any | None = None,
) -> bool:
    """Build one chunk for ``request`` and queue it for async sending.

    The payload is produced in the calling thread through
    ``_build_custom_process_payload``; this method only enqueues a save
    task and signals ``_work_available``. The original notes that the
    actual ``connector.put()`` happens on the background save thread and
    that non-KV data is identical across TP ranks, so only the
    data-transfer rank sends.

    Returns:
        ``False`` when no connector is configured or no payload could be
        built; ``True`` when a task was queued, and also on ranks that are
        not the data-transfer rank (where nothing is queued).
    """
    if self._omni_connector is None:
        logger.warning("[Stage-%s] send_chunk: connector is None", self._stage_id)
        return False
    if not self.is_data_transfer_rank():
        return True

    internal_id = getattr(request, "request_id", None) or getattr(request, "req_id", None)
    request_id = self._resolve_external_req_id(request, internal_id)
    # Remember the internal -> external mapping so that finish sentinels
    # can still resolve the external ID after the request is freed.
    if internal_id and internal_id != request_id:
        self._request_ids_mapping.setdefault(internal_id, request_id)
    chunk_index = self._put_req_chunk[request_id]

    payload = self._build_custom_process_payload(
        request_id=request_id,
        request=request,
        pooling_output=pooling_output,
    )
    if payload is None:
        # Only the very first chunk is worth a warning.
        if chunk_index == 0:
            logger.warning(
                "[Stage-%s] send_chunk: payload is None for req=%s chunk=%s (process_func=%s)",
                self._stage_id,
                request_id,
                chunk_index,
                self._custom_process_func,
            )
        return False

    self._put_req_chunk[request_id] += 1
    put_key = f"{request_id}_{self._stage_id}_{chunk_index}"
    if chunk_index == 0:
        logger.info(
            "[Stage-%s] send_chunk: first chunk enqueued, req=%s key=%s",
            self._stage_id,
            request_id,
            put_key,
        )

    save_task = {
        "stage_id": self._stage_id,
        "next_stage_id": self._next_stage_id,
        "put_key": put_key,
        "data": payload,
        "request_id": request_id,
    }
    with self._lock:
        self._pending_save_reqs.setdefault(request_id, deque()).append(save_task)
        self._pending_save_counts[request_id] += 1
    self._work_available.set()
    return True
def receive_multi_kv_cache(
    self,
    req: Any,
    cfg_kv_collect_func: Any | None = None,
    target_device: torch.device | None = None,
) -> bool:
    """Fetch the primary KV cache for ``req`` plus optional CFG companions.

    Steps performed here: receive the primary KV cache and apply it to the
    request, then - when the request's ``sampling_params`` carry
    ``cfg_kv_request_ids`` and a collect function is given - fetch the
    companion payloads and copy whatever the collect function returns onto
    ``req.sampling_params``. Failures while collecting companions are
    logged and do not affect the return value.

    Returns:
        ``True`` only when primary KV data was received and applied.
        ``False`` when there is no KV transfer manager, the request has no
        ID, or the request is not in ``self.requests`` (if that attribute
        exists).
    """
    manager = self._kv_transfer_manager
    if manager is None:
        return False

    request_id = getattr(req, "request_id", None)
    if not request_id:
        # Fall back to the first entry of ``request_ids`` when available.
        batch_ids = req.request_ids if hasattr(req, "request_ids") else None
        request_id = batch_ids[0] if batch_ids else None
    if not request_id:
        logger.warning("Request has no ID, cannot receive KV cache")
        return False

    known_requests = getattr(self, "requests", None)
    if known_requests is not None and request_id not in known_requests:
        logger.info("Skip receiving KV cache for inactive request %s", request_id)
        return False

    kv_data, _ = self.recv_kv_cache(request_id, target_device=target_device)
    primary_ok = bool(kv_data)
    if primary_ok:
        manager.apply_kv_cache_to_request(req, kv_data)

    companion_ids = getattr(getattr(req, "sampling_params", None), "cfg_kv_request_ids", None)
    if companion_ids and cfg_kv_collect_func:
        try:
            role_payloads = self.receive_cfg_companion_kv_payloads(
                companion_ids,
                target_device=target_device,
            )
            collected = cfg_kv_collect_func(request_id, role_payloads)
            if collected and hasattr(req, "sampling_params") and req.sampling_params is not None:
                for field_name, field_value in collected.items():
                    setattr(req.sampling_params, field_name, field_value)
                logger.info("Applied CFG KV caches: %s", list(collected.keys()))
        except Exception:
            logger.exception("Failed to collect CFG KV caches for %s", request_id)

    return primary_ok
def get_kv_target_ranks_for_send(self) -> list[int]:
    """Return the target ranks that receive this rank's KV shard(s).

    After validating the TP topology:

    * equal TP sizes map the local rank onto itself;
    * a larger source TP maps ``from_tp // to_tp`` consecutive source
      ranks onto one target rank;
    * a smaller source TP fans the local rank out to ``to_tp // from_tp``
      consecutive target ranks.
    """
    self._validate_kv_tp_topology()
    src_tp, dst_tp, rank = self._from_tp, self._to_tp, self._local_rank
    if src_tp == dst_tp:
        return [rank]
    if src_tp > dst_tp:
        return [rank // (src_tp // dst_tp)]
    fanout = dst_tp // src_tp
    first_target = rank * fanout
    return list(range(first_target, first_target + fanout))
def _slice_rank_sharded_kv_payload(self, payload: dict[str, Any] | None) -> dict[str, Any] | None:
    """Cut this rank's portion out of a KV payload when ``from_tp < to_tp``.

    The payload is returned untouched when it is ``None``, when
    ``from_tp >= to_tp``, or when it has no dict-valued ``"layer_blocks"``.
    Otherwise a shallow copy is returned whose ``key_cache`` /
    ``value_cache`` tensors are narrowed along dim ``-2`` to the slice
    selected by ``local_rank % (to_tp // from_tp)``, and whose metadata
    records ``sliced_for_local_rank``. The input mapping is not mutated.
    """
    if payload is None or self._from_tp >= self._to_tp:
        return payload

    ratio = self._to_tp // self._from_tp
    shard = self._local_rank % ratio
    blocks = payload.get("layer_blocks") if isinstance(payload, dict) else None
    if not isinstance(blocks, dict):
        return payload

    def _local_part(entry: Any) -> Any:
        # Non-tensors, tensors with fewer than two dims, and tensors whose
        # dim -2 is not divisible by the ratio pass through unchanged.
        if not isinstance(entry, torch.Tensor) or entry.ndim < 2:
            return entry
        extent = entry.shape[-2]
        if extent % ratio != 0:
            return entry
        width = extent // ratio
        return entry.narrow(-2, shard * width, width).contiguous()

    result = dict(payload)
    result["layer_blocks"] = {
        cache_name: [_local_part(entry) for entry in blocks.get(cache_name, [])]
        for cache_name in ("key_cache", "value_cache")
    }
    meta = dict(result.get("metadata", {}))
    meta["sliced_for_local_rank"] = self._local_rank
    result["metadata"] = meta
    return result
def mark_kv_transfer(
    self,
    req_id: str,
    seq_len: int,
    block_ids: list[int],
    custom_metadata: dict[str, Any] | None = None,
) -> None:
    """Queue a KV cache transfer for ``req_id``.

    Does nothing when the request already has a pending transfer.
    Otherwise the request is recorded as triggered and a transfer
    descriptor with ``seq_len`` and ``block_ids`` (plus
    ``custom_metadata`` when one is given) is stored as pending.
    """
    pending = self._kv_pending_transfers
    if req_id in pending:
        return

    self._kv_triggered_requests.add(req_id)
    descriptor: dict[str, Any] = {"seq_len": seq_len, "block_ids": block_ids}
    if custom_metadata is not None:
        descriptor["custom_metadata"] = custom_metadata
    pending[req_id] = descriptor
def ack_kv_transfers(self, req_ids: list[str] | set[str]) -> None:
    """Record the given KV transfers as completed.

    Each ID is removed from the active set (if present) and added to the
    completed set, from which ``drain_completed_kv_transfers`` later
    reports it.
    """
    acked = tuple(req_ids)
    self._kv_active_transfers.difference_update(acked)
    self._kv_completed_transfers.update(acked)
+ """ + from vllm_omni.outputs import OmniModelRunnerOutput + + output = OmniModelRunnerOutput(req_ids=[], req_id_to_index={}) + output.omni_connector_output = self.get_omni_connector_output() + return output + + def get_omni_connector_output(self) -> OmniConnectorOutput: + """Collect and reset transfer results for this execute_model cycle. + + ``request_metadata`` carries only lightweight scheduling metadata. + Full payloads remain owned by the Model Runner local cache for all + paths. + """ + if not hasattr(self, "_lock"): + return OmniConnectorOutput() + + tp_group = self._get_local_tp_group() + if self._async_chunk and tp_group is not None and getattr(tp_group, "world_size", 1) > 1: + if self.is_data_transfer_rank(): + with self._lock: + fanout_packet = self._collect_async_chunk_fanout_packet_locked() + else: + fanout_packet = None + fanout_packet = self._broadcast_tp_payload_packet(fanout_packet) + if fanout_packet is None: + newly_finished = set() + chunk_finished = set() + request_metadata = {} + else: + if not self.is_data_transfer_rank(): + self._apply_async_chunk_fanout_packet(fanout_packet) + newly_finished = set(fanout_packet["newly_finished"]) + chunk_finished = set(fanout_packet["chunk_finished"]) + request_metadata = dict(fanout_packet["request_metadata"]) + else: + with self._lock: + newly_finished = set(self._finished_load_reqs) + self._finished_load_reqs.clear() + chunk_finished = set(self._chunk_finished_req_ids) + self._chunk_finished_req_ids.clear() + request_metadata = dict(self._local_request_metadata) + self._local_request_metadata.clear() + # _send_side_request_payload is the async accumulation buffer for + # future recv chunks. Clearing it on every consumable wake-up drops + # intermediate + # thinker decode spans before the model side can consume them. + # Only terminal chunk_finished requests may release that buffer. 
+ for req_id in chunk_finished: + if req_id not in self._local_stage_payload_cache: + continue + ext_req_id = self._request_ids_mapping.get(req_id, req_id) + self._send_side_request_payload.pop(ext_req_id, None) + if ext_req_id != req_id: + self._send_side_request_payload.pop(req_id, None) + self._chunk_ready_req_ids.update(newly_finished) + + output = OmniConnectorOutput( + chunk_ready_req_ids=set(self._chunk_ready_req_ids), + chunk_finished_req_ids=chunk_finished, + request_metadata=request_metadata, + kv_sent_req_ids=list(self._kv_sent_req_ids), + stage_recv_req_ids=set(self._stage_recv_req_ids), + has_pending_kv_work=self.has_pending_kv_work(), + ) + if output.stage_recv_req_ids or chunk_finished or newly_finished: + logger.info( + "[Stage-%s] get_omni_connector_output: stage_recv=%s, chunk_finished=%s, chunk_ready=%s", + self._stage_id, + output.stage_recv_req_ids, + chunk_finished, + output.chunk_ready_req_ids, + ) + self._chunk_ready_req_ids.clear() + self._kv_sent_req_ids.clear() + self._stage_recv_req_ids.clear() + return output + + @staticmethod + def _connector_output_has_signals(output: OmniConnectorOutput) -> bool: + return bool( + output.chunk_ready_req_ids + or output.chunk_finished_req_ids + or output.request_metadata + or output.kv_sent_req_ids + or output.stage_recv_req_ids + or output.has_pending_kv_work + ) + + def attach_omni_connector_output(self, result: Any | None) -> Any: + omni_output = self.get_omni_connector_output() + if not self._connector_output_has_signals(omni_output): + return result + + from copy import copy + + from vllm.v1.worker.gpu_model_runner import EMPTY_MODEL_RUNNER_OUTPUT + + wrapped = copy(result if result is not None else EMPTY_MODEL_RUNNER_OUTPUT) + wrapped.omni_connector_output = omni_output + return wrapped + + # ------------------------------------------------------------------ # + # Properties for compatibility with custom_process funcs that access + # transfer_manager.put_req_chunk / request_payload / 
def _recv_loop(self) -> None:
    """Background thread: poll connector for incoming data.

    Runs until ``_stop_event`` is set. Each pass snapshots the request ids
    registered in ``_pending_load_reqs`` and calls
    ``_poll_single_request`` once per id. Exceptions raised while polling
    one request are logged and do not stop the loop.
    """
    # Number of polling passes that had at least one pending request;
    # used only to rate-limit the info log below.
    _recv_poll_count = 0
    while not self._stop_event.is_set():
        # Copy the ids while holding the lock so polling itself runs
        # without the lock held.
        with self._lock:
            pending_ids = list(self._pending_load_reqs.keys())

        if not pending_ids:
            # Nothing to receive: wait up to 10 ms on _work_available
            # (presumably a threading.Event set when work is registered —
            # confirm against the producer side), then re-check.
            self._work_available.wait(timeout=0.01)
            self._work_available.clear()
            continue

        _recv_poll_count += 1
        # Logs on pass 1, 5001, 10001, ... and shows at most five ids.
        if _recv_poll_count % 5000 == 1:
            logger.info(
                "[Stage-%s] _recv_loop: polling %s pending reqs: %s (poll#%s)",
                self._stage_id,
                len(pending_ids),
                pending_ids[:5],
                _recv_poll_count,
            )

        made_progress = False
        for req_id in pending_ids:
            if self._stop_event.is_set():
                break
            try:
                # ``or made_progress`` keeps the flag sticky once any
                # request in this pass returned a truthy result.
                made_progress = self._poll_single_request(req_id) or made_progress
            except Exception:
                logger.warning("Error receiving data for %s", req_id, exc_info=True)

        # No request advanced in this pass: wait up to 1 ms before polling
        # again instead of spinning.
        if not made_progress and not self._stop_event.is_set():
            self._work_available.wait(timeout=0.001)
            self._work_available.clear()
def _requeue_or_drop_failed_send(self, task: dict) -> None:
    """Give a failed send another attempt, or abandon it.

    The attempt number is kept on the task itself under ``"_retry_count"``.
    While it does not exceed ``_MAX_SEND_RETRIES`` the task is pushed back
    to the *front* of its request's queue in ``_pending_save_reqs``.
    Beyond that, the failure is logged and the pending-save count for the
    request is decremented via ``_decrement_pending_save_count``.
    """
    attempt = task.get("_retry_count", 0) + 1
    failed_req_id = task.get("request_id")
    retry_limit = self._MAX_SEND_RETRIES

    if attempt > retry_limit:
        logger.error(
            "[Stage-%s] Giving up on send for %s after %d retries",
            getattr(self, "_stage_id", "?"),
            failed_req_id,
            retry_limit,
        )
        self._decrement_pending_save_count(failed_req_id)
        return

    task["_retry_count"] = attempt
    logger.warning(
        "[Stage-%s] Re-enqueuing failed send for %s (retry %d/%d)",
        getattr(self, "_stage_id", "?"),
        failed_req_id,
        attempt,
        retry_limit,
    )
    with self._lock:
        # appendleft: the retried task goes ahead of tasks queued later
        # for the same request.
        self._pending_save_reqs.setdefault(failed_req_id, deque()).appendleft(task)
+ if self._payload_is_consumable(staged_payload) or metadata_in_flight or scheduler_wakeup_pending: + logger.debug( + "[Stage-%s] delaying recv for req=%s until staged async payload is handed to scheduler", + self._stage_id, + req_id, + ) + return False + + target_stage_id = self._stage_id - 1 + chunk_id = self._get_req_chunk[req_id] + external_req_id = self._request_ids_mapping.get(req_id, req_id) + connector_get_key = f"{external_req_id}_{target_stage_id}_{chunk_id}" + + if self._async_chunk: + result = self._recv_async_chunk_result( + connector, + str(target_stage_id), + str(self._stage_id), + connector_get_key, + ) + else: + result = self._recv_full_payload_result( + connector, + str(target_stage_id), + str(self._stage_id), + connector_get_key, + ) + + if result is None: + return False + + payload_data, _size = result + if not payload_data: + return False + if isinstance(payload_data, dict): + logger.info( + "[Stage-%s] recv_chunk_result: req=%s ext=%s key=%s keys=%s finished=%s", + self._stage_id, + req_id, + external_req_id, + connector_get_key, + sorted(payload_data.keys()), + bool(payload_data.get("finished")) if "finished" in payload_data else None, + ) + + self._get_req_chunk[req_id] += 1 + + if self._async_chunk: + is_finished = bool(payload_data.get("finished")) + incoming_payload_consumable = self._payload_is_consumable(payload_data) + + if self._model_mode == "ar": + payload_data = self._accumulate_payload(external_req_id, payload_data) + payload_consumable = incoming_payload_consumable + else: + new_ids = payload_data.get("code_predictor_codes", []) + if not new_ids and not is_finished: + return False + payload_consumable = self._payload_is_consumable(payload_data) + + with self._lock: + if is_finished: + self._chunk_finished_req_ids.add(req_id) + self._chunk_stream_completed.add(req_id) + # Local cache (RFC §2.4) — merge, don't replace, so that + # earlier chunk keys (e.g. 
thinker_prefill_embeddings from + # chunk 0) are not overwritten by later chunks. + existing = self._local_stage_payload_cache.get(req_id) + if existing is not None and isinstance(existing, dict) and isinstance(payload_data, dict): + existing.update(payload_data) + else: + self._local_stage_payload_cache[req_id] = payload_data + staged_payload = self._local_stage_payload_cache[req_id] + self._async_chunk_updated_req_ids.add(req_id) + self.put_local_request_metadata(req_id, self._extract_scheduling_metadata(staged_payload)) + # A finish-only sentinel still needs one terminal wake-up so + # the downstream stage can sync the merged local payload and + # flush/finish even when the last recv carries no new + # consumable chunk bytes. + if payload_consumable or is_finished: + self._finished_load_reqs.add(req_id) + if is_finished and not payload_consumable: + logger.debug( + "[Stage-%s] finish sentinel arrived for req=%s without new consumable payload", + self._stage_id, + req_id, + ) + elif not payload_consumable: + logger.debug( + "[Stage-%s] req=%s received metadata-only / non-consumable async payload; delaying wake-up", + self._stage_id, + req_id, + ) + if is_finished: + self._pending_load_reqs.pop(req_id, None) + else: + # full_payload_mode: the complete payload arrives in a single get(), + # so always unregister immediately. + if isinstance(payload_data, dict): + engine_inputs = payload_data.get("engine_inputs", payload_data) + else: + engine_inputs = payload_data + with self._lock: + self._local_stage_payload_cache[req_id] = self._snapshot_payload(engine_inputs) + # Publish full-payload readiness only after the aligned TP broadcast + # path in recv_full_payload_inputs() has materialized the payload on all + # local ranks. Publishing metadata / stage_recv from the background recv + # thread can let the scheduler observe a request before the payload is + # actually visible to the model thread. 
+ self._full_payload_pending_broadcast_req_ids.add(req_id) + self._pending_load_reqs.pop(req_id, None) + logger.info( + "[Stage-%s] full_payload recv complete: req=%s key=%s payload_type=%s", + self._stage_id, + req_id, + connector_get_key, + type(engine_inputs).__name__, + ) + + logger.debug("[Stage-%s] Received data for key %s", self._stage_id, connector_get_key) + return True + + def _build_custom_process_payload( + self, + request_id: str | None, + request: Any | None, + pooling_output: Any | None, + ) -> Any | None: + """Run the custom process hook with a best-effort finished kwarg.""" + if self._custom_process_func is None: + return None + + kwargs = { + "transfer_manager": self, + "pooling_output": pooling_output, + "request": request, + } + supports_is_finished = getattr( + self, + "_custom_process_supports_is_finished", + self._custom_process_supports_is_finished_kwarg(), + ) + is_finished_fn = getattr(request, "is_finished", None) + if callable(is_finished_fn): + try: + if supports_is_finished is not False: + kwargs["is_finished"] = bool(is_finished_fn()) + except Exception: + logger.debug("request.is_finished() failed for %s", request_id, exc_info=True) + + try: + return self._custom_process_func(**kwargs) + except TypeError as exc: + if "is_finished" not in kwargs or not self._is_unexpected_is_finished_kwarg_error(exc): + logger.exception("custom_process_stage_input_func failed for chunk %s", request_id) + return None + kwargs.pop("is_finished", None) + try: + return self._custom_process_func(**kwargs) + except Exception: + logger.exception("custom_process_stage_input_func failed for chunk %s", request_id) + return None + except Exception: + logger.exception("custom_process_stage_input_func failed for chunk %s", request_id) + return None + + def _custom_process_supports_is_finished_kwarg(self) -> bool | None: + """Return whether the custom process hook accepts `is_finished`.""" + if self._custom_process_func is None: + return None + try: + signature = 
inspect.signature(self._custom_process_func) + except (TypeError, ValueError): + return None + + for param in signature.parameters.values(): + if param.kind == inspect.Parameter.VAR_KEYWORD: + return True + + is_finished_param = signature.parameters.get("is_finished") + if is_finished_param is None: + return False + return is_finished_param.kind in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ) + + @staticmethod + def _is_unexpected_is_finished_kwarg_error(exc: TypeError) -> bool: + message = str(exc) + return ( + "unexpected keyword argument 'is_finished'" in message + or 'unexpected keyword argument "is_finished"' in message + or "positional-only arguments passed as keyword arguments: 'is_finished'" in message + ) + + def _send_single_request(self, task: dict) -> bool: + """Send one queued task via connector.put(). + + Returns True on success. On failure (put() raises or returns + ``success=False``), returns False **without** decrementing + ``_pending_save_counts`` so the caller can retry or clean up. 
+ """ + connector = self._omni_connector + if connector is None: + return True + + request_id = task.get("request_id") + payload_data = task.get("data") + if payload_data is None and task.get("request") is not None: + payload_data = self._build_custom_process_payload( + request_id=request_id, + request=task.get("request"), + pooling_output=task.get("pooling_output"), + ) + put_key = task.get("put_key") + + success, _size, _metadata = connector.put( + from_stage=str(task["stage_id"]), + to_stage=str(task["next_stage_id"]), + put_key=put_key, + data=payload_data, + ) + logger.info( + "[Stage-%s] _send_single_request: put_key=%s success=%s size=%s", + task["stage_id"], + put_key, + success, + _size, + ) + + if not success: + return False + + self._decrement_pending_save_count(request_id) + return True + + def _decrement_pending_save_count(self, request_id: str) -> None: + """Decrement pending save count and run deferred cleanup if zero.""" + cleanup_req_id = None + with self._lock: + remaining = self._pending_save_counts.get(request_id, 0) + if remaining > 1: + self._pending_save_counts[request_id] = remaining - 1 + elif remaining == 1: + self._pending_save_counts.pop(request_id, None) + if request_id in self._deferred_send_cleanup: + self._deferred_send_cleanup.remove(request_id) + cleanup_req_id = request_id + if cleanup_req_id is not None: + self._put_req_chunk.pop(cleanup_req_id, None) + self._send_side_request_payload.pop(cleanup_req_id, None) + self._code_prompt_token_ids.pop(cleanup_req_id, None) + + # ------------------------------------------------------------------ # + # Payload accumulation (ported from OmniChunkTransferAdapter) + # ------------------------------------------------------------------ # + + def _accumulate_payload(self, req_id: str, payload_data: dict[str, Any]) -> dict[str, Any]: + """Accumulate chunk payloads (concat tensors, extend lists). + + Returns a **shallow copy** of the accumulated state so callers + (e.g. 
``_poll_single_request``) can store it in + ``_local_stage_payload_cache`` without aliasing the authoritative + ``_send_side_request_payload`` dict. + """ + if req_id not in self._send_side_request_payload: + self._send_side_request_payload[req_id] = dict(payload_data) + return dict(self._send_side_request_payload[req_id]) + + origin = self._send_side_request_payload[req_id] + merged = dict(origin) + override_keys = payload_data.get("override_keys", ()) + drop_decode_span = False + decode_span_handled = False + for key, value in payload_data.items(): + if key == "finished": + merged[key] = value + continue + if key == THINKER_DECODE_EMBEDDINGS_KEY: + merged_span = merge_tensor_spans( + get_tensor_span( + origin, + tensor_key=THINKER_DECODE_EMBEDDINGS_KEY, + start_key=THINKER_DECODE_TOKEN_START_KEY, + end_key=THINKER_DECODE_TOKEN_END_KEY, + ), + get_tensor_span( + payload_data, + tensor_key=THINKER_DECODE_EMBEDDINGS_KEY, + start_key=THINKER_DECODE_TOKEN_START_KEY, + end_key=THINKER_DECODE_TOKEN_END_KEY, + ), + ) + if merged_span is not None: + merged[key], merged[THINKER_DECODE_TOKEN_START_KEY], merged[THINKER_DECODE_TOKEN_END_KEY] = ( + merged_span + ) + decode_span_handled = True + continue + if isinstance(value, torch.Tensor) and key in origin: + if ( + THINKER_DECODE_TOKEN_START_KEY in origin + or THINKER_DECODE_TOKEN_END_KEY in origin + or THINKER_DECODE_TOKEN_START_KEY in payload_data + or THINKER_DECODE_TOKEN_END_KEY in payload_data + ): + logger.warning( + "[Stage-%s] req=%s falling back to legacy thinker decode " + "merge due to missing/invalid/non-contiguous span " + "metadata", + self._stage_id, + req_id, + ) + drop_decode_span = True + merged[key] = torch.cat([origin[key], value], dim=0) + continue + merged[key] = value + continue + if key in {THINKER_DECODE_TOKEN_START_KEY, THINKER_DECODE_TOKEN_END_KEY}: + if decode_span_handled or drop_decode_span: + continue + merged[key] = value + continue + if key in override_keys: + merged[key] = value + continue + 
if isinstance(value, torch.Tensor) and key in origin: + merged[key] = torch.cat([origin[key], value], dim=0) + elif isinstance(value, list) and key in origin: + merged[key] = origin[key] + value + else: + merged[key] = value + + if drop_decode_span: + merged.pop(THINKER_DECODE_TOKEN_START_KEY, None) + merged.pop(THINKER_DECODE_TOKEN_END_KEY, None) + self._send_side_request_payload[req_id] = merged + return dict(merged) + + def drop_inactive_request_runtime_state(self, req_id: str) -> None: + """Clear inactive request state used by both the runner and mixin. + + This centralizes the model-runner-side cleanup pattern so + ``OmniGPUModelRunner`` can reuse it instead of open-coding the same + inactive-request state mutations. + """ + if hasattr(self, "model_intermediate_buffer"): + self.model_intermediate_buffer.pop(req_id, None) + self.drop_inactive_request_delivery_state(req_id) + + # ------------------------------------------------------------------ # + # Helpers + # ------------------------------------------------------------------ # + + @staticmethod + def _freeze_request_attr(value: Any) -> Any: + if isinstance(value, list): + return list(value) + if isinstance(value, tuple): + return list(value) + if isinstance(value, torch.Tensor): + return value.clone() + raw_list = getattr(value, "_x", None) + if raw_list is not None: + return list(raw_list) + return value + + def _snapshot_request_for_send(self, request: Any, external_req_id: str) -> Any: + finished = bool(getattr(request, "is_finished", lambda: False)()) + attrs: dict[str, Any] = {} + try: + attrs.update(vars(request)) + except TypeError: + pass + + for name in ( + "request_id", + "req_id", + "external_req_id", + "prompt_token_ids", + "output_token_ids", + "all_token_ids", + "additional_information", + "sampling_params", + "multi_modal_data", + "mm_hashes", + ): + if hasattr(request, name): + attrs[name] = self._freeze_request_attr(getattr(request, name)) + + attrs["external_req_id"] = external_req_id + 
attrs["_frozen_is_finished"] = finished + snapshot = SimpleNamespace(**attrs) + snapshot.is_finished = lambda: finished + return snapshot + + @staticmethod + def _create_connector(model_config: Any) -> OmniConnectorBase | None: + """Create a connector from model_config, or None if unconfigured.""" + connector_config = getattr(model_config, "stage_connector_config", None) + if connector_config is None: + return None + + if not isinstance(connector_config, dict): + connector_config = { + "name": getattr(connector_config, "name", None), + "extra": getattr(connector_config, "extra", None), + } + + name = connector_config.get("name") + if not isinstance(name, str) or not name.strip(): + raise RuntimeError("Invalid stage connector config: missing connector name") + name = name.strip() + + extra = connector_config.get("extra") + if extra is None: + extra = {} + elif not isinstance(extra, dict): + raise RuntimeError(f"Invalid extra config for connector {name}: expected dict, got {type(extra).__name__}") + + spec = ConnectorSpec(name=name, extra=extra) + try: + return OmniConnectorFactory.create_connector(spec) + except Exception as exc: + raise RuntimeError(f"Failed to create connector {name}") from exc + + @staticmethod + def _load_custom_func(model_config: Any) -> tuple[str | None, Any | None]: + """Load the connector payload builder for the downstream stage. + + Preferred source is ``custom_process_next_stage_input_func``. Some + full_payload_mode configs (async_chunk=false) only expose the next-stage prompt builder via + ``custom_process_input_func`` (for example ``thinker2talker``), while the + connector payload builder lives beside it as ``thinker2talker_full_payload``. + In that case, derive the full_payload_mode builder path automatically. 
+ """ + candidates: list[str] = [] + + next_stage_func = getattr(model_config, "custom_process_next_stage_input_func", None) + if isinstance(next_stage_func, str) and next_stage_func: + candidates.append(next_stage_func) + + if not getattr(model_config, "async_chunk", False): + input_func = getattr(model_config, "custom_process_input_func", None) + if isinstance(input_func, str) and input_func: + try: + module_path, func_name = input_func.rsplit(".", 1) + if func_name.endswith("_full_payload") or func_name.endswith("_batch"): + candidates.append(f"{module_path}.{func_name}") + else: + candidates.append(f"{module_path}.{func_name}_full_payload") + candidates.append(f"{module_path}.{func_name}_batch") + candidates.append(input_func) + except ValueError: + candidates.append(input_func) + + tried: set[str] = set() + for func_path in candidates: + if func_path in tried: + continue + tried.add(func_path) + try: + module_path, func_name = func_path.rsplit(".", 1) + module = importlib.import_module(module_path) + func = getattr(module, func_name, None) + if callable(func): + if not OmniConnectorModelRunnerMixin._is_connector_payload_builder(func): + logger.debug( + "Skipping incompatible connector payload hook %s; signature=%s", + func_path, + inspect.signature(func), + ) + continue + return func_path, func + except Exception: + logger.warning("Failed to load custom func: %s", func_path, exc_info=True) + + return None, None + + @staticmethod + def _is_connector_payload_builder(func: Any) -> bool: + """Whether *func* matches the mixin payload-builder contract.""" + try: + signature = inspect.signature(func) + except (TypeError, ValueError): + return False + + params = signature.parameters + if any(param.kind == inspect.Parameter.VAR_KEYWORD for param in params.values()): + return True + + required = {"transfer_manager", "pooling_output", "request"} + supported = { + name + for name, param in params.items() + if param.kind + in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + 
inspect.Parameter.KEYWORD_ONLY, + ) + } + return required.issubset(supported) + + def _resolve_external_req_id(self, request: Any, fallback_req_id: str) -> str: + """Resolve the external request ID consistently. + + Checks ``_request_ids_mapping`` first (populated by + ``register_chunk_recv``), then falls back to the request's + ``external_req_id`` attribute, and finally to the given + ``fallback_req_id``. + """ + mapped = self._request_ids_mapping.get(fallback_req_id) + if mapped is not None: + return mapped + if request is not None: + return getattr(request, "external_req_id", fallback_req_id) + return fallback_req_id + + def _resolve_next_stage_id(self, model_config: Any) -> int: + """Determine the downstream stage ID from connector config. + + Falls back to ``stage_id + 1`` when the config does not specify + a ``to_stage`` explicitly. + """ + connector_config = getattr(model_config, "stage_connector_config", None) + if connector_config is not None: + if isinstance(connector_config, dict): + to_stage = connector_config.get("to_stage") + else: + to_stage = getattr(connector_config, "to_stage", None) + if isinstance(to_stage, int): + return to_stage + if isinstance(to_stage, str) and to_stage.strip(): + return int(to_stage) + return self._stage_id + 1 + + @staticmethod + def _parse_rank_mapping(model_config: Any) -> dict[str, int]: + """Parse rank_mapping from connector config (optional). + + Returns ``{"from_tp": int, "to_tp": int, "local_rank": int}``. + When ``rank_mapping`` is absent, assumes 1:1 homogeneous mapping. 
+ """ + connector_config = getattr(model_config, "stage_connector_config", None) + if connector_config is not None and not isinstance(connector_config, dict): + connector_config = getattr(connector_config, "__dict__", {}) + + rank_mapping: dict = {} + if isinstance(connector_config, dict): + rank_mapping = connector_config.get("rank_mapping", {}) + + from_tp = int(rank_mapping.get("from_tp", 1)) + to_tp = int(rank_mapping.get("to_tp", 1)) + + local_rank = 0 + try: + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + except (ValueError, TypeError): + pass + + return {"from_tp": from_tp, "to_tp": to_tp, "local_rank": local_rank} + + # ------------------------------------------------------------------ # + # Heterogeneous TP rank support + # ------------------------------------------------------------------ # + + def _validate_kv_tp_topology(self) -> None: + """Reject heterogeneous TP mappings that cannot be routed losslessly.""" + if self._from_tp <= 0 or self._to_tp <= 0: + raise ValueError(f"Invalid KV TP mapping: from_tp={self._from_tp}, to_tp={self._to_tp}") + larger = max(self._from_tp, self._to_tp) + smaller = min(self._from_tp, self._to_tp) + if larger % smaller != 0: + raise ValueError( + f"KV TP mapping must be divisible for rank-aware routing: from_tp={self._from_tp}, to_tp={self._to_tp}" + ) + + def get_kv_remote_ranks(self) -> list[int]: + """Determine which remote ranks this local rank exchanges KV with. 
def is_data_transfer_rank(self) -> bool:
    """Report whether this rank takes part in data (non-KV) transfer.

    When ``_get_local_tp_group()`` returns a group whose ``world_size``
    is greater than 1, the decision is based on the group's
    ``rank_in_group`` (missing attribute counts as 0). In every other
    case it is based on ``_local_rank``. The rank participates exactly
    when the selected value is 0.
    """
    group = self._get_local_tp_group()
    multi_rank_tp = group is not None and getattr(group, "world_size", 1) > 1
    if multi_rank_tp:
        deciding_rank = getattr(group, "rank_in_group", 0)
    else:
        deciding_rank = self._local_rank
    return deciding_rank == 0
def get_tensor_span(payload: Mapping[str, Any], *, tensor_key: str, start_key: str, end_key: str) -> TensorSpan | None:
    """Extract a validated ``(tensor, start, end)`` span from *payload*.

    Args:
        payload: Mapping to read from.
        tensor_key: Key of the tensor.
        start_key: Key of the integer start offset.
        end_key: Key of the integer end offset.

    Returns:
        ``(tensor, start, end)`` when the entry under ``tensor_key`` is a
        ``torch.Tensor`` with at least one dimension, ``start`` and ``end``
        are ints with ``0 <= start <= end``, and ``end - start`` equals the
        size of the tensor's first dimension. ``None`` otherwise.
    """
    tensor = payload.get(tensor_key)
    start = payload.get(start_key)
    end = payload.get(end_key)
    if not isinstance(tensor, torch.Tensor):
        return None
    # A 0-dim tensor has no first dimension; report it as an invalid span
    # instead of letting ``tensor.shape[0]`` raise IndexError.
    if tensor.dim() == 0:
        return None
    if not isinstance(start, int) or not isinstance(end, int):
        return None
    if start < 0 or end < start or (end - start) != int(tensor.shape[0]):
        return None
    return tensor, start, end
TensorSpan | None, index: int) -> torch.Tensor | None: + if span is None: + return None + tensor, start, end = span + if index < start or index >= end: + return None + return tensor[index - start]