def test_load_balancer_select_returns_valid_index():
    """Check that RandomBalancer.select() yields an in-range integer index."""
    current = time()

    def make_instance(input_addr, output_addr, stage_id, queue_length):
        # Every fake instance is UP and shares the same heartbeat/registration time.
        return InstanceInfo(
            input_addr=input_addr,
            output_addr=output_addr,
            stage_id=stage_id,
            status=StageStatus.UP,
            queue_length=queue_length,
            last_heartbeat=current,
            registered_at=current,
        )

    instances = [
        make_instance("tcp://host:10001", "tcp://host:10001-out", 0, 0),
        make_instance("tcp://host:10002", "tcp://host:10002-out", 0, 1),
        make_instance("tcp://host:10003", "tcp://host:10003-out", 1, 2),
    ]

    # Same keys as the task dict built by async_omni; RandomBalancer ignores the contents.
    task: dict = {
        "request_id": "test",
        "engine_inputs": None,
        "sampling_params": None,
    }

    chosen = RandomBalancer().select(task, instances)

    assert isinstance(chosen, int)
    assert 0 <= chosen < len(instances)
project + +import json +import time + +import pytest +import zmq + +from vllm_omni.distributed.omni_coordinator import ( + InstanceList, + OmniCoordClientForHub, +) + + +def _bind_pub() -> tuple[zmq.Context, zmq.Socket, str]: + ctx = zmq.Context.instance() + pub = ctx.socket(zmq.PUB) + pub.bind("tcp://127.0.0.1:*") + endpoint = pub.getsockopt(zmq.LAST_ENDPOINT).decode("ascii") + return ctx, pub, endpoint + + +def _wait_for_condition(cond, timeout: float = 2.0, interval: float = 0.01) -> bool: + start = time.time() + while time.time() - start < timeout: + if cond(): + return True + time.sleep(interval) + return False + + +def test_hub_client_caches_instance_list_from_pub(): + """Verify OmniCoordClientForHub receives instance list updates from OmniCoordinator and caches for get_instance_list().""" + ctx, pub, endpoint = _bind_pub() + + client = OmniCoordClientForHub(endpoint) + # ZMQ PUB/SUB slow-joiner: allow SUB to finish connecting before first send + time.sleep(0.2) + + now = time.time() + instances_payload = [ + { + "input_addr": "tcp://stage:10001", + "output_addr": "tcp://stage:10001-out", + "stage_id": 0, + "status": "up", + "queue_length": 0, + "last_heartbeat": now, + "registered_at": now, + }, + { + "input_addr": "tcp://stage:10002", + "output_addr": "tcp://stage:10002-out", + "stage_id": 0, + "status": "up", + "queue_length": 1, + "last_heartbeat": now, + "registered_at": now, + }, + { + "input_addr": "tcp://stage:10003", + "output_addr": "tcp://stage:10003-out", + "stage_id": 1, + "status": "error", + "queue_length": 5, + "last_heartbeat": now, + "registered_at": now, + }, + ] + + payload = {"instances": instances_payload, "timestamp": now} + pub.send(json.dumps(payload).encode("utf-8")) + + assert _wait_for_condition(lambda: len(client.get_instance_list().instances) == 3) + + inst_list = client.get_instance_list() + assert isinstance(inst_list, InstanceList) + assert len(inst_list.instances) == 3 + + for src, inst in zip(instances_payload, 
def test_hub_client_close_closes_sub_socket():
    """Verify OmniCoordClientForHub.close() marks client as closed; second close raises."""
    ctx, pub, endpoint = _bind_pub()
    try:
        client = OmniCoordClientForHub(endpoint)
        client.close()

        # A second close() must be rejected rather than silently ignored.
        with pytest.raises(RuntimeError, match="already closed"):
            client.close()
    finally:
        # _bind_pub() hands out the process-wide zmq.Context.instance(); release
        # the PUB socket even when the test body fails so that no socket is
        # left open on that shared context.
        pub.close(0)
        ctx.term()
router.getsockopt(zmq.LAST_ENDPOINT).decode("ascii") + return ctx, router, endpoint + + +def _recv_event(router: zmq.Socket) -> dict: + frames = router.recv_multipart() + # ROUTER adds identity frame; the last frame is the payload. + payload = frames[-1] + return json.loads(payload.decode("utf-8")) + + +def test_stage_client_auto_register_on_init(): + """Verify OmniCoordClientForStage automatically sends initial registration/status-up event when created.""" + ctx, router, endpoint = _bind_router() + + input_addr = "tcp://stage:10001" + output_addr = "tcp://stage:10001-out" + stage_id = 0 + + client = OmniCoordClientForStage(endpoint, input_addr, output_addr, stage_id) + + event = _recv_event(router) + + assert event["event_type"] == "update" + assert event["status"] == StageStatus.UP.value + assert event["stage_id"] == stage_id + assert event["input_addr"] == input_addr + assert event["output_addr"] == output_addr + + client.close() + router.close(0) + ctx.term() + + +def test_stage_client_update_info_sends_correct_event(): + """Verify OmniCoordClientForStage.update_info() sends status/load update events with expected fields.""" + ctx, router, endpoint = _bind_router() + + input_addr = "tcp://stage:10002" + output_addr = "tcp://stage:10002-out" + stage_id = 1 + + client = OmniCoordClientForStage(endpoint, input_addr, output_addr, stage_id) + + # Discard initial registration event. 
def test_stage_client_close_sends_down_status():
    """Verify close() sends final status-down event before closing underlying socket."""
    ctx, router, endpoint = _bind_router()
    # _recv_event() does a blocking recv_multipart(); bound it so that a
    # missing event fails the test (zmq.Again) instead of hanging it forever.
    router.setsockopt(zmq.RCVTIMEO, 2000)

    input_addr = "tcp://stage:10003"
    output_addr = "tcp://stage:10003-out"
    stage_id = 2

    try:
        client = OmniCoordClientForStage(endpoint, input_addr, output_addr, stage_id)

        # Discard initial registration event.
        _recv_event(router)

        client.close()

        event = _recv_event(router)
        assert event["status"] == StageStatus.DOWN.value
        assert event["stage_id"] == stage_id
        assert event["input_addr"] == input_addr
        assert event["output_addr"] == output_addr

        assert client._socket.closed  # DEALER socket no longer usable after close
    finally:
        # Release the ROUTER socket and its context even if an assertion failed.
        router.close(0)
        ctx.term()
def _wait_for_instance_list(
    sub: zmq.Socket,
    expected_count: int,
    timeout: float = 3.0,
) -> dict | None:
    """Wait for an InstanceList message listing exactly ``expected_count`` instances.

    Args:
        sub: SUB socket to read published instance lists from.
        expected_count: Required length of the message's ``"instances"`` list.
        timeout: Overall number of seconds to keep waiting.

    Returns:
        The first decoded message whose ``"instances"`` list has
        ``expected_count`` entries, or None if none arrives within ``timeout``.
        Messages with a different count are discarded.
    """
    # Monotonic deadline: immune to wall-clock adjustments during the wait.
    deadline = time.monotonic() + timeout
    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return None
        # Poll in slices of at most 500 ms, but never past the overall
        # deadline, so the caller's timeout is not overshot by a full slice.
        poll_ms = min(500, max(1, int(remaining * 1000)))
        msg = _recv_instance_list(sub, timeout_ms=poll_ms)
        if msg is not None and len(msg.get("instances", [])) == expected_count:
            return msg
def test_omni_coordinator_heartbeat_timeout_handling():
    """Verify that when a stage instance stops sending heartbeats,
    OmniCoordinator marks it as unhealthy and excludes it from the active list.
    """
    router_addr = get_engine_client_zmq_addr(
        local_only=False,
        host="127.0.0.1",
        port=0,
    )
    pub_addr = get_engine_client_zmq_addr(
        local_only=False,
        host="127.0.0.1",
        port=0,
    )
    coordinator = OmniCoordinator(
        router_zmq_addr=router_addr,
        pub_zmq_addr=pub_addr,
        heartbeat_timeout=5.0,
    )

    # zmq.Context.instance() is a shared singleton, so a single handle serves
    # both the SUB socket and the raw DEALER socket created below.
    ctx = zmq.Context.instance()
    sub = ctx.socket(zmq.SUB)
    sub.connect(pub_addr)
    sub.setsockopt(zmq.SUBSCRIBE, b"")

    time.sleep(0.3)  # PUB/SUB slow-joiner

    # A and B: real clients that send heartbeats every 5s.
    # NOTE(review): that interval equals heartbeat_timeout above, so A and B
    # stay healthy only if each heartbeat is processed before the
    # coordinator's next timeout check — confirm this margin is acceptable.
    client_a = OmniCoordClientForStage(router_addr, "tcp://stage:a", "tcp://stage:a-out", 0)
    client_b = OmniCoordClientForStage(router_addr, "tcp://stage:b", "tcp://stage:b-out", 0)

    # C: raw DEALER that sends only registration, no heartbeat.
    dealer_c = ctx.socket(zmq.DEALER)
    dealer_c.connect(router_addr)
    reg_event = {
        "input_addr": "tcp://stage:c",
        "output_addr": "tcp://stage:c-out",
        "stage_id": 0,
        "event_type": "update",
        "status": StageStatus.UP.value,
        "queue_length": 0,
    }

    try:
        dealer_c.send(json.dumps(reg_event).encode("utf-8"))

        msg = _wait_for_instance_list(sub, expected_count=3)
        assert msg is not None, "Expected initial 3 instances"
        assert len(msg["instances"]) == 3

        # Wait for the update in which C has timed out and dropped off the
        # active list. Published messages queue on the SUB socket in order, so
        # polling right away sees the same first 2-instance message a fixed
        # sleep would, but returns as soon as it arrives. The 13 s budget
        # covers the 5 s heartbeat timeout plus the coordinator's check and
        # publish latency.
        msg_after_timeout = _wait_for_instance_list(sub, expected_count=2, timeout=13.0)
        assert msg_after_timeout is not None, "Expected InstanceList with 2 instances after timeout"
        instances = msg_after_timeout.get("instances", [])
        input_addrs = {inst["input_addr"] for inst in instances}

        assert "tcp://stage:a" in input_addrs
        assert "tcp://stage:b" in input_addrs
        assert "tcp://stage:c" not in input_addrs
    finally:
        # Tear everything down even when an assertion failed, so no socket
        # stays open on the shared context.
        client_a.close()
        client_b.close()
        dealer_c.close(0)
        coordinator.close()
        sub.close(0)
        ctx.term()
+ """ + router_addr = get_engine_client_zmq_addr( + local_only=False, + host="127.0.0.1", + port=0, + ) + pub_addr = get_engine_client_zmq_addr( + local_only=False, + host="127.0.0.1", + port=0, + ) + coordinator = OmniCoordinator( + router_zmq_addr=router_addr, + pub_zmq_addr=pub_addr, + heartbeat_timeout=1000.0, + ) + + sub_ctx = zmq.Context.instance() + sub = sub_ctx.socket(zmq.SUB) + sub.connect(pub_addr) + sub.setsockopt(zmq.SUBSCRIBE, b"") + + time.sleep(0.3) # PUB/SUB slow-joiner + + client = OmniCoordClientForStage(router_addr, "tcp://stage:shutdown", "tcp://stage:shutdown-out", 0) + + msg = _wait_for_instance_list(sub, expected_count=1) + assert msg is not None + assert len(msg["instances"]) == 1 + assert msg["instances"][0]["input_addr"] == "tcp://stage:shutdown" + + # Send down status (simulating graceful shutdown). + client.update_info(status=StageStatus.DOWN) + + # Receive updated list (should have 0 active instances). + msg = _wait_for_instance_list(sub, expected_count=0) + assert msg is not None + assert len(msg["instances"]) == 0 + + client.close() + coordinator.close() + sub.close(0) + sub_ctx.term() diff --git a/vllm_omni/distributed/omni_coordinator/__init__.py b/vllm_omni/distributed/omni_coordinator/__init__.py new file mode 100644 index 00000000000..cbef920d4be --- /dev/null +++ b/vllm_omni/distributed/omni_coordinator/__init__.py @@ -0,0 +1,27 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from .load_balancer import ( + LoadBalancer, + LoadBalancingPolicy, + RandomBalancer, + Task, +) +from .messages import InstanceEvent, InstanceInfo, InstanceList, StageStatus +from .omni_coord_client_for_hub import OmniCoordClientForHub +from .omni_coord_client_for_stage import OmniCoordClientForStage +from .omni_coordinator import OmniCoordinator + +__all__ = [ + "OmniCoordinator", + "StageStatus", + "InstanceEvent", + "InstanceInfo", + "InstanceList", + "OmniCoordClientForStage", + 
class LoadBalancer(ABC):
    """Interface that every load-balancing strategy implements.

    A concrete balancer overrides :meth:`select` to decide which instance
    receives a given task.
    """

    @abstractmethod
    def select(self, task: Task, instances: list[InstanceInfo]) -> int:
        """Pick the instance that should handle ``task``.

        Args:
            task: Task being routed. The random policy does not look at it;
                it is part of the signature for strategies that inspect
                task metadata.
            instances: Candidate instances to choose among.

        Returns:
            Position of the chosen instance within ``instances``.

        Raises:
            ValueError: If ``instances`` is empty.
        """
        raise NotImplementedError
@dataclass
class InstanceEvent:
    """Event sent by OmniCoordClientForStage to OmniCoordinator over ZMQ.

    Attributes:
        input_addr: Input ZMQ address of the stage instance (e.g. "tcp://host:port").
        output_addr: Output ZMQ address of the stage instance (e.g. "tcp://host:port").
        stage_id: ID of the stage the instance belongs to.
        event_type: Kind of event, "update" or "heartbeat".
        status: Current status of the instance.
        queue_length: Current queue length of the instance.
    """

    input_addr: str
    output_addr: str
    stage_id: int
    event_type: str
    status: StageStatus
    queue_length: int
@dataclass
class InstanceList:
    """Instance list update published by OmniCoordinator.

    The coordinator publishes one whenever its view of the active instances
    changes; OmniCoordClientForHub keeps the most recent one cached and
    exposes it to AsyncOmni and the load balancer.

    Attributes:
        instances: Instances contained in this update.
        timestamp: Time the list was last updated, in seconds.
    """

    instances: list[InstanceInfo]
    timestamp: float
def _open_sub_socket(self) -> zmq.Socket:
    """Create a SUB socket subscribed to everything and connect it to the coordinator.

    Raises:
        zmq.ZMQError, OSError: If configuring or connecting the socket
            fails. The half-initialised socket is closed before the error
            propagates, so it cannot keep the context from terminating.
    """
    sub = self._ctx.socket(zmq.SUB)
    try:
        sub.setsockopt(zmq.SUBSCRIBE, b"")
        sub.setsockopt(zmq.RCVTIMEO, 100)  # 100ms timeout, avoids busy-wait
        sub.connect(self._coord_zmq_addr)
    except (zmq.ZMQError, OSError):
        try:
            sub.close(0)
        except zmq.ZMQError:
            pass
        raise
    return sub


def _recv_loop(self) -> None:
    """Background loop that receives and caches instance lists.

    Runs in the thread started by ``__init__``. The first connection attempt
    reports its outcome through ``_init_error`` / ``_init_done``; afterwards
    the loop keeps receiving until ``_stop_event`` is set, reconnecting on
    socket errors and skipping messages that cannot be decoded. On exit the
    SUB socket is closed and the context terminated.
    """
    sub = None
    try:
        sub = self._open_sub_socket()
    except (zmq.ZMQError, OSError) as e:
        self._init_error.append(e)
    finally:
        # Unblock __init__ whether or not the first connect succeeded.
        self._init_done.set()

    try:
        while not self._stop_event.is_set():
            # (Re)create and connect SUB socket if needed.
            if sub is None:
                try:
                    sub = self._open_sub_socket()
                except (zmq.ZMQError, OSError) as e:
                    logger.error(
                        "Hub client failed to connect to coordinator at %s, will retry",
                        self._coord_zmq_addr,
                        exc_info=e,
                    )
                    # Back off for 1s, but wake up at once when close() sets
                    # the stop event so shutdown is not delayed.
                    self._stop_event.wait(1.0)
                    continue

            try:
                data = sub.recv()
            except zmq.Again:
                # Receive timeout: loop around to re-check the stop event.
                continue
            except zmq.ZMQError as e:
                logger.error("Hub client recv failed, will reconnect", exc_info=e)
                try:
                    sub.close()
                except zmq.ZMQError:
                    pass
                sub = None
                self._stop_event.wait(1.0)
                continue

            try:
                payload = json.loads(data.decode("utf-8"))
                inst_list = self._decode_instance_list(payload)
                with self._lock:
                    self._instance_list = inst_list
            except (
                json.JSONDecodeError,
                KeyError,
                ValueError,
                TypeError,
                AttributeError,
            ) as e:
                logger.warning("Invalid instance list message, skipping: %s", e)
    finally:
        try:
            if sub is not None:
                sub.close()
        except zmq.ZMQError:
            pass
        try:
            self._ctx.term()
        except zmq.ZMQError:
            pass
def get_instances_for_stage(self, stage_id: int) -> InstanceList:
    """Return the cached instances whose ``stage_id`` equals the given one.

    The result carries the timestamp of the cached list it was filtered from.
    """
    cached = self.get_instance_list()
    matching = []
    for candidate in cached.instances:
        if candidate.stage_id == stage_id:
            matching.append(candidate)
    return InstanceList(instances=matching, timestamp=cached.timestamp)
def _send_event(self, event_type: str) -> None:
    """Send an InstanceEvent to OmniCoordinator.

    Wire format: input_addr, output_addr, stage_id, event_type, status,
    queue_length. Every event, whatever its ``event_type``, carries the
    instance's current status and queue length.

    The send is non-blocking. If the send buffer is full (``zmq.Again``)
    the event is dropped. On any other send failure the client calls
    :meth:`_reconnect`, which keeps retrying until it succeeds or the client
    is stopped/closed, and then retries the send once.

    Args:
        event_type: Event kind placed in the payload, e.g. "update" or
            "heartbeat".

    Raises:
        RuntimeError: If the client is already closed, or if the send
            itself raised RuntimeError and recovery failed.
        zmq.ZMQError: If the send failed and either the reconnect or the
            retried send failed as well.
    """
    if self._closed:
        raise RuntimeError("Client already closed")

    event = InstanceEvent(
        input_addr=self._input_addr,
        output_addr=self._output_addr,
        stage_id=self._stage_id,
        event_type=event_type,
        status=self._status,
        queue_length=self._queue_length,
    )
    data = json.dumps(asdict(event)).encode("utf-8")

    # _reconnect() requires the caller to hold the send lock.
    with self._send_lock:
        try:
            self._socket.send(data, flags=zmq.NOBLOCK)
            return
        except zmq.Again:
            logger.debug("Send buffer full, dropping event")
            return
        except (RuntimeError, zmq.ZMQError) as e:
            # First send failed; try to reconnect. The method must be
            # *called*: a bare ``self._reconnect`` is always truthy, which
            # would skip the reconnect and retry on the broken socket.
            if not self._reconnect():
                logger.error("Failed to send event and reconnect to coordinator", exc_info=e)
                raise

            # Reconnected successfully; try sending once more.
            try:
                self._socket.send(data, flags=zmq.NOBLOCK)
            except zmq.Again:
                logger.debug("Send buffer full after reconnect, dropping event")
            except (RuntimeError, zmq.ZMQError) as e2:
                logger.error("Failed to send event after successful reconnect", exc_info=e2)
                raise
def close(self) -> None:
    """Send a final down event and close the underlying socket.

    The heartbeat thread is stopped first, then one last "update" event with
    status DOWN is sent on a best-effort basis. The socket and context are
    released and the client marked closed even if that final send fails.

    Raises:
        RuntimeError: If the client has already been closed.
    """
    if self._closed:
        raise RuntimeError("Client already closed")

    # Stop heartbeat thread first to avoid concurrent sends during shutdown.
    self._stop_event.set()
    if hasattr(self, "_heartbeat_thread"):
        self._heartbeat_thread.join(timeout=1.0)

    # Mark status as DOWN and send one last update.
    self._status = StageStatus.DOWN
    try:
        self._send_event("update")
    except (RuntimeError, zmq.ZMQError) as e:
        # _send_event re-raises both error types; the socket may already be
        # broken, so log and carry on with the teardown.
        logger.debug("Failed to send final DOWN event during close", exc_info=e)
    finally:
        # Close DEALER socket and terminate this client's context, whatever
        # happened to the final send.
        self._socket.close(0)
        try:
            self._ctx.term()
        except zmq.ZMQError:
            pass
        self._closed = True
+ self._ctx = zmq.Context() + self._router = self._ctx.socket(zmq.ROUTER) + self._router.bind(self._router_zmq_addr) + + self._pub = self._ctx.socket(zmq.PUB) + self._pub.bind(self._pub_zmq_addr) + + self._instances: dict[str, InstanceInfo] = {} + self._lock = threading.Lock() + self._pub_lock = threading.Lock() + + self._publish_min_interval: float = 0.1 # seconds + self._pending_broadcast: bool = False + + self._running = True + self._closed = False + self._stop_event = threading.Event() + + self._router.setsockopt(zmq.RCVTIMEO, 100) + + self._recv_thread = threading.Thread(target=self._recv_loop, daemon=True) + self._recv_thread.start() + + self._periodic_thread = threading.Thread(target=self._periodic_loop, daemon=True) + self._periodic_thread.start() + + def get_active_instances(self) -> InstanceList: + """Return an :class:`InstanceList` of active (UP) instances only.""" + with self._lock: + active = [inst for inst in self._instances.values() if inst.status == StageStatus.UP] + return InstanceList(instances=active, timestamp=time()) + + def add_new_instance(self, event: InstanceEvent) -> None: + """Add a new instance based on an incoming event.""" + with self._lock: + self._add_new_instance_locked(event) + self.publish_instance_list_update() + + def update_instance_info(self, event: InstanceEvent) -> None: + """Update an existing instance based on an incoming event.""" + with self._lock: + self._update_instance_info_locked(event) + self.publish_instance_list_update() + + def remove_instance(self, event: InstanceEvent) -> None: + """Mark an instance as removed / down based on an incoming event. + + This marks the instance's status as DOWN or ERROR (depending on the + event) but keeps it in the internal registry. It is removed from the + *active* instance list published to hubs. 
+ """ + with self._lock: + self._remove_instance_locked(event) + self.publish_instance_list_update() + + def publish_instance_list_update(self) -> None: + """Publish the current active instance list to all subscribers.""" + active_list = self.get_active_instances() + payload = asdict(active_list) + data = json.dumps(payload).encode("utf-8") + + with self._pub_lock: + try: + # PUB socket is best-effort; drop update if not ready. + self._pub.send(data, flags=zmq.NOBLOCK) + except (zmq.Again, zmq.ZMQError): + # Silently ignore send failures; next update will catch up. + return + + def _schedule_broadcast(self, force: bool) -> None: + """Schedule a broadcast, optionally bypassing throttling. + + When ``force`` is True, publish immediately. Otherwise, mark a pending + broadcast that will be flushed by the periodic loop at most once per + ``_publish_min_interval``. + """ + if force: + self.publish_instance_list_update() + else: + self._pending_broadcast = True + + def _mark_instance_error_locked(self, info: InstanceInfo) -> None: + """Mark instance as ERROR (e.g. after heartbeat timeout).""" + info.status = StageStatus.ERROR + + def _check_heartbeat_timeouts(self) -> None: + """Mark instances as ERROR if their heartbeat has timed out.""" + now = time() + timed_out = False + gc_ttl = 600.0 # 10 minutes + + with self._lock: + to_delete: list[str] = [] + + for input_addr, info in self._instances.items(): + if info.status == StageStatus.UP and now - info.last_heartbeat > self._heartbeat_timeout: + self._mark_instance_error_locked(info) + timed_out = True + elif info.status in (StageStatus.DOWN, StageStatus.ERROR) and now - info.last_heartbeat > gc_ttl: + to_delete.append(input_addr) + + for input_addr in to_delete: + del self._instances[input_addr] + if timed_out: + # Instance liveness changed; force immediate broadcast. 
+ self._schedule_broadcast(force=True) + + def close(self) -> None: + """Shut down background threads and close all ZMQ sockets.""" + if self._closed: + raise RuntimeError("Coordinator already closed") + + self._closed = True + self._running = False + self._stop_event.set() + + # Wait for threads to exit before closing sockets. + for thread in (self._recv_thread, self._periodic_thread): + thread.join(timeout=1.0) + + try: + self._router.close(0) + except zmq.ZMQError: + pass + + try: + self._pub.close(0) + except zmq.ZMQError: + pass + + try: + self._ctx.term() + except zmq.ZMQError: + pass + + def _parse_instance_event(self, data: dict[str, Any]) -> InstanceEvent | None: + """Parse wire payload dict into InstanceEvent. Returns None if invalid.""" + try: + return InstanceEvent( + input_addr=str(data["input_addr"]), + output_addr=str(data["output_addr"]), + stage_id=int(data["stage_id"]), + event_type=str(data["event_type"]), + status=StageStatus(data.get("status")), + queue_length=data.get("queue_length"), + ) + except (KeyError, ValueError, TypeError): + return None + + def _recv_loop(self) -> None: + """Background loop that receives and processes instance events.""" + while self._running: + try: + frames = self._router.recv_multipart() + except zmq.Again: + # RCVTIMEO expired, loop to recheck _running. + continue + except zmq.ZMQError: + # Socket likely closed or context terminated. + break + + if not frames: + continue + + payload = frames[-1] + try: + data = json.loads(payload.decode("utf-8")) + event = self._parse_instance_event(data) + except json.JSONDecodeError as e: + logger.warning("Invalid JSON in instance event, dropping: %s", e) + continue + if event is None: + logger.warning("Malformed instance event, dropping") + continue + + self._handle_event(event) + + def _periodic_loop(self) -> None: + """Periodic loop to check heartbeat timeouts and flush broadcasts. 
+ + Heartbeat timeouts are checked on their original cadence, while + queue_length / non-liveness updates are coalesced and flushed at + most once per ``_publish_min_interval``. + """ + heartbeat_interval = max(1.0, min(self._heartbeat_timeout / 2.0, 5.0)) + loop_interval = self._publish_min_interval + + last_heartbeat_check = 0.0 + while self._running: + now = time() + + if now - last_heartbeat_check >= heartbeat_interval: + self._check_heartbeat_timeouts() + last_heartbeat_check = now + + if self._pending_broadcast: + self.publish_instance_list_update() + self._pending_broadcast = False + + if self._stop_event.wait(timeout=loop_interval): + break + + def _handle_event(self, event: InstanceEvent) -> None: + """Dispatch an incoming event to the appropriate handler.""" + try: + input_addr = event.input_addr + + # Heartbeat: only update last_heartbeat; if previously ERROR, + # promote back to UP and broadcast once. + if event.event_type == "heartbeat": + promote = False + with self._lock: + info = self._instances.get(input_addr) + if info is not None: + info.last_heartbeat = time() + if info.status == StageStatus.ERROR: + info.status = StageStatus.UP + promote = True + if promote: + self._schedule_broadcast(force=True) + return + + # Check-and-act under single lock to avoid TOCTOU race (duplicate + # registration when concurrent events arrive for the same instance). + with self._lock: + force_broadcast = False + if input_addr not in self._instances: + self._add_new_instance_locked(event) + force_broadcast = True + else: + if event.status == StageStatus.DOWN: + self._remove_instance_locked(event) + force_broadcast = True + else: + self._update_instance_info_locked(event) + + # New instances / DOWN events are broadcast immediately; other + # updates (e.g. queue_length changes) are throttled via the + # periodic loop. 
+ self._schedule_broadcast(force=force_broadcast) + except (KeyError, ValueError, TypeError) as e: + logger.warning("Dropping malformed event: %s", e) + + def _add_new_instance_locked(self, event: InstanceEvent) -> None: + input_addr = event.input_addr + if not input_addr: + raise KeyError("input_addr required") + stage_id = event.stage_id + if stage_id < 0: + raise KeyError("stage_id required and must be non-negative") + + now = time() + info = InstanceInfo( + input_addr=input_addr, + output_addr=event.output_addr, + stage_id=stage_id, + status=event.status, + queue_length=event.queue_length, + last_heartbeat=now, + registered_at=now, + ) + self._instances[input_addr] = info + + def _update_instance_info_locked(self, event: InstanceEvent) -> None: + input_addr = event.input_addr + info = self._instances[input_addr] + + if event.status is not None: + info.status = event.status + + if event.queue_length is not None: + info.queue_length = event.queue_length + + def _remove_instance_locked(self, event: InstanceEvent) -> None: + input_addr = event.input_addr + info = self._instances.get(input_addr) + if info is None: + return + + info.status = StageStatus.DOWN