diff --git a/tests/diffusion/test_diffusion_engine_rpc_routing.py b/tests/diffusion/test_diffusion_engine_rpc_routing.py index f731f24728e..b1b63a1a2b8 100644 --- a/tests/diffusion/test_diffusion_engine_rpc_routing.py +++ b/tests/diffusion/test_diffusion_engine_rpc_routing.py @@ -136,6 +136,7 @@ def _make_engine_with_loop( calls and a real ``RequestScheduler``. """ engine = DiffusionEngine.__new__(DiffusionEngine) + engine._closed = False engine.executor = _ConcurrencyTrackingExecutor(rpc_delay=rpc_delay) sched = RequestScheduler() @@ -371,6 +372,7 @@ def test_collective_rpc_before_loop_starts_calls_executor_directly(): on the caller's thread without enqueueing. """ engine = DiffusionEngine.__new__(DiffusionEngine) + engine._closed = False engine._loop_started = False engine._rpc_lock = threading.RLock() engine._cv = threading.Condition(engine._rpc_lock) @@ -391,6 +393,7 @@ def test_collective_rpc_before_loop_starts_serializes_concurrent_callers(): serialize them so they cannot race on the shared executor MQ pair. 
""" engine = DiffusionEngine.__new__(DiffusionEngine) + engine._closed = False engine._loop_started = False engine._rpc_lock = threading.RLock() engine._cv = threading.Condition(engine._rpc_lock) diff --git a/tests/distributed/omni_coordinator/test_load_balancer.py b/tests/distributed/omni_coordinator/test_load_balancer.py index 8350b33d396..1bf4575eac4 100644 --- a/tests/distributed/omni_coordinator/test_load_balancer.py +++ b/tests/distributed/omni_coordinator/test_load_balancer.py @@ -6,18 +6,18 @@ import pytest from vllm_omni.distributed.omni_coordinator import ( - InstanceInfo, LeastQueueLengthBalancer, RandomBalancer, + ReplicaInfo, + ReplicaStatus, RoundRobinBalancer, - StageStatus, ) pytestmark = [pytest.mark.core_model, pytest.mark.cpu] def test_load_balancer_select_returns_valid_index(): - """Verify RandomBalancer.select() returns a valid index for instances.""" + """Verify RandomBalancer.select() returns a valid index for replicas.""" # Task structure mirrors async_omni; RandomBalancer ignores task contents. 
task: dict = { "request_id": "test", @@ -26,30 +26,30 @@ def test_load_balancer_select_returns_valid_index(): } now = time() - instances = [ - InstanceInfo( + replicas = [ + ReplicaInfo( input_addr="tcp://host:10001", output_addr="tcp://host:10001-out", stage_id=0, - status=StageStatus.UP, + status=ReplicaStatus.UP, queue_length=0, last_heartbeat=now, registered_at=now, ), - InstanceInfo( + ReplicaInfo( input_addr="tcp://host:10002", output_addr="tcp://host:10002-out", stage_id=0, - status=StageStatus.UP, + status=ReplicaStatus.UP, queue_length=1, last_heartbeat=now, registered_at=now, ), - InstanceInfo( + ReplicaInfo( input_addr="tcp://host:10003", output_addr="tcp://host:10003-out", stage_id=1, - status=StageStatus.UP, + status=ReplicaStatus.UP, queue_length=2, last_heartbeat=now, registered_at=now, @@ -58,38 +58,38 @@ def test_load_balancer_select_returns_valid_index(): balancer = RandomBalancer() - index = balancer.select(task, instances) + index = balancer.select(task, replicas) assert isinstance(index, int) - assert 0 <= index < len(instances) + assert 0 <= index < len(replicas) -def test_round_robin_balancer_cycles_instances(): +def test_round_robin_balancer_cycles_replicas(): now = time() - instances = [ - InstanceInfo( + replicas = [ + ReplicaInfo( input_addr="tcp://host:10001", output_addr="tcp://host:10001-out", stage_id=0, - status=StageStatus.UP, + status=ReplicaStatus.UP, queue_length=2, last_heartbeat=now, registered_at=now, ), - InstanceInfo( + ReplicaInfo( input_addr="tcp://host:10002", output_addr="tcp://host:10002-out", stage_id=0, - status=StageStatus.UP, + status=ReplicaStatus.UP, queue_length=1, last_heartbeat=now, registered_at=now, ), - InstanceInfo( + ReplicaInfo( input_addr="tcp://host:10003", output_addr="tcp://host:10003-out", stage_id=1, - status=StageStatus.UP, + status=ReplicaStatus.UP, queue_length=0, last_heartbeat=now, registered_at=now, @@ -97,35 +97,35 @@ def test_round_robin_balancer_cycles_instances(): ] balancer = 
RoundRobinBalancer() - results = [balancer.select({}, instances) for _ in range(5)] + results = [balancer.select({}, replicas) for _ in range(5)] # Default start_index=0 => 0,1,2,0,1 assert results == [0, 1, 2, 0, 1] -def test_round_robin_balancer_empty_instances_raises(): - with pytest.raises(ValueError, match="instances must not be empty"): +def test_round_robin_balancer_empty_replicas_raises(): + with pytest.raises(ValueError, match="replicas must not be empty"): RoundRobinBalancer().select({}, []) def test_round_robin_balancer_after_large_index_and_shorter_list(): - """Large start_index % len(instances) then counter wraps with shorter list.""" + """Large start_index % len(replicas) then counter wraps with shorter list.""" now = time() two = [ - InstanceInfo( + ReplicaInfo( input_addr="tcp://host:10001", output_addr="tcp://host:10001-out", stage_id=0, - status=StageStatus.UP, + status=ReplicaStatus.UP, queue_length=0, last_heartbeat=now, registered_at=now, ), - InstanceInfo( + ReplicaInfo( input_addr="tcp://host:10002", output_addr="tcp://host:10002-out", stage_id=0, - status=StageStatus.UP, + status=ReplicaStatus.UP, queue_length=0, last_heartbeat=now, registered_at=now, @@ -138,30 +138,30 @@ def test_round_robin_balancer_after_large_index_and_shorter_list(): def test_least_queue_length_balancer_picks_min_queue(): now = time() - instances = [ - InstanceInfo( + replicas = [ + ReplicaInfo( input_addr="tcp://host:10001", output_addr="tcp://host:10001-out", stage_id=0, - status=StageStatus.UP, + status=ReplicaStatus.UP, queue_length=2, last_heartbeat=now, registered_at=now, ), - InstanceInfo( + ReplicaInfo( input_addr="tcp://host:10002", output_addr="tcp://host:10002-out", stage_id=0, - status=StageStatus.UP, + status=ReplicaStatus.UP, queue_length=0, last_heartbeat=now, registered_at=now, ), - InstanceInfo( + ReplicaInfo( input_addr="tcp://host:10003", output_addr="tcp://host:10003-out", stage_id=1, - status=StageStatus.UP, + status=ReplicaStatus.UP, 
queue_length=5, last_heartbeat=now, registered_at=now, @@ -169,41 +169,41 @@ def test_least_queue_length_balancer_picks_min_queue(): ] balancer = LeastQueueLengthBalancer() - index = balancer.select({}, instances) + index = balancer.select({}, replicas) assert index == 1 -def test_least_queue_length_balancer_empty_instances_raises(): - with pytest.raises(ValueError, match="instances must not be empty"): +def test_least_queue_length_balancer_empty_replicas_raises(): + with pytest.raises(ValueError, match="replicas must not be empty"): LeastQueueLengthBalancer().select({}, []) def test_least_queue_length_balancer_equal_queues_uses_choice(mocker): now = time() - instances = [ - InstanceInfo( + replicas = [ + ReplicaInfo( input_addr="tcp://host:10001", output_addr="tcp://host:10001-out", stage_id=0, - status=StageStatus.UP, + status=ReplicaStatus.UP, queue_length=3, last_heartbeat=now, registered_at=now, ), - InstanceInfo( + ReplicaInfo( input_addr="tcp://host:10002", output_addr="tcp://host:10002-out", stage_id=0, - status=StageStatus.UP, + status=ReplicaStatus.UP, queue_length=3, last_heartbeat=now, registered_at=now, ), - InstanceInfo( + ReplicaInfo( input_addr="tcp://host:10003", output_addr="tcp://host:10003-out", stage_id=1, - status=StageStatus.UP, + status=ReplicaStatus.UP, queue_length=3, last_heartbeat=now, registered_at=now, @@ -214,21 +214,21 @@ def test_least_queue_length_balancer_equal_queues_uses_choice(mocker): "vllm_omni.distributed.omni_coordinator.load_balancer.random.choice", return_value=2, ) - assert balancer.select({}, instances) == 2 + assert balancer.select({}, replicas) == 2 def test_least_queue_length_balancer_negative_queue_raises(): now = time() - instances = [ - InstanceInfo( + replicas = [ + ReplicaInfo( input_addr="tcp://host:10001", output_addr="tcp://host:10001-out", stage_id=0, - status=StageStatus.UP, + status=ReplicaStatus.UP, queue_length=-1, last_heartbeat=now, registered_at=now, ), ] with pytest.raises(ValueError, 
match="queue_length must be non-negative"): - LeastQueueLengthBalancer().select({}, instances) + LeastQueueLengthBalancer().select({}, replicas) diff --git a/tests/distributed/omni_coordinator/test_omni_coord_client_for_hub.py b/tests/distributed/omni_coordinator/test_omni_coord_client_for_hub.py index 2fbd7c85bf8..542e3b06da8 100644 --- a/tests/distributed/omni_coordinator/test_omni_coord_client_for_hub.py +++ b/tests/distributed/omni_coordinator/test_omni_coord_client_for_hub.py @@ -8,8 +8,8 @@ import zmq from vllm_omni.distributed.omni_coordinator import ( - InstanceList, OmniCoordClientForHub, + ReplicaList, ) pytestmark = [pytest.mark.core_model, pytest.mark.cpu] @@ -32,8 +32,8 @@ def _wait_for_condition(cond, timeout: float = 2.0, interval: float = 0.01) -> b return False -def test_hub_client_caches_instance_list_from_pub(): - """Verify OmniCoordClientForHub receives instance list updates from OmniCoordinator and caches for get_instance_list().""" +def test_hub_client_caches_replica_list_from_pub(): + """Verify OmniCoordClientForHub receives replica list updates from OmniCoordinator and caches for get_replica_list().""" ctx, pub, endpoint = _bind_pub() client = OmniCoordClientForHub(endpoint) @@ -41,7 +41,7 @@ def test_hub_client_caches_instance_list_from_pub(): time.sleep(0.2) now = time.time() - instances_payload = [ + replicas_payload = [ { "input_addr": "tcp://stage:10001", "output_addr": "tcp://stage:10001-out", @@ -71,37 +71,37 @@ def test_hub_client_caches_instance_list_from_pub(): }, ] - payload = {"instances": instances_payload, "timestamp": now} + payload = {"replicas": replicas_payload, "timestamp": now} pub.send(json.dumps(payload).encode("utf-8")) - assert _wait_for_condition(lambda: len(client.get_instance_list().instances) == 3) + assert _wait_for_condition(lambda: len(client.get_replica_list().replicas) == 3) - inst_list = client.get_instance_list() - assert isinstance(inst_list, InstanceList) - assert len(inst_list.instances) == 3 + rep_list 
= client.get_replica_list() + assert isinstance(rep_list, ReplicaList) + assert len(rep_list.replicas) == 3 - for src, inst in zip(instances_payload, inst_list.instances, strict=True): - assert inst.input_addr == src["input_addr"] - assert inst.output_addr == src["output_addr"] - assert inst.stage_id == src["stage_id"] - assert inst.status.value == src["status"] + for src, rep in zip(replicas_payload, rep_list.replicas, strict=True): + assert rep.input_addr == src["input_addr"] + assert rep.output_addr == src["output_addr"] + assert rep.stage_id == src["stage_id"] + assert rep.status.value == src["status"] - stage0 = client.get_instances_for_stage(0) - stage1 = client.get_instances_for_stage(1) + stage0 = client.get_replicas_for_stage(0) + stage1 = client.get_replicas_for_stage(1) - assert all(inst.stage_id == 0 for inst in stage0.instances) - assert all(inst.stage_id == 1 for inst in stage1.instances) + assert all(rep.stage_id == 0 for rep in stage0.replicas) + assert all(rep.stage_id == 1 for rep in stage1.replicas) - # Send an updated list with fewer instances and verify cache refresh. + # Send an updated list with fewer replicas and verify cache refresh. 
updated_payload = { - "instances": instances_payload[:2], + "replicas": replicas_payload[:2], "timestamp": now + 1.0, } pub.send(json.dumps(updated_payload).encode("utf-8")) - assert _wait_for_condition(lambda: len(client.get_instance_list().instances) == 2) - updated_list = client.get_instance_list() - assert len(updated_list.instances) == 2 + assert _wait_for_condition(lambda: len(client.get_replica_list().replicas) == 2) + updated_list = client.get_replica_list() + assert len(updated_list.replicas) == 2 client.close() pub.close(0) diff --git a/tests/distributed/omni_coordinator/test_omni_coord_client_for_stage.py b/tests/distributed/omni_coordinator/test_omni_coord_client_for_stage.py index 0ba19c7fff7..46eb2e0fb90 100644 --- a/tests/distributed/omni_coordinator/test_omni_coord_client_for_stage.py +++ b/tests/distributed/omni_coordinator/test_omni_coord_client_for_stage.py @@ -9,7 +9,7 @@ from vllm_omni.distributed.omni_coordinator import ( OmniCoordClientForStage, - StageStatus, + ReplicaStatus, ) from vllm_omni.distributed.omni_coordinator import ( omni_coord_client_for_stage as stage_client_module, @@ -47,7 +47,7 @@ def test_stage_client_auto_register_on_init(): event = _recv_event(router) assert event["event_type"] == "update" - assert event["status"] == StageStatus.UP.value + assert event["status"] == ReplicaStatus.UP.value assert event["stage_id"] == stage_id assert event["input_addr"] == input_addr assert event["output_addr"] == output_addr @@ -70,13 +70,13 @@ def test_stage_client_update_info_sends_correct_event(): # Discard initial registration event. 
_recv_event(router) - client.update_info(status=StageStatus.ERROR) + client.update_info(status=ReplicaStatus.ERROR) client.update_info(queue_length=10) first = _recv_event(router) second = _recv_event(router) - assert first["status"] == StageStatus.ERROR.value + assert first["status"] == ReplicaStatus.ERROR.value assert first["stage_id"] == stage_id assert first["input_addr"] == input_addr assert first["output_addr"] == output_addr @@ -107,7 +107,7 @@ def test_stage_client_close_sends_down_status(): client.close() event = _recv_event(router) - assert event["status"] == StageStatus.DOWN.value + assert event["status"] == ReplicaStatus.DOWN.value assert event["stage_id"] == stage_id assert event["input_addr"] == input_addr assert event["output_addr"] == output_addr @@ -274,6 +274,7 @@ def set(self): client._closed = False client._heartbeat_interval = 0.0 client._stop_event = _FakeStopEvent() + client._on_heartbeat = None calls = {"count": 0} diff --git a/tests/distributed/omni_coordinator/test_omni_coordinator.py b/tests/distributed/omni_coordinator/test_omni_coordinator.py index eff3d429e40..81bd12454cb 100644 --- a/tests/distributed/omni_coordinator/test_omni_coordinator.py +++ b/tests/distributed/omni_coordinator/test_omni_coordinator.py @@ -11,14 +11,14 @@ from vllm_omni.distributed.omni_coordinator import ( OmniCoordClientForStage, OmniCoordinator, - StageStatus, + ReplicaStatus, ) pytestmark = [pytest.mark.core_model, pytest.mark.cpu] -def _recv_instance_list(sub: zmq.Socket, timeout_ms: int = 2000) -> dict | None: - """Receive InstanceList JSON from SUB socket. Returns None on timeout.""" +def _recv_replica_list(sub: zmq.Socket, timeout_ms: int = 2000) -> dict | None: + """Receive ReplicaList JSON from SUB socket. 
Returns None on timeout.""" sub.setsockopt(zmq.RCVTIMEO, timeout_ms) try: data = sub.recv() @@ -27,16 +27,16 @@ def _recv_instance_list(sub: zmq.Socket, timeout_ms: int = 2000) -> dict | None: return None -def _wait_for_instance_list( +def _wait_for_replica_list( sub: zmq.Socket, expected_count: int, timeout: float = 3.0, ) -> dict | None: - """Wait until received InstanceList with expected_count active instances.""" + """Wait until received ReplicaList with expected_count active replicas.""" start = time.time() while time.time() - start < timeout: - msg = _recv_instance_list(sub, timeout_ms=500) - if msg is not None and len(msg.get("instances", [])) == expected_count: + msg = _recv_replica_list(sub, timeout_ms=500) + if msg is not None and len(msg.get("replicas", [])) == expected_count: return msg return None @@ -45,7 +45,7 @@ def _drain_sub_messages(sub: zmq.Socket, max_seconds: float = 0.4) -> None: """Drain queued SUB messages for a short window.""" deadline = time.time() + max_seconds while time.time() < deadline: - _recv_instance_list(sub, timeout_ms=50) + _recv_replica_list(sub, timeout_ms=50) def test_omni_coordinator_pub_coalescing_on_rapid_queue_updates(): @@ -81,7 +81,7 @@ def test_omni_coordinator_pub_coalescing_on_rapid_queue_updates(): ) # Wait for initial registration broadcast and clear any queued messages. 
- msg = _wait_for_instance_list(sub, expected_count=1) + msg = _wait_for_replica_list(sub, expected_count=1) assert msg is not None _drain_sub_messages(sub) @@ -96,7 +96,7 @@ def test_omni_coordinator_pub_coalescing_on_rapid_queue_updates(): deadline = time.time() + window_s recv_count = 0 while time.time() < deadline: - if _recv_instance_list(sub, timeout_ms=100) is not None: + if _recv_replica_list(sub, timeout_ms=100) is not None: recv_count += 1 assert recv_count < update_count // 2, ( @@ -110,8 +110,8 @@ def test_omni_coordinator_pub_coalescing_on_rapid_queue_updates(): def test_omni_coordinator_registration_broadcast(): - """Verify that after multiple OmniCoordClientForStage instances register, - OmniCoordinator publishes an InstanceList containing all registered instances. + """Verify that after multiple OmniCoordClientForStage replicas register, + OmniCoordinator publishes a ReplicaList containing all registered replicas. """ router_addr = get_engine_client_zmq_addr( local_only=False, @@ -144,12 +144,12 @@ def test_omni_coordinator_registration_broadcast(): OmniCoordClientForStage(router_addr, "tcp://stage:10003", "tcp://stage:10003-out", 1), ] - msg = _wait_for_instance_list(sub, expected_count=3) - assert msg is not None, "Expected InstanceList with 3 instances" - assert len(msg["instances"]) == 3 + msg = _wait_for_replica_list(sub, expected_count=3) + assert msg is not None, "Expected ReplicaList with 3 replicas" + assert len(msg["replicas"]) == 3 assert isinstance(msg["timestamp"], (int, float)) - input_addrs = {inst["input_addr"] for inst in msg["instances"]} + input_addrs = {rep["input_addr"] for rep in msg["replicas"]} assert "tcp://stage:10001" in input_addrs assert "tcp://stage:10002" in input_addrs assert "tcp://stage:10003" in input_addrs @@ -162,7 +162,7 @@ def test_omni_coordinator_registration_broadcast(): def test_omni_coordinator_heartbeat_timeout_handling(): - """Verify that when a stage instance stops sending heartbeats, + """Verify that 
when a stage replica stops sending heartbeats, OmniCoordinator marks it as unhealthy and excludes it from the active list. """ router_addr = get_engine_client_zmq_addr( @@ -201,23 +201,23 @@ def test_omni_coordinator_heartbeat_timeout_handling(): "output_addr": "tcp://stage:c-out", "stage_id": 0, "event_type": "update", - "status": StageStatus.UP.value, + "status": ReplicaStatus.UP.value, "queue_length": 0, } dealer_c.send(json.dumps(reg_event).encode("utf-8")) - msg = _wait_for_instance_list(sub, expected_count=3) - assert msg is not None, "Expected initial 3 instances" - assert len(msg["instances"]) == 3 + msg = _wait_for_replica_list(sub, expected_count=3) + assert msg is not None, "Expected initial 3 replicas" + assert len(msg["replicas"]) == 3 # Wait for heartbeat timeout (timeout=5s, check interval ~2.5s). time.sleep(8.0) # Receive the update (C should be ERROR and excluded from active list). - msg_after_timeout = _wait_for_instance_list(sub, expected_count=2, timeout=5.0) - assert msg_after_timeout is not None, "Expected InstanceList with 2 instances after timeout" - instances = msg_after_timeout.get("instances", []) - input_addrs = {inst["input_addr"] for inst in instances} + msg_after_timeout = _wait_for_replica_list(sub, expected_count=2, timeout=5.0) + assert msg_after_timeout is not None, "Expected ReplicaList with 2 replicas after timeout" + replicas = msg_after_timeout.get("replicas", []) + input_addrs = {rep["input_addr"] for rep in replicas} assert "tcp://stage:a" in input_addrs assert "tcp://stage:b" in input_addrs @@ -232,8 +232,8 @@ def test_omni_coordinator_heartbeat_timeout_handling(): sub_ctx.term() -def test_omni_coordinator_instance_shutdown_handling(): - """Verify that when a stage instance sends status='down', +def test_omni_coordinator_replica_shutdown_handling(): + """Verify that when a stage replica sends status='down', OmniCoordinator removes it from the active list and broadcasts an updated list. 
""" router_addr = get_engine_client_zmq_addr( @@ -261,18 +261,18 @@ def test_omni_coordinator_instance_shutdown_handling(): client = OmniCoordClientForStage(router_addr, "tcp://stage:shutdown", "tcp://stage:shutdown-out", 0) - msg = _wait_for_instance_list(sub, expected_count=1) + msg = _wait_for_replica_list(sub, expected_count=1) assert msg is not None - assert len(msg["instances"]) == 1 - assert msg["instances"][0]["input_addr"] == "tcp://stage:shutdown" + assert len(msg["replicas"]) == 1 + assert msg["replicas"][0]["input_addr"] == "tcp://stage:shutdown" # Send down status (simulating graceful shutdown). - client.update_info(status=StageStatus.DOWN) + client.update_info(status=ReplicaStatus.DOWN) - # Receive updated list (should have 0 active instances). - msg = _wait_for_instance_list(sub, expected_count=0) + # Receive updated list (should have 0 active replicas). + msg = _wait_for_replica_list(sub, expected_count=0) assert msg is not None - assert len(msg["instances"]) == 0 + assert len(msg["replicas"]) == 0 client.close() coordinator.close() diff --git a/tests/engine/test_async_omni_engine_stage_init.py b/tests/engine/test_async_omni_engine_stage_init.py index 51ed6b46cc2..b423aa407e7 100644 --- a/tests/engine/test_async_omni_engine_stage_init.py +++ b/tests/engine/test_async_omni_engine_stage_init.py @@ -171,18 +171,6 @@ def test_collect_initialized_clients_for_cleanup_deduplicates_clients(): assert cleanup_clients == [shared, extra] -def test_initialize_stages_rejects_non_diffusion_replicas_in_single_stage_mode(): - engine = object.__new__(AsyncOmniEngine) - engine.single_stage_mode = True - engine.stage_configs = [types.SimpleNamespace(stage_id=0, runtime={"num_replicas": 2})] - - with pytest.raises( - ValueError, - match="single_stage_mode only supports num_replicas > 1 for diffusion stages", - ): - engine._validate_single_stage_mode_replica_constraints() - - def test_initialize_diffusion_replica_restores_device_visibility_after_local_init(monkeypatch): 
import vllm_omni.engine.async_omni_engine as engine_mod from vllm_omni.platforms import current_omni_platform @@ -427,6 +415,7 @@ def test_initialize_stages_cleans_up_successful_replicas_after_partial_multi_rep engine.single_stage_mode = False engine._single_stage_id_filter = None engine._omni_master_server = None + engine._coordinator_runtime = None engine.stage_configs = [types.SimpleNamespace()] cfg0 = types.SimpleNamespace(model_config=types.SimpleNamespace(max_model_len=64)) @@ -471,6 +460,7 @@ def test_initialize_stages_cleans_up_late_successful_replicas_after_early_multi_ engine.single_stage_mode = False engine._single_stage_id_filter = None engine._omni_master_server = None + engine._coordinator_runtime = None engine.stage_configs = [types.SimpleNamespace()] cfg0 = types.SimpleNamespace(model_config=types.SimpleNamespace(max_model_len=64)) diff --git a/tests/engine/test_orchestrator.py b/tests/engine/test_orchestrator.py index c84256f048f..9e050a3f872 100644 --- a/tests/engine/test_orchestrator.py +++ b/tests/engine/test_orchestrator.py @@ -1052,7 +1052,7 @@ async def test_stage_pool_abort_requests_logs_when_binding_is_missing(caplog) -> target_logger.setLevel(prev_level) assert not stage0.abort_calls - assert "abort: no binding for req=missing-req in stage-0" in caplog.text + assert "abort: no live binding for req=missing-req in stage-0" in caplog.text @pytest.mark.asyncio diff --git a/tests/engine/test_single_stage_mode.py b/tests/engine/test_single_stage_mode.py index bcd910620b4..52187c34172 100644 --- a/tests/engine/test_single_stage_mode.py +++ b/tests/engine/test_single_stage_mode.py @@ -203,6 +203,32 @@ def test_get_allocation_returns_correct_object(self): server = OmniMasterServer(master_address="127.0.0.1", master_port=15004, stage_ids=[3]) assert server.get_allocation(3) is server._stage_routes[(3, 0)] + def test_next_free_replica_id_skips_head_local_slot_until_filled(self): + # Head pre-allocates slot (0, 0) for its own 
register_stage_with_omni_master + # call. A same-host headless that registers with auto-assign BEFORE the + # head's own registration must NOT be handed slot 0 — it should land on + # slot 1 instead. Without the head_local_replicas reservation, + # _next_free_replica_id would see (0, 0) absent from _stage_configs and + # return 0, colliding with the head's own bound sockets. + server = OmniMasterServer( + master_address="127.0.0.1", + master_port=15010, + stage_ids=[0], + head_local_replicas={0: [0]}, + ) + assert server._next_free_replica_id(0) == 1 + + def test_next_free_replica_id_uses_remote_slot_when_unowned(self): + # When the head pre-allocates a remote-only slot (the head's _initialize_* + # path waits on get_stage_config), auto-assign SHOULD fill it so the + # head's wait unblocks. This is the original behavior, preserved. + server = OmniMasterServer( + master_address="127.0.0.1", + master_port=15011, + stage_ids=[0], + ) + assert server._next_free_replica_id(0) == 0 + # --------------------------------------------------------------------------- # OmniMasterServer registration flow @@ -233,36 +259,6 @@ def test_registration_reply_contains_handshake_address(self): ctx.term() server.stop() - def test_server_handles_unknown_stage_id_gracefully(self): - import msgspec - import zmq - from vllm.utils.network_utils import get_open_port - - master_port = get_open_port() - server = OmniMasterServer(master_address="127.0.0.1", master_port=master_port, stage_ids=[0]) - server.start() - - ctx = zmq.Context() - bad_sock = None - good_sock = None - try: - bad_sock = ctx.socket(zmq.DEALER) - bad_sock.connect(f"tcp://127.0.0.1:{master_port}") - bad_sock.send(msgspec.msgpack.encode({"stage_id": 99})) - assert not bad_sock.poll(timeout=500) - - good_sock = ctx.socket(zmq.DEALER) - good_sock.connect(f"tcp://127.0.0.1:{master_port}") - good_sock.send(msgspec.msgpack.encode({"stage_id": 0})) - assert good_sock.poll(timeout=2_000) - good_sock.recv() - finally: - for sock in 
(bad_sock, good_sock): - if sock is not None: - sock.close(linger=0) - ctx.term() - server.stop() - def test_registration_stores_stage_config(self): import msgspec import zmq @@ -474,6 +470,8 @@ def _build_engine( engine._omni_master_address = "127.0.0.1" engine._omni_master_port = 26000 engine._omni_master_server = None + engine._omni_heartbeat_timeout = 30.0 + engine._coordinator_runtime = None engine.async_chunk = False engine.diffusion_batch_size = 2 return engine @@ -508,10 +506,16 @@ def test_build_logical_stage_init_plans_marks_non_matching_stage_remote(self, mo def test_start_omni_master_server_uses_configured_stage_ids(self, mocker: MockerFixture): import vllm_omni.engine.async_omni_engine as engine_mod + from vllm_omni.distributed import omni_coordinator as omni_coord_mod engine = self._build_engine([], single_stage_mode=True, stage_id_filter=7) mock_oms = mocker.Mock(spec=OmniMasterServer) mocker.patch.object(engine_mod, "OmniMasterServer", return_value=mock_oms) + mocker.patch.object( + omni_coord_mod, + "OmniCoordinatorRuntime", + return_value=mocker.Mock(router_address="tcp://127.0.0.1:9999"), + ) stage_plans = [ _make_llm_plan(0, configured_stage_id=7, launch_mode="local"), @@ -520,12 +524,16 @@ def test_start_omni_master_server_uses_configured_stage_ids(self, mocker: Mocker engine._start_omni_master_server(stage_plans) - engine_mod.OmniMasterServer.assert_called_once_with( - master_address="127.0.0.1", - master_port=26000, - stage_ids=[7, 11], - stage_replica_counts={7: 1, 11: 1}, - ) + call_kwargs = engine_mod.OmniMasterServer.call_args.kwargs + assert call_kwargs["master_address"] == "127.0.0.1" + assert call_kwargs["master_port"] == 26000 + assert call_kwargs["stage_ids"] == [7, 11] + assert call_kwargs["stage_replica_counts"] == {7: 1, 11: 1} + # head_local_replicas reserves slots that the head will register + # itself (launch_mode == "local"). 
Stage 11 is remote, so it must + # NOT appear in the head-owned set — that slot is for the headless + # to fill via auto-assign. + assert call_kwargs["head_local_replicas"] == {7: [0]} mock_oms.start.assert_called_once() def test_start_omni_master_server_duplicate_stage_ids_raise(self): @@ -545,7 +553,9 @@ def test_start_omni_master_server_missing_address_raises(self): with pytest.raises(ValueError, match="requires both"): engine._start_omni_master_server([_make_llm_plan(0, configured_stage_id=7, launch_mode="local")]) - def test_build_logical_stage_init_plans_clears_runtime_cfg_in_single_stage_mode(self, mocker: MockerFixture): + def test_build_logical_stage_init_plans_preserves_runtime_cfg_for_local_llm_in_single_stage_mode( + self, mocker: MockerFixture + ): import vllm_omni.engine.async_omni_engine as engine_mod engine = self._build_engine([_make_stage_cfg(7)], single_stage_mode=True, stage_id_filter=7) @@ -574,7 +584,7 @@ def test_build_logical_stage_init_plans_clears_runtime_cfg_in_single_stage_mode( finally: monkeypatch.undo() - assert stage_plans[0].replicas[0].metadata.runtime_cfg is None + assert stage_plans[0].replicas[0].metadata.runtime_cfg == {"devices": "0"} def test_validate_single_stage_mode_allows_diffusion_replicas(self): stage_cfg = _make_stage_cfg(0, stage_type="diffusion") @@ -583,14 +593,6 @@ def test_validate_single_stage_mode_allows_diffusion_replicas(self): engine._validate_single_stage_mode_replica_constraints() - def test_validate_single_stage_mode_rejects_llm_replicas(self): - stage_cfg = _make_stage_cfg(0, stage_type="llm") - stage_cfg.runtime.num_replicas = 2 - engine = self._build_engine([stage_cfg], single_stage_mode=True, stage_id_filter=0) - - with pytest.raises(ValueError, match="only supports num_replicas > 1 for diffusion"): - engine._validate_single_stage_mode_replica_constraints() - def test_build_logical_stage_init_plans_preserves_diffusion_runtime_cfg_in_single_stage_mode( self, mocker: MockerFixture ): @@ -809,6 +811,7 @@ 
def test_initialize_llm_replica_single_stage_local_uses_launch_omni_core_engines engine.model = "fake-model" engine.single_stage_mode = True engine._omni_master_server = mocker.Mock(spec=OmniMasterServer) + engine._coordinator_runtime = None engine.stage_configs = [] fake_vllm_config = SimpleNamespace(parallel_config=SimpleNamespace()) @@ -896,6 +899,7 @@ def test_initialize_diffusion_replica_single_stage_local_registers_with_master(s engine._omni_master_server = mocker.Mock(spec=OmniMasterServer) engine._omni_master_server.address = "127.0.0.1" engine._omni_master_server.port = 25000 + engine._coordinator_runtime = None plan = _make_diffusion_plan(0, configured_stage_id=5, launch_mode="local").replicas[0] sentinel_client = SimpleNamespace() @@ -950,6 +954,9 @@ def test_initialize_diffusion_replica_single_stage_local_registers_with_master(s handshake_address="tcp://hs", request_address="tcp://req", response_address="tcp://resp", + omni_coordinator_address=None, + omni_stage_id=5, + omni_replica_id=0, ) mock_handshake.assert_called_once_with(proc, "tcp://hs", 60) mock_from_addresses.assert_called_once_with( @@ -972,6 +979,7 @@ def test_initialize_diffusion_replica_local_failure_terminates_proc(self, mocker engine._omni_master_server = mocker.Mock(spec=OmniMasterServer) engine._omni_master_server.address = "127.0.0.1" engine._omni_master_server.port = 25000 + engine._coordinator_runtime = None plan = _make_diffusion_plan(0, configured_stage_id=5, launch_mode="local").replicas[0] proc = mocker.Mock() diff --git a/tests/entrypoints/test_serve.py b/tests/entrypoints/test_serve.py index 1267a189a32..86c311a60ad 100644 --- a/tests/entrypoints/test_serve.py +++ b/tests/entrypoints/test_serve.py @@ -3,6 +3,8 @@ from __future__ import annotations import argparse +from types import SimpleNamespace +from typing import Any import pytest from pytest_mock import MockerFixture @@ -31,226 +33,366 @@ def test_serve_parser_accepts_no_async_chunk() -> None: assert args.async_chunk is 
False -def _make_headless_args() -> argparse.Namespace: - return argparse.Namespace( +# --------------------------------------------------------------------------- +# run_headless validation +# --------------------------------------------------------------------------- + + +def _make_headless_args(**overrides: Any) -> argparse.Namespace: + """Build an argparse.Namespace shaped like the headless CLI passes in. + + Defaults pass every validation gate so individual tests can mutate just + the field they're exercising. + """ + defaults = dict( model="fake-model", - stage_id=3, + stage_id=0, replica_id=0, omni_master_address="127.0.0.1", omni_master_port=26000, + omni_replica_address=None, + omni_dp_size_local=1, api_server_count=0, worker_backend="multi_process", stage_configs_path=None, + deploy_config=None, log_stats=False, disable_log_stats=False, stage_init_timeout=600, + tokenizer=None, + ) + defaults.update(overrides) + return argparse.Namespace(**defaults) + + +def test_run_headless_requires_stage_id() -> None: + args = _make_headless_args(stage_id=None) + with pytest.raises(ValueError, match="--stage-id is required"): + run_headless(args) + + +def test_run_headless_requires_master_address() -> None: + args = _make_headless_args(omni_master_address=None) + with pytest.raises(ValueError, match="--omni-master-address and --omni-master-port"): + run_headless(args) + + +def test_run_headless_requires_master_port() -> None: + args = _make_headless_args(omni_master_port=None) + with pytest.raises(ValueError, match="--omni-master-address and --omni-master-port"): + run_headless(args) + + +def test_run_headless_rejects_multi_api_server_count() -> None: + args = _make_headless_args(api_server_count=2) + with pytest.raises(ValueError, match="api_server_count can't be set"): + run_headless(args) + + +def test_run_headless_rejects_non_multiprocess_worker_backend() -> None: + args = _make_headless_args(worker_backend="ray") + with pytest.raises(ValueError, 
match="worker_backend=multi_process"): + run_headless(args) + + +def test_run_headless_raises_when_stage_id_not_in_configs(mocker: MockerFixture) -> None: + """Headless looks up its assigned stage_id in the loaded deploy YAML and + fails fast when the launcher's --stage-id doesn't match any entry.""" + other_stage = SimpleNamespace(stage_id=99) + mocker.patch( + "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", + return_value=("/fake/stages.yaml", [other_stage]), + ) + + args = _make_headless_args(stage_id=0) + with pytest.raises(ValueError, match="No stage config found for stage_id=0"): + run_headless(args) + + +# --------------------------------------------------------------------------- +# run_headless happy paths +# --------------------------------------------------------------------------- + + +def _make_stage_cfg(stage_id: int, stage_type: str) -> SimpleNamespace: + """Build a stage config that satisfies every attribute run_headless reads. + + Notably ``engine_args`` is a real dict (not a Mock) so + ``get_stage_devices_per_replica`` can call ``.get("tensor_parallel_size")`` + and feed the result through ``int()`` without TypeError. + """ + return SimpleNamespace( + stage_id=stage_id, + stage_type=stage_type, + # No "devices" key -> split_devices_for_replicas skipped, each replica + # inherits the launcher's CUDA_VISIBLE_DEVICES. 
+ runtime=None, + engine_args={}, ) -def test_run_headless_registers_stage_once_and_launches_all_local_engines(mocker: MockerFixture) -> None: - args = _make_headless_args() - stage_cfg = mocker.Mock(stage_id=3) - stage_cfgs = [stage_cfg] - parallel_config = mocker.Mock( - data_parallel_size_local=2, - data_parallel_rank=4, - data_parallel_rank_local=1, +def test_run_headless_llm_registers_with_auto_assigned_replica_id(mocker: MockerFixture) -> None: + """LLM headless: each loop iteration registers with auto-assigned + replica_id (master picks a free slot) and spawns one + ``OmniCoreEngineProcManager`` per local replica.""" + from vllm_omni.engine.stage_engine_startup import StageRegistrationResponse + + stage_cfg = _make_stage_cfg(0, stage_type="llm") + parallel_config = SimpleNamespace( + data_parallel_size_local=1, + data_parallel_rank=0, + data_parallel_rank_local=0, node_rank_within_dp=0, ) - vllm_config = mocker.Mock(parallel_config=parallel_config) - vllm_config.needs_dp_coordinator = False - executor_class = mocker.Mock() + vllm_config = SimpleNamespace(parallel_config=parallel_config, needs_dp_coordinator=False) engine_manager = mocker.Mock() mocker.patch( "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", - return_value=("/fake/stages.yaml", stage_cfgs), + return_value=("/fake/stages.yaml", [stage_cfg]), ) mocker.patch("vllm_omni.engine.stage_init_utils.prepare_engine_environment") - mocker.patch("vllm_omni.engine.stage_init_utils.load_omni_transfer_config_for_model", return_value=mocker.Mock()) - mocker.patch("vllm_omni.engine.stage_init_utils.get_stage_connector_spec", return_value={}) - mocker.patch("vllm_omni.engine.stage_init_utils.build_engine_args_dict", return_value={}) + mocker.patch("vllm_omni.engine.stage_init_utils.load_omni_transfer_config_for_model", return_value=None) mocker.patch( "vllm_omni.distributed.omni_connectors.utils.initialization.resolve_omni_kv_config_for_stage", return_value=(None, None, None), ) - 
mock_build_vllm_config = mocker.patch( + mocker.patch("vllm_omni.engine.stage_init_utils.get_stage_connector_spec", return_value={}) + mocker.patch("vllm_omni.engine.stage_init_utils.build_engine_args_dict", return_value={}) + mocker.patch( "vllm_omni.engine.stage_init_utils.build_vllm_config", - return_value=(vllm_config, executor_class), + return_value=(vllm_config, object), ) mock_register = mocker.patch( "vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master", - return_value="tcp://127.0.0.1:26001", + return_value=StageRegistrationResponse( + handshake_address="tcp://127.0.0.1:26001", + input_address="tcp://127.0.0.1:26002", + output_address="tcp://127.0.0.1:26003", + replica_id=0, + coordinator_router_address="tcp://127.0.0.1:26100", + ), ) - mock_manager_cls = mocker.patch("vllm.v1.engine.utils.CoreEngineProcManager", return_value=engine_manager) - mocker.patch("signal.signal") - run_headless(args) - - mock_build_vllm_config.assert_called_once_with( - stage_cfg, - "fake-model", - stage_connector_spec={}, - engine_args_dict={}, - headless=True, + mock_manager_cls = mocker.patch( + "vllm_omni.engine.omni_core_engine_proc_manager.OmniCoreEngineProcManager", + return_value=engine_manager, ) - mock_register.assert_called_once_with( - omni_master_address="127.0.0.1", - omni_master_port=26000, - omni_stage_id=3, - omni_stage_config=stage_cfg, - coordinator=None, - replica_id=0, - ) - mock_manager_cls.assert_called_once() - manager_kwargs = mock_manager_cls.call_args.kwargs - assert manager_kwargs["local_engine_count"] == 2 - assert manager_kwargs["start_index"] == 4 - assert manager_kwargs["local_start_index"] == 0 - assert manager_kwargs["local_client"] is False - assert manager_kwargs["handshake_address"] == "tcp://127.0.0.1:26001" - assert manager_kwargs["log_stats"] is False + mocker.patch("signal.signal") + + run_headless(_make_headless_args(stage_id=0)) + + # The launcher must request auto-assignment (replica_id=None) and the + # full response 
so it can wire the master-allocated coordinator into the + # spawned subprocess. ``replica_binds_sockets=False`` is required for LLM + # because the head binds all three sockets (handshake, input, output). + assert mock_register.call_count == 1 + kwargs = mock_register.call_args.kwargs + assert kwargs["omni_master_address"] == "127.0.0.1" + assert kwargs["omni_master_port"] == 26000 + assert kwargs["omni_stage_id"] == 0 + assert kwargs["omni_stage_config"] is stage_cfg + assert kwargs["replica_id"] is None + assert kwargs["return_full_response"] is True + assert kwargs["replica_binds_sockets"] is False + + assert mock_manager_cls.call_count == 1 + mgr_kwargs = mock_manager_cls.call_args.kwargs + assert mgr_kwargs["local_engine_count"] == 1 + assert mgr_kwargs["local_client"] is False + assert mgr_kwargs["handshake_address"] == "tcp://127.0.0.1:26001" + assert mgr_kwargs["omni_stage_id"] == 0 + assert mgr_kwargs["omni_coordinator_address"] == "tcp://127.0.0.1:26100" + assert mgr_kwargs["omni_replica_base_id"] == 0 + engine_manager.monitor_engine_liveness.assert_called_once_with() engine_manager.shutdown.assert_called_once_with() -def test_run_headless_honors_explicit_log_stats_flag(mocker: MockerFixture) -> None: - args = _make_headless_args() - args.log_stats = True - stage_cfg = mocker.Mock(stage_id=3) - stage_cfgs = [stage_cfg] - parallel_config = mocker.Mock( - data_parallel_size_local=2, - data_parallel_rank=4, - data_parallel_rank_local=1, +def test_run_headless_llm_launches_one_manager_per_omni_dp_size_local(mocker: MockerFixture) -> None: + """``--omni-dp-size-local=N`` must spawn N managers, each with its own + master-assigned replica_id, and join all of them before returning.""" + from vllm_omni.engine.stage_engine_startup import StageRegistrationResponse + + stage_cfg = _make_stage_cfg(0, stage_type="llm") + parallel_config = SimpleNamespace( + data_parallel_size_local=1, + data_parallel_rank=0, + data_parallel_rank_local=0, node_rank_within_dp=0, ) - 
vllm_config = mocker.Mock(parallel_config=parallel_config) - executor_class = mocker.Mock() - engine_manager = mocker.Mock() + vllm_config = SimpleNamespace(parallel_config=parallel_config, needs_dp_coordinator=False) + manager_a = mocker.Mock() + manager_b = mocker.Mock() mocker.patch( "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", - return_value=("/fake/stages.yaml", stage_cfgs), + return_value=("/fake/stages.yaml", [stage_cfg]), ) mocker.patch("vllm_omni.engine.stage_init_utils.prepare_engine_environment") - mocker.patch("vllm_omni.engine.stage_init_utils.load_omni_transfer_config_for_model", return_value=mocker.Mock()) - mocker.patch("vllm_omni.engine.stage_init_utils.get_stage_connector_spec", return_value={}) - mocker.patch("vllm_omni.engine.stage_init_utils.build_engine_args_dict", return_value={}) + mocker.patch("vllm_omni.engine.stage_init_utils.load_omni_transfer_config_for_model", return_value=None) mocker.patch( "vllm_omni.distributed.omni_connectors.utils.initialization.resolve_omni_kv_config_for_stage", return_value=(None, None, None), ) + mocker.patch("vllm_omni.engine.stage_init_utils.get_stage_connector_spec", return_value={}) + mocker.patch("vllm_omni.engine.stage_init_utils.build_engine_args_dict", return_value={}) mocker.patch( "vllm_omni.engine.stage_init_utils.build_vllm_config", - return_value=(vllm_config, executor_class), + return_value=(vllm_config, object), ) mocker.patch( "vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master", - return_value="tcp://127.0.0.1:26001", + side_effect=[ + StageRegistrationResponse( + handshake_address=f"tcp://127.0.0.1:2700{idx}", + input_address=f"tcp://127.0.0.1:2710{idx}", + output_address=f"tcp://127.0.0.1:2720{idx}", + replica_id=idx, + coordinator_router_address=None, + ) + for idx in (0, 1) + ], + ) + mock_manager_cls = mocker.patch( + "vllm_omni.engine.omni_core_engine_proc_manager.OmniCoreEngineProcManager", + side_effect=[manager_a, manager_b], ) - mock_manager_cls = 
mocker.patch("vllm.v1.engine.utils.CoreEngineProcManager", return_value=engine_manager) mocker.patch("signal.signal") - run_headless(args) - manager_kwargs = mock_manager_cls.call_args.kwargs - assert manager_kwargs["log_stats"] is True + run_headless(_make_headless_args(stage_id=0, omni_dp_size_local=2)) + assert mock_manager_cls.call_count == 2 + assigned_ids = [call.kwargs["omni_replica_base_id"] for call in mock_manager_cls.call_args_list] + assert assigned_ids == [0, 1] + + # Multi-replica path joins the monitor threads instead of calling + # ``monitor_engine_liveness`` synchronously on the main thread, but every + # manager must still be shut down in the finally block. + manager_a.shutdown.assert_called_once_with() + manager_b.shutdown.assert_called_once_with() -def test_run_headless_registers_llm_replica_id(mocker: MockerFixture) -> None: - args = _make_headless_args() - args.replica_id = 2 - stage_cfg = mocker.Mock(stage_id=3) - parallel_config = mocker.Mock( - data_parallel_size_local=1, - data_parallel_rank=0, - data_parallel_rank_local=0, - node_rank_within_dp=0, - ) - vllm_config = mocker.Mock(parallel_config=parallel_config) - vllm_config.needs_dp_coordinator = False - engine_manager = mocker.Mock() + +def test_run_headless_diffusion_registers_and_spawns_proc(mocker: MockerFixture) -> None: + """Diffusion headless: registers as auto-assign, spawns a single + ``StageDiffusionProc`` per local replica, and waits for it via + ``multiprocessing.connection.wait``.""" + from vllm_omni.engine.stage_engine_startup import StageRegistrationResponse + + stage_cfg = _make_stage_cfg(1, stage_type="diffusion") + od_config = mocker.Mock() + proc = mocker.Mock(sentinel=object(), exitcode=0) + proc.is_alive.return_value = False mocker.patch( "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", return_value=("/fake/stages.yaml", [stage_cfg]), ) mocker.patch("vllm_omni.engine.stage_init_utils.prepare_engine_environment") - 
mocker.patch("vllm_omni.engine.stage_init_utils.load_omni_transfer_config_for_model", return_value=mocker.Mock()) - mocker.patch("vllm_omni.engine.stage_init_utils.get_stage_connector_spec", return_value={}) - mocker.patch("vllm_omni.engine.stage_init_utils.build_engine_args_dict", return_value={}) + mocker.patch("vllm_omni.engine.stage_init_utils.load_omni_transfer_config_for_model", return_value=None) mocker.patch( "vllm_omni.distributed.omni_connectors.utils.initialization.resolve_omni_kv_config_for_stage", return_value=(None, None, None), ) mocker.patch( - "vllm_omni.engine.stage_init_utils.build_vllm_config", - return_value=(vllm_config, mocker.Mock()), + "vllm_omni.engine.stage_init_utils.extract_stage_metadata", + return_value=SimpleNamespace(stage_id=1, stage_type="diffusion"), ) + mock_inject = mocker.patch("vllm_omni.engine.stage_init_utils.inject_kv_stage_info") + mocker.patch("vllm_omni.engine.stage_init_utils.build_diffusion_config", return_value=od_config) mock_register = mocker.patch( "vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master", - return_value="tcp://127.0.0.1:26001", + return_value=StageRegistrationResponse( + handshake_address="tcp://127.0.0.1:26001", + input_address="tcp://127.0.0.1:26002", + output_address="tcp://127.0.0.1:26003", + replica_id=0, + coordinator_router_address="tcp://127.0.0.1:26100", + ), + ) + mock_spawn = mocker.patch( + "vllm_omni.diffusion.stage_diffusion_proc.spawn_diffusion_proc", + return_value=(proc, None, None, None), ) - mocker.patch("vllm.v1.engine.utils.CoreEngineProcManager", return_value=engine_manager) + mock_handshake = mocker.patch("vllm_omni.diffusion.stage_diffusion_proc.complete_diffusion_handshake") + # Replace the blocking wait with one that returns the only proc's sentinel + # immediately so the test does not hang. 
+ mocker.patch( + "multiprocessing.connection.wait", + side_effect=lambda sentinels: [sentinels[0]], + ) + mocker.patch("vllm_omni.engine.stage_init_utils.terminate_alive_proc") mocker.patch("signal.signal") - run_headless(args) + run_headless(_make_headless_args(stage_id=1)) - assert mock_register.call_args.kwargs["replica_id"] == 2 + mock_inject.assert_called_once_with(stage_cfg, 1) + reg_kwargs = mock_register.call_args.kwargs + assert reg_kwargs["omni_master_address"] == "127.0.0.1" + assert reg_kwargs["omni_master_port"] == 26000 + assert reg_kwargs["omni_stage_id"] == 1 + assert reg_kwargs["omni_stage_config"] is stage_cfg + assert reg_kwargs["replica_id"] is None + assert reg_kwargs["return_full_response"] is True -def test_run_headless_launches_diffusion_stage_via_omni_master(mocker: MockerFixture) -> None: - args = _make_headless_args() - args.replica_id = 1 - stage_cfg = mocker.Mock(stage_id=3, stage_type="diffusion") - stage_cfg.engine_args = mocker.Mock() - stage_cfg.engine_input_source = [] - stage_cfgs = [stage_cfg] - metadata = mocker.Mock(stage_id=3, stage_type="diffusion") - od_config = mocker.Mock() - proc = mocker.Mock() - proc.exitcode = 0 + spawn_kwargs = mock_spawn.call_args.kwargs + assert spawn_kwargs["handshake_address"] == "tcp://127.0.0.1:26001" + assert spawn_kwargs["request_address"] == "tcp://127.0.0.1:26002" + assert spawn_kwargs["response_address"] == "tcp://127.0.0.1:26003" + assert spawn_kwargs["omni_coordinator_address"] == "tcp://127.0.0.1:26100" + assert spawn_kwargs["omni_stage_id"] == 1 + assert spawn_kwargs["omni_replica_id"] == 0 + + mock_handshake.assert_called_once_with(proc, "tcp://127.0.0.1:26001", 600) + + +def test_run_headless_diffusion_raises_on_nonzero_proc_exit(mocker: MockerFixture) -> None: + """A diffusion replica that exits with a non-zero code must surface as a + RuntimeError from ``run_headless`` (the head needs the signal to roll + back its own stage init).""" + from vllm_omni.engine.stage_engine_startup 
import StageRegistrationResponse + + stage_cfg = _make_stage_cfg(1, stage_type="diffusion") + proc = mocker.Mock(sentinel=object(), exitcode=137, name="proc-stage1-rep0") proc.is_alive.return_value = False mocker.patch( "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", - return_value=("/fake/stages.yaml", stage_cfgs), + return_value=("/fake/stages.yaml", [stage_cfg]), ) mocker.patch("vllm_omni.engine.stage_init_utils.prepare_engine_environment") - mocker.patch("vllm_omni.engine.stage_init_utils.load_omni_transfer_config_for_model", return_value=mocker.Mock()) + mocker.patch("vllm_omni.engine.stage_init_utils.load_omni_transfer_config_for_model", return_value=None) mocker.patch( "vllm_omni.distributed.omni_connectors.utils.initialization.resolve_omni_kv_config_for_stage", return_value=(None, None, None), ) - mocker.patch("vllm_omni.engine.stage_init_utils.extract_stage_metadata", return_value=metadata) - mock_inject_stage_info = mocker.patch("vllm_omni.engine.stage_init_utils.inject_kv_stage_info") - mocker.patch("vllm_omni.engine.stage_init_utils.build_diffusion_config", return_value=od_config) - mock_register = mocker.patch( + mocker.patch( + "vllm_omni.engine.stage_init_utils.extract_stage_metadata", + return_value=SimpleNamespace(stage_id=1, stage_type="diffusion"), + ) + mocker.patch("vllm_omni.engine.stage_init_utils.inject_kv_stage_info") + mocker.patch("vllm_omni.engine.stage_init_utils.build_diffusion_config", return_value=mocker.Mock()) + mocker.patch( "vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master", - return_value=("tcp://127.0.0.1:26001", "tcp://127.0.0.1:26002", "tcp://127.0.0.1:26003"), + return_value=StageRegistrationResponse( + handshake_address="tcp://127.0.0.1:26001", + input_address="tcp://127.0.0.1:26002", + output_address="tcp://127.0.0.1:26003", + replica_id=0, + coordinator_router_address=None, + ), ) - mock_spawn = mocker.patch( + mocker.patch( 
"vllm_omni.diffusion.stage_diffusion_proc.spawn_diffusion_proc", - return_value=(proc, "tcp://127.0.0.1:26001", "tcp://127.0.0.1:26002", "tcp://127.0.0.1:26003"), + return_value=(proc, None, None, None), ) - mock_handshake = mocker.patch("vllm_omni.diffusion.stage_diffusion_proc.complete_diffusion_handshake") + mocker.patch("vllm_omni.diffusion.stage_diffusion_proc.complete_diffusion_handshake") + mocker.patch( + "multiprocessing.connection.wait", + side_effect=lambda sentinels: [sentinels[0]], + ) + mocker.patch("vllm_omni.engine.stage_init_utils.terminate_alive_proc") mocker.patch("signal.signal") - run_headless(args) - mock_inject_stage_info.assert_called_once_with(stage_cfg, 3) - mock_register.assert_called_once_with( - omni_master_address="127.0.0.1", - omni_master_port=26000, - omni_stage_id=3, - omni_stage_config=stage_cfg, - return_addresses=True, - replica_id=1, - ) - mock_spawn.assert_called_once_with( - "fake-model", - od_config, - handshake_address="tcp://127.0.0.1:26001", - request_address="tcp://127.0.0.1:26002", - response_address="tcp://127.0.0.1:26003", - ) - mock_handshake.assert_called_once_with(proc, "tcp://127.0.0.1:26001", 600) - proc.join.assert_called_once_with() + with pytest.raises(RuntimeError, match=r"exited with code 137"): + run_headless(_make_headless_args(stage_id=1)) diff --git a/vllm_omni/diffusion/stage_diffusion_client.py b/vllm_omni/diffusion/stage_diffusion_client.py index 476b44b5bb2..6ab44b4b0c6 100644 --- a/vllm_omni/diffusion/stage_diffusion_client.py +++ b/vllm_omni/diffusion/stage_diffusion_client.py @@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Any import zmq +import zmq.asyncio from vllm.logger import init_logger from vllm.v1.engine.exceptions import EngineDeadError @@ -124,6 +125,11 @@ def _initialize_client( self.engine_input_source = getattr(metadata, "engine_input_source", []) self._proc = proc self._owns_process = proc is not None + # Expose the ZMQ addresses on the instance so callers (e.g. 
+ # ``StagePool._client_input_addr``) can identify the diffusion + # replica by its bound address. + self.request_address = request_address + self.response_address = response_address self._zmq_ctx = zmq.Context() self._request_socket = self._zmq_ctx.socket(zmq.PUSH) @@ -131,6 +137,9 @@ def _initialize_client( self._response_socket = self._zmq_ctx.socket(zmq.PULL) self._response_socket.connect(response_address) + self._response_poller = zmq.asyncio.Poller() + self._response_poller.register(self._response_socket, zmq.POLLIN) + self._encoder = OmniMsgpackEncoder() self._decoder = OmniMsgpackDecoder() @@ -468,17 +477,27 @@ async def collective_rpc_async( try: while True: self._drain_responses() - if rpc_id in self._rpc_results: - return self._rpc_results.pop(rpc_id) + result = self._rpc_results.pop(rpc_id, None) + if result is not None: + return result if self._engine_dead or (self._owns_process and self._proc is not None and not self._proc.is_alive()): self._engine_dead = True raise EngineDeadError( f"StageDiffusionProc died while waiting for " f"collective_rpc '{method}' (exit code {self._proc.exitcode})" ) - if deadline and time.monotonic() > deadline: + if deadline is not None and time.monotonic() > deadline: raise TimeoutError(f"collective_rpc_async '{method}' timed out after {timeout}s") - await asyncio.sleep(0.01) + # Block (async) until data arrives on the ZMQ response + # socket or until the timeout expires, then loop back to + # drain and check. + if deadline is not None: + poll_timeout_ms = max(int((deadline - time.monotonic()) * 1000), 0) + else: + poll_timeout_ms = 100 + # no exception raised on timeout (capped at 100ms so the + # engine-dead check still fires regularly). 
+ await self._response_poller.poll(timeout=min(poll_timeout_ms, 100)) finally: self._pending_rpcs.discard(rpc_id) diff --git a/vllm_omni/diffusion/stage_diffusion_proc.py b/vllm_omni/diffusion/stage_diffusion_proc.py index 2f3ad55943a..871a29729f2 100644 --- a/vllm_omni/diffusion/stage_diffusion_proc.py +++ b/vllm_omni/diffusion/stage_diffusion_proc.py @@ -7,6 +7,7 @@ from __future__ import annotations import asyncio +import contextlib import signal import time from concurrent.futures import ThreadPoolExecutor @@ -30,6 +31,7 @@ OmniMsgpackDecoder, OmniMsgpackEncoder, ) +from vllm_omni.distributed.omni_coordinator import OmniCoordClientForStage from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.outputs import OmniRequestOutput @@ -54,6 +56,56 @@ def __init__(self, model: str, od_config: OmniDiffusionConfig) -> None: self._engine: DiffusionEngine | None = None self._executor: ThreadPoolExecutor | None = None self._closed = False + # Set by ``run_loop`` to the live dispatch task dict so + # :attr:`queue_length` can report in-flight requests for the + # OmniCoordinator heartbeat hook. + self._active_tasks: dict[str, asyncio.Task] | None = None + # Set when a request-handler detects that the engine's multiproc + # executor has died (e.g. a worker process crashed and the executor's + # monitor thread closed it). Once set, ``run_loop`` breaks out so the + # outer ``except``/``finally`` can send ``DIFFUSION_PROC_DEAD`` and + # ``ReplicaStatus.DOWN``, then the subprocess exits non-zero. Without + # this, the run_loop would swallow per-request errors and keep + # serving 500s indefinitely while heartbeats still report UP. + self._fatal_event: asyncio.Event | None = None + + @property + def queue_length(self) -> int: + """Number of in-flight diffusion requests. + + Returns 0 before :meth:`run_loop` starts and after it exits. 
+ """ + tasks = self._active_tasks + return 0 if tasks is None else len(tasks) + + def _is_executor_dead(self) -> bool: + """True iff the multiproc executor has been closed or marked failed. + + Detects the "workers died but the diffusion proc is still pulling + requests" case: ``MultiprocDiffusionExecutor`` sets ``_closed = True`` + and ``is_failed = True`` from its worker-monitor thread the moment any + worker process exits; every subsequent ``execute_request`` / + ``collective_rpc`` then raises ``RuntimeError("DiffusionExecutor is + closed.")`` inside the engine. Callers in ``run_loop`` use this to + decide whether a per-request failure is recoverable or fatal. + """ + if self._engine is None: + return False + executor = getattr(self._engine, "executor", None) + if executor is None: + return False + return bool(getattr(executor, "_closed", False) or getattr(executor, "is_failed", False)) + + def _signal_fatal_engine_failure(self, reason: str) -> None: + """Idempotently signal ``run_loop`` to tear down on a fatal engine error.""" + if self._fatal_event is None or self._fatal_event.is_set(): + return + logger.error( + "[StageDiffusionProc] fatal engine failure detected (%s); " + "signaling run_loop to send DIFFUSION_PROC_DEAD and exit.", + reason, + ) + self._fatal_event.set() # ------------------------------------------------------------------ # Initialization @@ -309,6 +361,14 @@ async def run_loop( decoder = OmniMsgpackDecoder() tasks: dict[str, asyncio.Task] = {} + # Expose the live task dict so :attr:`queue_length` (used by the + # OmniCoordinator heartbeat hook) can read the in-flight count. + self._active_tasks = tasks + # Wakes the main recv loop when a request-handler detects a fatal + # engine failure so we tear down promptly instead of swallowing + # "DiffusionExecutor is closed" on every subsequent request. 
+ fatal_event = asyncio.Event() + self._fatal_event = fatal_event async def _dispatch_request( request_id: str, @@ -342,12 +402,40 @@ async def _dispatch_request( } ) ) + # Per-request errors are usually recoverable, but a closed + # executor means every future request will get the same + # "DiffusionExecutor is closed" error. Signal the main loop + # to send DIFFUSION_PROC_DEAD and exit so the head's hub + # demotes this replica instead of waiting on the heartbeat + # timeout (~30 s by default). + if self._is_executor_dead(): + self._signal_fatal_engine_failure(f"add_request {request_id}: {e!s}") finally: tasks.pop(request_id, None) try: while True: - raw = await request_socket.recv() + # Await recv and fatal_event concurrently so the loop wakes + # up immediately when a per-request handler signals a fatal + # engine failure — even if no fresh ZMQ frame arrives. + recv_task: asyncio.Task = asyncio.ensure_future(request_socket.recv()) + fatal_task: asyncio.Task = asyncio.ensure_future(fatal_event.wait()) + try: + done, pending = await asyncio.wait( + [recv_task, fatal_task], + return_when=asyncio.FIRST_COMPLETED, + ) + finally: + for waiter in (recv_task, fatal_task): + if not waiter.done(): + waiter.cancel() + with contextlib.suppress(asyncio.CancelledError, Exception): + await waiter + if fatal_event.is_set(): + raise RuntimeError( + "StageDiffusionProc executor reported permanent failure; tearing down the diffusion subprocess." + ) + raw = recv_task.result() msg = decoder.decode(raw) msg_type = msg.get("type") @@ -397,6 +485,11 @@ async def _dispatch_batch( } ) ) + # Same rationale as the single-request path: a + # closed executor turns every subsequent batch + # into a 500, so escalate now. 
+ if self._is_executor_dead(): + self._signal_fatal_engine_failure(f"add_batch_request {rid}: {e!s}") finally: tasks.pop(rid, None) @@ -446,6 +539,13 @@ async def _dispatch_batch( } ) ) + # Collective RPCs run through the same multiproc + # executor — a closed executor means every future + # RPC fails the same way, so tear down promptly. + if self._is_executor_dead(): + self._signal_fatal_engine_failure( + f"collective_rpc {msg['method']} (rpc_id={rpc_id}): {e!s}" + ) elif msg_type == "shutdown": break @@ -466,6 +566,8 @@ async def _dispatch_batch( if tasks: await asyncio.gather(*tasks.values(), return_exceptions=True) + self._active_tasks = None + self._fatal_event = None request_socket.close() response_socket.close() ctx.term() @@ -504,8 +606,22 @@ def run_diffusion_proc( handshake_address: str, request_address: str, response_address: str, + *, + omni_coordinator_address: str | None = None, + omni_stage_id: int | None = None, + omni_replica_id: int = 0, ) -> None: - """Entry point for the diffusion subprocess.""" + """Entry point for the diffusion subprocess. + + Omni-specific kwargs (mirroring :meth:`StageEngineCoreProc.run_stage_core`): + - ``omni_coordinator_address``: ROUTER address of the head-side + OmniCoordinator. When set, a :class:`OmniCoordClientForStage` + reports the diffusion replica's status + queue length. + - ``omni_stage_id``: logical stage id; required when + ``omni_coordinator_address`` is set. + - ``omni_replica_id``: cluster-unique replica id within the + stage (logging / metrics only). 
+ """ shutdown_requested = False def signal_handler(signum: int, frame: Any) -> None: @@ -518,6 +634,7 @@ def signal_handler(signum: int, frame: Any) -> None: signal.signal(signal.SIGINT, signal_handler) proc = cls(model, od_config) + coord_client: OmniCoordClientForStage | None = None try: proc.initialize() @@ -529,6 +646,32 @@ def signal_handler(signum: int, frame: Any) -> None: handshake_socket.close() handshake_ctx.term() + # Wire OmniCoordClientForStage *after* READY so that the head + # has bound its head-side request/response sockets — the + # address pair we report is the same pair this proc binds to + # (request/response addresses passed in). + if omni_coordinator_address is not None: + if omni_stage_id is None: + raise ValueError("omni_stage_id must be provided when omni_coordinator_address is set") + coord_client = OmniCoordClientForStage( + coord_zmq_addr=omni_coordinator_address, + input_addr=request_address, + output_addr=response_address, + stage_id=int(omni_stage_id), + ) + + def _refresh_queue_length() -> None: + coord_client._queue_length = proc.queue_length # type: ignore[union-attr] + + coord_client._on_heartbeat = _refresh_queue_length + + logger.info( + "StageDiffusionProc registered with OmniCoordinator (stage_id=%d replica_id=%d coord=%s)", + omni_stage_id, + omni_replica_id, + omni_coordinator_address, + ) + # Run async event loop asyncio.run(proc.run_loop(request_address, response_address)) @@ -539,6 +682,9 @@ def signal_handler(signum: int, frame: Any) -> None: logger.exception("StageDiffusionProc encountered a fatal error.") raise finally: + if coord_client is not None: + with contextlib.suppress(RuntimeError): + coord_client.close() proc.close() @@ -551,10 +697,17 @@ def spawn_diffusion_proc( handshake_address: str | None = None, request_address: str | None = None, response_address: str | None = None, + *, + omni_coordinator_address: str | None = None, + omni_stage_id: int | None = None, + omni_replica_id: int = 0, ) -> 
tuple[BaseProcess, str, str, str]: """Spawn a StageDiffusionProc subprocess. Returns ``(proc, handshake_address, request_address, response_address)``. + + Pass ``omni_coordinator_address`` / ``omni_stage_id`` / ``omni_replica_id`` + to have the subprocess publish heartbeats to an OmniCoordinator. """ handshake_address = handshake_address or get_open_zmq_ipc_path() request_address = request_address or get_open_zmq_ipc_path() @@ -570,6 +723,9 @@ def spawn_diffusion_proc( "handshake_address": handshake_address, "request_address": request_address, "response_address": response_address, + "omni_coordinator_address": omni_coordinator_address, + "omni_stage_id": omni_stage_id, + "omni_replica_id": omni_replica_id, }, ) proc.start() diff --git a/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py b/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py index 7fdd66cdbd3..241a6070ddc 100644 --- a/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py +++ b/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py @@ -21,6 +21,7 @@ build_rank_aware_send_keys, get_kv_target_ranks, get_local_tp_rank, + get_omni_replica_id, get_tp_world_size, kv_zmq_port, merge_received_rank_shards, @@ -415,7 +416,13 @@ def connector(self): stage_int = int(self.config.from_stage) if self.config.from_stage is not None else 0 except (TypeError, ValueError): stage_int = 0 - zmq_port = kv_zmq_port(base_port, stage_int, self._tp_topo.local_rank) + replica_id = get_omni_replica_id() + zmq_port = kv_zmq_port( + base_port, + stage_int, + self._tp_topo.local_rank, + replica_id=replica_id, + ) if self.config.need_send_cache: c_extra["role"] = "sender" diff --git a/vllm_omni/distributed/omni_connectors/utils/initialization.py b/vllm_omni/distributed/omni_connectors/utils/initialization.py index e6d68f9bcee..a8cafba6159 100644 --- a/vllm_omni/distributed/omni_connectors/utils/initialization.py +++ b/vllm_omni/distributed/omni_connectors/utils/initialization.py @@ -26,9 +26,17 @@ # 
Port stride between TP ranks so each worker binds a unique ZMQ port # when TP > 1. Must be larger than the maximum number of pipeline stages. -# Formula: zmq_port = base + KV_TRANSFER_PORT_OFFSET + rank * STRIDE + stage +# Formula: +# zmq_port = base + KV_TRANSFER_PORT_OFFSET +# + replica * KV_REPLICA_PORT_STRIDE +# + rank * KV_RANK_PORT_STRIDE +# + stage KV_RANK_PORT_STRIDE = 16 +# Port stride between Omni replicas of the same stage. This reserves a +# comfortably sized block per replica for TP-rank and stage offsets. +KV_REPLICA_PORT_STRIDE = 1024 + def initialize_connectors_from_config( config_path: str | Path | None = None, diff --git a/vllm_omni/distributed/omni_connectors/utils/kv_utils.py b/vllm_omni/distributed/omni_connectors/utils/kv_utils.py index 12b9b3d4f77..fe0df21f25b 100644 --- a/vllm_omni/distributed/omni_connectors/utils/kv_utils.py +++ b/vllm_omni/distributed/omni_connectors/utils/kv_utils.py @@ -16,7 +16,7 @@ ) from vllm.logger import init_logger -from .initialization import KV_RANK_PORT_STRIDE, KV_TRANSFER_PORT_OFFSET +from .initialization import KV_RANK_PORT_STRIDE, KV_REPLICA_PORT_STRIDE, KV_TRANSFER_PORT_OFFSET logger = init_logger(__name__) @@ -94,20 +94,42 @@ def get_tp_world_size() -> int: return 1 +def get_omni_replica_id() -> int: + """Return the Omni replica id for this worker process.""" + try: + replica_id = int(os.environ.get("VLLM_OMNI_REPLICA_ID", "0")) + except (ValueError, TypeError): + return 0 + return max(replica_id, 0) + + # ------------------------------------------------------------------ # # ZMQ port computation # ------------------------------------------------------------------ # -def kv_zmq_port(base_port: int, from_stage: int, local_rank: int = 0) -> int: +def kv_zmq_port( + base_port: int, + from_stage: int, + local_rank: int = 0, + replica_id: int | None = None, +) -> int: """Compute the ZMQ port for a KV-transfer connector. 
- Each TP rank gets its own port so that TP > 1 deployments do not - cause ``EADDRINUSE`` when multiple sender workers bind on the same - host. The formula is backward-compatible: rank 0 produces the same - port as the previous ``base + OFFSET + stage`` formula. + Each Omni replica and TP rank gets its own port so multi-replica or + TP > 1 deployments do not cause ``EADDRINUSE`` when multiple sender + workers bind on the same host. The formula is backward-compatible: + replica 0 / rank 0 produces the previous ``base + OFFSET + stage`` port. + """ - return base_port + KV_TRANSFER_PORT_OFFSET + local_rank * KV_RANK_PORT_STRIDE + from_stage + replica = get_omni_replica_id() if replica_id is None else max(int(replica_id), 0) + return ( + base_port + + KV_TRANSFER_PORT_OFFSET + + replica * KV_REPLICA_PORT_STRIDE + + local_rank * KV_RANK_PORT_STRIDE + + from_stage + ) # ------------------------------------------------------------------ # diff --git a/vllm_omni/distributed/omni_coordinator/__init__.py b/vllm_omni/distributed/omni_coordinator/__init__.py index 6894e311378..326fb9b9a16 100644 --- a/vllm_omni/distributed/omni_coordinator/__init__.py +++ b/vllm_omni/distributed/omni_coordinator/__init__.py @@ -8,18 +8,21 @@ RandomBalancer, RoundRobinBalancer, Task, + build_load_balancer_factory, ) -from .messages import InstanceEvent, InstanceInfo, InstanceList, StageStatus +from .messages import ReplicaEvent, ReplicaInfo, ReplicaList, ReplicaStatus from .omni_coord_client_for_hub import OmniCoordClientForHub from .omni_coord_client_for_stage import OmniCoordClientForStage from .omni_coordinator import OmniCoordinator +from .runtime import OmniCoordinatorRuntime __all__ = [ "OmniCoordinator", - "StageStatus", - "InstanceEvent", - "InstanceInfo", - "InstanceList", + "OmniCoordinatorRuntime", + "ReplicaStatus", + "ReplicaEvent", + "ReplicaInfo", + "ReplicaList", "OmniCoordClientForStage", "OmniCoordClientForHub", "Task", @@ -28,4 +31,5 @@ "RandomBalancer", "RoundRobinBalancer", 
"LeastQueueLengthBalancer", + "build_load_balancer_factory", ] diff --git a/vllm_omni/distributed/omni_coordinator/load_balancer.py b/vllm_omni/distributed/omni_coordinator/load_balancer.py index 41b03be1630..bcd9822ec69 100644 --- a/vllm_omni/distributed/omni_coordinator/load_balancer.py +++ b/vllm_omni/distributed/omni_coordinator/load_balancer.py @@ -6,17 +6,18 @@ import random import threading from abc import ABC, abstractmethod +from collections.abc import Callable from enum import Enum from typing import Any, TypedDict -from .messages import InstanceInfo +from .messages import ReplicaInfo class Task(TypedDict, total=False): - """Task structure passed from async_omni (stage.submit(task)). + """Task structure passed to ``StagePool.pick`` / ``LoadBalancer.select``. - Mirrors the dict built in AsyncOmni with request_id, engine_inputs, - sampling_params. Future load-balancing policies may use these fields. + Mirrors the dict built around a stage submission with request_id and any + payload-related fields a future load-balancing policy might inspect. """ request_id: str @@ -28,10 +29,7 @@ class LoadBalancingPolicy(str, Enum): """Enumeration for load balancing policies. These policies are used by :class:`LoadBalancer` implementations to route - tasks to a subset of available instances. - - TODO(NumberWan): Map enum values to balancer classes when OmniCoordinator - integration lands. Tracked in https://github.com/vllm-project/vllm-omni/pull/2448 + tasks to a subset of available replicas. """ RANDOM = "random" @@ -42,74 +40,60 @@ class LoadBalancingPolicy(str, Enum): class LoadBalancer(ABC): """Abstract base class for load balancers. - Subclasses implement :meth:`select` to choose an instance for a given task. + Subclasses implement :meth:`select` to choose a replica for a given task. """ @abstractmethod - def select(self, task: Task, instances: list[InstanceInfo]) -> int: - """Route a task to one of the available instances. 
+ def select(self, task: Task, replicas: list[ReplicaInfo]) -> int: + """Route a task to one of the available replicas. Args: task: The task to route. Not used by the random policy but reserved for future strategies that may inspect task metadata. - instances: List of available instances to choose from. + replicas: List of available replicas to choose from. Returns: - Index of the selected instance in ``instances``. + Index of the selected replica in ``replicas``. Raises: - ValueError: If ``instances`` is empty. + ValueError: If ``replicas`` is empty. """ raise NotImplementedError class RandomBalancer(LoadBalancer): - """Load balancer that selects an instance uniformly at random. - - It intentionally ignores the task payload and chooses a random index from - the provided instance list. More sophisticated policies (e.g. round-robin, - least-queue-length) can be implemented as additional subclasses of - :class:`LoadBalancer`. - """ + """Load balancer that selects a replica uniformly at random.""" - def select(self, task: Task, instances: list[InstanceInfo]) -> int: # noqa: ARG002 - if not instances: - raise ValueError("instances must not be empty") + def select(self, task: Task, replicas: list[ReplicaInfo]) -> int: # noqa: ARG002 + if not replicas: + raise ValueError("replicas must not be empty") - return random.randrange(len(instances)) + return random.randrange(len(replicas)) class RoundRobinBalancer(LoadBalancer): - """Load balancer that selects instances in a round-robin fashion. + """Load balancer that selects replicas in a round-robin fashion. - This implementation keeps a running index modulo ``len(instances)``. It - therefore depends on the **order and stable meaning** of the ``instances`` + This implementation keeps a running index modulo ``len(replicas)``. It + therefore depends on the **order and stable meaning** of the ``replicas`` list between calls. 
If the list length or ordering changes, the sequence of picks may skip or repeat entries relative to a fixed set of backends. - When instance membership changes dynamically, callers should reset routing - state—for example by constructing a new ``RoundRobinBalancer`` or resetting - ``_next_index``—similar to rebuilding ``itertools.cycle`` after mutating - the instance list (as in vLLM's disaggregated proxy examples). - - Concurrency: ``select`` is synchronous and is expected to run on the - coordinator asyncio event loop thread without ``await`` inside this - method, so a single invocation is not interleaved with another on that - thread. A :class:`threading.Lock` still serializes updates to - ``_next_index`` for callers that might invoke ``select`` from multiple - threads or alongside threaded infrastructure (e.g. ZMQ receive threads). + Concurrency: a ``threading.Lock`` serializes updates to ``_next_index`` + for callers that invoke ``select`` from multiple threads or alongside + threaded infrastructure (e.g. ZMQ receive threads). """ def __init__(self, start_index: int = 0) -> None: self._next_index = start_index self._lock = threading.Lock() - def select(self, task: Task, instances: list[InstanceInfo]) -> int: # noqa: ARG002 - if not instances: - raise ValueError("instances must not be empty") + def select(self, task: Task, replicas: list[ReplicaInfo]) -> int: # noqa: ARG002 + if not replicas: + raise ValueError("replicas must not be empty") - n = len(instances) + n = len(replicas) with self._lock: idx = self._next_index % n self._next_index = (self._next_index + 1) % n @@ -117,27 +101,43 @@ def select(self, task: Task, instances: list[InstanceInfo]) -> int: # noqa: ARG class LeastQueueLengthBalancer(LoadBalancer): - """Select the instance with the smallest ``queue_length``. + """Select the replica with the smallest ``queue_length``. 
- If multiple instances share the same minimum queue length, one of them is + If multiple replicas share the same minimum queue length, one of them is chosen uniformly at random. Raises: - ValueError: If any instance has a negative ``queue_length``. + ValueError: If any replica has a negative ``queue_length``. """ - def select(self, task: Task, instances: list[InstanceInfo]) -> int: # noqa: ARG002 - if not instances: - raise ValueError("instances must not be empty") + def select(self, task: Task, replicas: list[ReplicaInfo]) -> int: # noqa: ARG002 + if not replicas: + raise ValueError("replicas must not be empty") - queue_lengths = [inst.queue_length for inst in instances] + queue_lengths = [rep.queue_length for rep in replicas] if any(q < 0 for q in queue_lengths): - raise ValueError("queue_length must be non-negative for all instances") + raise ValueError("queue_length must be non-negative for all replicas") min_q = min(queue_lengths) candidates = [i for i, q in enumerate(queue_lengths) if q == min_q] return random.choice(candidates) +def build_load_balancer_factory(policy: str) -> Callable[[], LoadBalancer]: + """Translate ``--omni-lb-policy`` (string) into a per-pool LB factory.""" + try: + normalized = LoadBalancingPolicy(policy) + except ValueError as exc: + valid = ", ".join(p.value for p in LoadBalancingPolicy) + raise ValueError(f"unknown --omni-lb-policy {policy!r} (valid: {valid})") from exc + if normalized is LoadBalancingPolicy.RANDOM: + return RandomBalancer + if normalized is LoadBalancingPolicy.ROUND_ROBIN: + return RoundRobinBalancer + if normalized is LoadBalancingPolicy.LEAST_QUEUE_LENGTH: + return LeastQueueLengthBalancer + raise ValueError(f"unhandled load balancing policy {normalized!r}") + + __all__ = [ "Task", "LoadBalancingPolicy", @@ -145,4 +145,5 @@ def select(self, task: Task, instances: list[InstanceInfo]) -> int: # noqa: ARG "RandomBalancer", "RoundRobinBalancer", "LeastQueueLengthBalancer", + "build_load_balancer_factory", ] diff 
--git a/vllm_omni/distributed/omni_coordinator/messages.py b/vllm_omni/distributed/omni_coordinator/messages.py index 2bb590139e2..824475195f2 100644 --- a/vllm_omni/distributed/omni_coordinator/messages.py +++ b/vllm_omni/distributed/omni_coordinator/messages.py @@ -7,55 +7,55 @@ from enum import Enum -class StageStatus(str, Enum): - """Enumeration for stage instance status.""" +class ReplicaStatus(str, Enum): + """Enumeration for stage replica status.""" - UP = "up" # Instance is ready and available - DOWN = "down" # Instance is shutdown gracefully - ERROR = "error" # Instance encountered an error or timeout + UP = "up" # Replica is ready and available + DOWN = "down" # Replica is shutdown gracefully + ERROR = "error" # Replica encountered an error or timeout @dataclass -class InstanceEvent: +class ReplicaEvent: """Wire payload from OmniCoordClientForStage to OmniCoordinator. Schema for Stage → Coordinator events over ZMQ: input_addr, output_addr, stage_id, status, queue_length, event_type. """ - input_addr: str # Stage instance input ZMQ address (e.g., "tcp://host:port") - output_addr: str # Stage instance output ZMQ address (e.g., "tcp://host:port") + input_addr: str # Stage replica input ZMQ address (e.g., "tcp://host:port") + output_addr: str # Stage replica output ZMQ address (e.g., "tcp://host:port") stage_id: int # Stage ID event_type: str # "update" | "heartbeat" - status: StageStatus # Current status + status: ReplicaStatus # Current status queue_length: int # Current queue length @dataclass -class InstanceInfo: - """Metadata for a single stage instance. +class ReplicaInfo: + """Metadata for a single stage replica. This type is stored in OmniCoordinator's internal registry and is also - published to hubs via :class:`InstanceList`. + published to hubs via :class:`ReplicaList`. 
""" - input_addr: str # Stage instance input ZMQ address (e.g., "tcp://host:port") - output_addr: str # Stage instance output ZMQ address (e.g., "tcp://host:port") - stage_id: int # Stage ID of this instance - status: StageStatus # Current status of the instance - queue_length: int # Current queue length of this instance + input_addr: str # Stage replica input ZMQ address (e.g., "tcp://host:port") + output_addr: str # Stage replica output ZMQ address (e.g., "tcp://host:port") + stage_id: int # Stage ID of this replica + status: ReplicaStatus # Current status of the replica + queue_length: int # Current queue length of this replica last_heartbeat: float # Timestamp of the last heartbeat received (seconds) - registered_at: float # Timestamp when the instance was registered (seconds) + registered_at: float # Timestamp when the replica was registered (seconds) @dataclass -class InstanceList: - """Container for instance list updates. +class ReplicaList: + """Container for replica list updates. - OmniCoordinator publishes an :class:`InstanceList` whenever its view of - active instances changes. OmniCoordClientForHub caches the latest value + OmniCoordinator publishes a :class:`ReplicaList` whenever its view of + active replicas changes. OmniCoordClientForHub caches the latest value and exposes it to AsyncOmni and the load balancer. 
""" - instances: list[InstanceInfo] + replicas: list[ReplicaInfo] timestamp: float # Time when the list was last updated (seconds) diff --git a/vllm_omni/distributed/omni_coordinator/omni_coord_client_for_hub.py b/vllm_omni/distributed/omni_coordinator/omni_coord_client_for_hub.py index 9081e45917c..f4b6e5f80fe 100644 --- a/vllm_omni/distributed/omni_coordinator/omni_coord_client_for_hub.py +++ b/vllm_omni/distributed/omni_coordinator/omni_coord_client_for_hub.py @@ -9,16 +9,16 @@ import zmq -from .messages import InstanceInfo, InstanceList, StageStatus +from .messages import ReplicaInfo, ReplicaList, ReplicaStatus logger = logging.getLogger(__name__) class OmniCoordClientForHub: - """Client for AsyncOmni side to receive instance list updates. + """Client for AsyncOmni side to receive replica list updates. This client maintains a SUB socket connected to OmniCoordinator's PUB - endpoint and caches the latest :class:`InstanceList` in memory for use by + endpoint and caches the latest :class:`ReplicaList` in memory for use by the load balancer and routing logic. 
""" @@ -28,7 +28,7 @@ def __init__(self, coord_zmq_addr: str) -> None: self._ctx = zmq.Context() self._lock = threading.Lock() - self._instance_list: InstanceList | None = None + self._replica_list: ReplicaList | None = None self._closed = False self._stop_event = threading.Event() self._init_done = threading.Event() @@ -41,29 +41,29 @@ def __init__(self, coord_zmq_addr: str) -> None: if self._init_error: raise RuntimeError(f"Failed to connect to coordinator at {self._coord_zmq_addr}") from self._init_error[0] - def _decode_instance_list(self, payload: dict[str, Any]) -> InstanceList: - """Convert a JSON-decoded dict into an :class:`InstanceList`.""" - instances_payload = payload.get("instances", []) - instances: list[InstanceInfo] = [] - for inst in instances_payload: - instances.append( - InstanceInfo( - input_addr=inst["input_addr"], - output_addr=inst["output_addr"], - stage_id=int(inst["stage_id"]), - status=StageStatus(inst["status"]), - queue_length=int(inst["queue_length"]), - last_heartbeat=float(inst["last_heartbeat"]), - registered_at=float(inst["registered_at"]), + def _decode_replica_list(self, payload: dict[str, Any]) -> ReplicaList: + """Convert a JSON-decoded dict into a :class:`ReplicaList`.""" + replicas_payload = payload.get("replicas", []) + replicas: list[ReplicaInfo] = [] + for rep in replicas_payload: + replicas.append( + ReplicaInfo( + input_addr=rep["input_addr"], + output_addr=rep["output_addr"], + stage_id=int(rep["stage_id"]), + status=ReplicaStatus(rep["status"]), + queue_length=int(rep["queue_length"]), + last_heartbeat=float(rep["last_heartbeat"]), + registered_at=float(rep["registered_at"]), ) ) timestamp = float(payload.get("timestamp", time())) - return InstanceList(instances=instances, timestamp=timestamp) + return ReplicaList(replicas=replicas, timestamp=timestamp) def _recv_loop(self) -> None: - """Background loop that receives and caches instance lists.""" - sub = None + """Background loop that receives and caches replica 
lists.""" + sub: zmq.Socket | None = None try: sub = self._ctx.socket(zmq.SUB) sub.setsockopt(zmq.SUBSCRIBE, b"") @@ -115,9 +115,9 @@ def _recv_loop(self) -> None: try: payload = json.loads(data.decode("utf-8")) - inst_list = self._decode_instance_list(payload) + rep_list = self._decode_replica_list(payload) with self._lock: - self._instance_list = inst_list + self._replica_list = rep_list except ( json.JSONDecodeError, KeyError, @@ -125,7 +125,7 @@ def _recv_loop(self) -> None: TypeError, AttributeError, ) as e: - logger.warning("Invalid instance list message, skipping: %s", e) + logger.warning("Invalid replica list message, skipping: %s", e) finally: try: if sub is not None: @@ -137,22 +137,22 @@ def _recv_loop(self) -> None: except zmq.ZMQError: pass - def get_instance_list(self) -> InstanceList: - """Return the latest cached :class:`InstanceList`. + def get_replica_list(self) -> ReplicaList: + """Return the latest cached :class:`ReplicaList`. If no update has been received yet, returns an empty list with ``timestamp=0.0``. 
""" with self._lock: - if self._instance_list is None: - return InstanceList(instances=[], timestamp=0.0) - return self._instance_list - - def get_instances_for_stage(self, stage_id: int) -> InstanceList: - """Return instances filtered by ``stage_id``.""" - base = self.get_instance_list() - filtered = [inst for inst in base.instances if inst.stage_id == stage_id] - return InstanceList(instances=filtered, timestamp=base.timestamp) + if self._replica_list is None: + return ReplicaList(replicas=[], timestamp=0.0) + return self._replica_list + + def get_replicas_for_stage(self, stage_id: int) -> ReplicaList: + """Return replicas filtered by ``stage_id``.""" + base = self.get_replica_list() + filtered = [rep for rep in base.replicas if rep.stage_id == stage_id] + return ReplicaList(replicas=filtered, timestamp=base.timestamp) def close(self) -> None: """Close the SUB socket and stop the background thread.""" diff --git a/vllm_omni/distributed/omni_coordinator/omni_coord_client_for_stage.py b/vllm_omni/distributed/omni_coordinator/omni_coord_client_for_stage.py index cd3c99ab812..aee25e584a6 100644 --- a/vllm_omni/distributed/omni_coordinator/omni_coord_client_for_stage.py +++ b/vllm_omni/distributed/omni_coordinator/omni_coord_client_for_stage.py @@ -1,24 +1,26 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib import json import logging import threading import time +from collections.abc import Callable from dataclasses import asdict import zmq -from .messages import InstanceEvent, StageStatus +from .messages import ReplicaEvent, ReplicaStatus logger = logging.getLogger(__name__) class OmniCoordClientForStage: - """Client used by stage instances to send events to OmniCoordinator. + """Client used by stage replicas to send events to OmniCoordinator. This client maintains a DEALER socket connected to OmniCoordinator's - ROUTER endpoint and sends JSON-encoded events describing instance status. 
+ ROUTER endpoint and sends JSON-encoded events describing replica status. """ def __init__( @@ -42,13 +44,18 @@ def __init__( self._socket.close() raise RuntimeError(f"Failed to connect to coordinator at {self._coord_zmq_addr}: {e}") from e - self._status = StageStatus.UP + self._status = ReplicaStatus.UP self._queue_length = 0 self._closed = False self._closing = False self._heartbeat_interval = 5.0 self._stop_event = threading.Event() self._send_lock = threading.RLock() + # Optional hook invoked from the heartbeat thread before each + # heartbeat send. Stages set this to refresh ``queue_length`` (or any + # other field) just-in-time. Exceptions raised by the hook are + # suppressed and logged. + self._on_heartbeat: Callable[[], None] | None = None self._send_event("update") @@ -100,11 +107,13 @@ def _reconnect(self, max_retries: int = 3, retry_interval: float = 5.0) -> bool: return False def _send_event(self, event_type: str) -> None: - """Send an InstanceEvent to OmniCoordinator. + """Send a ReplicaEvent to OmniCoordinator. Wire format: input_addr, output_addr, stage_id, status, queue_length, event_type. - For "update": includes status and queue_length from instance state. - For "heartbeat": status and queue_length are null. + For "update": includes status and queue_length from replica state. + For "heartbeat": includes the latest queue_length (refreshed by the + optional ``_on_heartbeat`` hook) so the coordinator can propagate + live load to load balancers between explicit ``update`` events. 
On send failure (ZMQError / RuntimeError), attempts to reconnect up to 3 times (5s sleep each) and retries the send once after a @@ -114,7 +123,7 @@ def _send_event(self, event_type: str) -> None: if self._closed: raise RuntimeError("Client already closed") - event = InstanceEvent( + event = ReplicaEvent( input_addr=self._input_addr, output_addr=self._output_addr, stage_id=self._stage_id, @@ -147,10 +156,10 @@ def _send_event(self, event_type: str) -> None: def update_info( self, - status: StageStatus | None = None, + status: ReplicaStatus | None = None, queue_length: int | None = None, ) -> None: - """Update instance information and notify OmniCoordinator. + """Update replica information and notify OmniCoordinator. At least one of ``status`` or ``queue_length`` must be provided. """ @@ -174,6 +183,15 @@ def _heartbeat_loop(self) -> None: if self._closed: break + # Invoke the optional pre-heartbeat hook so callers (e.g. the + # engine subprocess) can refresh ``queue_length`` from live state + # before the heartbeat is sent. Exceptions are swallowed so a + # buggy hook never breaks the heartbeat loop. + hook = self._on_heartbeat + if hook is not None: + with contextlib.suppress(Exception): + hook() + try: self._send_event("heartbeat") except (RuntimeError, zmq.ZMQError) as e: @@ -199,7 +217,7 @@ def close(self) -> None: self._closing = True # Mark status as DOWN and send one last update. 
- self._status = StageStatus.DOWN + self._status = ReplicaStatus.DOWN try: self._send_event("update") except (RuntimeError, zmq.ZMQError): diff --git a/vllm_omni/distributed/omni_coordinator/omni_coordinator.py b/vllm_omni/distributed/omni_coordinator/omni_coordinator.py index 2c7c8fbb995..8f75f841dbf 100644 --- a/vllm_omni/distributed/omni_coordinator/omni_coordinator.py +++ b/vllm_omni/distributed/omni_coordinator/omni_coordinator.py @@ -11,22 +11,22 @@ import zmq -from .messages import InstanceEvent, InstanceInfo, InstanceList, StageStatus +from .messages import ReplicaEvent, ReplicaInfo, ReplicaList, ReplicaStatus logger = logging.getLogger(__name__) class OmniCoordinator: - """Coordinator for stage instances and hub clients. + """Coordinator for stage replicas and hub clients. - This service receives instance events from :class:`OmniCoordClientForStage` - via a ZMQ ROUTER socket and publishes active instance lists to + This service receives replica events from :class:`OmniCoordClientForStage` + via a ZMQ ROUTER socket and publishes active replica lists to :class:`OmniCoordClientForHub` via a PUB socket. - The coordinator maintains an in-memory registry of all known instances, + The coordinator maintains an in-memory registry of all known replicas, including their status, queue length, and heartbeat timestamps. A background thread periodically checks for heartbeat timeouts and marks - unhealthy instances as ``StageStatus.ERROR``. + unhealthy replicas as ``ReplicaStatus.ERROR``. """ def __init__( @@ -40,7 +40,7 @@ def __init__( Args: router_zmq_addr: ZMQ address to bind the ROUTER socket. pub_zmq_addr: ZMQ address to bind the PUB socket. - heartbeat_timeout: Seconds before an instance is considered + heartbeat_timeout: Seconds before a replica is considered unhealthy if no heartbeat / update is received. 
""" self._router_zmq_addr = router_zmq_addr @@ -55,7 +55,7 @@ def __init__( self._pub = self._ctx.socket(zmq.PUB) self._pub.bind(self._pub_zmq_addr) - self._instances: dict[str, InstanceInfo] = {} + self._replicas: dict[str, ReplicaInfo] = {} self._lock = threading.Lock() self._pub_lock = threading.Lock() @@ -75,43 +75,43 @@ def __init__( self._periodic_thread = threading.Thread(target=self._periodic_loop, daemon=True) self._periodic_thread.start() - def get_active_instances(self) -> InstanceList: - """Return an :class:`InstanceList` of active (UP) instances only.""" + def get_active_replicas(self) -> ReplicaList: + """Return a :class:`ReplicaList` of active (UP) replicas only.""" with self._lock: - active = [inst for inst in self._instances.values() if inst.status == StageStatus.UP] - return InstanceList(instances=active, timestamp=time()) + active = [rep for rep in self._replicas.values() if rep.status == ReplicaStatus.UP] + return ReplicaList(replicas=active, timestamp=time()) - def add_new_instance(self, event: InstanceEvent) -> None: - """Add a new instance based on an incoming event.""" + def add_new_replica(self, event: ReplicaEvent) -> None: + """Add a new replica based on an incoming event.""" with self._lock: - self._add_new_instance_locked(event) + self._add_new_replica_locked(event) self._schedule_broadcast() - def update_instance_info(self, event: InstanceEvent) -> None: - """Update an existing instance based on an incoming event.""" + def update_replica_info(self, event: ReplicaEvent) -> None: + """Update an existing replica based on an incoming event.""" with self._lock: - self._update_instance_info_locked(event) + self._update_replica_info_locked(event) self._schedule_broadcast() - def remove_instance(self, event: InstanceEvent) -> None: - """Mark an instance as removed / down based on an incoming event. + def remove_replica(self, event: ReplicaEvent) -> None: + """Mark a replica as removed / down based on an incoming event. 
- This marks the instance's status as DOWN or ERROR (depending on the + This marks the replica's status as DOWN or ERROR (depending on the event) but keeps it in the internal registry. It is removed from the - *active* instance list published to hubs. + *active* replica list published to hubs. """ with self._lock: - self._remove_instance_locked(event) + self._remove_replica_locked(event) self._schedule_broadcast() - def publish_instance_list_update(self) -> bool: - """Publish the current active instance list to all subscribers. + def publish_replica_list_update(self) -> bool: + """Publish the current active replica list to all subscribers. Returns: True if the PUB send succeeded, False if it was dropped (e.g. socket not ready when using ``zmq.NOBLOCK``). """ - active_list = self.get_active_instances() + active_list = self.get_active_replicas() payload = asdict(active_list) data = json.dumps(payload).encode("utf-8") @@ -133,12 +133,12 @@ def _schedule_broadcast(self) -> None: with self._pending_lock: self._pending_broadcast = True - def _mark_instance_error_locked(self, info: InstanceInfo) -> None: - """Mark instance as ERROR (e.g. after heartbeat timeout).""" - info.status = StageStatus.ERROR + def _mark_replica_error_locked(self, info: ReplicaInfo) -> None: + """Mark replica as ERROR (e.g. 
after heartbeat timeout).""" + info.status = ReplicaStatus.ERROR def _check_heartbeat_timeouts(self) -> None: - """Mark instances as ERROR if their heartbeat has timed out.""" + """Mark replicas as ERROR if their heartbeat has timed out.""" now = time() timed_out = False gc_ttl = 600.0 # 10 minutes @@ -146,17 +146,17 @@ def _check_heartbeat_timeouts(self) -> None: with self._lock: to_delete: list[str] = [] - for input_addr, info in self._instances.items(): - if info.status == StageStatus.UP and now - info.last_heartbeat > self._heartbeat_timeout: - self._mark_instance_error_locked(info) + for input_addr, info in self._replicas.items(): + if info.status == ReplicaStatus.UP and now - info.last_heartbeat > self._heartbeat_timeout: + self._mark_replica_error_locked(info) timed_out = True - elif info.status in (StageStatus.DOWN, StageStatus.ERROR) and now - info.last_heartbeat > gc_ttl: + elif info.status in (ReplicaStatus.DOWN, ReplicaStatus.ERROR) and now - info.last_heartbeat > gc_ttl: to_delete.append(input_addr) for input_addr in to_delete: - del self._instances[input_addr] + del self._replicas[input_addr] if timed_out: - # Instance liveness changed; request broadcast. + # Replica liveness changed; request broadcast. self._schedule_broadcast() def close(self) -> None: @@ -187,22 +187,22 @@ def close(self) -> None: except zmq.ZMQError: pass - def _parse_instance_event(self, data: dict[str, Any]) -> InstanceEvent | None: - """Parse wire payload dict into InstanceEvent. Returns None if invalid.""" + def _parse_replica_event(self, data: dict[str, Any]) -> ReplicaEvent | None: + """Parse wire payload dict into ReplicaEvent. 
Returns None if invalid.""" try: - return InstanceEvent( + return ReplicaEvent( input_addr=str(data["input_addr"]), output_addr=str(data["output_addr"]), stage_id=int(data["stage_id"]), event_type=str(data["event_type"]), - status=StageStatus(data.get("status")), + status=ReplicaStatus(data.get("status")), queue_length=data.get("queue_length"), ) except (KeyError, ValueError, TypeError): return None def _recv_loop(self) -> None: - """Background loop that receives and processes instance events.""" + """Background loop that receives and processes replica events.""" while self._running: try: frames = self._router.recv_multipart() @@ -219,12 +219,12 @@ def _recv_loop(self) -> None: payload = frames[-1] try: data = json.loads(payload.decode("utf-8")) - event = self._parse_instance_event(data) + event = self._parse_replica_event(data) except json.JSONDecodeError as e: - logger.warning("Invalid JSON in instance event, dropping: %s", e) + logger.warning("Invalid JSON in replica event, dropping: %s", e) continue if event is None: - logger.warning("Malformed instance event, dropping") + logger.warning("Malformed replica event, dropping") continue self._handle_event(event) @@ -234,7 +234,10 @@ def _periodic_loop(self) -> None: Heartbeat timeouts are checked on their original cadence, while all broadcast requests are coalesced and flushed at most once per - ``_publish_min_interval``. + ``_publish_min_interval``. The heartbeat-check tick also schedules a + keepalive broadcast so late-joining hubs (which miss any PUB sends + that happened before their SUB connected) catch up within at most + ``heartbeat_interval`` seconds. 
""" heartbeat_interval = max(1.0, min(self._heartbeat_timeout / 2.0, 5.0)) loop_interval = self._publish_min_interval @@ -245,6 +248,13 @@ def _periodic_loop(self) -> None: if now - last_heartbeat_check >= heartbeat_interval: self._check_heartbeat_timeouts() + # Keepalive broadcast: ZMQ PUB doesn't queue for late + # subscribers, so an OmniCoordClientForHub that connects + # after the initial UP events miss them entirely and would + # never see the current replica list otherwise. Scheduling a + # broadcast on every heartbeat tick caps that staleness at + # ``heartbeat_interval`` without flooding the wire. + self._schedule_broadcast() last_heartbeat_check = now with self._pending_lock: @@ -256,43 +266,51 @@ def _periodic_loop(self) -> None: continue # Publish outside lock. Clear pending only on success. - if self.publish_instance_list_update(): + if self.publish_replica_list_update(): with self._pending_lock: self._pending_broadcast = False if self._stop_event.wait(timeout=loop_interval): break - def _handle_event(self, event: InstanceEvent) -> None: + def _handle_event(self, event: ReplicaEvent) -> None: """Dispatch an incoming event to the appropriate handler.""" try: input_addr = event.input_addr - # Heartbeat: only update last_heartbeat; if previously ERROR, - # promote back to UP and broadcast once. + # Heartbeat: refresh last_heartbeat and queue_length. The stage + # client refreshes queue_length just-in-time via its + # ``_on_heartbeat`` hook, so heartbeats are the only periodic + # source of live load for LeastQueueLengthBalancer; failing to + # propagate it here would let the policy route on stale data. + # If previously ERROR, promote back to UP and broadcast once. 
if event.event_type == "heartbeat": promote = False + queue_changed = False with self._lock: - info = self._instances.get(input_addr) + info = self._replicas.get(input_addr) if info is not None: info.last_heartbeat = time() - if info.status == StageStatus.ERROR: - info.status = StageStatus.UP + if event.queue_length is not None and info.queue_length != event.queue_length: + info.queue_length = event.queue_length + queue_changed = True + if info.status == ReplicaStatus.ERROR: + info.status = ReplicaStatus.UP promote = True - if promote: + if promote or queue_changed: self._schedule_broadcast() return # Check-and-act under single lock to avoid TOCTOU race (duplicate - # registration when concurrent events arrive for the same instance). + # registration when concurrent events arrive for the same replica). with self._lock: - if input_addr not in self._instances: - self._add_new_instance_locked(event) + if input_addr not in self._replicas: + self._add_new_replica_locked(event) else: - if event.status == StageStatus.DOWN: - self._remove_instance_locked(event) + if event.status == ReplicaStatus.DOWN: + self._remove_replica_locked(event) else: - self._update_instance_info_locked(event) + self._update_replica_info_locked(event) # Any non-heartbeat state change that affects the active list # is coalesced and flushed via the periodic loop. 
@@ -300,7 +318,7 @@ def _handle_event(self, event: InstanceEvent) -> None: except (KeyError, ValueError, TypeError) as e: logger.warning("Dropping malformed event: %s", e) - def _add_new_instance_locked(self, event: InstanceEvent) -> None: + def _add_new_replica_locked(self, event: ReplicaEvent) -> None: input_addr = event.input_addr if not input_addr: raise KeyError("input_addr required") @@ -309,7 +327,7 @@ def _add_new_instance_locked(self, event: InstanceEvent) -> None: raise KeyError("stage_id required and must be non-negative") now = time() - info = InstanceInfo( + info = ReplicaInfo( input_addr=input_addr, output_addr=event.output_addr, stage_id=stage_id, @@ -318,11 +336,11 @@ def _add_new_instance_locked(self, event: InstanceEvent) -> None: last_heartbeat=now, registered_at=now, ) - self._instances[input_addr] = info + self._replicas[input_addr] = info - def _update_instance_info_locked(self, event: InstanceEvent) -> None: + def _update_replica_info_locked(self, event: ReplicaEvent) -> None: input_addr = event.input_addr - info = self._instances[input_addr] + info = self._replicas[input_addr] if event.status is not None: info.status = event.status @@ -330,10 +348,10 @@ def _update_instance_info_locked(self, event: InstanceEvent) -> None: if event.queue_length is not None: info.queue_length = event.queue_length - def _remove_instance_locked(self, event: InstanceEvent) -> None: + def _remove_replica_locked(self, event: ReplicaEvent) -> None: input_addr = event.input_addr - info = self._instances.get(input_addr) + info = self._replicas.get(input_addr) if info is None: return - info.status = StageStatus.DOWN + info.status = ReplicaStatus.DOWN diff --git a/vllm_omni/distributed/omni_coordinator/runtime.py b/vllm_omni/distributed/omni_coordinator/runtime.py new file mode 100644 index 00000000000..71df7c0521c --- /dev/null +++ b/vllm_omni/distributed/omni_coordinator/runtime.py @@ -0,0 +1,74 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 
Copyright contributors to the vLLM project + +"""Lifecycle wrapper around :class:`OmniCoordinator`. + +``OmniCoordinatorRuntime`` is the single-purpose owner of the head-side +coordinator process artifacts: it picks two free TCP ports, constructs an +:class:`OmniCoordinator` bound to them, exposes the resulting addresses, and +provides a single ``close()`` method to tear everything down. + +The ROUTER address is later handed to :class:`OmniMasterServer` so it can be +published to registering replicas; the PUB address is handed to the +``Orchestrator``, which constructs its :class:`OmniCoordClientForHub` against +it. +""" + +from __future__ import annotations + +import logging + +from vllm.utils.network_utils import get_open_ports_list + +from .omni_coordinator import OmniCoordinator + +logger = logging.getLogger(__name__) + + +class OmniCoordinatorRuntime: + """Own one :class:`OmniCoordinator` and the two ports it binds. + + Constructor binds; :meth:`close` tears down. The class deliberately does + not expose the coordinator instance — callers should consume the + coordinator only via its wire protocol through + :class:`OmniCoordClientForStage` and :class:`OmniCoordClientForHub`. 
+ """ + + def __init__( + self, + *, + host: str, + heartbeat_timeout: float, + ) -> None: + if not host: + raise ValueError("host must be a non-empty string") + if heartbeat_timeout <= 0: + raise ValueError("heartbeat_timeout must be positive") + + router_port, pub_port = get_open_ports_list(count=2) + self.router_address: str = f"tcp://{host}:{router_port}" + self.pub_address: str = f"tcp://{host}:{pub_port}" + + self._closed = False + self._coordinator = OmniCoordinator( + router_zmq_addr=self.router_address, + pub_zmq_addr=self.pub_address, + heartbeat_timeout=heartbeat_timeout, + ) + + logger.info( + "[OmniCoordinatorRuntime] Started (router=%s pub=%s heartbeat_timeout=%.1fs)", + self.router_address, + self.pub_address, + heartbeat_timeout, + ) + + def close(self) -> None: + """Tear down the underlying coordinator. Idempotent.""" + if self._closed: + return + self._closed = True + try: + self._coordinator.close() + except Exception: + logger.exception("[OmniCoordinatorRuntime] coordinator close failed") diff --git a/vllm_omni/engine/arg_utils.py b/vllm_omni/engine/arg_utils.py index 4a2f46b01bd..14d0d43e830 100644 --- a/vllm_omni/engine/arg_utils.py +++ b/vllm_omni/engine/arg_utils.py @@ -155,6 +155,10 @@ def _add_omni_specific_args(cls, parser: argparse.ArgumentParser) -> argparse.Ar omni_master_address: str | None = None omni_master_port: int | None = None + # OmniCoordinator integration knobs (process-local). 
+ omni_dp_size_local: int = 1 + omni_lb_policy: str = "random" + omni_heartbeat_timeout: float = 30.0 stage_configs_path: str | None = None output_modalities: list[str] | None = None log_stats: bool = False diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py index 8a850f5659a..420f808dc2f 100644 --- a/vllm_omni/engine/async_omni_engine.py +++ b/vllm_omni/engine/async_omni_engine.py @@ -18,7 +18,7 @@ import time import uuid import weakref -from collections.abc import Mapping, Sequence +from collections.abc import Awaitable, Callable, Mapping, Sequence from contextlib import ExitStack from dataclasses import asdict, dataclass from typing import TYPE_CHECKING, Any, Literal, cast @@ -45,6 +45,10 @@ from vllm_omni.distributed.omni_connectors.utils.initialization import ( resolve_omni_kv_config_for_stage, ) +from vllm_omni.distributed.omni_coordinator import ( + LoadBalancer, + build_load_balancer_factory, +) from vllm_omni.engine import OmniEngineCoreRequest from vllm_omni.engine.messages import ( AbortRequestMessage, @@ -53,6 +57,7 @@ CollectiveRPCResultMessage, EngineQueueMessage, ErrorMessage, + RegisterRemoteReplicaMessage, ShutdownRequestMessage, StageSubmissionMessage, ) @@ -77,6 +82,7 @@ from vllm_omni.engine.stage_init_utils import ( LogicalStageInitPlan, ReplicaInitPlan, + StageRemoteFactoryContext, _inject_inferred_kv_tp_topology, acquire_device_locks, acquire_diffusion_device_locks, @@ -85,6 +91,7 @@ build_llm_stage_output_processor, build_stage0_input_processor, build_vllm_config, + capture_stage_factory_contexts, compute_replica_layout, extract_stage_metadata, get_stage_connector_spec, @@ -148,6 +155,7 @@ class StageRuntimeInfo: # trigger the ``create_model_config`` guard). 
_PARENT_ARGS_STRIP: frozenset[str] = frozenset({"stage_configs_path"}) + # Fields always populated by callers (via ``from_cli_args`` / ``asdict``) so # their presence as an override is never a surprise — suppress the # "override ignored" warning for these. @@ -229,10 +237,9 @@ def _weak_shutdown_async_omni_engine( pass for q in (request_queue, output_queue, rpc_output_queue): - if q is None: - continue try: - q.close() + if q is not None: + q.close() except Exception: pass @@ -301,6 +308,23 @@ def __init__( self._omni_master_port: int | None = kwargs.get("omni_master_port") self._omni_master_server: OmniMasterServer | None = None + # New omni-coordinator flags. Consumed only in single_stage_mode. + # ``omni_dp_size_local`` is process-local: each invocation (head and + # every headless) launches that many replicas for its own stage. + self._omni_dp_size_local: int = int(kwargs.get("omni_dp_size_local") or 1) + if self._omni_dp_size_local < 1: + raise ValueError(f"--omni-dp-size-local must be >= 1, got {self._omni_dp_size_local}") + self._omni_lb_policy: str = str(kwargs.get("omni_lb_policy") or "random") + self._omni_heartbeat_timeout: float = float(kwargs.get("omni_heartbeat_timeout") or 30.0) + if self._omni_heartbeat_timeout <= 0: + raise ValueError(f"--omni-heartbeat-timeout must be > 0, got {self._omni_heartbeat_timeout}") + # Coordinator runtime (head-distributed only). + self._coordinator_runtime: Any | None = None + # Per-stage construction context, captured after _initialize_stages + # and used by ``_build_remote_replica`` (the RemoteReplicaFactory + # passed to Orchestrator) when a headless replica registers. + self._stage_remote_factory_contexts: dict[int, StageRemoteFactoryContext] = {} + if single_stage_mode: logger.info( "[AsyncOmniEngine] Single-stage mode enabled (stage_id_filter=%s, master=%s:%s)", @@ -321,9 +345,14 @@ def __init__( self.supported_tasks: tuple[str, ...] 
= ("generate",) self.default_sampling_params_list: list[OmniSamplingParams] = [] self.stage_metadata: list[StageRuntimeInfo] = [] - self.request_queue: janus.Queue[EngineQueueMessage] | None = None - self.output_queue: janus.Queue[EngineQueueMessage] | None = None - self.rpc_output_queue: janus.Queue[EngineQueueMessage] | None = None + # Janus queues are constructed eagerly here (not deferred to the + # orchestrator thread) so the master server's ROUTER thread always + # sees a non-None ``self.request_queue`` when on_register fires. + # ``async_q`` lazily binds to whatever event loop first awaits on + # it (the orchestrator loop), so cross-thread use stays correct. + self.request_queue: janus.Queue[EngineQueueMessage] = janus.Queue() + self.output_queue: janus.Queue[EngineQueueMessage] = janus.Queue() + self.rpc_output_queue: janus.Queue[EngineQueueMessage] = janus.Queue() self._shutdown_called = False self._weak_finalizer: weakref.finalize | None = None self._rpc_lock = threading.Lock() @@ -443,30 +472,37 @@ def _shutdown_initialized_clients(clients: Sequence[StageClient]) -> None: ) def _validate_single_stage_mode_replica_constraints(self) -> None: - """Reject unsupported replica fan-out in single-stage mode.""" + """Apply --omni-dp-size-local to the local stage's runtime.num_replicas. + + In the previous revision this method rejected LLM stages with + ``num_replicas > 1``. The whole point of ``--omni-dp-size-local`` is + to lift that restriction for the *local* stage, so the rejection is + gone. We now use this hook to write ``--omni-dp-size-local`` onto + the self-stage's runtime config so downstream code + (``compute_replica_layout`` → ``_build_logical_stage_init_plans``) + sees a consistent view. 
+ """ if not self.single_stage_mode: return + target_stage_id = self._single_stage_id_filter + if target_stage_id is None: + return - unsupported: list[tuple[int, int]] = [] for idx, stage_cfg in enumerate(self.stage_configs): - runtime_cfg = getattr(stage_cfg, "runtime", {}) - num_replicas = int( - runtime_cfg.get("num_replicas", 1) - if hasattr(runtime_cfg, "get") - else getattr(runtime_cfg, "num_replicas", 1) - ) - if num_replicas <= 1: - continue - if getattr(stage_cfg, "stage_type", "llm") == "diffusion": - continue stage_id = int(getattr(stage_cfg, "stage_id", idx)) - unsupported.append((stage_id, num_replicas)) - - if unsupported: - raise ValueError( - "single_stage_mode only supports num_replicas > 1 for diffusion stages; " - f"found non-diffusion stages {unsupported}" - ) + runtime_cfg = getattr(stage_cfg, "runtime", None) + if runtime_cfg is None: + continue + if stage_id == target_stage_id: + # Self stage: take --omni-dp-size-local from this process. + try: + runtime_cfg.num_replicas = self._omni_dp_size_local + except Exception: + if hasattr(runtime_cfg, "__setitem__"): + runtime_cfg["num_replicas"] = self._omni_dp_size_local + # Other stages keep their config-declared num_replicas; in + # head-distributed mode they will be launched as ``launch_mode + # == "remote"`` with the configured count. def _build_logical_stage_init_plans( self, @@ -540,9 +576,18 @@ def _build_logical_stage_init_plans( replica_metadata = extract_stage_metadata(replica_cfg) replica_metadata.replica_id = replica_id - if self.single_stage_mode: - if replica_metadata.stage_type != "diffusion": - replica_metadata.runtime_cfg = None + # In single_stage_mode the head only owns its self stage's + # replicas; the remote-stage metadata exists only so the + # orchestrator can route requests through StagePool. Wiping + # ``runtime_cfg`` there makes sense (we don't manage those + # devices). 
For the *self stage* we MUST keep the + # per-replica runtime so ``setup_stage_devices`` can apply + # the device split from ``compute_replica_layout`` — + # otherwise every replica inherits the parent's full + # CUDA_VISIBLE_DEVICES and stacks on cuda:0 (OOM with any + # model whose footprint exceeds ~1/(2N) of the card). + if launch_mode == "remote" and replica_metadata.stage_type != "diffusion": + replica_metadata.runtime_cfg = None replicas.append( ReplicaInitPlan( @@ -578,6 +623,11 @@ def _start_omni_master_server(self, stage_plans: Sequence[LogicalStageInitPlan]) all_stage_ids: list[int] = [] stage_replica_counts: dict[int, int] = {} + # Slots that the head itself will register (launch_mode == "local") + # — auto-assign on the master must not hand these out to remote + # headless registrations even within the race window before the + # head's own register_stage_with_omni_master call completes. + head_local_replicas: dict[int, list[int]] = {} seen_stage_ids: set[int] = set() for plan in stage_plans: stage_id = plan.configured_stage_id @@ -588,12 +638,27 @@ def _start_omni_master_server(self, stage_plans: Sequence[LogicalStageInitPlan]) seen_stage_ids.add(stage_id) all_stage_ids.append(stage_id) stage_replica_counts[stage_id] = len(plan.replicas) + local_rids = [rep.replica_id for rep in plan.replicas if rep.launch_mode == "local"] + if local_rids: + head_local_replicas[stage_id] = local_rids + + # Start the OmniCoordinator runtime first so its router address is + # available to publish in every registration reply. 
+ from vllm_omni.distributed.omni_coordinator import OmniCoordinatorRuntime + + self._coordinator_runtime = OmniCoordinatorRuntime( + host=self._omni_master_address, + heartbeat_timeout=self._omni_heartbeat_timeout, + ) self._omni_master_server = OmniMasterServer( master_address=self._omni_master_address, master_port=self._omni_master_port, stage_ids=all_stage_ids, stage_replica_counts=stage_replica_counts, + coordinator_router_address=self._coordinator_runtime.router_address, + on_register=self._dispatch_master_register, + head_local_replicas=head_local_replicas, ) self._omni_master_server.start() logger.info( @@ -601,6 +666,147 @@ def _start_omni_master_server(self, stage_plans: Sequence[LogicalStageInitPlan]) all_stage_ids, ) + # ------------------------------------------------------------------ + # Remote replica factory (head-side client construction) + # ------------------------------------------------------------------ + + async def _build_remote_replica(self, stage_id: int, replica_id: int) -> Any: + """Construct a head-side stage client for a newly-registered remote replica. + + Used by :class:`Orchestrator` as its ``remote_replica_factory``. + The orchestrator awaits this from its own asyncio loop, so the + client is created in the same loop that owns ZMQ sockets — no + cross-thread setup is required. + + Raises if the stage is not known or if its construction context + was not captured (e.g. the stage was empty at bring-up time). 
+ """ + ctx = self._stage_remote_factory_contexts.get(stage_id) + if ctx is None: + raise RuntimeError( + f"no factory context captured for stage {stage_id}; " + f"known stages: {sorted(self._stage_remote_factory_contexts.keys())}" + ) + if self._omni_master_server is None: + raise RuntimeError("OmniMasterServer is not running; cannot build remote replica") + + alloc = self._omni_master_server.get_allocation(stage_id, replica_id=replica_id) + + # Build a per-replica copy of the base metadata so ``replica_id`` + # is correct (StageMetadata is a plain dataclass-like). + metadata = copy.copy(ctx.base_metadata) + try: + metadata.replica_id = replica_id + except Exception: + # Best-effort: if the metadata object is frozen / unusual, the + # downstream client will fall back to ``replica_id = 0``. + pass + + if ctx.stage_type == "diffusion": + client = StageDiffusionClient.from_addresses( + metadata, + request_address=alloc.input_bind_address, + response_address=alloc.output_bind_address, + batch_size=ctx.diffusion_batch_size, + ) + logger.info( + "[AsyncOmniEngine] Built remote diffusion client for stage=%d replica=%d (req=%s resp=%s)", + stage_id, + replica_id, + alloc.input_bind_address, + alloc.output_bind_address, + ) + return client + + # LLM path + if ctx.vllm_config is None or ctx.executor_class is None: + raise RuntimeError(f"stage {stage_id} factory context is missing vllm_config / executor_class") + + # The headless's StageEngineCoreProc subprocess calls + # vllm.v1.engine.core.startup_handshake at boot and blocks until the + # head's handshake ROUTER answers — without this step it hits the + # built-in 5-minute timeout and exits. The bootstrap (pre-allocated) + # path runs `connect_remote_engine_cores` to perform that + # handshake; dynamic attach must do the same, otherwise every + # replica that comes in via `on_register` (auto-assigned or + # explicit-id-beyond-pre-alloc) deadlocks. 
Run the blocking + # handshake in a thread so the orchestrator loop stays responsive, + # then build the async client on this loop where `make_async_mp_client` + # expects to be invoked. + master_server = self._omni_master_server + ctx_vllm_config = ctx.vllm_config + ctx_executor_class = ctx.executor_class + + def _run_handshake() -> Any: + with connect_remote_engine_cores( + vllm_config=ctx_vllm_config, + omni_master_server=master_server, + stage_id=stage_id, + replica_id=replica_id, + ) as remote_resources: + _engine_manager, _coordinator, addresses, _ = remote_resources + return _engine_manager, _coordinator, addresses + + engine_manager, coordinator, addresses = await asyncio.to_thread(_run_handshake) + + client_addresses: dict[str, str] = { + "input_address": addresses.inputs[0], + "output_address": addresses.outputs[0], + } + if addresses.frontend_stats_publish_address is not None: + client_addresses["stats_update_address"] = addresses.frontend_stats_publish_address + + client = StageEngineCoreClientBase.make_async_mp_client( + vllm_config=ctx_vllm_config, + executor_class=ctx_executor_class, + metadata=metadata, + client_addresses=client_addresses, + proc=None, + engine_manager=engine_manager, + coordinator=coordinator, + ) + logger.info( + "[AsyncOmniEngine] Built remote LLM client for stage=%d replica=%d (input=%s)", + stage_id, + replica_id, + client_addresses["input_address"], + ) + return client + + # ------------------------------------------------------------------ + # OmniCoordinator on_register proxy + # ------------------------------------------------------------------ + + def _dispatch_master_register(self, stage_id: int, replica_id: int, alloc: Any) -> None: + """Forward a master-server registration to the orchestrator queue. + + Called on :class:`OmniMasterServer`'s ROUTER thread (not the + orchestrator loop). Must return promptly. 
``self.request_queue`` + is constructed eagerly in ``__init__`` (not deferred to the + orchestrator thread), before the master server starts, so it + is non-None here; ``janus.Queue.sync_q`` is thread-safe by + construction. + """ + if self.request_queue is None: + logger.warning( + "[AsyncOmniEngine] request_queue not initialized; dropping register_remote_replica stage=%d replica=%d", + stage_id, + replica_id, + ) + return + msg = RegisterRemoteReplicaMessage( + stage_id=int(stage_id), + replica_id=int(replica_id), + ) + try: + self.request_queue.sync_q.put_nowait(msg) + except Exception: + logger.exception( + "[AsyncOmniEngine] Failed to enqueue register_remote_replica for stage=%d replica=%d", + stage_id, + replica_id, + ) + def _initialize_llm_replica( self, plan: ReplicaInitPlan, @@ -686,6 +892,11 @@ def _initialize_llm_replica( stage_init_timeout, ) if self.single_stage_mode and self._omni_master_server is not None: + coord_router_addr: str | None = ( + self._coordinator_runtime.router_address + if self._coordinator_runtime is not None + else None + ) engine_manager, coordinator, addresses = launch_stack.enter_context( launch_omni_core_engines( vllm_config=vllm_config, @@ -695,6 +906,7 @@ def _initialize_llm_replica( stage_id=plan.metadata.stage_id, stage_config=stage_cfg, replica_id=plan.replica_id, + omni_coordinator_address=coord_router_addr, ) ) else: @@ -830,12 +1042,20 @@ def _initialize_diffusion_replica( "[AsyncOmniEngine] Stage %s diffusion registration completed", plan.metadata.stage_id, ) + coord_router_addr: str | None = ( + self._coordinator_runtime.router_address + if self._coordinator_runtime is not None + else None + ) proc, _, _, _ = spawn_diffusion_proc( self.model, od_config, handshake_address=handshake_address, request_address=request_address, response_address=response_address, + omni_coordinator_address=coord_router_addr, + omni_stage_id=plan.metadata.stage_id, + omni_replica_id=plan.replica_id, ) 
complete_diffusion_handshake(proc, handshake_address, stage_init_timeout) logger.info( @@ -1049,7 +1269,13 @@ def _initialize_stages(self, stage_init_timeout: int) -> None: replicas_per_stage, replica_devices_map, ) + # Capture per-stage context now (before _start_omni_master_server) + # so the on_register proxy can build head-side clients for + # registrations that arrive immediately after the server starts. if self.single_stage_mode: + self._stage_remote_factory_contexts = capture_stage_factory_contexts( + stage_plans, diffusion_batch_size=self.diffusion_batch_size + ) self._start_omni_master_server(stage_plans) stage_pools: list[StagePool] = [] @@ -1085,6 +1311,14 @@ def _initialize_stages(self, stage_init_timeout: int) -> None: self._omni_master_server.stop() except Exception: logger.exception("[AsyncOmniEngine] Failed to stop OmniMasterServer during stage-init cleanup") + if self._coordinator_runtime is not None: + try: + self._coordinator_runtime.close() + except Exception: + logger.exception( + "[AsyncOmniEngine] Failed to close OmniCoordinatorRuntime during stage-init cleanup" + ) + self._coordinator_runtime = None raise self.stage_pools = stage_pools @@ -1104,13 +1338,6 @@ def _initialize_stages(self, stage_init_timeout: int) -> None: supported_tasks.add("speech") self.supported_tasks = tuple(supported_tasks) if supported_tasks else ("generate",) - def _initialize_janus_queues(self) -> None: - """Initialize janus queues inside orchestrator thread loop context.""" - self.request_queue = janus.Queue() - self.output_queue = janus.Queue() - self.rpc_output_queue = janus.Queue() - logger.debug("[AsyncOmniEngine] janus queues initialized in orchestrator thread loop") - def _bootstrap_orchestrator( self, stage_init_timeout: int, @@ -1122,10 +1349,15 @@ def _bootstrap_orchestrator( asyncio.set_event_loop(loop) async def _run_orchestrator() -> None: - self._initialize_janus_queues() - self._initialize_stages(stage_init_timeout) pd_config = self._detect_pd_config() 
+ coordinator_pub_address: str | None = None + load_balancer_factory: Callable[[], LoadBalancer] | None = None + remote_replica_factory: Callable[[int, int], Awaitable[Any]] | None = None + if self._coordinator_runtime is not None: + coordinator_pub_address = self._coordinator_runtime.pub_address + load_balancer_factory = build_load_balancer_factory(self._omni_lb_policy) + remote_replica_factory = self._build_remote_replica orchestrator = Orchestrator( request_async_queue=self.request_queue.async_q, output_async_queue=self.output_queue.async_q, @@ -1133,6 +1365,9 @@ async def _run_orchestrator() -> None: stage_pools=self.stage_pools, async_chunk=self.async_chunk, pd_config=pd_config, + coordinator_pub_address=coordinator_pub_address, + load_balancer_factory=load_balancer_factory, + remote_replica_factory=remote_replica_factory, ) if not startup_future.done(): startup_future.set_result(asyncio.get_running_loop()) @@ -1948,8 +2183,6 @@ def add_request( reasoning_ended=reasoning_ended, resumable=resumable, ) - if self.request_queue is None: - raise RuntimeError("request_queue is not initialized") self.request_queue.sync_q.put_nowait(msg) # CFG companion expansion: create and enqueue companion requests @@ -2017,8 +2250,6 @@ def add_streaming_update( resumable=resumable, message_type="streaming_update", ) - if self.request_queue is None: - raise RuntimeError("request_queue is not initialized") self.request_queue.sync_q.put_nowait(msg) async def add_streaming_update_async( @@ -2045,8 +2276,6 @@ async def add_streaming_update_async( def try_get_output(self, timeout: float = 0.001) -> EngineQueueMessage | None: """Read one output message from the Orchestrator output queue.""" - if self.output_queue is None: - return None try: return self.output_queue.sync_q.get(timeout=timeout) except queue.Empty: @@ -2056,8 +2285,6 @@ def try_get_output(self, timeout: float = 0.001) -> EngineQueueMessage | None: async def try_get_output_async(self) -> EngineQueueMessage | None: """Async 
read from the Orchestrator output queue.""" - if self.output_queue is None: - return None try: return self.output_queue.sync_q.get_nowait() except queue.Empty: @@ -2092,11 +2319,6 @@ def collective_rpc( This uses a dedicated RPC output queue so control-plane messages do not race with the normal request output polling loop. """ - if self.request_queue is None: - raise RuntimeError("request_queue is not initialized") - if self.rpc_output_queue is None: - raise RuntimeError("rpc_output_queue is not initialized") - rpc_id = uuid.uuid4().hex msg = CollectiveRPCRequestMessage( rpc_id=rpc_id, @@ -2184,8 +2406,6 @@ def shutdown(self) -> None: logger.warning("[AsyncOmniEngine] Orchestrator thread did not exit in time") for q in (self.request_queue, self.output_queue, self.rpc_output_queue): - if q is None: - continue try: q.close() except Exception: @@ -2198,6 +2418,13 @@ def shutdown(self) -> None: logger.exception("[AsyncOmniEngine] Failed to stop OmniMasterServer during shutdown") self._omni_master_server = None + if self._coordinator_runtime is not None: + try: + self._coordinator_runtime.close() + except Exception: + logger.exception("[AsyncOmniEngine] Failed to close OmniCoordinatorRuntime during shutdown") + self._coordinator_runtime = None + def _try_shutdown(self, *args, **kwargs) -> None: try: self.shutdown() diff --git a/vllm_omni/engine/messages.py b/vllm_omni/engine/messages.py index 28a5c721cb1..0e55105f13f 100644 --- a/vllm_omni/engine/messages.py +++ b/vllm_omni/engine/messages.py @@ -56,6 +56,18 @@ class ShutdownRequestMessage(EngineQueueMessage, kw_only=True): type: Literal["shutdown"] = "shutdown" +class RegisterRemoteReplicaMessage(EngineQueueMessage, kw_only=True): + type: Literal["register_remote_replica"] = "register_remote_replica" + stage_id: int + replica_id: int + + +class UnregisterRemoteReplicaMessage(EngineQueueMessage, kw_only=True): + type: Literal["unregister_remote_replica"] = "unregister_remote_replica" + stage_id: int + input_addr: str + 
+ class ErrorMessage(EngineQueueMessage, kw_only=True): type: Literal["error"] = "error" error: str diff --git a/vllm_omni/engine/omni_core_engine_proc_manager.py b/vllm_omni/engine/omni_core_engine_proc_manager.py new file mode 100644 index 00000000000..8df5fd05326 --- /dev/null +++ b/vllm_omni/engine/omni_core_engine_proc_manager.py @@ -0,0 +1,162 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Process manager for omni stage engine subprocesses. + +This is a drop-in replacement for vLLM's :class:`CoreEngineProcManager` that +spawns :meth:`StageEngineCoreProc.run_stage_core` instead of the upstream +``EngineCoreProc.run_engine_core``, and forwards omni-specific kwargs +(coordinator address, stage id, per-rank replica id). + +Each spawned subprocess corresponds to exactly one omni *replica*: it has its +own ZMQ allocation from :class:`OmniMasterServer` and (when an +``omni_coordinator_address`` is provided) its own +:class:`OmniCoordClientForStage` reporting heartbeat / status. + +Liveness monitoring and shutdown are inherited from +:class:`CoreEngineProcManager` unchanged. +""" + +from __future__ import annotations + +import contextlib +import threading +import weakref +from multiprocessing.process import BaseProcess +from multiprocessing.queues import Queue + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils import numa_utils +from vllm.utils.system_utils import get_mp_context +from vllm.v1.engine.utils import CoreEngineProcManager +from vllm.v1.executor import Executor +from vllm.v1.utils import shutdown + +from vllm_omni.engine.stage_engine_core_proc import StageEngineCoreProc + +logger = init_logger(__name__) + +try: + # ``set_device_control_env_var`` lives next to CoreEngineProcManager and + # is only required for non-CUDA DP, so we tolerate its absence on + # older / future vLLM revisions. 
+ from vllm.v1.engine.utils import set_device_control_env_var # type: ignore +except ImportError: # pragma: no cover - depends on vLLM build + set_device_control_env_var = None # type: ignore[assignment] + + +class OmniCoreEngineProcManager(CoreEngineProcManager): + """Spawn :class:`StageEngineCoreProc` subprocesses with omni kwargs. + + The body mirrors :class:`CoreEngineProcManager.__init__` because the + upstream class hardcodes ``target=EngineCoreProc.run_engine_core`` and + does not expose an extensibility hook. The differences from upstream are: + + * ``target`` is :meth:`StageEngineCoreProc.run_stage_core`. + * Per-rank ``omni_replica_id`` is computed as + ``base_replica_id + rank_idx`` and added to each subprocess's kwargs. + * ``omni_coordinator_address`` (if provided) and ``omni_stage_id`` are + added to every subprocess's kwargs. + """ + + def __init__( + self, + local_engine_count: int, + start_index: int, + local_start_index: int, + vllm_config: VllmConfig, + local_client: bool, + handshake_address: str, + executor_class: type[Executor], + log_stats: bool, + *, + omni_stage_id: int, + omni_coordinator_address: str | None = None, + omni_replica_base_id: int = 0, + client_handshake_address: str | None = None, + tensor_queue: Queue | None = None, + ) -> None: + # NOTE: we intentionally do not call ``super().__init__`` — the + # parent's body hardcodes the wrong target. We re-implement it here + # while reusing the parent's instance methods (shutdown, monitor). 
+ if local_engine_count <= 0: + raise ValueError(f"local_engine_count must be > 0, got {local_engine_count}") + + context = get_mp_context() + common_kwargs: dict[str, object] = { + "vllm_config": vllm_config, + "local_client": local_client, + "handshake_address": handshake_address, + "executor_class": executor_class, + "log_stats": log_stats, + "tensor_queue": tensor_queue, + "omni_stage_id": int(omni_stage_id), + "omni_coordinator_address": omni_coordinator_address, + } + + if client_handshake_address: + common_kwargs["client_handshake_address"] = client_handshake_address + + # Intra-replica vLLM DP mesh (i.e. ``data_parallel_size`` ranks sharing + # one engine, one DPCoordinator, one set of weights). Distinct from + # the omni-level notion of multiple independent replicas of a stage — + # those each spawn their own OmniCoreEngineProcManager and never join + # a vLLM DP group across replicas. + has_intra_replica_dp = vllm_config.parallel_config.data_parallel_size > 1 + + self.processes: list[BaseProcess] = [] + local_dp_ranks: list[int] = [] + for index in range(local_engine_count): + local_index = local_start_index + index + global_index = start_index + index + # Each spawned subprocess is one omni replica. The replica id + # is contiguous within this manager; the master server may have + # pre-allocated a contiguous block starting at ``omni_replica_base_id``. 
+ omni_replica_id = omni_replica_base_id + index + + local_dp_ranks.append(local_index) + self.processes.append( + context.Process( + target=StageEngineCoreProc.run_stage_core, + name=( + f"StageEngineCoreProc_stage{omni_stage_id}" + f"_replica{omni_replica_id}" + (f"_DP{global_index}" if has_intra_replica_dp else "") + ), + kwargs=common_kwargs + | { + "dp_rank": global_index, + "local_dp_rank": local_index, + "omni_replica_id": omni_replica_id, + }, + ) + ) + + self._finalizer = weakref.finalize(self, shutdown, self.processes) + self.manager_stopped = threading.Event() + self.failed_proc_name: str | None = None + + try: + for proc, local_dp_rank in zip(self.processes, local_dp_ranks): + device_control_context: contextlib.AbstractContextManager[None] = contextlib.nullcontext() + if ( + has_intra_replica_dp + and set_device_control_env_var is not None + and (not current_platform.is_cuda_alike() or vllm_config.parallel_config.use_ray) + ): + device_control_context = set_device_control_env_var(vllm_config, local_dp_rank) + + with ( + device_control_context, + numa_utils.configure_subprocess( + vllm_config, + local_rank=0, + dp_local_rank=local_dp_rank, + process_kind="EngineCore", + ), + ): + proc.start() + finally: + if self.finished_procs(): + self.shutdown() diff --git a/vllm_omni/engine/orchestrator.py b/vllm_omni/engine/orchestrator.py index ea615d5790b..27bb44ad7e5 100644 --- a/vllm_omni/engine/orchestrator.py +++ b/vllm_omni/engine/orchestrator.py @@ -4,12 +4,20 @@ Runs inside a background thread with its own asyncio event loop. Owns logical request progression across stage pools and handles stage-to-stage transfer logic. 
+ +In distributed mode (``coordinator_pub_address`` provided), it also +owns the single :class:`OmniCoordClientForHub`, runs a +:meth:`_watch_replica_list` task that converts replica disappearances +into ``unregister_remote_replica`` control messages, and handles the +``register_remote_replica`` / ``unregister_remote_replica`` flow that +attaches / detaches head-side stage clients for headless replicas. """ from __future__ import annotations import asyncio import time as _time +from collections.abc import Awaitable, Callable from dataclasses import dataclass, field from typing import Any @@ -22,6 +30,12 @@ from vllm.v1.engine import EngineCoreOutputs from vllm.v1.engine.exceptions import EngineDeadError +from vllm_omni.distributed.omni_coordinator import ( + LoadBalancer, + OmniCoordClientForHub, + RandomBalancer, + ReplicaStatus, +) from vllm_omni.engine import OmniEngineCoreRequest from vllm_omni.engine.cfg_companion_tracker import CfgCompanionTracker from vllm_omni.engine.messages import ( @@ -32,13 +46,26 @@ EngineQueueMessage, ErrorMessage, OutputMessage, + RegisterRemoteReplicaMessage, + ShutdownRequestMessage, StageMetricsMessage, StageSubmissionMessage, + UnregisterRemoteReplicaMessage, ) from vllm_omni.engine.serialization import serialize_additional_information from vllm_omni.engine.stage_pool import StagePool from vllm_omni.outputs import OmniRequestOutput +# Factory signature for building a head-side stage client for a +# *dynamically attached* (auto-assigned) remote replica. +# +# Receives ``(stage_id, replica_id)`` and returns an awaitable yielding the +# constructed client (any type — it must satisfy the shape expected by the +# matching :class:`StagePool`, i.e. expose ``client_addresses["input_address"]`` +# or ``request_address``, plus the usual ``add_request_async`` / +# ``get_output_async`` / ``shutdown`` surface). 
+RemoteReplicaFactory = Callable[[int, int], Awaitable[Any]] + logger = init_logger(__name__) @@ -124,6 +151,10 @@ class StreamingInputState: class Orchestrator: """Runs inside a background thread's asyncio event loop.""" + # Cadence at which the replica-list watcher polls for disappearances. + _WATCH_REPLICA_INTERVAL_S: float = 0.5 + _WATCH_REPLICA_IDLE_INTERVAL_S: float = 1.0 + def __init__( self, request_async_queue: janus.AsyncQueue[EngineQueueMessage], @@ -133,6 +164,9 @@ def __init__( *, async_chunk: bool = False, pd_config: dict[str, Any] | None = None, + coordinator_pub_address: str | None = None, + load_balancer_factory: Callable[[], LoadBalancer] | None = None, + remote_replica_factory: RemoteReplicaFactory | None = None, ) -> None: self.request_async_queue = request_async_queue self.output_async_queue = output_async_queue @@ -159,6 +193,28 @@ def __init__( self._fatal_error: str | None = None self._fatal_error_stage_id: int | None = None + # Background tasks for fire-and-forget message handlers (currently + # only ``register_remote_replica`` and ``unregister_remote_replica``). + # Held as a set so each task's reference survives the loop and the + # task can self-deregister on completion. + self._membership_tasks: set[asyncio.Task[None]] = set() + + # Distributed-mode wiring. The hub is constructed on the + # orchestrator's asyncio loop because it spawns a SUB background + # thread; building it from another thread would race the + # ``_init_done`` event. + self._hub: OmniCoordClientForHub | None = ( + OmniCoordClientForHub(coordinator_pub_address) if coordinator_pub_address is not None else None + ) + self._remote_replica_factory = remote_replica_factory + # Inject hub + per-pool LB into each StagePool so they can run + # distributed dispatch via ``StagePool.pick``. 
+ if self._hub is not None: + factory = load_balancer_factory or RandomBalancer + for pool in self.stage_pools: + pool.attach_hub(self._hub) + pool.attach_load_balancer(factory()) + async def run(self) -> None: """Main entry point for the Orchestrator event loop.""" logger.info("[Orchestrator] Starting event loop") @@ -168,9 +224,16 @@ async def run(self) -> None: self._orchestration_output_handler(), name="orchestrator-stage-output-handler", ) + # The replica watcher only runs in distributed mode. It's still + # created in both cases so ``run()`` has a uniform task graph; + # ``_watch_replica_list`` is a no-op poll when ``self._hub`` is None. + watch_task = asyncio.create_task( + self._watch_replica_list(), + name="orchestrator-replica-watcher", + ) try: - await asyncio.gather(request_task, output_task) + await asyncio.gather(request_task, output_task, watch_task) except asyncio.CancelledError: raise except EngineDeadError as e: @@ -185,11 +248,11 @@ async def run(self) -> None: raise finally: self._shutdown_event.set() - for task in (request_task, output_task): + for task in (request_task, output_task, watch_task): if not task.done(): task.cancel() try: - await asyncio.gather(request_task, output_task, return_exceptions=True) + await asyncio.gather(request_task, output_task, watch_task, return_exceptions=True) except Exception: pass @@ -199,8 +262,32 @@ async def run(self) -> None: if self._fatal_error is not None: await self._drain_pending_requests_on_fatal() + # Wait briefly for any in-flight membership handlers (register / + # unregister remote replica) to finish so they don't leave the + # head-side pool in a half-attached state. Cancel anything that + # hasn't completed in time; the generic pending-task sweep below + # will collect the cancellations. 
+ if self._membership_tasks: + try: + await asyncio.wait_for( + asyncio.gather(*self._membership_tasks, return_exceptions=True), + timeout=10.0, + ) + except (asyncio.TimeoutError, Exception): + for t in self._membership_tasks: + if not t.done(): + t.cancel() + self._shutdown_stages() + # Close the hub last so any in-flight dispatch still has access. + if self._hub is not None: + try: + self._hub.close() + except RuntimeError: + pass + self._hub = None + loop = asyncio.get_running_loop() pending = [t for t in asyncio.all_tasks(loop) if t is not asyncio.current_task() and not t.done()] for task in pending: @@ -208,6 +295,27 @@ async def run(self) -> None: if pending: await asyncio.gather(*pending, return_exceptions=True) + # ---- Background task helpers ---- + + def _spawn_membership_task(self, coro: Awaitable[None], *, label: str) -> None: + """Run a fire-and-forget membership-change coroutine. + + Holds a strong reference until completion (asyncio would otherwise + garbage-collect a bare task), and logs any uncaught exception. + """ + task = asyncio.create_task(coro, name=f"orchestrator-{label}") + self._membership_tasks.add(task) + + def _on_done(t: asyncio.Task[None]) -> None: + self._membership_tasks.discard(t) + if t.cancelled(): + return + exc = t.exception() + if exc is not None: + logger.error("[Orchestrator] %s task crashed", label, exc_info=exc) + + task.add_done_callback(_on_done) + # ---- Request handling ---- async def _request_handler(self) -> None: @@ -226,7 +334,20 @@ async def _request_handler(self) -> None: await self._handle_abort(msg) elif msg_type == "collective_rpc": await self._handle_collective_rpc(msg) - elif msg_type == "shutdown": + elif isinstance(msg, RegisterRemoteReplicaMessage): + # Dynamic-attach involves a ~5s blocking handshake (run in a + # thread by ``_build_remote_replica``); ``await`` here would + # block the queue and stall the next ``add_request`` until + # the attach finishes. 
Dispatch as a background task so the + # main message loop keeps draining. + self._spawn_membership_task(self._handle_register_remote_replica(msg), label="register_remote_replica") + elif isinstance(msg, UnregisterRemoteReplicaMessage): + # Symmetric with register: keep the main queue flowing. + self._spawn_membership_task( + self._handle_unregister_remote_replica(msg), + label="unregister_remote_replica", + ) + elif isinstance(msg, ShutdownRequestMessage): logger.info("[Orchestrator] Received shutdown signal") self._shutdown_event.set() # Pre-mark stage clients as shutting down to prevent @@ -415,7 +536,7 @@ async def _handle_collective_rpc(self, msg: CollectiveRPCRequestMessage) -> None results: list[Any] = [] stage_ids: list[int] = [] for pool in target_pools: - for replica_id in range(pool.num_replicas): + for replica_id in pool.live_replica_ids(): stage_result = await pool.collective_rpc( replica_id=replica_id, method=method, @@ -451,7 +572,7 @@ async def _orchestration_loop(self) -> None: idle = True for stage_id in range(self.num_stages): pool = self.stage_pools[stage_id] - for replica_id in range(pool.num_replicas): + for replica_id in pool.live_replica_ids(): if self._shutdown_event.is_set(): return @@ -1105,14 +1226,137 @@ async def _drain_pending_requests_on_fatal(self) -> None: ) self.request_states.pop(req_id, None) + # ---- Distributed-mode replica attach / detach ---- + + async def _watch_replica_list(self) -> None: + """Convert hub replica disappearances into unregister control messages.""" + last_up: set[tuple[int, str]] = set() + while not self._shutdown_event.is_set(): + if self._hub is None: + # No coordinator wired up; sleep coarsely and re-check shutdown. 
+ try: + await asyncio.sleep(self._WATCH_REPLICA_IDLE_INTERVAL_S) + except asyncio.CancelledError: + raise + continue + + try: + snap = self._hub.get_replica_list() + current = {(rep.stage_id, rep.input_addr) for rep in snap.replicas if rep.status == ReplicaStatus.UP} + for stage_id, addr in last_up - current: + await self.request_async_queue.put( + UnregisterRemoteReplicaMessage( + stage_id=stage_id, + input_addr=addr, + ) + ) + last_up = current + except asyncio.CancelledError: + raise + except Exception: + logger.exception("[Orchestrator] _watch_replica_list iteration failed") + + try: + await asyncio.sleep(self._WATCH_REPLICA_INTERVAL_S) + except asyncio.CancelledError: + raise + + async def _handle_register_remote_replica(self, msg: RegisterRemoteReplicaMessage) -> None: + """Bind a head-side client for a newly registered remote replica.""" + stage_id = int(msg.stage_id) + replica_id = int(msg.replica_id) + if not (0 <= stage_id < self.num_stages): + logger.warning( + "[Orchestrator] register_remote_replica: stage_id %d out of range (num_stages=%d)", + stage_id, + self.num_stages, + ) + return + if self._remote_replica_factory is None: + logger.warning( + "[Orchestrator] register_remote_replica received for stage=%d replica=%d but no factory installed", + stage_id, + replica_id, + ) + return + + try: + await self._attach_remote_replica(stage_id, replica_id) + except Exception: + logger.exception( + "[Orchestrator] failed to attach remote replica stage=%d replica=%d", + stage_id, + replica_id, + ) + + async def _handle_unregister_remote_replica(self, msg: UnregisterRemoteReplicaMessage) -> None: + """Tear down the head-side client for a vanished remote replica.""" + stage_id = int(msg.stage_id) + input_addr = str(msg.input_addr) + if not (0 <= stage_id < self.num_stages): + return + pool = self.stage_pools[stage_id] + affected = pool.invalidate_addr(input_addr) + self._detach_remote_replica(stage_id, input_addr) + if affected: + await 
self._cleanup_request_ids(affected, abort=True) + for req_id in affected: + await self.output_async_queue.put( + ErrorMessage( + error="stage replica disappeared", + request_id=req_id, + stage_id=stage_id, + ) + ) + + async def _attach_remote_replica(self, stage_id: int, replica_id: int) -> None: + """Build a head-side stage client via the injected factory and register it.""" + factory = self._remote_replica_factory + if factory is None: + return + pool = self.stage_pools[stage_id] + client = await factory(stage_id, replica_id) + input_addr = StagePool._client_input_addr(client) + if input_addr is None: + raise RuntimeError( + f"remote replica factory for stage {stage_id} produced a client without a discoverable input address" + ) + pool.add_client(input_addr, client) + logger.info( + "[Orchestrator] attached remote replica stage=%d replica=%d addr=%s", + stage_id, + replica_id, + input_addr, + ) + + def _detach_remote_replica(self, stage_id: int, input_addr: str) -> None: + """Shut down + remove the head-side client at ``input_addr``.""" + pool = self.stage_pools[stage_id] + client = pool.remove_client(input_addr) + if client is None: + return + try: + client.shutdown() + except Exception: + logger.exception( + "[Orchestrator] failed to shutdown client for stage=%d addr=%s", + stage_id, + input_addr, + ) + logger.info( + "[Orchestrator] detached remote replica stage=%d addr=%s", + stage_id, + input_addr, + ) + def _shutdown_stages(self) -> None: """Shutdown all stage pools.""" if self._stages_shutdown: return self._stages_shutdown = True - total = sum(pool.num_replicas for pool in self.stage_pools) + total = sum(pool.live_num_replicas for pool in self.stage_pools) logger.info("[Orchestrator] Shutting down all %d client(s)", total) for pool in self.stage_pools: - for replica_id in range(pool.num_replicas): + for replica_id in pool.live_replica_ids(): pool.shutdown_replica(replica_id) diff --git a/vllm_omni/engine/stage_engine_core_client.py 
b/vllm_omni/engine/stage_engine_core_client.py index 25e4294e51b..e9e2944a748 100644 --- a/vllm_omni/engine/stage_engine_core_client.py +++ b/vllm_omni/engine/stage_engine_core_client.py @@ -23,6 +23,7 @@ from vllm_omni.distributed.omni_connectors.utils.initialization import ( KV_TRANSFER_PORT_OFFSET, ) +from vllm_omni.distributed.omni_connectors.utils.kv_utils import kv_zmq_port from vllm_omni.engine.stage_client import StageClientBase from vllm_omni.engine.stage_init_utils import StageMetadata @@ -344,7 +345,12 @@ def _initialize_kv_sender_endpoint(self) -> None: try: # Orchestrator always reports rank-0's port; receiver # workers add their own local_rank * KV_RANK_PORT_STRIDE. - sender_port = int(base_port) + KV_TRANSFER_PORT_OFFSET + int(from_stage) + sender_port = kv_zmq_port( + int(base_port), + int(from_stage), + local_rank=0, + replica_id=self.replica_id, + ) except (TypeError, ValueError): logger.warning( "[StageEngineCoreClient] stage-%s [rep-%s] could not resolve sender_zmq_port " @@ -386,7 +392,12 @@ def get_kv_sender_info( # rank-0 base port; receiver workers adjust per KV_RANK_PORT_STRIDE. 
return { "host": self._kv_sender_host, - "zmq_port": base_port + kv_transfer_port_offset + int(self.stage_id), + "zmq_port": kv_zmq_port( + base_port - KV_TRANSFER_PORT_OFFSET + kv_transfer_port_offset, + int(self.stage_id), + local_rank=0, + replica_id=self.replica_id, + ), } def set_engine_outputs(self, engine_outputs: EngineCoreOutput) -> None: diff --git a/vllm_omni/engine/stage_engine_core_proc.py b/vllm_omni/engine/stage_engine_core_proc.py index 2ab8b37dd5f..0b3dadcbd1c 100644 --- a/vllm_omni/engine/stage_engine_core_proc.py +++ b/vllm_omni/engine/stage_engine_core_proc.py @@ -7,6 +7,8 @@ from __future__ import annotations +import contextlib +import os import signal from multiprocessing.process import BaseProcess from typing import TYPE_CHECKING, Any @@ -33,6 +35,8 @@ ) from vllm.v1.utils import shutdown +from vllm_omni.distributed.omni_coordinator import OmniCoordClientForStage + if TYPE_CHECKING: from vllm.config import VllmConfig from vllm.v1.executor import Executor @@ -53,28 +57,41 @@ def run_stage_core( *args: Any, dp_rank: int = 0, local_dp_rank: int = 0, + omni_coordinator_address: str | None = None, + omni_stage_id: int | None = None, + omni_replica_id: int = 0, **kwargs: Any, ) -> None: - """Launch StageEngineCoreProc busy loop in background process.""" + """Launch StageEngineCoreProc busy loop in background process. + + Omni-specific kwargs: + - ``omni_coordinator_address``: ROUTER address of the head-side + :class:`OmniCoordinator`. When provided, this subprocess + instantiates an :class:`OmniCoordClientForStage` after the + HELLO/INIT/READY handshake completes and reports its status + + queue length via heartbeats. The hook is wired so each + heartbeat refreshes ``queue_length`` from the live scheduler. + - ``omni_stage_id``: logical stage id this replica belongs to. + Required when ``omni_coordinator_address`` is provided. + - ``omni_replica_id``: cluster-unique replica id within the + stage (assigned by :class:`OmniMasterServer`). 
Used for + logging / metrics, the process title, and ``VLLM_OMNI_REPLICA_ID``. + """ signal_callback: SignalCallback | None = None maybe_register_config_serialize_by_value() engine_core: StageEngineCoreProc | None = None + coord_client: OmniCoordClientForStage | None = None try: - vllm_config: VllmConfig = kwargs["vllm_config"] - parallel_config = vllm_config.parallel_config + # NOTE: previous revisions hardcoded data_parallel_size=1 here + # (TODO referencing issue #984). The hardcoding has been removed + # so the DP fields propagate through from the caller exactly + # like upstream vLLM. - set_process_title(f"StageEngineCoreProc_DP{dp_rank}") + stage_label = f"stage{omni_stage_id}" if omni_stage_id is not None else "noid" + set_process_title(f"StageEngineCoreProc_{stage_label}_replica{omni_replica_id}_DP{dp_rank}") decorate_logs() - - # the current vllm-omni does not support data parallelism, - # so we set the data parallel size to 1. - # [TODO] support data parallelism in the future. - # https://github.com/vllm-project/vllm-omni/issues/984 - parallel_config.data_parallel_size = 1 - parallel_config.data_parallel_size_local = 1 - parallel_config.data_parallel_rank = 0 - parallel_config.data_parallel_index = dp_rank + os.environ["VLLM_OMNI_REPLICA_ID"] = str(max(int(omni_replica_id), 0)) engine_core = StageEngineCoreProc( *args, @@ -82,6 +99,41 @@ def run_stage_core( **kwargs, ) + # Each subprocess corresponds to exactly one omni replica with + # its own OmniMasterServer allocation, so the heartbeat client + # runs unconditionally — there is no dp_rank-based gating. 
+ if omni_coordinator_address is not None: + if omni_stage_id is None: + raise ValueError("omni_stage_id must be provided when omni_coordinator_address is set") + addresses: EngineZmqAddresses = engine_core.addresses + if not addresses.inputs or not addresses.outputs: + raise RuntimeError( + "EngineCore handshake did not populate input/output addresses; " + "cannot start OmniCoordClientForStage" + ) + coord_client = OmniCoordClientForStage( + coord_zmq_addr=omni_coordinator_address, + input_addr=addresses.inputs[0], + output_addr=addresses.outputs[0], + stage_id=int(omni_stage_id), + ) + + def _refresh_queue_length() -> None: + """Pre-heartbeat hook: refresh queue_length from scheduler.""" + scheduler = getattr(engine_core, "scheduler", None) + if scheduler is None: + return + try: + coord_client._queue_length = int( # type: ignore[union-attr] + scheduler.get_num_unfinished_requests() + ) + except Exception: + # Live scheduler stats are best-effort — heartbeats + # must not fail because of a stats lookup error. 
+ pass + + coord_client._on_heartbeat = _refresh_queue_length + def wakeup_engine() -> None: engine_core.input_queue.put_nowait((EngineCoreRequestType.WAKEUP, None)) @@ -111,6 +163,9 @@ def signal_handler(signum: int, frame: Any) -> None: signal.signal(signal.SIGINT, signal.SIG_DFL) if signal_callback is not None: signal_callback.stop() + if coord_client is not None: + with contextlib.suppress(RuntimeError): + coord_client.close() if engine_core is not None: engine_core.shutdown() diff --git a/vllm_omni/engine/stage_engine_startup.py b/vllm_omni/engine/stage_engine_startup.py index 05bcdf7d138..a2744eea47b 100644 --- a/vllm_omni/engine/stage_engine_startup.py +++ b/vllm_omni/engine/stage_engine_startup.py @@ -4,8 +4,9 @@ import contextlib import dataclasses +import socket import threading -from collections.abc import Iterator +from collections.abc import Callable, Iterator from dataclasses import dataclass from typing import Any @@ -31,6 +32,15 @@ StageRoute = tuple[int, int] +# Sentinel that signals "auto-assign me a replica_id" on the wire. Negative +# values are not valid replica ids, so any sub-zero value works equivalently. +AUTO_ASSIGN_REPLICA_ID = -1 + +# Callback signature for OmniMasterServer.on_register. Fires only for +# auto-assigned replicas (new, headless-launched). The arguments are +# (stage_id, replica_id, allocation). +OnRegisterCallback = Callable[[int, int, "StageAllocation"], None] + # Poll period (ms) used by the registration/handshake loop. _POLL_PERIOD_MS = 5_000 # Default timeout (s) for a stage to send READY. 
@@ -110,6 +120,10 @@ def __init__( master_port: int, stage_ids: list[int], stage_replica_counts: dict[int, int] | None = None, + *, + coordinator_router_address: str | None = None, + on_register: OnRegisterCallback | None = None, + head_local_replicas: dict[int, list[int]] | None = None, ) -> None: self._address = master_address self._port = master_port @@ -117,25 +131,41 @@ def __init__( self._stage_configs: dict[StageRoute, Any] = {} self._stage_coordinator_addresses: dict[StageRoute, StageCoordinatorAddresses] = {} self._stage_config_events: dict[StageRoute, threading.Event] = {} + # Coordinator ROUTER address echoed back in every registration reply + # so OmniCoordClientForStage knows where to connect from inside the + # engine subprocess. + self._coordinator_router_address = coordinator_router_address + # Fires only for *newly assigned* (auto-assigned) replicas, not for + # head-side pre-allocated slots that already have head-side clients. + self._on_register = on_register + # Single allocation lock shared by all stages, so that concurrent + # registrations from multiple headless processes for the same stage + # don't race on the routing table. + self._alloc_lock = threading.Lock() + self._stage_ids_known: set[int] = set(int(sid) for sid in stage_ids) self._thread: threading.Thread | None = None self._stop_event = threading.Event() stage_replica_counts = dict(stage_replica_counts or {}) + # Slots the *head* itself will fill via ``launch_omni_core_engines`` + # / its own ``register_stage_with_omni_master`` call. Auto-assigning + # headless registrations must skip these even when they appear + # ``_stage_configs``-unfilled — otherwise a fast headless on the same + # host can race the head's own registration and steal slot 0. 
+ self._head_local_slots: set[StageRoute] = set() + for sid, rids in (head_local_replicas or {}).items(): + for rid in rids: + self._head_local_slots.add((int(sid), int(rid))) + for sid in stage_ids: - replica_count = max(1, int(stage_replica_counts.get(sid, 1))) + replica_count = int(stage_replica_counts.get(sid, 1)) + # Allow 0 explicitly so non-self stages (head distributed mode) + # can declare "no local replicas; remote ones will register + # dynamically". + if replica_count < 0: + raise ValueError(f"stage_replica_counts[{sid}] must be >= 0, got {replica_count}") for replica_id in range(replica_count): - route = (sid, replica_id) - self._stage_config_events[route] = threading.Event() - self._stage_coordinator_addresses[route] = StageCoordinatorAddresses() - hs_port, inp_port, out_port = get_open_ports_list(count=3) - self._stage_routes[route] = StageAllocation( - handshake_bind_address=f"tcp://{master_address}:{hs_port}", - handshake_connect_address=f"tcp://{master_address}:{hs_port}", - input_bind_address=f"tcp://{master_address}:{inp_port}", - input_connect_address=f"tcp://{master_address}:{inp_port}", - output_bind_address=f"tcp://{master_address}:{out_port}", - output_connect_address=f"tcp://{master_address}:{out_port}", - ) + self._allocate_route_locked(sid, replica_id) logger.info( "[OmniMasterServer] Pre-allocated addresses for stages %s (master=%s:%d)", @@ -157,10 +187,84 @@ def port(self) -> int: """Return the registration port exposed to stage launchers.""" return self._port + @property + def coordinator_router_address(self) -> str | None: + """Return the OmniCoordinator ROUTER address echoed to replicas.""" + return self._coordinator_router_address + def get_allocation(self, stage_id: int, replica_id: int = 0) -> StageAllocation: """Return the full address allocation for *stage_id*.""" return self._stage_routes[(stage_id, replica_id)] + # ------------------------------------------------------------------ + # Allocation + # 
------------------------------------------------------------------ + + def _allocate_route_locked(self, stage_id: int, replica_id: int) -> StageAllocation: + """Allocate handshake/input/output ports for ``(stage_id, replica_id)``. + + Idempotent: if the route already exists, returns the existing + allocation unchanged. Caller is responsible for holding + ``self._alloc_lock`` when needed. + """ + route = (stage_id, replica_id) + existing = self._stage_routes.get(route) + if existing is not None: + return existing + + self._stage_config_events[route] = threading.Event() + self._stage_coordinator_addresses[route] = StageCoordinatorAddresses() + hs_port, inp_port, out_port = get_open_ports_list(count=3) + alloc = StageAllocation( + handshake_bind_address=f"tcp://{self._address}:{hs_port}", + handshake_connect_address=f"tcp://{self._address}:{hs_port}", + input_bind_address=f"tcp://{self._address}:{inp_port}", + input_connect_address=f"tcp://{self._address}:{inp_port}", + output_bind_address=f"tcp://{self._address}:{out_port}", + output_connect_address=f"tcp://{self._address}:{out_port}", + ) + self._stage_routes[route] = alloc + return alloc + + def _next_free_replica_id(self, stage_id: int) -> int: + """Return the next replica id to assign for an auto-assign registration. + + Strategy: prefer filling a pre-allocated-but-unfilled slot (one that + ``__init__`` reserved in ``_stage_routes`` but no registration has + completed yet) so the head's bootstrap path — which waits on + ``_stage_config_events[(stage_id, replica_id)]`` for specific + pre-allocated ids — unblocks. Only when every pre-allocated slot for + this stage has been filled do we allocate a fresh id. + + Slots in ``_head_local_slots`` are reserved for the head's own + ``launch_omni_core_engines`` registration. 
Auto-assign must skip + them even when ``_stage_configs`` shows them unfilled — otherwise a + same-host headless that registers before the head's own + ``register_stage_with_omni_master`` call would steal slot 0. + + Without this, a headless contributor using ``--omni-dp-size-local > 1`` + (auto-assign mode) would skip past pre-allocated slot 0 and pick ids + beyond ``num_replicas``, deadlocking the head's + ``connect_remote_engine_cores`` wait. + """ + # Pre-allocated slots that haven't received a registration yet are + # tracked by absence from ``_stage_configs``. Head-owned slots are + # not auto-assignable. + for sid, rid in sorted(self._stage_routes): + if sid != stage_id: + continue + if (sid, rid) in self._head_local_slots: + continue + if (sid, rid) not in self._stage_configs: + return rid + # Every pre-allocated slot is filled (or head-owned); allocate a + # fresh id past the existing routes. + used = {rid for (sid, rid) in self._stage_routes if sid == stage_id} + rid = 0 + while rid in used: + rid += 1 + return rid + def register_stage_config( self, stage_id: int, @@ -280,12 +384,21 @@ def _serve(self, ctx: zmq.Context) -> None: # type: ignore[type-arg] poller = zmq.Poller() poller.register(reg_socket, zmq.POLLIN) + # The server runs until ``stop()`` is called so that headless replicas + # spawned after the head finished its initial bring-up can still + # register dynamically. ``pending`` is kept around purely for + # debug-level logging of which pre-allocated slots have not yet + # registered; once empty it does not terminate the loop. 
pending: set[StageRoute] = set(self._stage_routes.keys()) - while pending and not self._stop_event.is_set(): + while not self._stop_event.is_set(): events: list[tuple[zmq.Socket, int]] = poller.poll(_POLL_PERIOD_MS) # type: ignore[assignment] if not events: - logger.debug("[OmniMasterServer] Still waiting for registration from stages: %s", pending) + if pending: + logger.debug( + "[OmniMasterServer] Still waiting for registration from pre-allocated slots: %s", + pending, + ) continue for sock, _ in events: @@ -296,12 +409,12 @@ def _serve(self, ctx: zmq.Context) -> None: # type: ignore[type-arg] # Cleanup reg_socket.close(linger=0) - logger.info("[OmniMasterServer] All stages registered; server thread exiting.") + logger.info("[OmniMasterServer] Server thread exiting.") def _handle_registration(self, reg_socket: zmq.Socket) -> StageRoute | None: # type: ignore[type-arg] """Receive a stage registration and reply with the handshake address. - Returns the registered stage_id on success, or None on failure. + Returns ``(stage_id, replica_id)`` on success or ``None`` on failure. 
""" frames = reg_socket.recv_multipart() if len(frames) < 2: @@ -318,45 +431,186 @@ def _handle_registration(self, reg_socket: zmq.Socket) -> StageRoute | None: # logger.warning("[OmniMasterServer] Failed to decode registration message: %s", exc) return None - stage_id: int | None = msg.get("stage_id") - replica_id = int(msg.get("replica_id", 0) or 0) - key = (stage_id, replica_id) - if key not in self._stage_routes: + stage_id_raw = msg.get("stage_id") + if not isinstance(stage_id_raw, int) or stage_id_raw < 0: logger.warning( - "[OmniMasterServer] Received registration for unknown stage_id=%s replica_id=%s", - stage_id, - replica_id, + "[OmniMasterServer] Registration missing or invalid stage_id: %r", + stage_id_raw, ) return None + stage_id: int = stage_id_raw + + incoming_replica_id = int(msg.get("replica_id", 0) or 0) + was_auto_assigned = incoming_replica_id < 0 + + # Distinguish two registration shapes: + # - Pre-allocated slots (concrete replica_id >= 0): the head built + # this slot during _initialize_stages. Just confirm it; do NOT + # fire ``on_register`` (the head already has a head-side client). + # - Auto-assigned slots (replica_id == AUTO_ASSIGN_REPLICA_ID): + # a *new* replica from a headless launcher. Allocate, then + # fire ``on_register`` so the orchestrator attaches. + with self._alloc_lock: + if was_auto_assigned: + replica_id = self._next_free_replica_id(stage_id) + # When auto-assign picks a slot the head pre-allocated (and + # is therefore waiting on in ``connect_remote_engine_cores``), + # the head's bootstrap path builds the head-side client. We + # must NOT also fire ``on_register`` for it; otherwise the + # orchestrator would build a duplicate client and overwrite + # the bootstrap-built one in the pool, leaking it. 
+ preexisting_slot = (stage_id, replica_id) in self._stage_routes + alloc = self._allocate_route_locked(stage_id, replica_id) + if preexisting_slot: + was_auto_assigned = False + else: + replica_id = incoming_replica_id + if (stage_id, replica_id) not in self._stage_routes: + # Tolerate explicit replica_ids that haven't been + # pre-allocated (e.g. headless that wants a specific id). + alloc = self._allocate_route_locked(stage_id, replica_id) + was_auto_assigned = True + else: + alloc = self._stage_routes[(stage_id, replica_id)] + + # Cross-host override: when the registering replica advertised + # its own bind address + ports, rewrite the StageAllocation so + # each socket is rooted on the host that actually binds it + # (the master's pre-allocated ports are unreachable from a + # remote replica's host). + # + # Diffusion and LLM stages have different binder ownership: + # + # Diffusion remote replica (StageDiffusionProc): + # - handshake: replica binds -> rewrite to replica IP + # - input : replica binds -> rewrite to replica IP + # - output : replica binds -> rewrite to replica IP + # + # LLM remote replica (CoreClient on head): + # - handshake: head binds (``connect_remote_engine_cores``) + # -> keep on master IP, worker TCP-connects + # - input : head binds (``CoreClient`` ROUTER) + # -> keep on master IP, worker TCP-connects + # - output : head binds (``CoreClient`` PULL — default + # bind=True for PULL in ``make_zmq_socket``) + # -> keep on master IP, worker TCP-connects + # + # The registrant indicates which case via the boolean + # ``replica_binds_sockets`` payload flag. It defaults to + # True (the diffusion / single-host case) so older callers + # still get the previous full-rewrite semantics. For LLM + # remote replicas, the master keeps every address on its + # own host and the remote worker establishes 3 outbound + # TCP connections to the master. 
+ new_bind_address = msg.get("replica_bind_address") + if new_bind_address: + replica_binds_sockets = bool(msg.get("replica_binds_sockets", True)) + if replica_binds_sockets: + hs_port = int(msg["replica_handshake_port"]) + inp_port = int(msg["replica_input_port"]) + out_port = int(msg["replica_output_port"]) + hs_bind_addr = f"tcp://{new_bind_address}:{hs_port}" + inp_bind_addr = f"tcp://{new_bind_address}:{inp_port}" + out_bind_addr = f"tcp://{new_bind_address}:{out_port}" + alloc = StageAllocation( + handshake_bind_address=hs_bind_addr, + handshake_connect_address=hs_bind_addr, + input_bind_address=inp_bind_addr, + input_connect_address=inp_bind_addr, + output_bind_address=out_bind_addr, + output_connect_address=out_bind_addr, + ) + self._stage_routes[(stage_id, replica_id)] = alloc + logger.info( + "[OmniMasterServer] Stage %d replica %d cross-host bind (sockets bound on %s; replica_ip=%s)", + stage_id, + replica_id, + "replica" if replica_binds_sockets else "master", + new_bind_address, + ) - self.register_stage_config( - stage_id, - msg.get("stage_config"), - coordinator_addresses=StageCoordinatorAddresses( - coordinator_input=msg.get("coordinator_input"), - coordinator_output=msg.get("coordinator_output"), - frontend_stats_publish_address=msg.get("frontend_stats_publish_address"), - ), - replica_id=replica_id, - ) + # Mark the slot as filled *inside* the lock. Without this, + # concurrent auto-assign registrations from a second headless + # could call ``_next_free_replica_id`` between the lock + # releasing above and the ``register_stage_config`` call + # below, observe the slot as unfilled, and hand the same + # pre-allocated handshake/input/output addresses to two + # different replicas — which then collide on + # ``zmq_socket_ctx(handshake_address, ROUTER, bind=True)``. 
+ self.register_stage_config( + stage_id, + msg.get("stage_config"), + coordinator_addresses=StageCoordinatorAddresses( + coordinator_input=msg.get("coordinator_input"), + coordinator_output=msg.get("coordinator_output"), + frontend_stats_publish_address=msg.get("frontend_stats_publish_address"), + ), + replica_id=replica_id, + ) + + # Fire on_register only for genuinely new (auto-assigned or newly + # allocated) replicas, on the ROUTER thread. Callback is expected to + # be cheap and non-blocking (e.g. enqueue onto an asyncio queue). + if was_auto_assigned and self._on_register is not None: + try: + self._on_register(stage_id, replica_id, alloc) + except Exception: + logger.exception( + "[OmniMasterServer] on_register callback failed for stage=%d replica=%d", + stage_id, + replica_id, + ) - alloc = self._stage_routes[key] response = msgspec.msgpack.encode( { "handshake_address": alloc.handshake_connect_address, "input_address": alloc.input_bind_address, "output_address": alloc.output_bind_address, + "replica_id": replica_id, + "coordinator_router_address": self._coordinator_router_address, } ) # ROUTER-DEALER: reply is [identity, payload] (no empty delimiter). reg_socket.send_multipart([identity, response]) logger.info( - "[OmniMasterServer] Stage %d replica %d registered; assigned handshake=%s", + "[OmniMasterServer] Stage %d replica %d registered (auto=%s); handshake=%s", stage_id, replica_id, + was_auto_assigned, alloc.handshake_connect_address, ) - return key + return (stage_id, replica_id) + + +@dataclass(frozen=True) +class StageRegistrationResponse: + """Reply payload returned by :class:`OmniMasterServer` after a successful registration.""" + + handshake_address: str + input_address: str + output_address: str + replica_id: int + coordinator_router_address: str | None + + +def _detect_local_bind_address(master_address: str, master_port: int) -> str: + """Return the local IP the kernel would use to reach the master. 
+ + Uses a connected UDP socket as a routing-table probe: ``connect()`` on + SOCK_DGRAM sends no packets but forces a route lookup, after which + ``getsockname()[0]`` exposes the source IP that an outbound packet to + ``(master_address, master_port)`` would carry. For a co-located master + this returns the loopback or eth0 IP (same effect as the legacy + ``self._address`` behaviour); for a remote master it returns the + NIC IP that's actually reachable from the master — which is exactly + the address the headless's per-stage ZMQ sockets must bind on. + """ + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + s.connect((master_address, master_port)) + return s.getsockname()[0] + finally: + s.close() def register_stage_with_omni_master( @@ -367,23 +621,37 @@ def register_stage_with_omni_master( omni_stage_config: Any = None, coordinator: DPCoordinator | None = None, return_addresses: bool = False, - replica_id: int = 0, -) -> str | tuple[str, str, str]: + replica_id: int | None = 0, + return_full_response: bool = False, + replica_bind_address: str | None = None, + replica_binds_sockets: bool = True, +) -> str | tuple[str, str, str] | StageRegistrationResponse: """Register a stage with the omni master server. Returns the per-stage handshake address by default. When ``return_addresses`` is true, also returns the stage input/output - addresses allocated by the master. + addresses allocated by the master. When ``return_full_response`` is + true, returns the full :class:`StageRegistrationResponse` including the + assigned ``replica_id`` and the OmniCoordinator ROUTER address (if + published by the master). + + Pass ``replica_id=None`` to request auto-assignment of a free replica + id by the master (used by headless launchers). 
""" + if replica_id is None: + wire_replica_id = AUTO_ASSIGN_REPLICA_ID + else: + wire_replica_id = int(replica_id) + reg_ctx = zmq.Context() try: reg_sock: zmq.Socket = reg_ctx.socket(zmq.DEALER) # type: ignore[attr-defined] try: reg_sock.connect(f"tcp://{omni_master_address}:{omni_master_port}") - payload = { + payload: dict[str, Any] = { "stage_id": omni_stage_id, - "replica_id": replica_id, + "replica_id": wire_replica_id, "stage_config": _serialize_stage_config(omni_stage_config), } if coordinator is not None: @@ -392,6 +660,31 @@ def register_stage_with_omni_master( payload["coordinator_output"] = coordinator_output payload["frontend_stats_publish_address"] = coordinator.get_stats_publish_address() + # Always advertise THIS host's local bind address + 3 locally + # free ports so the master can root the per-stage socket + # allocation on the replica's own interface. For a co-located + # replica the detected IP matches the master's address and + # the override is a no-op semantically; for a cross-host + # replica it's what makes the headless's ROUTER bind succeed + # (otherwise the master would hand back ``tcp://<master-ip>:port`` + # and ``zmq.bind`` would EADDRNOTAVAIL on the remote host). + if replica_bind_address is None: + replica_bind_address = _detect_local_bind_address(omni_master_address, omni_master_port) + hs_port, inp_port, out_port = get_open_ports_list(count=3) + payload["replica_bind_address"] = replica_bind_address + payload["replica_handshake_port"] = hs_port + payload["replica_input_port"] = inp_port + payload["replica_output_port"] = out_port + # ``False`` only for LLM headless replicas: the head's + # ``connect_remote_engine_cores`` is the binder for the + # handshake ROUTER, and ``CoreClient.__init__`` binds the + # input ROUTER and the output PULL (``make_zmq_socket`` + # defaults bind=True for PULL). 
The master must keep all + # three addresses on the master's host so the head can + # ``bind`` them; the remote LLM worker TCP-connects across + # hosts on all three. + payload["replica_binds_sockets"] = bool(replica_binds_sockets) + reg_sock.send(msgspec.msgpack.encode(payload)) timeout_ms = _DEFAULT_STARTUP_TIMEOUT_S * 1_000 if not reg_sock.poll(timeout=timeout_ms): @@ -402,13 +695,16 @@ def register_stage_with_omni_master( f"for stage {omni_stage_id}." ) response_bytes = reg_sock.recv() - response = msgspec.msgpack.decode(response_bytes) - handshake_address: str = response["handshake_address"] - input_address: str = response["input_address"] - output_address: str = response["output_address"] + response_msg = msgspec.msgpack.decode(response_bytes) + handshake_address: str = response_msg["handshake_address"] + input_address: str = response_msg["input_address"] + output_address: str = response_msg["output_address"] + assigned_replica_id: int = int(response_msg.get("replica_id", wire_replica_id)) + coord_router_addr: str | None = response_msg.get("coordinator_router_address") logger.info( - "Stage %d registered; handshake_address=%s", + "Stage %d replica %d registered; handshake_address=%s", omni_stage_id, + assigned_replica_id, handshake_address, ) finally: @@ -416,6 +712,14 @@ def register_stage_with_omni_master( finally: reg_ctx.term() + if return_full_response: + return StageRegistrationResponse( + handshake_address=handshake_address, + input_address=input_address, + output_address=output_address, + replica_id=assigned_replica_id, + coordinator_router_address=coord_router_addr, + ) if return_addresses: return handshake_address, input_address, output_address return handshake_address @@ -543,8 +847,16 @@ def launch_omni_core_engines( stage_id: int, stage_config: Any = None, replica_id: int = 0, + *, + omni_coordinator_address: str | None = None, ) -> Iterator[tuple[CoreEngineProcManager, DPCoordinator | None, EngineZmqAddresses]]: - """Launch local engine cores 
using the omni registration flow.""" + """Launch local engine cores using the omni registration flow. + + When ``omni_coordinator_address`` is provided, the spawned engine + subprocesses use :class:`OmniCoreEngineProcManager` and each + instantiates an :class:`OmniCoordClientForStage` after the handshake + completes so the head's :class:`OmniCoordinator` knows about them. + """ addresses = omni_master_server.get_zmq_addresses(stage_id, replica_id=replica_id) parallel_config = vllm_config.parallel_config # Determine the number of local engines and their ranks. @@ -608,16 +920,35 @@ def launch_omni_core_engines( handshake_bind_address = omni_master_server.get_allocation(stage_id, replica_id=replica_id).handshake_bind_address with zmq_socket_ctx(handshake_bind_address, zmq.ROUTER, bind=True) as handshake_socket: - local_engine_manager = CoreEngineProcManager( - local_engine_count=local_engine_count, - start_index=start_index, - local_start_index=local_start_index, - vllm_config=vllm_config, - local_client=True, - handshake_address=handshake_address, - executor_class=executor_class, - log_stats=log_stats, - ) + if omni_coordinator_address is not None: + # Use the omni subclass so each spawned subprocess instantiates + # an OmniCoordClientForStage and heartbeats to the coordinator. 
+ from vllm_omni.engine.omni_core_engine_proc_manager import OmniCoreEngineProcManager + + local_engine_manager: CoreEngineProcManager = OmniCoreEngineProcManager( + local_engine_count=local_engine_count, + start_index=start_index, + local_start_index=local_start_index, + vllm_config=vllm_config, + local_client=True, + handshake_address=handshake_address, + executor_class=executor_class, + log_stats=log_stats, + omni_stage_id=stage_id, + omni_coordinator_address=omni_coordinator_address, + omni_replica_base_id=replica_id, + ) + else: + local_engine_manager = CoreEngineProcManager( + local_engine_count=local_engine_count, + start_index=start_index, + local_start_index=local_start_index, + vllm_config=vllm_config, + local_client=True, + handshake_address=handshake_address, + executor_class=executor_class, + log_stats=log_stats, + ) yield local_engine_manager, coordinator, addresses diff --git a/vllm_omni/engine/stage_init_utils.py b/vllm_omni/engine/stage_init_utils.py index 7ab89400b6f..1a9f98fe5f3 100644 --- a/vllm_omni/engine/stage_init_utils.py +++ b/vllm_omni/engine/stage_init_utils.py @@ -61,6 +61,63 @@ class LogicalStageInitPlan: replicas: list[ReplicaInitPlan] +@dataclass +class StageRemoteFactoryContext: + """Per-stage context cached by AsyncOmniEngine for dynamic replica attach. + + Populated once during ``_bootstrap_orchestrator`` from the per-stage + init plans. ``_build_remote_replica`` consumes it to construct the + right head-side stage client when a headless replica registers. + """ + + stage_id: int + stage_type: str + stage_cfg: Any + base_metadata: Any + # LLM-only fields: + vllm_config: Any | None = None + executor_class: type | None = None + # Diffusion-only fields: + diffusion_batch_size: int = 1 + + +def capture_stage_factory_contexts( + stage_plans: Sequence[LogicalStageInitPlan], + diffusion_batch_size: int, +) -> dict[int, StageRemoteFactoryContext]: + """Snapshot per-stage construction context for dynamic replica attach. 
+ + Called once after ``_initialize_stages`` finishes. The captured + context holds everything ``_build_remote_replica`` needs to build a + fresh head-side client when a new headless replica registers + (vllm_config / executor_class for LLM, batch_size for diffusion, + plus the base stage metadata). + + Per-replica fields like ``replica_id`` are filled in at build time, + not at capture time. + """ + contexts: dict[int, StageRemoteFactoryContext] = {} + for plan in stage_plans: + if not plan.replicas: + # Stage was declared but has zero replicas locally, so there is + # no replica plan to use as a template; no factory context is + # captured for this stage. + continue + template = plan.replicas[0] + stage_id = int(plan.configured_stage_id) + stage_type = template.metadata.stage_type or "llm" + contexts[stage_id] = StageRemoteFactoryContext( + stage_id=stage_id, + stage_type=stage_type, + stage_cfg=template.stage_cfg, + base_metadata=template.metadata, + vllm_config=template.stage_vllm_config, + executor_class=template.executor_class, + diffusion_batch_size=diffusion_batch_size, + ) + return contexts + + def _resolve_model_to_local_path(model: str) -> str: """Resolve an HF Hub model ID to a local cache path.""" if os.path.isdir(model): @@ -420,32 +477,59 @@ def split_devices_for_replicas( """Split a devices string into per-replica subsets. When ``num_replicas`` is 1, returns ``[devices_str]`` unchanged. - Otherwise, the total number of device IDs must equal - ``num_replicas * tp_size``; each replica gets ``tp_size`` consecutive - device IDs. + Otherwise, two YAML shapes are accepted: + + 1. **Legacy / pool mode** — ``len(devices) == num_replicas * tp_size``: + the string enumerates the full per-stage pool. Each replica gets + ``tp_size`` consecutive entries. The values are logical indices + into the launcher's ``CUDA_VISIBLE_DEVICES``. + + ``split_devices_for_replicas("1,2,3,4", 2, 2, 1) → ["1,2", "3,4"]`` + + 2. 
**Template mode** — ``len(devices) == tp_size``: the YAML declares + a single per-replica template (the same shape one replica would + use), and is **dp-independent**. Each replica r gets the offsets + ``[r*tp_size + a for a in template]`` of the launcher's + ``CUDA_VISIBLE_DEVICES``. The template's entries must lie in + ``[0, tp_size)``. - Example:: + ``split_devices_for_replicas("0,1", 2, 2, 1) → ["0,1", "2,3"]`` + ``split_devices_for_replicas("0,1", 4, 2, 1) → ["0,1", "2,3", "4,5", "6,7"]`` - split_devices_for_replicas("1,2,3,4", num_replicas=2, tp_size=2, stage_id=1) - # → ["1,2", "3,4"] + This lets the same ``devices: "0,1"`` YAML work for any + ``--omni-dp-size-local``: the launcher's CVD scales, the YAML + does not. + + Any other length raises ``ValueError`` (the two modes are + length-disjoint for ``num_replicas > 1``). """ if num_replicas <= 1 or devices_str is None: return [devices_str] if devices_str is not None else [devices_str] device_list = [d.strip() for d in devices_str.split(",") if d.strip()] - required = num_replicas * tp_size - if len(device_list) != required: - raise ValueError( - f"Stage {stage_id}: num_replicas={num_replicas}, " - f"tensor_parallel_size={tp_size} requires " - f"{required} devices, got {len(device_list)}: {devices_str}" - ) - result: list[str] = [] - for r in range(num_replicas): - chunk = device_list[r * tp_size : (r + 1) * tp_size] - result.append(",".join(chunk)) - return result + if len(device_list) == num_replicas * tp_size: + return [",".join(device_list[r * tp_size : (r + 1) * tp_size]) for r in range(num_replicas)] + + if len(device_list) == tp_size: + try: + offsets = [int(a) for a in device_list] + except ValueError as e: + raise ValueError(f"Stage {stage_id}: template-mode devices must be ints, got {devices_str!r}") from e + bad = [a for a in offsets if not (0 <= a < tp_size)] + if bad: + raise ValueError( + f"Stage {stage_id}: template-mode device offset(s) {bad} " + f"out of range [0, {tp_size}); 
devices={devices_str!r}" + ) + return [",".join(str(r * tp_size + a) for a in offsets) for r in range(num_replicas)] + + raise ValueError( + f"Stage {stage_id}: devices={devices_str!r} has {len(device_list)} id(s); " + f"need either {tp_size} (template, dp-independent) or " + f"{num_replicas * tp_size} (pool / legacy). " + f"num_replicas={num_replicas}, tensor_parallel_size={tp_size}." + ) def get_stage_tp_size(stage_cfg: Any) -> int: @@ -479,9 +563,19 @@ def get_stage_devices_per_replica(stage_cfg: Any) -> int: def compute_replica_layout( stage_configs: Sequence[Any], + *, + allow_zero: bool = False, ) -> tuple[list[int], dict[int, list[str]]]: """Compute per-stage replica counts and device assignments. + Args: + stage_configs: per-stage config objects with a ``runtime`` sub-config + exposing ``num_replicas`` and ``devices``. + allow_zero: when True, ``num_replicas == 0`` is honored (used by + single-stage / head-distributed mode for non-self stages that + will be filled dynamically by remote registrations); when False + (default), the count is clamped to at least 1. + Returns: replicas_per_stage: num_replicas per logical stage. 
replica_devices_map: stage_idx -> per-replica device strings @@ -495,7 +589,9 @@ def compute_replica_layout( if hasattr(runtime_cfg, "get") else getattr(runtime_cfg, "num_replicas", 1) ) - replicas_per_stage.append(max(1, num_replicas)) + if num_replicas < 0: + raise ValueError(f"num_replicas must be >= 0, got {num_replicas}") + replicas_per_stage.append(num_replicas if allow_zero else max(1, num_replicas)) replica_devices_map: dict[int, list[str]] = {} for stage_id, stage_cfg in enumerate(stage_configs): diff --git a/vllm_omni/engine/stage_pool.py b/vllm_omni/engine/stage_pool.py index 427368a02a3..68de787881f 100644 --- a/vllm_omni/engine/stage_pool.py +++ b/vllm_omni/engine/stage_pool.py @@ -10,6 +10,13 @@ from vllm.logger import init_logger from vllm.v1.engine import EngineCoreOutputs +from vllm_omni.distributed.omni_coordinator import ( + LoadBalancer, + OmniCoordClientForHub, + ReplicaInfo, + ReplicaStatus, +) +from vllm_omni.distributed.omni_coordinator.load_balancer import Task from vllm_omni.engine.stage_client import ( StagePoolClient, StagePoolDiffusionClient, @@ -35,7 +42,33 @@ class _ReplicaMetrics: class StagePool: - """Replicas of one logical stage with RR + affinity selection.""" + """Replicas of one logical stage + per-stage routing (LB + affinity). + + The pool owns the head-side stage clients for one logical stage. It also + absorbs the per-stage dispatch responsibility (load balancing, affinity + tracking, bounded-wait pick) that used to live in a separate + ``StageDispatcher`` class — see the design doc for the rationale. + + In distributed mode (when an :class:`OmniCoordClientForHub` and a + :class:`LoadBalancer` are injected via :meth:`attach_hub` / + :meth:`attach_load_balancer`), :meth:`pick` consults the hub's cached + replica list and routes via the load balancer, sticking subsequent calls + for the same ``request_id`` to the same replica. 
+ + In non-distributed mode (no hub attached), :meth:`pick` falls back to the + legacy ``select_replica_id`` round-robin path so the multi-stage + in-process invocation is unchanged. + + Dynamic replica membership: when a remote replica is added or removed + (driven by :class:`Orchestrator` via :meth:`add_client` / + :meth:`remove_client`), the pool keeps stable integer ``replica_id``s by + storing clients in a list whose entries can be ``None`` after a removal. + Iteration callers should use :meth:`live_replica_ids` rather than + ``range(pool.num_replicas)`` to skip the gaps. + """ + + DISPATCH_WAIT_TIMEOUT_S: float = 10.0 + DISPATCH_RETRY_INTERVAL_S: float = 0.1 def __init__( self, @@ -50,33 +83,69 @@ def __init__( else: normalized_clients = [clients] - if not normalized_clients: - raise ValueError(f"StagePool for stage {stage_id} has no replicas") + # Allow empty pools when running in distributed head mode for a + # non-self stage; clients will arrive via add_client(...). self.stage_id = stage_id - self.clients: list[StagePoolClient] = normalized_clients + # Slots can become None after a dynamic remove_client (distributed mode); + # iterate via live_replica_ids() to skip holes. + self.clients: list[StagePoolClient | None] = list(normalized_clients) self._output_processor = output_processor self._stage_vllm_config = stage_vllm_config self._next_replica_id = 0 self._request_bindings: dict[str, int] = {} self._replica_metrics: list[_ReplicaMetrics] = [_ReplicaMetrics() for _ in self.clients] + # Distributed-mode state. Populated by add_client / remove_client. + self._addr_to_replica_id: dict[str, int] = {} + for replica_id, client in enumerate(self.clients): + if client is not None: + addr = self._client_input_addr(client) + if addr is not None: + self._addr_to_replica_id[addr] = replica_id + + # Distributed-mode dispatch hooks (injected by Orchestrator on bring-up). 
+ self._hub: OmniCoordClientForHub | None = None + self._lb: LoadBalancer | None = None + # ``request_id`` → ``input_addr`` affinity (distributed mode only). + # Kept separate from the legacy ``_request_bindings`` so the two + # binding shapes do not collide. + self._affinity: dict[str, str] = {} + # ---- Stage-level properties ---- @property def num_replicas(self) -> int: + """Total slot count, including ``None`` holes from removed replicas. + + Use :meth:`live_replica_ids` to iterate only live entries. + """ return len(self.clients) + @property + def live_num_replicas(self) -> int: + """Number of currently live (non-None) replicas in this pool.""" + return sum(1 for c in self.clients if c is not None) + + def live_replica_ids(self) -> list[int]: + """Return the indices of currently live replicas in this pool.""" + return [i for i, c in enumerate(self.clients) if c is not None] + @property def stage_type(self) -> str | None: - return self.stage_client.stage_type + client = self.stage_client + return None if client is None else client.stage_type @property def final_output(self) -> bool: - return self.clients[0].final_output + client = self.stage_client + return False if client is None else bool(client.final_output) @property - def stage_client(self) -> StagePoolClient: - return self.clients[0] + def stage_client(self) -> StagePoolClient | None: + for client in self.clients: + if client is not None: + return client + return None @property def llm_stage_client(self) -> StagePoolLLMClient: @@ -90,11 +159,209 @@ def stage_vllm_config(self) -> Any: def output_processor(self) -> Any: return self._output_processor - # ---- Route binding lifecycle ---- + @property + def is_distributed(self) -> bool: + """True iff a hub has been attached (i.e. 
running in head-distributed mode).""" + return self._hub is not None + + # ---- Distributed-mode dispatch hooks ---- + + def attach_hub(self, hub: OmniCoordClientForHub | None) -> None: + """Inject the shared :class:`OmniCoordClientForHub`. + + Called once by :class:`Orchestrator` after the hub is constructed. + ``hub=None`` keeps the pool in legacy mode (no behavior change). + """ + self._hub = hub + + def attach_load_balancer(self, lb: LoadBalancer | None) -> None: + """Inject the per-pool :class:`LoadBalancer` for distributed-mode pick.""" + self._lb = lb + + # ---- Dynamic membership (distributed mode) ---- + + @staticmethod + def _client_input_addr(client: Any) -> str | None: + """Return the input ZMQ address advertised by ``client`` if any. + + LLM clients expose ``client_addresses["input_address"]``; diffusion + clients expose ``request_address``. Both are stable strings used by + :class:`OmniCoordinator` to key replicas. + """ + request_address = getattr(client, "request_address", None) + if isinstance(request_address, str) and request_address: + return request_address + addrs = getattr(client, "client_addresses", None) + if isinstance(addrs, dict): + addr = addrs.get("input_address") + if isinstance(addr, str) and addr: + return addr + return None + + def add_client(self, input_addr: str, client: Any) -> int: + """Register a head-side client for ``input_addr``. + + Returns the assigned ``replica_id`` (index into :attr:`clients`). + If the address is already known, replaces the existing client and + returns its existing id (this should not happen in practice — the + master server assigns unique slots — but the contract is idempotent + to keep the dispatch layer robust). 
+ """ + if not input_addr: + raise ValueError("input_addr must be a non-empty string") + + existing = self._addr_to_replica_id.get(input_addr) + if existing is not None: + self.clients[existing] = client + return existing + + replica_id = len(self.clients) + self.clients.append(client) + self._addr_to_replica_id[input_addr] = replica_id + self._replica_metrics.append(_ReplicaMetrics()) + return replica_id + + def remove_client(self, input_addr: str) -> Any | None: + """Remove the client at ``input_addr``. Returns the removed client or ``None``. + + Slot is marked ``None`` to preserve indices for outstanding bindings. + """ + replica_id = self._addr_to_replica_id.pop(input_addr, None) + if replica_id is None: + return None + client = self.clients[replica_id] + self.clients[replica_id] = None + return client + + def get_client_by_addr(self, input_addr: str) -> Any | None: + """Return the live client for ``input_addr`` if present.""" + replica_id = self._addr_to_replica_id.get(input_addr) + if replica_id is None: + return None + return self.clients[replica_id] + + def get_replica_id_by_addr(self, input_addr: str) -> int | None: + """Return the stable replica_id for ``input_addr`` if registered.""" + return self._addr_to_replica_id.get(input_addr) + + # ---- Per-request distributed dispatch ---- + + async def pick( + self, + request_id: str, + task: Task | None = None, + *, + affinity_request_id: str | None = None, + ) -> int: + """Return a replica id for ``request_id``. + + In distributed mode: consults the hub for UP replicas, runs the load + balancer, and records affinity so future picks for the same + ``request_id`` return the same replica. Bounded wait up to + ``DISPATCH_WAIT_TIMEOUT_S`` when no UP replica is currently usable. + + In non-distributed (legacy) mode: delegates to + :meth:`select_replica_id`. + """ + if self._hub is None or self._lb is None: + return self.select_replica_id(request_id, affinity_request_id=affinity_request_id) + + # 1. 
Sticky: previously bound and still serviceable? + bound_addr = self._affinity.get(request_id) + if bound_addr is not None: + replica_id = self._serviceable_replica_id_for_addr(bound_addr) + if replica_id is not None: + return replica_id + # Bound replica is gone or DOWN — fall through to re-select. + self._affinity.pop(request_id, None) + + # 2. Inherited affinity (CFG companion sharing a parent request_id). + if affinity_request_id is not None: + parent_addr = self._affinity.get(affinity_request_id) + if parent_addr is not None: + replica_id = self._serviceable_replica_id_for_addr(parent_addr) + if replica_id is not None: + self._affinity[request_id] = parent_addr + return replica_id + + # 3. Fresh pick: poll hub + LB with bounded wait. + task = task or Task(request_id=request_id) + deadline = _time.monotonic() + self.DISPATCH_WAIT_TIMEOUT_S + while True: + candidates = self._collect_serviceable_replicas() + if candidates: + # LB chose an index *into our candidates list*. + lb_idx = self._lb.select(task, [rep for rep, _ in candidates]) + replica_info, replica_id = candidates[lb_idx] + self._affinity[request_id] = replica_info.input_addr + return replica_id + + now = _time.monotonic() + if now >= deadline: + raise RuntimeError(f"no UP replica for stage {self.stage_id} after {self.DISPATCH_WAIT_TIMEOUT_S:.1f}s") + await asyncio.sleep(min(self.DISPATCH_RETRY_INTERVAL_S, deadline - now)) + + def _collect_serviceable_replicas(self) -> list[tuple[ReplicaInfo, int]]: + """Return list of ``(ReplicaInfo, replica_id)`` for UP, attached replicas.""" + if self._hub is None: + return [] + snap = self._hub.get_replicas_for_stage(self.stage_id) + out: list[tuple[ReplicaInfo, int]] = [] + for rep in snap.replicas: + if rep.status != ReplicaStatus.UP: + continue + replica_id = self._addr_to_replica_id.get(rep.input_addr) + if replica_id is None: + continue # Hub knows about it but head-side client not attached yet. 
+ if self.clients[replica_id] is None: + continue + out.append((rep, replica_id)) + return out + + def _serviceable_replica_id_for_addr(self, input_addr: str) -> int | None: + """Return ``replica_id`` for ``input_addr`` iff currently UP + attached.""" + if self._hub is None: + return None + replica_id = self._addr_to_replica_id.get(input_addr) + if replica_id is None or self.clients[replica_id] is None: + return None + snap = self._hub.get_replicas_for_stage(self.stage_id) + for rep in snap.replicas: + if rep.input_addr == input_addr and rep.status == ReplicaStatus.UP: + return replica_id + return None + + def bind(self, request_id: str, input_addr: str) -> None: + """Explicitly record affinity (distributed mode).""" + self._affinity[request_id] = input_addr + + def release(self, request_id: str) -> None: + """Drop affinity (distributed mode) and legacy binding for ``request_id``.""" + self._affinity.pop(request_id, None) + self.release_binding(request_id) + + def invalidate_addr(self, input_addr: str) -> list[str]: + """Drop affinity rows pointing at ``input_addr``; return affected request ids.""" + affected: list[str] = [rid for rid, addr in self._affinity.items() if addr == input_addr] + for rid in affected: + self._affinity.pop(rid, None) + return affected + + # ---- Legacy (non-distributed) route binding ---- def get_bound_replica_id(self, request_id: str) -> int | None: - """Return the currently bound replica id for *request_id* if present.""" - return self._request_bindings.get(request_id) + """Return the currently bound replica id for *request_id* if present. + + In distributed mode the binding may have been recorded via + :meth:`pick`; we honor it transparently here. 
+ """ + legacy = self._request_bindings.get(request_id) + if legacy is not None: + return legacy + addr = self._affinity.get(request_id) + if addr is None: + return None + return self._addr_to_replica_id.get(addr) def get_bound_client(self, request_id: str) -> StagePoolClient | None: """Return the currently bound client for *request_id* if present.""" @@ -113,6 +380,7 @@ def get_bound_llm_client(self, request_id: str) -> StagePoolLLMClient | None: def release_binding(self, request_id: str) -> None: """Drop the route binding for *request_id* in this stage.""" self._request_bindings.pop(request_id, None) + self._affinity.pop(request_id, None) def release_bindings(self, request_ids: list[str]) -> None: """Drop route bindings for the given request ids in this stage.""" @@ -125,27 +393,43 @@ def select_replica_id( *, affinity_request_id: str | None = None, ) -> int: - """Pick a replica id for *request_id* and cache the choice.""" + """Pick a replica id for *request_id* and cache the choice (legacy path).""" cached = self.get_bound_replica_id(request_id) - if cached is not None: + if cached is not None and self.clients[cached] is not None: return cached - chosen = self.get_bound_replica_id(affinity_request_id) if affinity_request_id is not None else None + chosen: int | None = None + if affinity_request_id is not None: + parent = self.get_bound_replica_id(affinity_request_id) + if parent is not None and self.clients[parent] is not None: + chosen = parent + if chosen is None: - if self.num_replicas == 1: - chosen = 0 + live = self.live_replica_ids() + if not live: + raise RuntimeError(f"stage {self.stage_id} has no live replicas") + if len(live) == 1: + chosen = live[0] else: - chosen = self._next_replica_id - self._next_replica_id = (self._next_replica_id + 1) % self.num_replicas + # Round-robin over live replicas only. 
+ start = self._next_replica_id % len(live) + chosen = live[start] + self._next_replica_id = (self._next_replica_id + 1) % len(live) self._request_bindings[request_id] = chosen return chosen def _llm_client(self, replica_id: int) -> StagePoolLLMClient: - return cast(StagePoolLLMClient, self.clients[replica_id]) + client = self.clients[replica_id] + if client is None: + raise RuntimeError(f"stage {self.stage_id} replica {replica_id} is not attached") + return cast(StagePoolLLMClient, client) def _diffusion_client(self, replica_id: int) -> StagePoolDiffusionClient: - return cast(StagePoolDiffusionClient, self.clients[replica_id]) + client = self.clients[replica_id] + if client is None: + raise RuntimeError(f"stage {self.stage_id} replica {replica_id} is not attached") + return cast(StagePoolDiffusionClient, client) # ---- Metrics ---- @@ -206,7 +490,7 @@ async def submit_initial( params = params_override if params_override is not None else req_state.sampling_params_list[self.stage_id] submit_kwargs = dict(submit_kwargs or {}) if self.stage_type == "diffusion": - replica_id = self.select_replica_id( + replica_id = await self._pick_or_select( request_id, affinity_request_id=affinity_request_id, ) @@ -217,10 +501,13 @@ async def submit_initial( await client.add_request_async(request_id, request, params, **submit_kwargs) return replica_id - replica_id = self.select_replica_id( + replica_id = await self._pick_or_select( request_id, affinity_request_id=affinity_request_id, ) + client = self.clients[replica_id] + if client is None: + raise RuntimeError(f"stage {self.stage_id} replica {replica_id} is not attached") try: self.output_processor.add_request( request=request, @@ -262,8 +549,12 @@ async def submit_update( """Submit a streaming update to an already admitted request.""" params = req_state.sampling_params_list[self.stage_id] replica_id = self.get_bound_replica_id(request_id) - if replica_id is None: - replica_id = self.select_replica_id(request_id) + if replica_id is 
None or self.clients[replica_id] is None: + replica_id = await self._pick_or_select(request_id) + + client = self.clients[replica_id] + if client is None: + raise RuntimeError(f"stage {self.stage_id} replica {replica_id} is not attached") if self.stage_type == "diffusion": await self._diffusion_client(replica_id).add_request_async(request_id, request, params) @@ -281,6 +572,17 @@ async def submit_update( await self._llm_client(replica_id).add_request_async(request) return replica_id + async def _pick_or_select( + self, + request_id: str, + *, + affinity_request_id: str | None = None, + ) -> int: + """Bridge to ``pick`` in distributed mode or ``select_replica_id`` legacy.""" + if self.is_distributed: + return await self.pick(request_id, affinity_request_id=affinity_request_id) + return self.select_replica_id(request_id, affinity_request_id=affinity_request_id) + # ---- Stage-local polling ---- async def _poll_stage_raw(self, client: StagePoolLLMClient) -> EngineCoreOutputs | None: @@ -296,7 +598,10 @@ async def process_llm_raw_outputs( raw_outputs: EngineCoreOutputs, ) -> list[Any]: """Run the shared LLM output processor on one raw poll result.""" - client = self._llm_client(replica_id) + raw_client = self.clients[replica_id] + if raw_client is None: + return [] + client = cast(StagePoolLLMClient, raw_client) processor = self.output_processor processed = processor.process_outputs( raw_outputs.outputs, @@ -319,7 +624,10 @@ async def poll_llm_raw_output( timeout_s: float = 0.001, ) -> EngineCoreOutputs | None: """Poll raw EngineCore outputs from one LLM replica once.""" - client = self._llm_client(replica_id) + raw_client = self.clients[replica_id] + if raw_client is None: + return None + client = cast(StagePoolLLMClient, raw_client) try: return await asyncio.wait_for( self._poll_stage_raw(client), @@ -339,7 +647,10 @@ async def poll_llm_raw_output( def poll_diffusion_output(self, replica_id: int) -> Any | None: """Drain one ready diffusion output from the given 
replica if present.""" - return self._diffusion_client(replica_id).get_diffusion_output_nowait() + raw_client = self.clients[replica_id] + if raw_client is None: + return None + return cast(StagePoolDiffusionClient, raw_client).get_diffusion_output_nowait() # ---- Stage-local control plane ---- @@ -347,7 +658,7 @@ async def abort_requests(self, request_ids: list[str]) -> None: """Abort the given requests in this stage pool. Request-bound abort routing stays inside the pool because route affinity - (`request_id -> replica_id`) is pool-owned. + (``request_id -> replica_id``) is pool-owned. """ if not request_ids: return @@ -355,13 +666,16 @@ async def abort_requests(self, request_ids: list[str]) -> None: request_ids_by_replica: dict[int, list[str]] = {} for request_id in request_ids: replica_id = self.get_bound_replica_id(request_id) - if replica_id is None: - logger.debug("[StagePool] abort: no binding for req=%s in stage-%s", request_id, self.stage_id) + if replica_id is None or self.clients[replica_id] is None: + logger.debug("[StagePool] abort: no live binding for req=%s in stage-%s", request_id, self.stage_id) continue request_ids_by_replica.setdefault(replica_id, []).append(request_id) for replica_id, replica_request_ids in request_ids_by_replica.items(): - await self.clients[replica_id].abort_requests_async(replica_request_ids) + client = self.clients[replica_id] + if client is None: + continue + await client.abort_requests_async(replica_request_ids) # Clean up OutputProcessor state (e.g. mm_accumulated tensors) that # would otherwise leak — aborted requests never produce a final @@ -377,10 +691,15 @@ async def collective_rpc( timeout: float | None = None, args: tuple[Any, ...] 
= (), kwargs: dict[str, Any] | None = None, - ) -> Any: + ) -> dict[str, Any] | Any: """Dispatch a stage-scoped control-plane RPC to one physical route.""" kwargs = dict(kwargs or {}) client = self.clients[replica_id] + if client is None: + return { + "supported": False, + "error": f"stage {self.stage_id} replica {replica_id} is not attached", + } try: return await client.collective_rpc_async( method=method, @@ -402,7 +721,11 @@ async def collective_rpc( def shutdown_replica(self, replica_id: int) -> None: """Shutdown one backend handle in this stage pool.""" + if replica_id >= len(self.clients): + return client = self.clients[replica_id] + if client is None: + return try: client.shutdown() logger.info( diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py index f2689d4b50f..50b825d7c4a 100644 --- a/vllm_omni/entrypoints/cli/serve.py +++ b/vllm_omni/entrypoints/cli/serve.py @@ -9,6 +9,8 @@ import json import os import signal +import threading +from multiprocessing import connection from types import FrameType from typing import Any @@ -103,6 +105,62 @@ def validate(self, args: argparse.Namespace) -> None: if args.stage_id is not None and (args.omni_master_address is None or args.omni_master_port is None): raise ValueError("--stage-id requires both --omni-master-address and --omni-master-port to be set") + # --omni-replica-address is only consulted in run_headless(); reject it + # on the head so a misconfigured launch fails loudly instead of being + # silently ignored. + if getattr(args, "omni_replica_address", None) is not None and not args.headless: + raise ValueError("--omni-replica-address requires --headless to be set") + + # --omni-dp-size-local is process-local. A value other than 1 only + # makes sense when this process owns a stage (head or headless). 
+ omni_dp_size_local = getattr(args, "omni_dp_size_local", None) + if omni_dp_size_local is not None: + if omni_dp_size_local < 1: + raise ValueError(f"--omni-dp-size-local must be >= 1, got {omni_dp_size_local}") + if omni_dp_size_local != 1 and args.stage_id is None: + raise ValueError("--omni-dp-size-local != 1 requires --stage-id to be set") + + # vLLM CLI args that omni does not honor: parallelism comes from the + # per-stage YAML (parallel_config:, enable_expert_parallel:) and the + # process-local replica count from --omni-dp-size-local. Passing the + # vLLM equivalents on the command line would silently disagree with + # those sources of truth, so reject them at parse time. + if getattr(args, "omni", False): + explicit_cli_keys: set[str] = getattr(args, "_cli_explicit_keys", set()) or set() + prohibited_with_omni: dict[str, str] = { + "data_parallel_size": "--data-parallel-size", + "data_parallel_size_local": "--data-parallel-size-local", + "data_parallel_address": "--data-parallel-address", + "data_parallel_rpc_port": "--data-parallel-rpc-port", + "data_parallel_start_rank": "--data-parallel-start-rank", + "data_parallel_backend": "--data-parallel-backend", + "api_server_count": "--api-server-count", + "enable_expert_parallel": "--enable-expert-parallel", + } + offenders = sorted(flag for dest, flag in prohibited_with_omni.items() if dest in explicit_cli_keys) + if offenders: + raise ValueError( + "The following CLI args are not supported under --omni: " + f"{', '.join(offenders)}. Configure parallelism through the " + "per-stage YAML (`--deploy-config` / `--stage-configs-path`) " + "and replica count via `--omni-dp-size-local`." + ) + + # --omni-lb-policy is validated against the LoadBalancingPolicy enum. 
+ omni_lb_policy = getattr(args, "omni_lb_policy", None) + if omni_lb_policy is not None: + from vllm_omni.distributed.omni_coordinator import LoadBalancingPolicy + + try: + LoadBalancingPolicy(omni_lb_policy) + except ValueError as exc: + valid = ", ".join(p.value for p in LoadBalancingPolicy) + raise ValueError(f"--omni-lb-policy={omni_lb_policy!r} is not one of: {valid}") from exc + + omni_heartbeat_timeout = getattr(args, "omni_heartbeat_timeout", None) + if omni_heartbeat_timeout is not None and omni_heartbeat_timeout <= 0: + raise ValueError(f"--omni-heartbeat-timeout must be > 0, got {omni_heartbeat_timeout}") + # Skip validation for diffusion models as they have different requirements from vllm_omni.diffusion.utils.hf_utils import is_diffusion_model @@ -190,8 +248,12 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu omni_config_group.add_argument( "--replica-id", type=int, - default=0, - help="Replica id to register when launching a single headless stage.", + default=None, + help=( + "Deprecated and ignored — replica ids are auto-assigned by the " + "master server. Specifying this flag prints a warning and has " + "no effect." + ), ) omni_config_group.add_argument( "--stage-init-timeout", @@ -253,6 +315,50 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu type=int, help="Port of the Omni orchestrator (master).", ) + omni_config_group.add_argument( + "--omni-replica-address", + "-ora", + type=str, + default=None, + help=( + "Local bind address (this host's IP) that the headless stage " + "advertises to the Omni master for its handshake/input/output " + "ZMQ sockets. If unset, auto-detected via a UDP-connect " + "routing probe against --omni-master-address. Override only " + "when the auto-detected IP is wrong (e.g. multi-NIC host " + "where the master is reachable on the wrong interface)." 
+ ), + ) + omni_config_group.add_argument( + "--omni-dp-size-local", + type=int, + default=1, + help=( + "Number of stage replicas this runtime launches locally for its " + "own --stage-id. Process-local: head and every headless invocation " + "read their own copy; values may differ across invocations. " + "Requires --stage-id to be set when not equal to 1." + ), + ) + omni_config_group.add_argument( + "--omni-lb-policy", + type=str, + default="random", + choices=["random", "round-robin", "least-queue-length"], + help=( + "Per-stage load-balancing policy used by the head's StagePool to " + "route requests across UP replicas. Only consulted on the head runtime." + ), + ) + omni_config_group.add_argument( + "--omni-heartbeat-timeout", + type=float, + default=30.0, + help=( + "Seconds before an unreporting replica is marked ERROR in the " + "OmniCoordinator. Only consulted on the head runtime." + ), + ) # Diffusion model specific arguments omni_config_group.add_argument( @@ -564,9 +670,15 @@ def _create_default_diffusion_stage_cfg(args: argparse.Namespace) -> list[dict[s def run_headless(args: argparse.Namespace) -> None: - """Run a single stage in headless mode.""" + """Run a single stage in headless mode. + + Honors ``--omni-dp-size-local``: launches that many replicas locally for + ``--stage-id``. Each replica registers with the head's OmniMasterServer + (auto-assigned replica id when ``--omni-dp-size-local > 1`` so multiple + headless invocations can coexist) and reports heartbeats to the head's + OmniCoordinator. 
+ """ from vllm.v1.engine.coordinator import DPCoordinator - from vllm.v1.engine.utils import CoreEngineProcManager from vllm.v1.executor.multiproc_executor import MultiprocExecutor from vllm.version import __version__ as VLLM_VERSION @@ -575,6 +687,7 @@ def run_headless(args: argparse.Namespace) -> None: spawn_diffusion_proc, ) from vllm_omni.distributed.omni_connectors.utils.initialization import resolve_omni_kv_config_for_stage + from vllm_omni.engine.omni_core_engine_proc_manager import OmniCoreEngineProcManager from vllm_omni.engine.stage_engine_startup import register_stage_with_omni_master from vllm_omni.engine.stage_init_utils import ( build_diffusion_config, @@ -582,23 +695,40 @@ def run_headless(args: argparse.Namespace) -> None: build_vllm_config, extract_stage_metadata, get_stage_connector_spec, + get_stage_devices_per_replica, inject_kv_stage_info, load_omni_transfer_config_for_model, prepare_engine_environment, + setup_stage_devices, + split_devices_for_replicas, terminate_alive_proc, ) from vllm_omni.entrypoints.utils import inject_omni_kv_config, load_and_resolve_stage_configs + from vllm_omni.platforms import current_omni_platform model = args.model stage_id: int | None = args.stage_id - replica_id: int = args.replica_id omni_master_address: str | None = args.omni_master_address omni_master_port: int | None = args.omni_master_port + omni_replica_address: str | None = getattr(args, "omni_replica_address", None) + omni_dp_size_local: int = max(1, int(getattr(args, "omni_dp_size_local", 1) or 1)) if stage_id is None: raise ValueError("--stage-id is required in headless mode") - if replica_id < 0: - raise ValueError("--replica-id must be >= 0 in headless mode") + + # ``--replica-id`` is deprecated and ignored — replica ids are + # auto-assigned by ``OmniMasterServer`` so headless processes carry + # no knowledge of their per-replica id at launch time. 
Warn (don't + # error) when the operator still supplies it so existing launchers + # keep working with a single log line. + explicit_cli_keys: set[str] = getattr(args, "_cli_explicit_keys", set()) or set() + if "replica_id" in explicit_cli_keys: + logger.warning( + "[Headless] --replica-id is deprecated and ignored " + "(supplied value: %s). Replica ids are auto-assigned by the " + "master server.", + args.replica_id, + ) if omni_master_address is None or omni_master_port is None: raise ValueError("--omni-master-address and --omni-master-port are required in headless mode") api_server_count = args.api_server_count or 0 @@ -609,6 +739,11 @@ def run_headless(args: argparse.Namespace) -> None: args_dict = vars(args).copy() args_dict.pop("_cli_explicit_keys", None) + # Forward ``--deploy-config`` so the headless reads the same YAML the + # head was launched with — otherwise ``load_and_resolve_stage_configs`` + # falls back to ``vllm_omni/deploy/.yaml`` and the headless's + # view of ``stage.runtime.devices`` diverges from the head's, breaking + # the per-replica device split. config_path, stage_configs = load_and_resolve_stage_configs( model, args_dict.get("stage_configs_path"), @@ -631,6 +766,39 @@ def run_headless(args: argparse.Namespace) -> None: omni_transfer_config = load_omni_transfer_config_for_model(model, config_path) omni_conn_cfg, omni_from, omni_to = resolve_omni_kv_config_for_stage(omni_transfer_config, stage_id) + # When ``--omni-dp-size-local > 1``, slice the YAML's ``devices:`` field + # into per-replica subsets so each subprocess we spawn below sees a + # narrowed ``CUDA_VISIBLE_DEVICES`` and doesn't stack on cuda:0. Mirrors + # the head-side per-replica device application at + # ``async_omni_engine.py`` (setup_stage_devices around each launch). 
+ runtime_cfg = getattr(stage_cfg, "runtime", None) + devices_str: str | None = None + if runtime_cfg is not None: + devices_str = ( + runtime_cfg.get("devices") if hasattr(runtime_cfg, "get") else getattr(runtime_cfg, "devices", None) + ) + devices_per_replica = get_stage_devices_per_replica(stage_cfg) + if devices_str: + # Always remap YAML's logical devices through setup_stage_devices, + # even for omni_dp_size_local==1. The launcher's CUDA_VISIBLE_DEVICES + # is dropped from the engine-subprocess env between vllm-serve and + # OmniCoreEngineProcManager.Process, so the worker would otherwise + # default cuda:0 to physical GPU 0 and collide with a co-located + # head on the same host (see hyi3_multi_host_1 reproducer). + per_replica_devices: list[str | None] = split_devices_for_replicas( + devices_str, omni_dp_size_local, devices_per_replica, stage_id + ) + logger.info( + "[Headless] Stage %d: %d local replicas, devices_per_replica=%d, per-replica devices: %s", + stage_id, + omni_dp_size_local, + devices_per_replica, + per_replica_devices, + ) + else: + per_replica_devices = [None] * omni_dp_size_local + device_control_env = current_omni_platform.device_control_env_var + if stage_cfg.stage_type == "diffusion": metadata = extract_stage_metadata(stage_cfg) if omni_conn_cfg: @@ -639,39 +807,107 @@ def run_headless(args: argparse.Namespace) -> None: od_config = build_diffusion_config(model, stage_cfg, metadata) logger.info( - "[Headless] Launching diffusion stage %d replica %d via OmniMasterServer at %s:%d", + "[Headless] Launching %d diffusion replica(s) for stage %d via OmniMasterServer at %s:%d", + omni_dp_size_local, stage_id, - replica_id, omni_master_address, omni_master_port, ) - proc = None + procs: list[Any] = [] try: - handshake_address, request_address, response_address = register_stage_with_omni_master( - omni_master_address=omni_master_address, - omni_master_port=omni_master_port, - omni_stage_id=stage_id, - omni_stage_config=stage_cfg, - 
return_addresses=True, - replica_id=replica_id, - ) - proc, _, _, _ = spawn_diffusion_proc( - model, - od_config, - handshake_address=handshake_address, - request_address=request_address, - response_address=response_address, + for _rep_idx in range(omni_dp_size_local): + # Always auto-assign: headless processes carry no knowledge + # of their per-replica id and the master server is the sole + # authority on the per-stage id namespace. + response = register_stage_with_omni_master( + omni_master_address=omni_master_address, + omni_master_port=omni_master_port, + omni_stage_id=stage_id, + omni_stage_config=stage_cfg, + replica_id=None, + return_full_response=True, + replica_bind_address=omni_replica_address, + ) + # Apply this replica's CUDA_VISIBLE_DEVICES (only when + # ``--omni-dp-size-local > 1`` and the YAML's stage devices + # field is set). The spawned subprocess inherits the env at + # spawn time; we restore the parent env afterwards so the + # next replica's setup sees the same baseline. + previous_visible_devices = os.environ.get(device_control_env) + try: + if per_replica_devices[_rep_idx] is not None: + setup_stage_devices(stage_id, {"devices": per_replica_devices[_rep_idx]}) + # Each StageDiffusionProc starts its own + # torch.distributed group bound to + # ``od_config.master_port``. Without an explicit + # per-replica override all spawned subprocesses + # share the value ``OmniDiffusionConfig.__post_init__`` + # picked once (and the second binder hits EADDRINUSE + # on ``init_process_group``). We can't use + # kernel-ephemeral allocation either, because the + # master server's pre-allocated ZMQ ports (returned + # by ``register_stage_with_omni_master``) also live + # in the ephemeral range and are not actually bound + # until the headless ``_perform_diffusion_handshake`` + # runs — so picking an ephemeral port here can steal + # a port the master server already promised to a + # sibling headless. 
Use ``settle_port`` from a base + # above the Linux default ephemeral range + # (32768-60999) so torch.distributed master ports + # never overlap with ZMQ allocations. + if omni_dp_size_local > 1: + od_config.master_port = od_config.settle_port( + 61000 + _rep_idx * 100, + port_inc=37, + ) + proc, _, _, _ = spawn_diffusion_proc( + model, + od_config, + handshake_address=response.handshake_address, + request_address=response.input_address, + response_address=response.output_address, + omni_coordinator_address=response.coordinator_router_address, + omni_stage_id=stage_id, + omni_replica_id=response.replica_id, + ) + finally: + if previous_visible_devices is None: + current_omni_platform.unset_device_control_env_var() + else: + current_omni_platform.set_device_control_env_var(previous_visible_devices) + complete_diffusion_handshake(proc, response.handshake_address, args.stage_init_timeout) + procs.append(proc) + logger.info( + "[Headless] Diffusion replica id=%d for stage %d is up (coord=%s)", + response.replica_id, + stage_id, + response.coordinator_router_address, + ) + + # Block on the sentinel set so any replica crash is detected + # immediately (the previous per-proc join loop only noticed + # crashes in registration order). Any exit triggers fleet + # shutdown via the finally block; non-zero exits propagate. 
+ sentinel_to_proc = {p.sentinel: p for p in procs} + died = connection.wait(list(sentinel_to_proc.keys())) + first = sentinel_to_proc[died[0]] + logger.info( + "[Headless] Diffusion replica %s exited (code=%s); shutting down stage %d.", + first.name, + first.exitcode, + stage_id, ) - complete_diffusion_handshake(proc, handshake_address, args.stage_init_timeout) - proc.join() - if proc.exitcode not in (None, 0): - raise RuntimeError(f"Diffusion stage {stage_id} replica {replica_id} exited with code {proc.exitcode}") + if first.exitcode not in (None, 0): + raise RuntimeError( + f"Diffusion stage {stage_id} replica {first.name!r} exited with code {first.exitcode}" + ) return finally: - logger.info("[Headless] Shutting down stage %d replica %d.", stage_id, replica_id) - if proc is not None and proc.is_alive(): - terminate_alive_proc(proc) + logger.info("[Headless] Shutting down %d diffusion replica(s) for stage %d.", len(procs), stage_id) + for p in procs: + if p.is_alive(): + terminate_alive_proc(p) stage_connector_spec = get_stage_connector_spec( omni_transfer_config=omni_transfer_config, @@ -679,8 +915,10 @@ def run_headless(args: argparse.Namespace) -> None: async_chunk=False, ) - # Device assignment is managed externally (e.g. CUDA_VISIBLE_DEVICES); - # runtime_cfg is intentionally ignored in headless mode. + # ``runtime_cfg`` is mostly inherited from the parent's + # CUDA_VISIBLE_DEVICES; when ``--omni-dp-size-local > 1`` we additionally + # bracket each replica's spawn below with setup_stage_devices so they + # don't all stack on cuda:0 (see ``per_replica_devices`` above). 
engine_args_dict = build_engine_args_dict( stage_cfg, model, @@ -746,57 +984,116 @@ def signal_handler(signum: int, frame: FrameType | None) -> None: enable_wave_coordination=vllm_config.model_config.is_moe, ) logger.info( - "[Headless] Started DP Coordinator process for stage %d replica %d (PID: %d)", + "[Headless] Started DP Coordinator process for stage %d (PID: %d)", stage_id, - replica_id, coordinator.proc.pid, ) logger.info( - "[Headless] Launching %d engine core(s) for stage %d replica %d via OmniMasterServer at %s:%d", + "[Headless] Launching %d omni replica(s) (vLLM dp_size_local=%d each) for stage %d " + "via OmniMasterServer at %s:%d", + omni_dp_size_local, local_engine_count, stage_id, - replica_id, omni_master_address, omni_master_port, ) - # Headless mode launches all local engine cores for a single stage. - # The OmniMasterServer allocates one handshake endpoint per stage, so we - # register the stage once here and let every local engine core reuse the - # returned handshake address directly. - handshake_address = register_stage_with_omni_master( - omni_master_address=omni_master_address, - omni_master_port=omni_master_port, - omni_stage_id=stage_id, - omni_stage_config=stage_cfg, - coordinator=coordinator, - replica_id=replica_id, - ) - - engine_manager = None + # One OmniMasterServer registration per omni replica; each registration + # yields its own (handshake, input, output) allocation and the head's + # OmniCoordinator ROUTER address. We then spawn one + # OmniCoreEngineProcManager per replica so its subprocess gets the + # right replica id wired into its OmniCoordClientForStage. 
log_stats = bool(args.log_stats) if args.disable_log_stats: log_stats = False + engine_managers: list[Any] = [] + monitor_threads: list[threading.Thread] = [] + + def _monitor_target(mgr: Any) -> None: + try: + mgr.monitor_engine_liveness() + except Exception: + logger.exception("[Headless] monitor_engine_liveness raised") + try: - engine_manager = CoreEngineProcManager( - local_engine_count=local_engine_count, - start_index=dp_rank, - local_start_index=0, - vllm_config=vllm_config, - local_client=False, - handshake_address=handshake_address, - executor_class=executor_class, - log_stats=log_stats, - ) - # vllm>=0.19 renamed CoreEngineProcManager.join_first() to - # monitor_engine_liveness() (see upstream PR #35862). - engine_manager.monitor_engine_liveness() + for _rep_idx in range(omni_dp_size_local): + # Always auto-assign: see the diffusion branch comment above + # for the rationale (headless owns no replica-id namespace). + response = register_stage_with_omni_master( + omni_master_address=omni_master_address, + omni_master_port=omni_master_port, + omni_stage_id=stage_id, + omni_stage_config=stage_cfg, + coordinator=coordinator, + replica_id=None, + return_full_response=True, + replica_bind_address=omni_replica_address, + # LLM headless: the head binds *all* three sockets — + # handshake ROUTER (``connect_remote_engine_cores``), + # input ROUTER and output PULL (``CoreClient`` — + # ``make_zmq_socket`` defaults bind=True for PULL). + # The remote LLM worker is purely a connector: it + # opens 3 outbound TCP connections to the master's + # host. So the master must keep every address on + # its own host; rewriting any of them to this + # replica's NIC makes the head's ``bind`` go + # EADDRNOTAVAIL on a cross-host launch. + replica_binds_sockets=False, + ) + # Per-replica CUDA_VISIBLE_DEVICES, same pattern as the diffusion + # branch above. 
OmniCoreEngineProcManager.__init__ spawns its + # subprocesses via context.Process inside the constructor, so we + # must set the env *before* instantiation and restore after. + previous_visible_devices = os.environ.get(device_control_env) + try: + if per_replica_devices[_rep_idx] is not None: + setup_stage_devices(stage_id, {"devices": per_replica_devices[_rep_idx]}) + mgr = OmniCoreEngineProcManager( + local_engine_count=local_engine_count, + start_index=dp_rank, + local_start_index=0, + vllm_config=vllm_config, + local_client=False, + handshake_address=response.handshake_address, + executor_class=executor_class, + log_stats=log_stats, + omni_stage_id=stage_id, + omni_coordinator_address=response.coordinator_router_address, + omni_replica_base_id=response.replica_id, + ) + finally: + if previous_visible_devices is None: + current_omni_platform.unset_device_control_env_var() + else: + current_omni_platform.set_device_control_env_var(previous_visible_devices) + engine_managers.append(mgr) + logger.info( + "[Headless] Stage %d replica id=%d up (coord=%s)", + stage_id, + response.replica_id, + response.coordinator_router_address, + ) + + # Run all managers' liveness monitors in parallel. Each blocks + # until its own subprocesses exit (or fail). 
+ if len(engine_managers) == 1: + engine_managers[0].monitor_engine_liveness() + else: + for mgr in engine_managers: + t = threading.Thread(target=_monitor_target, args=(mgr,), name=f"omni-replica-monitor-{id(mgr):x}") + t.start() + monitor_threads.append(t) + for t in monitor_threads: + t.join() finally: - logger.info("[Headless] Shutting down stage %d.", stage_id) - if engine_manager is not None: - engine_manager.shutdown() + logger.info("[Headless] Shutting down stage %d (%d managers).", stage_id, len(engine_managers)) + for mgr in engine_managers: + try: + mgr.shutdown() + except Exception: + logger.exception("[Headless] engine manager shutdown failed") if coordinator is not None: coordinator.shutdown() diff --git a/vllm_omni/entrypoints/stage_utils.py b/vllm_omni/entrypoints/stage_utils.py index 7b725f469eb..b86c59fcb00 100644 --- a/vllm_omni/entrypoints/stage_utils.py +++ b/vllm_omni/entrypoints/stage_utils.py @@ -95,6 +95,24 @@ def _map_device_list(stage_id: int, device_list: list[str], visible_device_list: if not all(device.isdigit() for device in device_list): raise ValueError("Logical devices must be non-negative integers") + # Idempotency: if every requested id already names an entry in the visible + # pool (by value, not by index), treat the YAML field as physical and skip + # the remap. This avoids a double-mapping crash when a parent harness has + # already narrowed CUDA_VISIBLE_DEVICES to specific physical ids before + # spawning the subprocess (the test runtime's OmniServerStageCli does this + # — it sets each stage subprocess's env via ``_set_stage_device_env``). + # Without this guard, a YAML ``devices: "1"`` against a narrowed visible + # list ``["1"]`` would try ``visible[1]`` (out of range) and raise. 
+ visible_set = set(visible_device_list) + if all(device in visible_set for device in device_list): + logger.info( + "Stage %s requested devices %s are already in the visible set %s; using as-is (no remap).", + stage_id, + device_list, + visible_device_list, + ) + return list(device_list) + logical_ids = [int(device) for device in device_list] mapped_devices = [visible_device_list[idx] for idx in logical_ids if idx < num_visible] mapping_pairs = [