From 5c44ca58d2db3d0494ae323f3c352504b2db9a72 Mon Sep 17 00:00:00 2001 From: chickeyton Date: Mon, 9 Feb 2026 13:40:19 +0800 Subject: [PATCH] add DpCoordinator LoadBalancer --- .../distributed/dp_coordinator/__init__.py | 77 +++ .../dp_coordinator/client_for_hub.py | 353 ++++++++++ .../dp_coordinator/client_for_stage.py | 209 ++++++ .../dp_coordinator/dp_coordinator.py | 610 ++++++++++++++++++ .../distributed/dp_coordinator/messages.py | 70 ++ .../distributed/load_balancer/__init__.py | 22 + .../load_balancer/load_balancer.py | 58 ++ 7 files changed, 1399 insertions(+) create mode 100644 vllm_omni/distributed/dp_coordinator/__init__.py create mode 100644 vllm_omni/distributed/dp_coordinator/client_for_hub.py create mode 100644 vllm_omni/distributed/dp_coordinator/client_for_stage.py create mode 100644 vllm_omni/distributed/dp_coordinator/dp_coordinator.py create mode 100644 vllm_omni/distributed/dp_coordinator/messages.py create mode 100644 vllm_omni/distributed/load_balancer/__init__.py create mode 100644 vllm_omni/distributed/load_balancer/load_balancer.py diff --git a/vllm_omni/distributed/dp_coordinator/__init__.py b/vllm_omni/distributed/dp_coordinator/__init__.py new file mode 100644 index 00000000000..22bfc1908f5 --- /dev/null +++ b/vllm_omni/distributed/dp_coordinator/__init__.py @@ -0,0 +1,77 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""DPCoordinator - Data Parallel Coordinator for vLLM-Omni. + +This module provides a coordinator for managing data parallel stage instances, +aggregating their status, and publishing updates for load balancing. 
+ +Example usage: + + from vllm_omni.distributed.dp_coordinator import ( + DPCoordinator, DPCoordinatorConfig, + ClientForStage, ClientForStageConfig, + ClientForHub, ClientForHubConfig, + StageStatus, + ) + + # Start coordinator + coord_config = DPCoordinatorConfig( + pub_address="tcp://*:5555", + router_address="tcp://*:5556", + ) + coord = DPCoordinator(coord_config) + coord.start() + + # Connect stage side client (worker side) + stage_config = ClientForStageConfig( + coordinator_addr="tcp://localhost:5556", + zmq_addr="tcp://worker-1:8000", + ) + stage_client = ClientForStage(stage_id=0, config=stage_config) + stage_client.start() + stage_client.set_status(StageStatus.UP) + + # Connect hub side client (API server side) + hub_config = ClientForHubConfig( + coordinator_addr="tcp://localhost:5555", + ) + hub_client = ClientForHub(hub_config) + hub_client.start() + + # Query available instances for load balancing + ready_instances = hub_client.get_ready_instances(stage_id=0) + least_loaded = hub_client.get_least_loaded_instance(stage_id=0) + + # Cleanup + hub_client.stop() + stage_client.stop() + coord.stop() +""" + +from .dp_coordinator import DPCoordinator, DPCoordinatorConfig +from .client_for_hub import ClientForHub, ClientForHubConfig +from .client_for_stage import ClientForStage, ClientForStageConfig +from .messages import ( + EventType, + InstanceInfo, + InstanceListing, + StageStatus, +) + +__all__ = [ + # Coordinator + "DPCoordinator", + "DPCoordinatorConfig", + # Hub Side Client (API server side) + "ClientForHub", + "ClientForHubConfig", + # Stage Side Client (worker side) + "ClientForStage", + "ClientForStageConfig", + # Messages + "EventType", + "InstanceInfo", + "InstanceListing", + "StageStatus", +] diff --git a/vllm_omni/distributed/dp_coordinator/client_for_hub.py b/vllm_omni/distributed/dp_coordinator/client_for_hub.py new file mode 100644 index 00000000000..8e745008af7 --- /dev/null +++ b/vllm_omni/distributed/dp_coordinator/client_for_hub.py @@ 
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""ClientForHub - Client for API servers to communicate with DPCoordinator.

This module implements the client that API servers (hubs) use to receive
instance listings from the coordinator for load balancing purposes.
"""

import threading
import time
from dataclasses import dataclass
from typing import Any, Callable

import zmq
from vllm.logger import init_logger

from vllm_omni.distributed.omni_connectors.utils.serialization import OmniSerializer

from .messages import InstanceInfo, InstanceListing, StageStatus

logger = init_logger(__name__)


@dataclass
class ClientForHubConfig:
    """Configuration for ClientForHub.

    Attributes:
        coordinator_addr: Address of the coordinator PUB socket
        reconnect_interval_ms: Interval between reconnection attempts
        stale_threshold_ms: Time after which cached data is considered stale
    """

    coordinator_addr: str = "tcp://localhost:5555"
    reconnect_interval_ms: int = 5000
    stale_threshold_ms: int = 10000


class ClientForHub:
    """Client for API servers to receive instance listings from DPCoordinator.

    This client subscribes to instance listing broadcasts from the coordinator
    and maintains a local cache of available stage instances for load balancing.
    """

    def __init__(self, config: ClientForHubConfig | None = None):
        """Initialize the hub side client.

        Args:
            config: Client configuration (optional)
        """
        self._config = config or ClientForHubConfig()

        # ZMQ (created lazily in start(); the context is only owned — and
        # therefore terminated in stop() — when start() created it itself)
        self._zmq_context: zmq.Context | None = None
        self._sub_socket: zmq.Socket | None = None

        # Instance cache, keyed by the instance's ZMQ address
        self._instance_cache: dict[str, InstanceInfo] = {}
        self._cache_lock = threading.RLock()
        self._last_update: float = 0.0

        # Background receiver thread
        self._receiver_thread: threading.Thread | None = None
        self._running = False
        self._stop_event = threading.Event()

        # Callbacks invoked whenever a fresh listing arrives
        self._on_update_callbacks: list[Callable[[InstanceListing], None]] = []

    @property
    def last_update(self) -> float:
        """Get the timestamp of the last update."""
        return self._last_update

    @property
    def is_stale(self) -> bool:
        """Check if the cached data is stale (older than stale_threshold_ms)."""
        if self._last_update == 0.0:
            # Never received a listing yet.
            return True
        age_ms = (time.time() - self._last_update) * 1000
        return age_ms > self._config.stale_threshold_ms

    def start(self, zmq_context: zmq.Context | None = None) -> None:
        """Start the client and connect to the coordinator.

        Args:
            zmq_context: Optional externally owned ZMQ context; when omitted,
                a private context is created (and terminated in stop()).
        """
        if self._running:
            logger.warning("ClientForHub is already running")
            return

        logger.info("Starting ClientForHub...")

        # Initialize ZMQ
        if zmq_context is None:
            self._zmq_context = zmq.Context()
            zmq_context = self._zmq_context

        # SUB socket for receiving instance listings (topic "listing")
        self._sub_socket = zmq_context.socket(zmq.SUB)
        self._sub_socket.connect(self._config.coordinator_addr)
        self._sub_socket.setsockopt_string(zmq.SUBSCRIBE, "listing")
        self._sub_socket.setsockopt(zmq.RCVTIMEO, 1000)  # 1s timeout for polling
        logger.info(
            f"Subscribed to coordinator at {self._config.coordinator_addr}"
        )

        self._running = True
        self._stop_event.clear()

        # Start receiver thread
        self._receiver_thread = threading.Thread(
            target=self._receiver_loop,
            name="ClientForHub-Receiver",
            daemon=True,
        )
        self._receiver_thread.start()

        logger.info("ClientForHub started")

    def stop(self) -> None:
        """Stop the client and disconnect from the coordinator.

        Teardown order matters: clear the run flag and stop event first so
        the receiver thread exits, join it, and only then close the socket
        and terminate the (owned) context.
        """
        if not self._running:
            return

        logger.info("Stopping ClientForHub...")

        self._running = False
        self._stop_event.set()

        # Wait for receiver thread
        if self._receiver_thread and self._receiver_thread.is_alive():
            self._receiver_thread.join(timeout=2.0)

        # Close sockets
        if self._sub_socket:
            self._sub_socket.close()
            self._sub_socket = None
        if self._zmq_context:
            # Only set when start() created the context, so an externally
            # supplied context is never terminated here.
            self._zmq_context.term()
            self._zmq_context = None

        logger.info("ClientForHub stopped")

    def add_update_callback(
        self, callback: Callable[[InstanceListing], None]
    ) -> None:
        """Add a callback to be called when instance listing is updated.

        Args:
            callback: Function to call with the new InstanceListing
        """
        self._on_update_callbacks.append(callback)

    def remove_update_callback(
        self, callback: Callable[[InstanceListing], None]
    ) -> None:
        """Remove a previously added callback.

        Args:
            callback: The callback function to remove
        """
        if callback in self._on_update_callbacks:
            self._on_update_callbacks.remove(callback)

    def get_instance_listing(self) -> InstanceListing:
        """Get the current instance listing from cache.

        Returns:
            InstanceListing containing all known instances
        """
        with self._cache_lock:
            instances = list(self._instance_cache.values())
            return InstanceListing(instances=instances, timestamp=self._last_update)

    def get_instances_by_stage(self, stage_id: int) -> list[InstanceInfo]:
        """Get all instances for a specific stage.

        Args:
            stage_id: The stage ID to filter by

        Returns:
            List of InstanceInfo for the specified stage
        """
        with self._cache_lock:
            return [
                inst
                for inst in self._instance_cache.values()
                if inst.stage_id == stage_id
            ]

    def get_ready_instances(self, stage_id: int | None = None) -> list[InstanceInfo]:
        """Get all instances that are ready to accept requests.

        Args:
            stage_id: Optional stage ID to filter by

        Returns:
            List of ready InstanceInfo, optionally filtered by stage
        """
        with self._cache_lock:
            instances = [
                inst
                for inst in self._instance_cache.values()
                if inst.status == StageStatus.UP
            ]
            if stage_id is not None:
                instances = [inst for inst in instances if inst.stage_id == stage_id]
            return instances

    def get_instance(self, zmq_addr: str) -> InstanceInfo | None:
        """Get a specific instance by ZMQ address.

        Args:
            zmq_addr: The ZMQ address of the instance to look up

        Returns:
            InstanceInfo if found, None otherwise
        """
        with self._cache_lock:
            return self._instance_cache.get(zmq_addr)

    def get_least_loaded_instance(
        self, stage_id: int | None = None
    ) -> InstanceInfo | None:
        """Get the instance with the lowest load.

        Args:
            stage_id: Optional stage ID to filter by

        Returns:
            InstanceInfo with lowest load, or None if no ready instances
        """
        instances = self.get_ready_instances(stage_id)
        if not instances:
            return None

        # Sort by queue length (number of unfinished tasks)
        return min(instances, key=lambda inst: inst.queue_length)

    def health(self) -> dict[str, Any]:
        """Get client health status.

        Returns:
            Dictionary with health information
        """
        with self._cache_lock:
            total_instances = len(self._instance_cache)
            ready_instances = sum(
                1
                for inst in self._instance_cache.values()
                if inst.status == StageStatus.UP
            )
            instances_by_stage: dict[int, int] = {}
            for inst in self._instance_cache.values():
                instances_by_stage[inst.stage_id] = (
                    instances_by_stage.get(inst.stage_id, 0) + 1
                )

        return {
            "running": self._running,
            "connected": self._running and not self.is_stale,
            "total_instances": total_instances,
            "ready_instances": ready_instances,
            "instances_by_stage": instances_by_stage,
            "last_update": self._last_update,
            "is_stale": self.is_stale,
            "coordinator_addr": self._config.coordinator_addr,
        }

    def _receiver_loop(self) -> None:
        """Background thread that receives instance listings."""
        logger.debug("Receiver thread started")

        while self._running and not self._stop_event.is_set():
            try:
                # Receive multipart message: [topic, data]
                frames = self._sub_socket.recv_multipart()
                if len(frames) >= 2:
                    topic = frames[0].decode("utf-8")
                    data = frames[1]

                    if topic == "listing":
                        self._handle_listing(data)

            except zmq.Again:
                # RCVTIMEO expired with nothing to read; loop to re-check
                # the stop flags.
                continue
            except zmq.ZMQError as e:
                if self._running:
                    logger.error(f"ZMQ error in receiver: {e}")
                    # Back off before retrying instead of spinning.
                    self._stop_event.wait(
                        self._config.reconnect_interval_ms / 1000.0
                    )
            except Exception as e:
                if self._running:
                    logger.error(f"Error in receiver loop: {e}")

        logger.debug("Receiver thread stopped")

    def _handle_listing(self, data: bytes) -> None:
        """Handle an incoming instance listing.

        Args:
            data: Serialized listing data
        """
        try:
            listing_dict = OmniSerializer.deserialize(data)

            if not isinstance(listing_dict, dict):
                logger.warning(f"Received non-dict listing: {type(listing_dict)}")
                return

            # Parse instances
            # NOTE(review): this reads the "zmq_addr" key, while the
            # coordinator's register_instance constructs InstanceInfo with
            # instance_id=... — confirm against messages.InstanceInfo that the
            # published dicts actually carry a "zmq_addr" field.
            instances: list[InstanceInfo] = []
            for inst_dict in listing_dict.get("instances", []):
                # Parse status; unknown/invalid values degrade to DOWN so a
                # malformed entry never shows up as routable.
                status_str = inst_dict.get("status", "down")
                if isinstance(status_str, StageStatus):
                    status = status_str
                else:
                    try:
                        status = StageStatus(status_str)
                    except ValueError:
                        status = StageStatus.DOWN

                instance = InstanceInfo(
                    stage_id=inst_dict.get("stage_id", 0),
                    zmq_addr=inst_dict.get("zmq_addr", ""),
                    status=status,
                    queue_length=inst_dict.get("queue_length", 0),
                    last_heartbeat=inst_dict.get("last_heartbeat", 0.0),
                    registered_at=inst_dict.get("registered_at", 0.0),
                )
                instances.append(instance)

            # Update cache (replace wholesale — the listing is authoritative)
            timestamp = listing_dict.get("timestamp", time.time())
            with self._cache_lock:
                self._instance_cache = {inst.zmq_addr: inst for inst in instances}
                self._last_update = timestamp

            # Notify callbacks
            listing = InstanceListing(instances=instances, timestamp=timestamp)
            for callback in self._on_update_callbacks:
                try:
                    callback(listing)
                except Exception as e:
                    # One faulty callback must not break the others.
                    logger.error(f"Error in update callback: {e}")

            logger.debug(f"Updated instance cache: {len(instances)} instances")

        except Exception as e:
            logger.error(f"Error handling listing: {e}")
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""ClientForStage - Client for workers to communicate with DPCoordinator.

Stage workers use this client to announce themselves to the coordinator,
emit periodic heartbeats, and push status and load changes.
"""

import threading
import time
from dataclasses import dataclass

import zmq
from vllm.logger import init_logger

from vllm_omni.distributed.omni_connectors.utils.serialization import OmniSerializer

from .messages import EventType, StageStatus

logger = init_logger(__name__)


@dataclass
class ClientForStageConfig:
    """Configuration for ClientForStage.

    Attributes:
        coordinator_addr: Address of the coordinator ROUTER socket.
        zmq_addr: ZMQ address where this stage receives tasks; also used as
            the socket identity presented to the coordinator.
        heartbeat_interval_ms: Interval between heartbeats.
    """

    coordinator_addr: str = "tcp://localhost:5556"
    zmq_addr: str = ""
    heartbeat_interval_ms: int = 1000


class ClientForStage:
    """Client used by a stage worker to talk to the DPCoordinator.

    Sends periodic heartbeats plus on-change status and load reports for a
    single stage instance.
    """

    def __init__(self, stage_id: int, config: ClientForStageConfig):
        """Create a stage client.

        Args:
            stage_id: Identifier for the stage (e.g., 0 for LLM, 1 for diffusion)
            config: Client configuration
        """
        self._stage_id = stage_id
        self._config = config

        # Reported state; an RLock because senders re-enter it inside
        # _send_event().
        self._lock = threading.RLock()
        self._status = StageStatus.DOWN
        self._queue_length = 0

        # ZMQ resources, created lazily in start()
        self._zmq_context: zmq.Context | None = None
        self._dealer_socket: zmq.Socket | None = None

        # Heartbeat thread control
        self._running = False
        self._stop_event = threading.Event()
        self._heartbeat_thread: threading.Thread | None = None

    @property
    def zmq_addr(self) -> str:
        """ZMQ address of this stage instance."""
        return self._config.zmq_addr

    @property
    def stage_id(self) -> int:
        """Identifier of the stage this client reports for."""
        return self._stage_id

    @property
    def status(self) -> StageStatus:
        """Most recently reported status."""
        return self._status

    def start(self, zmq_context: zmq.Context | None = None) -> None:
        """Connect to the coordinator and begin heartbeating.

        Args:
            zmq_context: Optional externally owned ZMQ context; when omitted,
                a private context is created and torn down in stop().
        """
        if self._running:
            logger.warning("ClientForStage is already running")
            return

        logger.info(f"Starting ClientForStage {self._config.zmq_addr}...")

        if zmq_context is None:
            # We own this context, so stop() will terminate it.
            zmq_context = zmq.Context()
            self._zmq_context = zmq_context

        # DEALER socket; the identity lets the coordinator's ROUTER associate
        # incoming events with this instance.
        sock = zmq_context.socket(zmq.DEALER)
        sock.setsockopt_string(zmq.IDENTITY, self._config.zmq_addr)
        sock.connect(self._config.coordinator_addr)
        self._dealer_socket = sock
        logger.info(f"Connected to coordinator at {self._config.coordinator_addr}")

        self._stop_event.clear()
        self._running = True

        worker = threading.Thread(
            target=self._heartbeat_loop,
            name=f"ClientForStage-Heartbeat-{self._config.zmq_addr}",
            daemon=True,
        )
        self._heartbeat_thread = worker
        worker.start()

        logger.info(f"ClientForStage {self._config.zmq_addr} started")

    def stop(self) -> None:
        """Report DOWN, stop heartbeating, and release ZMQ resources."""
        if not self._running:
            return

        logger.info(f"Stopping ClientForStage {self._config.zmq_addr}...")

        # Best-effort DOWN notification while the socket is still open.
        try:
            self.set_status(StageStatus.DOWN)
        except Exception as e:
            logger.warning(f"Failed to send status update: {e}")

        self._running = False
        self._stop_event.set()

        worker = self._heartbeat_thread
        if worker is not None and worker.is_alive():
            worker.join(timeout=2.0)

        sock = self._dealer_socket
        if sock:
            self._dealer_socket = None
            sock.close()
        ctx = self._zmq_context
        if ctx:
            self._zmq_context = None
            ctx.term()

        logger.info(f"ClientForStage {self._config.zmq_addr} stopped")

    def update_load(
        self,
        queue_length: int
    ) -> None:
        """Send a load update to the coordinator if the load changed.

        Args:
            queue_length: Number of unfinished tasks.
        """
        with self._lock:
            if queue_length == self._queue_length:
                # Unchanged load — skip the network round trip.
                return
            self._queue_length = queue_length
            self._send_event(EventType.LOAD_UPDATE)

    def set_status(self, status: StageStatus) -> None:
        """Report a new instance status to the coordinator if it changed.

        Args:
            status: New status for the instance
        """
        with self._lock:
            if status == self._status:
                # Unchanged status — nothing to report.
                return
            self._status = status
            self._send_event(EventType.STATUS_UPDATE)

    def _send_event(self, event_type: EventType) -> None:
        """Serialize and send one event to the coordinator.

        Args:
            event_type: Type of the event

        Raises:
            RuntimeError: If called before start().
        """
        if not self._dealer_socket:
            raise RuntimeError("Client not started")

        # Snapshot state and send under the lock so concurrent senders
        # (heartbeat thread vs. caller threads) never interleave on the socket.
        with self._lock:
            payload = {
                "event_type": event_type.value,
                "stage_id": self._stage_id,
                "zmq_addr": self.zmq_addr,
                "status": self._status.value,
                "queue_length": self._queue_length,
                "timestamp": time.time(),
            }
            self._dealer_socket.send(OmniSerializer.serialize(payload))

    def _heartbeat_loop(self) -> None:
        """Heartbeat thread body: send HEARTBEAT events until asked to stop."""
        logger.debug(f"Heartbeat thread started for {self.zmq_addr}")
        interval = self._config.heartbeat_interval_ms / 1000.0

        while not self._stop_event.is_set() and self._running:
            try:
                self._send_event(EventType.HEARTBEAT)
            except Exception as e:
                logger.error(f"Error sending heartbeat: {e}")

            # Event.wait doubles as an interruptible sleep.
            self._stop_event.wait(interval)

        logger.debug(f"Heartbeat thread stopped for {self.zmq_addr}")
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""DPCoordinator - Data Parallel Coordinator for vLLM-Omni.

This module implements the coordinator that aggregates stage instance status
and publishes updates via ZMQ PUB/SUB for load balancing.
"""

import threading
import time
from dataclasses import asdict, dataclass
from typing import Any

import zmq
from vllm.logger import init_logger

from vllm_omni.distributed.omni_connectors.utils.serialization import OmniSerializer

from .messages import (
    EventType,
    InstanceInfo,
    InstanceListing,
    StageStatus,
)

logger = init_logger(__name__)


@dataclass
class DPCoordinatorConfig:
    """Configuration for DPCoordinator.

    Attributes:
        pub_address: Address for PUB socket (broadcasts instance listings)
        router_address: Address for ROUTER socket (receives registrations/heartbeats)
        heartbeat_timeout_ms: Time before marking an instance as stale
        publish_interval_ms: Interval between instance listing broadcasts
    """

    pub_address: str = "tcp://*:5555"
    router_address: str = "tcp://*:5556"
    heartbeat_timeout_ms: int = 5000
    publish_interval_ms: int = 500


class DPCoordinator:
    """Data Parallel Coordinator for managing stage instances.

    The coordinator maintains a registry of stage instances, receives
    heartbeats and status updates, and broadcasts instance listings
    for load balancing.

    This class implements the singleton pattern - use get_singleton() to
    obtain the process-wide coordinator. (The accessor is not named
    ``get_instance`` because that name is the per-registry-entry lookup
    ``get_instance(instance_id)`` defined below; a classmethod of the same
    name would be shadowed by it and uncallable.)
    """

    _instance: "DPCoordinator | None" = None
    _instance_lock: threading.Lock = threading.Lock()

    def __init__(self, config: DPCoordinatorConfig | None = None):
        """Initialize the coordinator.

        Args:
            config: Configuration for the coordinator. If None, uses defaults.
        """
        self._config = config or DPCoordinatorConfig()

        # ZMQ context and sockets (created lazily in start())
        self._zmq_context: zmq.Context | None = None
        self._pub_socket: zmq.Socket | None = None
        self._router_socket: zmq.Socket | None = None

        # Thread-safe registry: instance_id (zmq_addr) -> InstanceInfo
        self._registry: dict[str, InstanceInfo] = {}
        self._registry_lock = threading.RLock()

        # Background threads
        self._receiver_thread: threading.Thread | None = None
        self._publisher_thread: threading.Thread | None = None
        self._health_checker_thread: threading.Thread | None = None

        # Control flags
        self._running = False
        self._stop_event = threading.Event()

    @classmethod
    def get_singleton(cls, config: DPCoordinatorConfig | None = None) -> "DPCoordinator":
        """Get or create the singleton coordinator instance.

        Note: this accessor was previously declared as ``get_instance``, but
        that definition was shadowed by the instance method
        ``get_instance(instance_id)`` later in the class, making the singleton
        accessor uncallable. Renamed to resolve the collision.

        Args:
            config: Configuration for the coordinator (only used on first call)

        Returns:
            The singleton DPCoordinator instance
        """
        with cls._instance_lock:
            if cls._instance is None:
                cls._instance = cls(config)
            return cls._instance

    @classmethod
    def reset_instance(cls) -> None:
        """Reset the singleton instance. Primarily for testing."""
        with cls._instance_lock:
            if cls._instance is not None:
                cls._instance.stop()
                cls._instance = None

    @property
    def config(self) -> DPCoordinatorConfig:
        """Get the coordinator configuration."""
        return self._config

    def start(self, zmq_context: zmq.Context | None = None) -> None:
        """Start the coordinator and all background threads.

        Args:
            zmq_context: Optional externally owned ZMQ context; when omitted,
                a private context is created and terminated in stop().
        """
        if self._running:
            logger.warning("DPCoordinator is already running")
            return

        logger.info("Starting DPCoordinator...")

        # Initialize ZMQ
        if zmq_context is None:
            self._zmq_context = zmq.Context()
            zmq_context = self._zmq_context

        # PUB socket for broadcasting instance listings
        self._pub_socket = zmq_context.socket(zmq.PUB)
        self._pub_socket.bind(self._config.pub_address)
        logger.info(f"PUB socket bound to {self._config.pub_address}")

        # ROUTER socket for receiving registrations and heartbeats
        self._router_socket = zmq_context.socket(zmq.ROUTER)
        self._router_socket.bind(self._config.router_address)
        self._router_socket.setsockopt(zmq.RCVTIMEO, 100)  # 100ms timeout for polling
        logger.info(f"ROUTER socket bound to {self._config.router_address}")

        self._running = True
        self._stop_event.clear()

        # Start background threads
        self._receiver_thread = threading.Thread(
            target=self._receiver_loop, name="DPCoordinator-Receiver", daemon=True
        )
        self._publisher_thread = threading.Thread(
            target=self._publisher_loop, name="DPCoordinator-Publisher", daemon=True
        )
        self._health_checker_thread = threading.Thread(
            target=self._health_checker_loop,
            name="DPCoordinator-HealthChecker",
            daemon=True,
        )

        self._receiver_thread.start()
        self._publisher_thread.start()
        self._health_checker_thread.start()

        logger.info("DPCoordinator started successfully")

    def stop(self) -> None:
        """Stop the coordinator and all background threads.

        Teardown order: clear run flags so threads exit, join them, then
        close sockets and terminate the (owned) context.
        """
        if not self._running:
            return

        logger.info("Stopping DPCoordinator...")
        self._running = False
        self._stop_event.set()

        # Wait for threads to finish
        if self._receiver_thread and self._receiver_thread.is_alive():
            self._receiver_thread.join(timeout=2.0)
        if self._publisher_thread and self._publisher_thread.is_alive():
            self._publisher_thread.join(timeout=2.0)
        if self._health_checker_thread and self._health_checker_thread.is_alive():
            self._health_checker_thread.join(timeout=2.0)

        # Close sockets
        if self._pub_socket:
            self._pub_socket.close()
            self._pub_socket = None
        if self._router_socket:
            self._router_socket.close()
            self._router_socket = None
        if self._zmq_context:
            # Only set when start() created the context itself.
            self._zmq_context.term()
            self._zmq_context = None

        logger.info("DPCoordinator stopped")

    def register_instance(
        self,
        instance_id: str,
        stage_id: int,
        status: StageStatus = StageStatus.DOWN,
        queue_length: int = 0,
    ) -> None:
        """Register a new stage instance.

        Args:
            instance_id: Unique identifier for this instance (zmq_addr)
            stage_id: Identifier for the stage type
            status: Initial status
            queue_length: Initial queue length
        """
        timestamp = time.time()
        # NOTE(review): ClientForHub parses published listings via the
        # "zmq_addr" key, while this constructs InstanceInfo with
        # instance_id=... — confirm against messages.InstanceInfo that both
        # sides agree on the field name.
        instance_info = InstanceInfo(
            stage_id=stage_id,
            instance_id=instance_id,
            status=status,
            queue_length=queue_length,
            last_heartbeat=timestamp,
            registered_at=timestamp,
        )

        with self._registry_lock:
            if instance_id in self._registry:
                logger.warning(f"Instance {instance_id} already registered, updating")
            self._registry[instance_id] = instance_info

        logger.info(f"Registered instance {instance_id} (stage={stage_id})")

    def unregister_instance(self, instance_id: str) -> bool:
        """Unregister a stage instance.

        Args:
            instance_id: The instance ID (zmq_addr) to unregister

        Returns:
            True if the instance was found and removed, False otherwise
        """
        with self._registry_lock:
            if instance_id in self._registry:
                del self._registry[instance_id]
                logger.info(f"Unregistered instance {instance_id}")
                return True
            else:
                logger.warning(f"Instance {instance_id} not found for unregistration")
                return False

    def update_instance(
        self,
        instance_id: str,
        status: StageStatus | None = None,
        queue_length: int | None = None,
    ) -> bool:
        """Update the status and/or queue length of an instance.

        Also refreshes the instance's last_heartbeat timestamp.

        Args:
            instance_id: The instance ID (zmq_addr) to update
            status: New status (optional)
            queue_length: New queue length (optional)

        Returns:
            True if the instance was found and updated, False otherwise
        """
        timestamp = time.time()

        with self._registry_lock:
            if instance_id not in self._registry:
                return False

            instance = self._registry[instance_id]
            instance.last_heartbeat = timestamp

            if status is not None:
                instance.status = status
            if queue_length is not None:
                instance.queue_length = queue_length

            return True

    def get_instance_listing(self) -> InstanceListing:
        """Get the current instance listing.

        Returns:
            InstanceListing containing all registered instances
        """
        with self._registry_lock:
            instances = list(self._registry.values())

        return InstanceListing(instances=instances, timestamp=time.time())

    def get_instances_by_stage(self, stage_id: int) -> list[InstanceInfo]:
        """Get all instances for a specific stage.

        Args:
            stage_id: The stage ID to filter by

        Returns:
            List of InstanceInfo for the specified stage
        """
        with self._registry_lock:
            return [
                inst for inst in self._registry.values() if inst.stage_id == stage_id
            ]

    def get_instance(self, instance_id: str) -> InstanceInfo | None:
        """Get a specific instance by ID.

        Args:
            instance_id: The instance ID (zmq_addr) to look up

        Returns:
            InstanceInfo if found, None otherwise
        """
        with self._registry_lock:
            return self._registry.get(instance_id)

    def health(self) -> dict[str, Any]:
        """Get coordinator health status.

        Returns:
            Dictionary with health information
        """
        with self._registry_lock:
            total_instances = len(self._registry)
            ready_instances = sum(
                1
                for inst in self._registry.values()
                if inst.status == StageStatus.UP
            )
            instances_by_stage: dict[int, int] = {}
            for inst in self._registry.values():
                instances_by_stage[inst.stage_id] = (
                    instances_by_stage.get(inst.stage_id, 0) + 1
                )

        return {
            "running": self._running,
            "total_instances": total_instances,
            "ready_instances": ready_instances,
            "instances_by_stage": instances_by_stage,
            "pub_address": self._config.pub_address,
            "router_address": self._config.router_address,
        }

    def _receiver_loop(self) -> None:
        """Background thread that receives messages from clients."""
        logger.debug("Receiver thread started")

        while self._running and not self._stop_event.is_set():
            try:
                # ROUTER prepends the peer identity; a DEALER peer normally
                # yields [identity, message]. Taking the first and last frames
                # also tolerates an empty delimiter frame in between.
                frames = self._router_socket.recv_multipart()
                if len(frames) < 2:
                    logger.warning(f"Received malformed message: {len(frames)} frames")
                    continue

                identity = frames[0]
                message_data = frames[-1]  # Last frame is the message

                self._handle_message(identity, message_data)

            except zmq.Again:
                # RCVTIMEO expired; loop to re-check stop flags.
                continue
            except zmq.ZMQError as e:
                if self._running:
                    logger.error(f"ZMQ error in receiver: {e}")
            except Exception as e:
                if self._running:
                    logger.error(f"Error in receiver loop: {e}")

        logger.debug("Receiver thread stopped")

    def _handle_message(self, identity: bytes, message_data: bytes) -> None:
        """Handle an incoming message from a client.

        Args:
            identity: ZMQ identity of the sender (zmq_addr)
            message_data: Serialized message data
        """
        try:
            message = OmniSerializer.deserialize(message_data)

            if not isinstance(message, dict):
                logger.warning(f"Received non-dict message: {type(message)}")
                return

            event_type = message.get("event_type")
            if event_type is None:
                logger.warning("Received message without event_type")
                return

            # Instance ID is the ZMQ identity (zmq_addr)
            instance_id = identity.decode("utf-8")

            # Extract common fields from message
            stage_id = message.get("stage_id", 0)
            status_str = message.get("status")
            status = StageStatus(status_str) if status_str else None
            queue_length = message.get("queue_length", 0)

            # event_type arrives as the enum *value* string; accept the enum
            # object too for in-process callers.
            if event_type == EventType.STATUS_UPDATE or event_type == "status_update":
                self._handle_status_update(instance_id, stage_id, status, queue_length)
            elif event_type == EventType.LOAD_UPDATE or event_type == "load_update":
                self._handle_load_update(instance_id, stage_id, status, queue_length)
            elif event_type == EventType.HEARTBEAT or event_type == "heartbeat":
                self._handle_heartbeat(instance_id, stage_id, status, queue_length)
            else:
                logger.warning(f"Unknown event type: {event_type}")

        except Exception as e:
            logger.error(f"Error handling message: {e}")

    def _handle_status_update(
        self,
        instance_id: str,
        stage_id: int,
        status: StageStatus | None,
        queue_length: int,
    ) -> None:
        """Handle a status update message.

        Auto-registers unknown instances and unregisters instances with DOWN status.
        """
        if status == StageStatus.DOWN:
            # DOWN means the instance is leaving; drop it from the registry.
            self.unregister_instance(instance_id)
            return

        # Auto-register if instance is unknown
        if self.get_instance(instance_id) is None:
            self.register_instance(
                instance_id=instance_id,
                stage_id=stage_id,
                status=status or StageStatus.DOWN,
                queue_length=queue_length,
            )
            if status == StageStatus.UP:
                logger.info(f"Instance {instance_id} is now UP")
        elif status:
            self.update_instance(instance_id, status=status, queue_length=queue_length)
            if status == StageStatus.UP:
                logger.info(f"Instance {instance_id} is now UP")
            elif status == StageStatus.ERROR:
                logger.error(f"Instance {instance_id} reported error")

    def _upsert_instance(
        self,
        instance_id: str,
        stage_id: int,
        status: StageStatus | None,
        queue_length: int,
    ) -> None:
        """Register the instance if unknown, otherwise refresh its record.

        Shared by heartbeat and load-update handling, which previously
        duplicated this logic verbatim.
        """
        if self.get_instance(instance_id) is None:
            self.register_instance(
                instance_id=instance_id,
                stage_id=stage_id,
                status=status or StageStatus.DOWN,
                queue_length=queue_length,
            )
        else:
            self.update_instance(instance_id, status=status, queue_length=queue_length)

    def _handle_heartbeat(
        self,
        instance_id: str,
        stage_id: int,
        status: StageStatus | None,
        queue_length: int,
    ) -> None:
        """Handle a heartbeat message (refreshes last_heartbeat via update)."""
        self._upsert_instance(instance_id, stage_id, status, queue_length)

    def _handle_load_update(
        self,
        instance_id: str,
        stage_id: int,
        status: StageStatus | None,
        queue_length: int,
    ) -> None:
        """Handle a load update message."""
        self._upsert_instance(instance_id, stage_id, status, queue_length)

    def _publisher_loop(self) -> None:
        """Background thread that publishes instance listings periodically."""
        logger.debug("Publisher thread started")
        publish_interval = self._config.publish_interval_ms / 1000.0

        while self._running and not self._stop_event.is_set():
            try:
                listing = self.get_instance_listing()
                self._publish_listing(listing)
            except Exception as e:
                if self._running:
                    logger.error(f"Error in publisher loop: {e}")

            self._stop_event.wait(publish_interval)

        logger.debug("Publisher thread stopped")

    def _publish_listing(self, listing: InstanceListing) -> None:
        """Publish an instance listing to subscribers.

        Args:
            listing: The instance listing to publish
        """
        if not self._pub_socket:
            return

        try:
            # Convert to serializable dict
            listing_dict = {
                "instances": [asdict(inst) for inst in listing.instances],
                "timestamp": listing.timestamp,
            }
            # Convert enum values to strings
            for inst in listing_dict["instances"]:
                if isinstance(inst.get("status"), StageStatus):
                    inst["status"] = inst["status"].value
                elif inst.get("status"):
                    inst["status"] = str(inst["status"])

            data = OmniSerializer.serialize(listing_dict)
            self._pub_socket.send_multipart([b"listing", data])

        except Exception as e:
            logger.error(f"Error publishing listing: {e}")

    def _health_checker_loop(self) -> None:
        """Background thread that detects stale instances."""
        logger.debug("Health checker thread started")
        # Check twice per timeout window so staleness is detected promptly.
        check_interval = self._config.heartbeat_timeout_ms / 1000.0 / 2.0

        while self._running and not self._stop_event.is_set():
            try:
                self._check_stale_instances()
            except Exception as e:
                if self._running:
                    logger.error(f"Error in health checker loop: {e}")

            self._stop_event.wait(check_interval)

        logger.debug("Health checker thread stopped")

    def _check_stale_instances(self) -> None:
        """Check for and handle stale instances."""
        current_time = time.time()
        timeout_seconds = self._config.heartbeat_timeout_ms / 1000.0
        stale_instances: list[str] = []

        with self._registry_lock:
            for instance_id, instance in self._registry.items():
                # DOWN instances are expected to be silent; ERROR instances
                # were already flagged — skip both to avoid re-warning on
                # every check cycle. A new heartbeat restores their status.
                if instance.status in (StageStatus.DOWN, StageStatus.ERROR):
                    continue
                if current_time - instance.last_heartbeat > timeout_seconds:
                    stale_instances.append(instance_id)

        # Mark outside the iteration; update_instance re-acquires the RLock.
        for instance_id in stale_instances:
            logger.warning(f"Instance {instance_id} is stale (no heartbeat)")
            self.update_instance(instance_id, status=StageStatus.ERROR)


def main():
    """Entry point for running the coordinator as a standalone process."""
    import argparse

    parser = argparse.ArgumentParser(description="DPCoordinator - Data Parallel Coordinator")
    parser.add_argument(
        "--pub-address",
        type=str,
        default="tcp://*:5555",
        help="PUB socket bind address (default: tcp://*:5555)",
    )
    parser.add_argument(
        "--router-address",
        type=str,
        default="tcp://*:5556",
        help="ROUTER socket bind address (default: tcp://*:5556)",
    )
    parser.add_argument(
        "--heartbeat-timeout-ms",
        type=int,
        default=5000,
        help="Heartbeat timeout in milliseconds (default: 5000)",
    )
    parser.add_argument(
        "--publish-interval-ms",
        type=int,
        default=500,
        help="Instance listing publish interval in milliseconds (default: 500)",
    )

    args = parser.parse_args()

    config = DPCoordinatorConfig(
        pub_address=args.pub_address,
        router_address=args.router_address,
        heartbeat_timeout_ms=args.heartbeat_timeout_ms,
        publish_interval_ms=args.publish_interval_ms,
    )

    coordinator = DPCoordinator(config)
    coordinator.start()

    print("DPCoordinator started")
    print(f"  PUB address: {config.pub_address}")
    print(f"  ROUTER address: {config.router_address}")
    print("Press Ctrl+C to stop...")

    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        print("\nShutting down...")
        coordinator.stop()


if __name__ == "__main__":
    main()

This module defines the dataclasses and enums used for communication
between the DPCoordinator and StageClients.
"""

from dataclasses import dataclass, field
from enum import Enum
# NOTE(review): Any is unused in the visible code -- confirm before removing.
from typing import Any


class StageStatus(str, Enum):
    """Status of a stage instance.

    Subclasses str so members compare equal to their raw string values
    (StageStatus.UP == "up"), which the coordinator's message dispatch
    relies on when payloads carry plain strings.
    """

    UP = "up"
    DOWN = "down"
    ERROR = "error"


class EventType(str, Enum):
    """Types of events sent between coordinator and clients.

    Subclasses str so members compare equal to their raw string values.
    """

    STATUS_UPDATE = "status_update"
    LOAD_UPDATE = "load_update"
    HEARTBEAT = "heartbeat"


@dataclass
class InstanceInfo:
    """Instance state stored in the coordinator registry.

    Attributes:
        stage_id: Identifier for the stage type
        zmq_addr: Unique identifier for this instance
        status: Current status of the instance
        queue_length: Number of unfinished tasks
        last_heartbeat: Unix timestamp of last heartbeat
        registered_at: Unix timestamp when instance was registered
    """

    stage_id: int
    zmq_addr: str
    status: StageStatus = StageStatus.DOWN
    queue_length: int = 0
    # Timestamps default to 0.0 ("never"); presumably register_instance
    # stamps them on registration -- a 0.0 heartbeat reads as stale.
    last_heartbeat: float = 0.0
    registered_at: float = 0.0


@dataclass
class InstanceListing:
    """List of instances for PUB broadcast.
+ + Attributes: + instances: List of instance information + timestamp: Unix timestamp when listing was generated + """ + + instances: list[InstanceInfo] = field(default_factory=list) + timestamp: float = 0.0 + + def get_instances_by_stage(self, stage_id: int) -> list[InstanceInfo]: + """Get all instances for a specific stage.""" + return [inst for inst in self.instances if inst.stage_id == stage_id] + + def get_ready_instances(self) -> list[InstanceInfo]: + """Get all instances that are ready to accept requests.""" + return [inst for inst in self.instances if inst.status == StageStatus.UP] diff --git a/vllm_omni/distributed/load_balancer/__init__.py b/vllm_omni/distributed/load_balancer/__init__.py new file mode 100644 index 00000000000..a398a7edf42 --- /dev/null +++ b/vllm_omni/distributed/load_balancer/__init__.py @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""LoadBalancer - Task routing for vLLM-Omni distributed stages.""" + +from .load_balancer import ( + ConsistentHashBalancer, + LeastLoadedBalancer, + LoadBalancer, + RandomBalancer, + RoundRobinBalancer, + Task, +) + +__all__ = [ + "LoadBalancer", + "RandomBalancer", + "RoundRobinBalancer", + "LeastLoadedBalancer", + "ConsistentHashBalancer", + "Task", +] diff --git a/vllm_omni/distributed/load_balancer/load_balancer.py b/vllm_omni/distributed/load_balancer/load_balancer.py new file mode 100644 index 00000000000..08c4ffe93a2 --- /dev/null +++ b/vllm_omni/distributed/load_balancer/load_balancer.py @@ -0,0 +1,58 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""LoadBalancer - Task routing for vLLM-Omni distributed stages. + +This module implements load balancers that route tasks to stage instances +based on different strategies. 
+""" + +import hashlib +import random +import threading +from abc import ABC, abstractmethod +from dataclasses import dataclass + +from vllm_omni.distributed.dp_coordinator.messages import InstanceInfo + + +@dataclass +class Task: + """Placeholder for task object.""" + + session_id: str + request_id: str + + +class LoadBalancer(ABC): + """Abstract base class for load balancers.""" + + @abstractmethod + def select( + self, + task: Task, + instances: list[InstanceInfo], + ) -> int: + """Route a task to an instance. + + Args: + task: The task to route + instances: List of available instances + + Returns: + Index of the selected instance + """ + pass + + +class RandomBalancer(LoadBalancer): + """Randomly select an instance from the available pool.""" + + def select( + self, + task: Task, + instances: list[InstanceInfo], + ) -> int: + if not instances: + raise ValueError("No instances available") + return random.randrange(len(instances))