diff --git a/vllm_omni/diffusion/stage_diffusion_client.py b/vllm_omni/diffusion/stage_diffusion_client.py index e19db453644..f7d8977f936 100644 --- a/vllm_omni/diffusion/stage_diffusion_client.py +++ b/vllm_omni/diffusion/stage_diffusion_client.py @@ -122,6 +122,11 @@ def _initialize_client( self.engine_input_source = getattr(metadata, "engine_input_source", []) self._proc = proc self._owns_process = proc is not None + # Expose the ZMQ addresses on the instance so callers (e.g. + # ``StagePool._client_input_addr``) can identify the diffusion + # replica by its bound address. + self.request_address = request_address + self.response_address = response_address self._zmq_ctx = zmq.Context() self._request_socket = self._zmq_ctx.socket(zmq.PUSH) diff --git a/vllm_omni/diffusion/stage_diffusion_proc.py b/vllm_omni/diffusion/stage_diffusion_proc.py index 2f3ad55943a..9d34df895b0 100644 --- a/vllm_omni/diffusion/stage_diffusion_proc.py +++ b/vllm_omni/diffusion/stage_diffusion_proc.py @@ -7,6 +7,7 @@ from __future__ import annotations import asyncio +import contextlib import signal import time from concurrent.futures import ThreadPoolExecutor @@ -30,6 +31,7 @@ OmniMsgpackDecoder, OmniMsgpackEncoder, ) +from vllm_omni.distributed.omni_coordinator import OmniCoordClientForStage from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.outputs import OmniRequestOutput @@ -54,6 +56,19 @@ def __init__(self, model: str, od_config: OmniDiffusionConfig) -> None: self._engine: DiffusionEngine | None = None self._executor: ThreadPoolExecutor | None = None self._closed = False + # Set by ``run_loop`` to the live dispatch task dict so + # :attr:`queue_length` can report in-flight requests for the + # OmniCoordinator heartbeat hook. + self._active_tasks: dict[str, asyncio.Task] | None = None + + @property + def queue_length(self) -> int: + """Number of in-flight diffusion requests. + + Returns 0 before :meth:`run_loop` starts and after it exits. 
+ """ + tasks = self._active_tasks + return 0 if tasks is None else len(tasks) # ------------------------------------------------------------------ # Initialization @@ -309,6 +324,9 @@ async def run_loop( decoder = OmniMsgpackDecoder() tasks: dict[str, asyncio.Task] = {} + # Expose the live task dict so :attr:`queue_length` (used by the + # OmniCoordinator heartbeat hook) can read the in-flight count. + self._active_tasks = tasks async def _dispatch_request( request_id: str, @@ -466,6 +484,7 @@ async def _dispatch_batch( if tasks: await asyncio.gather(*tasks.values(), return_exceptions=True) + self._active_tasks = None request_socket.close() response_socket.close() ctx.term() @@ -504,8 +523,22 @@ def run_diffusion_proc( handshake_address: str, request_address: str, response_address: str, + *, + omni_coordinator_address: str | None = None, + omni_stage_id: int | None = None, + omni_replica_id: int = 0, ) -> None: - """Entry point for the diffusion subprocess.""" + """Entry point for the diffusion subprocess. + + Omni-specific kwargs (mirroring :meth:`StageEngineCoreProc.run_stage_core`): + - ``omni_coordinator_address``: ROUTER address of the head-side + OmniCoordinator. When set, a :class:`OmniCoordClientForStage` + reports the diffusion replica's status + queue length. + - ``omni_stage_id``: logical stage id; required when + ``omni_coordinator_address`` is set. + - ``omni_replica_id``: cluster-unique replica id within the + stage (logging / metrics only). 
+ """ shutdown_requested = False def signal_handler(signum: int, frame: Any) -> None: @@ -518,6 +551,7 @@ def signal_handler(signum: int, frame: Any) -> None: signal.signal(signal.SIGINT, signal_handler) proc = cls(model, od_config) + coord_client: OmniCoordClientForStage | None = None try: proc.initialize() @@ -529,6 +563,32 @@ def signal_handler(signum: int, frame: Any) -> None: handshake_socket.close() handshake_ctx.term() + # Wire OmniCoordClientForStage *after* READY so that the head + # has bound its head-side request/response sockets — the + # address pair we report is the same pair this proc binds to + # (request/response addresses passed in). + if omni_coordinator_address is not None: + if omni_stage_id is None: + raise ValueError("omni_stage_id must be provided when omni_coordinator_address is set") + coord_client = OmniCoordClientForStage( + coord_zmq_addr=omni_coordinator_address, + input_addr=request_address, + output_addr=response_address, + stage_id=int(omni_stage_id), + ) + + def _refresh_queue_length() -> None: + coord_client._queue_length = proc.queue_length # type: ignore[union-attr] + + coord_client._on_heartbeat = _refresh_queue_length + + logger.info( + "StageDiffusionProc registered with OmniCoordinator (stage_id=%d replica_id=%d coord=%s)", + omni_stage_id, + omni_replica_id, + omni_coordinator_address, + ) + # Run async event loop asyncio.run(proc.run_loop(request_address, response_address)) @@ -539,6 +599,9 @@ def signal_handler(signum: int, frame: Any) -> None: logger.exception("StageDiffusionProc encountered a fatal error.") raise finally: + if coord_client is not None: + with contextlib.suppress(RuntimeError): + coord_client.close() proc.close() @@ -551,10 +614,17 @@ def spawn_diffusion_proc( handshake_address: str | None = None, request_address: str | None = None, response_address: str | None = None, + *, + omni_coordinator_address: str | None = None, + omni_stage_id: int | None = None, + omni_replica_id: int = 0, ) -> 
tuple[BaseProcess, str, str, str]: """Spawn a StageDiffusionProc subprocess. Returns ``(proc, handshake_address, request_address, response_address)``. + + Pass ``omni_coordinator_address`` / ``omni_stage_id`` / ``omni_replica_id`` + to have the subprocess publish heartbeats to an OmniCoordinator. """ handshake_address = handshake_address or get_open_zmq_ipc_path() request_address = request_address or get_open_zmq_ipc_path() @@ -570,6 +640,9 @@ def spawn_diffusion_proc( "handshake_address": handshake_address, "request_address": request_address, "response_address": response_address, + "omni_coordinator_address": omni_coordinator_address, + "omni_stage_id": omni_stage_id, + "omni_replica_id": omni_replica_id, }, ) proc.start() diff --git a/vllm_omni/distributed/omni_coordinator/__init__.py b/vllm_omni/distributed/omni_coordinator/__init__.py index 6894e311378..f5cf8fceb43 100644 --- a/vllm_omni/distributed/omni_coordinator/__init__.py +++ b/vllm_omni/distributed/omni_coordinator/__init__.py @@ -9,17 +9,19 @@ RoundRobinBalancer, Task, ) -from .messages import InstanceEvent, InstanceInfo, InstanceList, StageStatus +from .messages import ReplicaEvent, ReplicaInfo, ReplicaList, StageStatus from .omni_coord_client_for_hub import OmniCoordClientForHub from .omni_coord_client_for_stage import OmniCoordClientForStage from .omni_coordinator import OmniCoordinator +from .runtime import OmniCoordinatorRuntime __all__ = [ "OmniCoordinator", + "OmniCoordinatorRuntime", "StageStatus", - "InstanceEvent", - "InstanceInfo", - "InstanceList", + "ReplicaEvent", + "ReplicaInfo", + "ReplicaList", "OmniCoordClientForStage", "OmniCoordClientForHub", "Task", diff --git a/vllm_omni/distributed/omni_coordinator/load_balancer.py b/vllm_omni/distributed/omni_coordinator/load_balancer.py index 41b03be1630..9f21d0b9773 100644 --- a/vllm_omni/distributed/omni_coordinator/load_balancer.py +++ b/vllm_omni/distributed/omni_coordinator/load_balancer.py @@ -9,14 +9,14 @@ from enum import Enum from 
typing import Any, TypedDict -from .messages import InstanceInfo +from .messages import ReplicaInfo class Task(TypedDict, total=False): - """Task structure passed from async_omni (stage.submit(task)). + """Task structure passed to ``StagePool.pick`` / ``LoadBalancer.select``. - Mirrors the dict built in AsyncOmni with request_id, engine_inputs, - sampling_params. Future load-balancing policies may use these fields. + Mirrors the dict built around a stage submission with request_id and any + payload-related fields a future load-balancing policy might inspect. """ request_id: str @@ -28,10 +28,7 @@ class LoadBalancingPolicy(str, Enum): """Enumeration for load balancing policies. These policies are used by :class:`LoadBalancer` implementations to route - tasks to a subset of available instances. - - TODO(NumberWan): Map enum values to balancer classes when OmniCoordinator - integration lands. Tracked in https://github.com/vllm-project/vllm-omni/pull/2448 + tasks to a subset of available replicas. """ RANDOM = "random" @@ -42,74 +39,60 @@ class LoadBalancingPolicy(str, Enum): class LoadBalancer(ABC): """Abstract base class for load balancers. - Subclasses implement :meth:`select` to choose an instance for a given task. + Subclasses implement :meth:`select` to choose a replica for a given task. """ @abstractmethod - def select(self, task: Task, instances: list[InstanceInfo]) -> int: - """Route a task to one of the available instances. + def select(self, task: Task, replicas: list[ReplicaInfo]) -> int: + """Route a task to one of the available replicas. Args: task: The task to route. Not used by the random policy but reserved for future strategies that may inspect task metadata. - instances: List of available instances to choose from. + replicas: List of available replicas to choose from. Returns: - Index of the selected instance in ``instances``. + Index of the selected replica in ``replicas``. Raises: - ValueError: If ``instances`` is empty. 
+ ValueError: If ``replicas`` is empty. """ raise NotImplementedError class RandomBalancer(LoadBalancer): - """Load balancer that selects an instance uniformly at random. - - It intentionally ignores the task payload and chooses a random index from - the provided instance list. More sophisticated policies (e.g. round-robin, - least-queue-length) can be implemented as additional subclasses of - :class:`LoadBalancer`. - """ + """Load balancer that selects a replica uniformly at random.""" - def select(self, task: Task, instances: list[InstanceInfo]) -> int: # noqa: ARG002 - if not instances: - raise ValueError("instances must not be empty") + def select(self, task: Task, replicas: list[ReplicaInfo]) -> int: # noqa: ARG002 + if not replicas: + raise ValueError("replicas must not be empty") - return random.randrange(len(instances)) + return random.randrange(len(replicas)) class RoundRobinBalancer(LoadBalancer): - """Load balancer that selects instances in a round-robin fashion. + """Load balancer that selects replicas in a round-robin fashion. - This implementation keeps a running index modulo ``len(instances)``. It - therefore depends on the **order and stable meaning** of the ``instances`` + This implementation keeps a running index modulo ``len(replicas)``. It + therefore depends on the **order and stable meaning** of the ``replicas`` list between calls. If the list length or ordering changes, the sequence of picks may skip or repeat entries relative to a fixed set of backends. - When instance membership changes dynamically, callers should reset routing - state—for example by constructing a new ``RoundRobinBalancer`` or resetting - ``_next_index``—similar to rebuilding ``itertools.cycle`` after mutating - the instance list (as in vLLM's disaggregated proxy examples). 
- - Concurrency: ``select`` is synchronous and is expected to run on the - coordinator asyncio event loop thread without ``await`` inside this - method, so a single invocation is not interleaved with another on that - thread. A :class:`threading.Lock` still serializes updates to - ``_next_index`` for callers that might invoke ``select`` from multiple - threads or alongside threaded infrastructure (e.g. ZMQ receive threads). + Concurrency: a ``threading.Lock`` serializes updates to ``_next_index`` + for callers that invoke ``select`` from multiple threads or alongside + threaded infrastructure (e.g. ZMQ receive threads). """ def __init__(self, start_index: int = 0) -> None: self._next_index = start_index self._lock = threading.Lock() - def select(self, task: Task, instances: list[InstanceInfo]) -> int: # noqa: ARG002 - if not instances: - raise ValueError("instances must not be empty") + def select(self, task: Task, replicas: list[ReplicaInfo]) -> int: # noqa: ARG002 + if not replicas: + raise ValueError("replicas must not be empty") - n = len(instances) + n = len(replicas) with self._lock: idx = self._next_index % n self._next_index = (self._next_index + 1) % n @@ -117,22 +100,22 @@ def select(self, task: Task, instances: list[InstanceInfo]) -> int: # noqa: ARG class LeastQueueLengthBalancer(LoadBalancer): - """Select the instance with the smallest ``queue_length``. + """Select the replica with the smallest ``queue_length``. - If multiple instances share the same minimum queue length, one of them is + If multiple replicas share the same minimum queue length, one of them is chosen uniformly at random. Raises: - ValueError: If any instance has a negative ``queue_length``. + ValueError: If any replica has a negative ``queue_length``. 
""" - def select(self, task: Task, instances: list[InstanceInfo]) -> int: # noqa: ARG002 - if not instances: - raise ValueError("instances must not be empty") + def select(self, task: Task, replicas: list[ReplicaInfo]) -> int: # noqa: ARG002 + if not replicas: + raise ValueError("replicas must not be empty") - queue_lengths = [inst.queue_length for inst in instances] + queue_lengths = [rep.queue_length for rep in replicas] if any(q < 0 for q in queue_lengths): - raise ValueError("queue_length must be non-negative for all instances") + raise ValueError("queue_length must be non-negative for all replicas") min_q = min(queue_lengths) candidates = [i for i, q in enumerate(queue_lengths) if q == min_q] return random.choice(candidates) diff --git a/vllm_omni/distributed/omni_coordinator/messages.py b/vllm_omni/distributed/omni_coordinator/messages.py index 2bb590139e2..5861aaf992d 100644 --- a/vllm_omni/distributed/omni_coordinator/messages.py +++ b/vllm_omni/distributed/omni_coordinator/messages.py @@ -8,23 +8,23 @@ class StageStatus(str, Enum): - """Enumeration for stage instance status.""" + """Enumeration for stage replica status.""" - UP = "up" # Instance is ready and available - DOWN = "down" # Instance is shutdown gracefully - ERROR = "error" # Instance encountered an error or timeout + UP = "up" # Replica is ready and available + DOWN = "down" # Replica is shutdown gracefully + ERROR = "error" # Replica encountered an error or timeout @dataclass -class InstanceEvent: +class ReplicaEvent: """Wire payload from OmniCoordClientForStage to OmniCoordinator. Schema for Stage → Coordinator events over ZMQ: input_addr, output_addr, stage_id, status, queue_length, event_type. 
""" - input_addr: str # Stage instance input ZMQ address (e.g., "tcp://host:port") - output_addr: str # Stage instance output ZMQ address (e.g., "tcp://host:port") + input_addr: str # Stage replica input ZMQ address (e.g., "tcp://host:port") + output_addr: str # Stage replica output ZMQ address (e.g., "tcp://host:port") stage_id: int # Stage ID event_type: str # "update" | "heartbeat" status: StageStatus # Current status @@ -32,30 +32,30 @@ class InstanceEvent: @dataclass -class InstanceInfo: - """Metadata for a single stage instance. +class ReplicaInfo: + """Metadata for a single stage replica. This type is stored in OmniCoordinator's internal registry and is also - published to hubs via :class:`InstanceList`. + published to hubs via :class:`ReplicaList`. """ - input_addr: str # Stage instance input ZMQ address (e.g., "tcp://host:port") - output_addr: str # Stage instance output ZMQ address (e.g., "tcp://host:port") - stage_id: int # Stage ID of this instance - status: StageStatus # Current status of the instance - queue_length: int # Current queue length of this instance + input_addr: str # Stage replica input ZMQ address (e.g., "tcp://host:port") + output_addr: str # Stage replica output ZMQ address (e.g., "tcp://host:port") + stage_id: int # Stage ID of this replica + status: StageStatus # Current status of the replica + queue_length: int # Current queue length of this replica last_heartbeat: float # Timestamp of the last heartbeat received (seconds) - registered_at: float # Timestamp when the instance was registered (seconds) + registered_at: float # Timestamp when the replica was registered (seconds) @dataclass -class InstanceList: - """Container for instance list updates. +class ReplicaList: + """Container for replica list updates. - OmniCoordinator publishes an :class:`InstanceList` whenever its view of - active instances changes. 
OmniCoordClientForHub caches the latest value + OmniCoordinator publishes a :class:`ReplicaList` whenever its view of + active replicas changes. OmniCoordClientForHub caches the latest value and exposes it to AsyncOmni and the load balancer. """ - instances: list[InstanceInfo] + replicas: list[ReplicaInfo] timestamp: float # Time when the list was last updated (seconds) diff --git a/vllm_omni/distributed/omni_coordinator/omni_coord_client_for_hub.py b/vllm_omni/distributed/omni_coordinator/omni_coord_client_for_hub.py index 9081e45917c..7044ff36357 100644 --- a/vllm_omni/distributed/omni_coordinator/omni_coord_client_for_hub.py +++ b/vllm_omni/distributed/omni_coordinator/omni_coord_client_for_hub.py @@ -9,16 +9,16 @@ import zmq -from .messages import InstanceInfo, InstanceList, StageStatus +from .messages import ReplicaInfo, ReplicaList, StageStatus logger = logging.getLogger(__name__) class OmniCoordClientForHub: - """Client for AsyncOmni side to receive instance list updates. + """Client for AsyncOmni side to receive replica list updates. This client maintains a SUB socket connected to OmniCoordinator's PUB - endpoint and caches the latest :class:`InstanceList` in memory for use by + endpoint and caches the latest :class:`ReplicaList` in memory for use by the load balancer and routing logic. 
""" @@ -28,7 +28,7 @@ def __init__(self, coord_zmq_addr: str) -> None: self._ctx = zmq.Context() self._lock = threading.Lock() - self._instance_list: InstanceList | None = None + self._replica_list: ReplicaList | None = None self._closed = False self._stop_event = threading.Event() self._init_done = threading.Event() @@ -41,29 +41,29 @@ def __init__(self, coord_zmq_addr: str) -> None: if self._init_error: raise RuntimeError(f"Failed to connect to coordinator at {self._coord_zmq_addr}") from self._init_error[0] - def _decode_instance_list(self, payload: dict[str, Any]) -> InstanceList: - """Convert a JSON-decoded dict into an :class:`InstanceList`.""" - instances_payload = payload.get("instances", []) - instances: list[InstanceInfo] = [] - for inst in instances_payload: - instances.append( - InstanceInfo( - input_addr=inst["input_addr"], - output_addr=inst["output_addr"], - stage_id=int(inst["stage_id"]), - status=StageStatus(inst["status"]), - queue_length=int(inst["queue_length"]), - last_heartbeat=float(inst["last_heartbeat"]), - registered_at=float(inst["registered_at"]), + def _decode_replica_list(self, payload: dict[str, Any]) -> ReplicaList: + """Convert a JSON-decoded dict into a :class:`ReplicaList`.""" + replicas_payload = payload.get("replicas", []) + replicas: list[ReplicaInfo] = [] + for rep in replicas_payload: + replicas.append( + ReplicaInfo( + input_addr=rep["input_addr"], + output_addr=rep["output_addr"], + stage_id=int(rep["stage_id"]), + status=StageStatus(rep["status"]), + queue_length=int(rep["queue_length"]), + last_heartbeat=float(rep["last_heartbeat"]), + registered_at=float(rep["registered_at"]), ) ) timestamp = float(payload.get("timestamp", time())) - return InstanceList(instances=instances, timestamp=timestamp) + return ReplicaList(replicas=replicas, timestamp=timestamp) def _recv_loop(self) -> None: - """Background loop that receives and caches instance lists.""" - sub = None + """Background loop that receives and caches replica 
lists.""" + sub: zmq.Socket | None = None try: sub = self._ctx.socket(zmq.SUB) sub.setsockopt(zmq.SUBSCRIBE, b"") @@ -115,9 +115,9 @@ def _recv_loop(self) -> None: try: payload = json.loads(data.decode("utf-8")) - inst_list = self._decode_instance_list(payload) + rep_list = self._decode_replica_list(payload) with self._lock: - self._instance_list = inst_list + self._replica_list = rep_list except ( json.JSONDecodeError, KeyError, @@ -125,7 +125,7 @@ def _recv_loop(self) -> None: TypeError, AttributeError, ) as e: - logger.warning("Invalid instance list message, skipping: %s", e) + logger.warning("Invalid replica list message, skipping: %s", e) finally: try: if sub is not None: @@ -137,22 +137,22 @@ def _recv_loop(self) -> None: except zmq.ZMQError: pass - def get_instance_list(self) -> InstanceList: - """Return the latest cached :class:`InstanceList`. + def get_replica_list(self) -> ReplicaList: + """Return the latest cached :class:`ReplicaList`. If no update has been received yet, returns an empty list with ``timestamp=0.0``. 
""" with self._lock: - if self._instance_list is None: - return InstanceList(instances=[], timestamp=0.0) - return self._instance_list - - def get_instances_for_stage(self, stage_id: int) -> InstanceList: - """Return instances filtered by ``stage_id``.""" - base = self.get_instance_list() - filtered = [inst for inst in base.instances if inst.stage_id == stage_id] - return InstanceList(instances=filtered, timestamp=base.timestamp) + if self._replica_list is None: + return ReplicaList(replicas=[], timestamp=0.0) + return self._replica_list + + def get_replicas_for_stage(self, stage_id: int) -> ReplicaList: + """Return replicas filtered by ``stage_id``.""" + base = self.get_replica_list() + filtered = [rep for rep in base.replicas if rep.stage_id == stage_id] + return ReplicaList(replicas=filtered, timestamp=base.timestamp) def close(self) -> None: """Close the SUB socket and stop the background thread.""" diff --git a/vllm_omni/distributed/omni_coordinator/omni_coord_client_for_stage.py b/vllm_omni/distributed/omni_coordinator/omni_coord_client_for_stage.py index cd3c99ab812..f6ea1605c7b 100644 --- a/vllm_omni/distributed/omni_coordinator/omni_coord_client_for_stage.py +++ b/vllm_omni/distributed/omni_coordinator/omni_coord_client_for_stage.py @@ -1,24 +1,26 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib import json import logging import threading import time +from collections.abc import Callable from dataclasses import asdict import zmq -from .messages import InstanceEvent, StageStatus +from .messages import ReplicaEvent, StageStatus logger = logging.getLogger(__name__) class OmniCoordClientForStage: - """Client used by stage instances to send events to OmniCoordinator. + """Client used by stage replicas to send events to OmniCoordinator. This client maintains a DEALER socket connected to OmniCoordinator's - ROUTER endpoint and sends JSON-encoded events describing instance status. 
+ ROUTER endpoint and sends JSON-encoded events describing replica status. """ def __init__( @@ -49,6 +51,11 @@ def __init__( self._heartbeat_interval = 5.0 self._stop_event = threading.Event() self._send_lock = threading.RLock() + # Optional hook invoked from the heartbeat thread before each + # heartbeat send. Stages set this to refresh ``queue_length`` (or any + # other field) just-in-time. Exceptions raised by the hook are + # silently suppressed (they are not logged). + self._on_heartbeat: Callable[[], None] | None = None self._send_event("update") @@ -100,11 +107,13 @@ def _reconnect(self, max_retries: int = 3, retry_interval: float = 5.0) -> bool: return False def _send_event(self, event_type: str) -> None: - """Send an InstanceEvent to OmniCoordinator. + """Send a ReplicaEvent to OmniCoordinator. Wire format: input_addr, output_addr, stage_id, status, queue_length, event_type. - For "update": includes status and queue_length from instance state. - For "heartbeat": status and queue_length are null. + For "update": includes status and queue_length from replica state. + For "heartbeat": includes the latest queue_length (refreshed by the + optional ``_on_heartbeat`` hook) so the coordinator can propagate + live load to load balancers between explicit ``update`` events. On send failure (ZMQError / RuntimeError), attempts to reconnect up to 3 times (5s sleep each) and retries the send once after a @@ -114,7 +123,7 @@ def _send_event(self, event_type: str) -> None: if self._closed: raise RuntimeError("Client already closed") - event = InstanceEvent( + event = ReplicaEvent( input_addr=self._input_addr, output_addr=self._output_addr, stage_id=self._stage_id, @@ -150,7 +159,7 @@ def update_info( status: StageStatus | None = None, queue_length: int | None = None, ) -> None: - """Update instance information and notify OmniCoordinator. + """Update replica information and notify OmniCoordinator. At least one of ``status`` or ``queue_length`` must be provided. 
""" @@ -174,6 +183,15 @@ def _heartbeat_loop(self) -> None: if self._closed: break + # Invoke the optional pre-heartbeat hook so callers (e.g. the + # engine subprocess) can refresh ``queue_length`` from live state + # before the heartbeat is sent. Exceptions are swallowed so a + # buggy hook never breaks the heartbeat loop. + hook = self._on_heartbeat + if hook is not None: + with contextlib.suppress(Exception): + hook() + try: self._send_event("heartbeat") except (RuntimeError, zmq.ZMQError) as e: diff --git a/vllm_omni/distributed/omni_coordinator/omni_coordinator.py b/vllm_omni/distributed/omni_coordinator/omni_coordinator.py index 2c7c8fbb995..edd9d80e3b1 100644 --- a/vllm_omni/distributed/omni_coordinator/omni_coordinator.py +++ b/vllm_omni/distributed/omni_coordinator/omni_coordinator.py @@ -11,22 +11,22 @@ import zmq -from .messages import InstanceEvent, InstanceInfo, InstanceList, StageStatus +from .messages import ReplicaEvent, ReplicaInfo, ReplicaList, StageStatus logger = logging.getLogger(__name__) class OmniCoordinator: - """Coordinator for stage instances and hub clients. + """Coordinator for stage replicas and hub clients. - This service receives instance events from :class:`OmniCoordClientForStage` - via a ZMQ ROUTER socket and publishes active instance lists to + This service receives replica events from :class:`OmniCoordClientForStage` + via a ZMQ ROUTER socket and publishes active replica lists to :class:`OmniCoordClientForHub` via a PUB socket. - The coordinator maintains an in-memory registry of all known instances, + The coordinator maintains an in-memory registry of all known replicas, including their status, queue length, and heartbeat timestamps. A background thread periodically checks for heartbeat timeouts and marks - unhealthy instances as ``StageStatus.ERROR``. + unhealthy replicas as ``StageStatus.ERROR``. """ def __init__( @@ -40,7 +40,7 @@ def __init__( Args: router_zmq_addr: ZMQ address to bind the ROUTER socket. 
pub_zmq_addr: ZMQ address to bind the PUB socket. - heartbeat_timeout: Seconds before an instance is considered + heartbeat_timeout: Seconds before a replica is considered unhealthy if no heartbeat / update is received. """ self._router_zmq_addr = router_zmq_addr @@ -55,7 +55,7 @@ def __init__( self._pub = self._ctx.socket(zmq.PUB) self._pub.bind(self._pub_zmq_addr) - self._instances: dict[str, InstanceInfo] = {} + self._replicas: dict[str, ReplicaInfo] = {} self._lock = threading.Lock() self._pub_lock = threading.Lock() @@ -75,43 +75,43 @@ def __init__( self._periodic_thread = threading.Thread(target=self._periodic_loop, daemon=True) self._periodic_thread.start() - def get_active_instances(self) -> InstanceList: - """Return an :class:`InstanceList` of active (UP) instances only.""" + def get_active_replicas(self) -> ReplicaList: + """Return a :class:`ReplicaList` of active (UP) replicas only.""" with self._lock: - active = [inst for inst in self._instances.values() if inst.status == StageStatus.UP] - return InstanceList(instances=active, timestamp=time()) + active = [rep for rep in self._replicas.values() if rep.status == StageStatus.UP] + return ReplicaList(replicas=active, timestamp=time()) - def add_new_instance(self, event: InstanceEvent) -> None: - """Add a new instance based on an incoming event.""" + def add_new_replica(self, event: ReplicaEvent) -> None: + """Add a new replica based on an incoming event.""" with self._lock: - self._add_new_instance_locked(event) + self._add_new_replica_locked(event) self._schedule_broadcast() - def update_instance_info(self, event: InstanceEvent) -> None: - """Update an existing instance based on an incoming event.""" + def update_replica_info(self, event: ReplicaEvent) -> None: + """Update an existing replica based on an incoming event.""" with self._lock: - self._update_instance_info_locked(event) + self._update_replica_info_locked(event) self._schedule_broadcast() - def remove_instance(self, event: InstanceEvent) -> 
None: - """Mark an instance as removed / down based on an incoming event. + def remove_replica(self, event: ReplicaEvent) -> None: + """Mark a replica as removed / down based on an incoming event. - This marks the instance's status as DOWN or ERROR (depending on the + This marks the replica's status as DOWN or ERROR (depending on the event) but keeps it in the internal registry. It is removed from the - *active* instance list published to hubs. + *active* replica list published to hubs. """ with self._lock: - self._remove_instance_locked(event) + self._remove_replica_locked(event) self._schedule_broadcast() - def publish_instance_list_update(self) -> bool: - """Publish the current active instance list to all subscribers. + def publish_replica_list_update(self) -> bool: + """Publish the current active replica list to all subscribers. Returns: True if the PUB send succeeded, False if it was dropped (e.g. socket not ready when using ``zmq.NOBLOCK``). """ - active_list = self.get_active_instances() + active_list = self.get_active_replicas() payload = asdict(active_list) data = json.dumps(payload).encode("utf-8") @@ -133,12 +133,12 @@ def _schedule_broadcast(self) -> None: with self._pending_lock: self._pending_broadcast = True - def _mark_instance_error_locked(self, info: InstanceInfo) -> None: - """Mark instance as ERROR (e.g. after heartbeat timeout).""" + def _mark_replica_error_locked(self, info: ReplicaInfo) -> None: + """Mark replica as ERROR (e.g. 
after heartbeat timeout).""" info.status = StageStatus.ERROR def _check_heartbeat_timeouts(self) -> None: - """Mark instances as ERROR if their heartbeat has timed out.""" + """Mark replicas as ERROR if their heartbeat has timed out.""" now = time() timed_out = False gc_ttl = 600.0 # 10 minutes @@ -146,17 +146,17 @@ def _check_heartbeat_timeouts(self) -> None: with self._lock: to_delete: list[str] = [] - for input_addr, info in self._instances.items(): + for input_addr, info in self._replicas.items(): if info.status == StageStatus.UP and now - info.last_heartbeat > self._heartbeat_timeout: - self._mark_instance_error_locked(info) + self._mark_replica_error_locked(info) timed_out = True elif info.status in (StageStatus.DOWN, StageStatus.ERROR) and now - info.last_heartbeat > gc_ttl: to_delete.append(input_addr) for input_addr in to_delete: - del self._instances[input_addr] + del self._replicas[input_addr] if timed_out: - # Instance liveness changed; request broadcast. + # Replica liveness changed; request broadcast. self._schedule_broadcast() def close(self) -> None: @@ -187,10 +187,10 @@ def close(self) -> None: except zmq.ZMQError: pass - def _parse_instance_event(self, data: dict[str, Any]) -> InstanceEvent | None: - """Parse wire payload dict into InstanceEvent. Returns None if invalid.""" + def _parse_replica_event(self, data: dict[str, Any]) -> ReplicaEvent | None: + """Parse wire payload dict into ReplicaEvent. 
Returns None if invalid.""" try: - return InstanceEvent( + return ReplicaEvent( input_addr=str(data["input_addr"]), output_addr=str(data["output_addr"]), stage_id=int(data["stage_id"]), @@ -202,7 +202,7 @@ def _parse_instance_event(self, data: dict[str, Any]) -> InstanceEvent | None: return None def _recv_loop(self) -> None: - """Background loop that receives and processes instance events.""" + """Background loop that receives and processes replica events.""" while self._running: try: frames = self._router.recv_multipart() @@ -219,12 +219,12 @@ def _recv_loop(self) -> None: payload = frames[-1] try: data = json.loads(payload.decode("utf-8")) - event = self._parse_instance_event(data) + event = self._parse_replica_event(data) except json.JSONDecodeError as e: - logger.warning("Invalid JSON in instance event, dropping: %s", e) + logger.warning("Invalid JSON in replica event, dropping: %s", e) continue if event is None: - logger.warning("Malformed instance event, dropping") + logger.warning("Malformed replica event, dropping") continue self._handle_event(event) @@ -234,7 +234,10 @@ def _periodic_loop(self) -> None: Heartbeat timeouts are checked on their original cadence, while all broadcast requests are coalesced and flushed at most once per - ``_publish_min_interval``. + ``_publish_min_interval``. The heartbeat-check tick also schedules a + keepalive broadcast so late-joining hubs (which miss any PUB sends + that happened before their SUB connected) catch up within at most + ``heartbeat_interval`` seconds. 
""" heartbeat_interval = max(1.0, min(self._heartbeat_timeout / 2.0, 5.0)) loop_interval = self._publish_min_interval @@ -245,6 +248,13 @@ def _periodic_loop(self) -> None: if now - last_heartbeat_check >= heartbeat_interval: self._check_heartbeat_timeouts() + # Keepalive broadcast: ZMQ PUB doesn't queue for late + # subscribers, so an OmniCoordClientForHub that connects + # after the initial UP events misses them entirely and would + # never see the current replica list otherwise. Scheduling a + # broadcast on every heartbeat tick caps that staleness at + # ``heartbeat_interval`` without flooding the wire. + self._schedule_broadcast() last_heartbeat_check = now with self._pending_lock: @@ -256,43 +266,51 @@ def _periodic_loop(self) -> None: continue # Publish outside lock. Clear pending only on success. - if self.publish_instance_list_update(): + if self.publish_replica_list_update(): with self._pending_lock: self._pending_broadcast = False if self._stop_event.wait(timeout=loop_interval): break - def _handle_event(self, event: InstanceEvent) -> None: + def _handle_event(self, event: ReplicaEvent) -> None: """Dispatch an incoming event to the appropriate handler.""" try: input_addr = event.input_addr - # Heartbeat: only update last_heartbeat; if previously ERROR, - # promote back to UP and broadcast once. + # Heartbeat: refresh last_heartbeat and queue_length. The stage + # client refreshes queue_length just-in-time via its + # ``_on_heartbeat`` hook, so heartbeats are the only periodic + # source of live load for LeastQueueLengthBalancer; failing to + # propagate it here would let the policy route on stale data. + # If previously ERROR, promote back to UP and broadcast once. 
if event.event_type == "heartbeat": promote = False + queue_changed = False with self._lock: - info = self._instances.get(input_addr) + info = self._replicas.get(input_addr) if info is not None: info.last_heartbeat = time() + if event.queue_length is not None and info.queue_length != event.queue_length: + info.queue_length = event.queue_length + queue_changed = True if info.status == StageStatus.ERROR: info.status = StageStatus.UP promote = True - if promote: + if promote or queue_changed: self._schedule_broadcast() return # Check-and-act under single lock to avoid TOCTOU race (duplicate - # registration when concurrent events arrive for the same instance). + # registration when concurrent events arrive for the same replica). with self._lock: - if input_addr not in self._instances: - self._add_new_instance_locked(event) + if input_addr not in self._replicas: + self._add_new_replica_locked(event) else: if event.status == StageStatus.DOWN: - self._remove_instance_locked(event) + self._remove_replica_locked(event) else: - self._update_instance_info_locked(event) + self._update_replica_info_locked(event) # Any non-heartbeat state change that affects the active list # is coalesced and flushed via the periodic loop. 
@@ -300,7 +318,7 @@ def _handle_event(self, event: InstanceEvent) -> None: except (KeyError, ValueError, TypeError) as e: logger.warning("Dropping malformed event: %s", e) - def _add_new_instance_locked(self, event: InstanceEvent) -> None: + def _add_new_replica_locked(self, event: ReplicaEvent) -> None: input_addr = event.input_addr if not input_addr: raise KeyError("input_addr required") @@ -309,7 +327,7 @@ def _add_new_instance_locked(self, event: InstanceEvent) -> None: raise KeyError("stage_id required and must be non-negative") now = time() - info = InstanceInfo( + info = ReplicaInfo( input_addr=input_addr, output_addr=event.output_addr, stage_id=stage_id, @@ -318,11 +336,11 @@ def _add_new_instance_locked(self, event: InstanceEvent) -> None: last_heartbeat=now, registered_at=now, ) - self._instances[input_addr] = info + self._replicas[input_addr] = info - def _update_instance_info_locked(self, event: InstanceEvent) -> None: + def _update_replica_info_locked(self, event: ReplicaEvent) -> None: input_addr = event.input_addr - info = self._instances[input_addr] + info = self._replicas[input_addr] if event.status is not None: info.status = event.status @@ -330,9 +348,9 @@ def _update_instance_info_locked(self, event: InstanceEvent) -> None: if event.queue_length is not None: info.queue_length = event.queue_length - def _remove_instance_locked(self, event: InstanceEvent) -> None: + def _remove_replica_locked(self, event: ReplicaEvent) -> None: input_addr = event.input_addr - info = self._instances.get(input_addr) + info = self._replicas.get(input_addr) if info is None: return diff --git a/vllm_omni/distributed/omni_coordinator/runtime.py b/vllm_omni/distributed/omni_coordinator/runtime.py new file mode 100644 index 00000000000..71df7c0521c --- /dev/null +++ b/vllm_omni/distributed/omni_coordinator/runtime.py @@ -0,0 +1,74 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Lifecycle wrapper around 
:class:`OmniCoordinator`. + +``OmniCoordinatorRuntime`` is the single-purpose owner of the head-side +coordinator process artifacts: it picks two free TCP ports, constructs an +:class:`OmniCoordinator` bound to them, exposes the resulting addresses, and +provides a single ``close()`` method to tear everything down. + +The ROUTER address is later handed to :class:`OmniMasterServer` so it can be +published to registering replicas; the PUB address is handed to the +``Orchestrator``, which constructs its :class:`OmniCoordClientForHub` against +it. +""" + +from __future__ import annotations + +import logging + +from vllm.utils.network_utils import get_open_ports_list + +from .omni_coordinator import OmniCoordinator + +logger = logging.getLogger(__name__) + + +class OmniCoordinatorRuntime: + """Own one :class:`OmniCoordinator` and the two ports it binds. + + Constructor binds; :meth:`close` tears down. The class deliberately does + not expose the coordinator instance — callers should consume the + coordinator only via its wire protocol through + :class:`OmniCoordClientForStage` and :class:`OmniCoordClientForHub`. + """ + + def __init__( + self, + *, + host: str, + heartbeat_timeout: float, + ) -> None: + if not host: + raise ValueError("host must be a non-empty string") + if heartbeat_timeout <= 0: + raise ValueError("heartbeat_timeout must be positive") + + router_port, pub_port = get_open_ports_list(count=2) + self.router_address: str = f"tcp://{host}:{router_port}" + self.pub_address: str = f"tcp://{host}:{pub_port}" + + self._closed = False + self._coordinator = OmniCoordinator( + router_zmq_addr=self.router_address, + pub_zmq_addr=self.pub_address, + heartbeat_timeout=heartbeat_timeout, + ) + + logger.info( + "[OmniCoordinatorRuntime] Started (router=%s pub=%s heartbeat_timeout=%.1fs)", + self.router_address, + self.pub_address, + heartbeat_timeout, + ) + + def close(self) -> None: + """Tear down the underlying coordinator. 
Idempotent.""" + if self._closed: + return + self._closed = True + try: + self._coordinator.close() + except Exception: + logger.exception("[OmniCoordinatorRuntime] coordinator close failed") diff --git a/vllm_omni/engine/arg_utils.py b/vllm_omni/engine/arg_utils.py index f01094befab..2af6f93d1ff 100644 --- a/vllm_omni/engine/arg_utils.py +++ b/vllm_omni/engine/arg_utils.py @@ -163,6 +163,10 @@ def _add_omni_specific_args(cls, parser: argparse.ArgumentParser) -> argparse.Ar omni_master_address: str | None = None omni_master_port: int | None = None + # OmniCoordinator integration knobs (process-local). + omni_dp_size_local: int = 1 + omni_lb_policy: str = "random" + omni_heartbeat_timeout: float = 30.0 stage_configs_path: str | None = None output_modalities: list[str] | None = None log_stats: bool = False diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py index ed4eed80844..a54f803d55c 100644 --- a/vllm_omni/engine/async_omni_engine.py +++ b/vllm_omni/engine/async_omni_engine.py @@ -18,7 +18,7 @@ import time import uuid import weakref -from collections.abc import Mapping, Sequence +from collections.abc import Awaitable, Callable, Mapping, Sequence from contextlib import ExitStack from dataclasses import asdict from typing import TYPE_CHECKING, Any @@ -44,6 +44,13 @@ from vllm_omni.distributed.omni_connectors.utils.initialization import ( resolve_omni_kv_config_for_stage, ) +from vllm_omni.distributed.omni_coordinator import ( + LeastQueueLengthBalancer, + LoadBalancer, + LoadBalancingPolicy, + RandomBalancer, + RoundRobinBalancer, +) from vllm_omni.engine import OmniEngineCoreRequest from vllm_omni.engine.orchestrator import Orchestrator from vllm_omni.engine.serialization import ( @@ -127,12 +134,49 @@ # trigger the ``create_model_config`` guard). 
_PARENT_ARGS_STRIP: frozenset[str] = frozenset({"stage_configs_path"}) + +def _build_load_balancer_factory(policy: str) -> Callable[[], LoadBalancer]: + """Translate ``--omni-lb-policy`` (string) into a per-pool LB factory.""" + try: + normalized = LoadBalancingPolicy(policy) + except ValueError as exc: + valid = ", ".join(p.value for p in LoadBalancingPolicy) + raise ValueError(f"unknown --omni-lb-policy {policy!r} (valid: {valid})") from exc + if normalized is LoadBalancingPolicy.RANDOM: + return RandomBalancer + if normalized is LoadBalancingPolicy.ROUND_ROBIN: + return RoundRobinBalancer + if normalized is LoadBalancingPolicy.LEAST_QUEUE_LENGTH: + return LeastQueueLengthBalancer + raise ValueError(f"unhandled load balancing policy {normalized!r}") + + # Fields always populated by callers (via ``from_cli_args`` / ``asdict``) so # their presence as an override is never a surprise — suppress the # "override ignored" warning for these. _PARENT_ARGS_NO_WARN: frozenset[str] = frozenset({"model"}) +@dataclasses.dataclass +class _StageRemoteFactoryContext: + """Per-stage context cached by AsyncOmniEngine for dynamic replica attach. + + Populated once during ``_bootstrap_orchestrator`` from the per-stage + init plans. ``_build_remote_replica`` consumes it to construct the + right head-side stage client when a headless replica registers. 
+ """ + + stage_id: int + stage_type: str + stage_cfg: Any + base_metadata: Any + # LLM-only fields: + vllm_config: Any | None = None + executor_class: type | None = None + # Diffusion-only fields: + diffusion_batch_size: int = 1 + + def _inject_global_id(target: Any, request_id: str) -> None: """Inject global_request_id into a prompt dict's additional_information.""" if isinstance(target, dict): @@ -189,27 +233,24 @@ def _apply_omni_final_stage_metadata( def _weak_shutdown_async_omni_engine( - orchestrator_thread: threading.Thread | None, - request_queue: janus.Queue[dict[str, Any]] | None, - output_queue: janus.Queue[dict[str, Any]] | None, - rpc_output_queue: janus.Queue[dict[str, Any]] | None, + orchestrator_thread: threading.Thread, + request_queue: janus.Queue[dict[str, Any]], + output_queue: janus.Queue[dict[str, Any]], + rpc_output_queue: janus.Queue[dict[str, Any]], ) -> None: """Best-effort orchestrator cleanup for GC finalization.""" try: - if request_queue is not None: - request_queue.sync_q.put_nowait({"type": "shutdown"}) + request_queue.sync_q.put_nowait({"type": "shutdown"}) except Exception: pass try: - if orchestrator_thread is not None and orchestrator_thread.is_alive(): + if orchestrator_thread.is_alive(): orchestrator_thread.join(timeout=10) except Exception: pass for q in (request_queue, output_queue, rpc_output_queue): - if q is None: - continue try: q.close() except Exception: @@ -280,6 +321,23 @@ def __init__( self._omni_master_port: int | None = kwargs.get("omni_master_port") self._omni_master_server: OmniMasterServer | None = None + # New omni-coordinator flags. Consumed only in single_stage_mode. + # ``omni_dp_size_local`` is process-local: each invocation (head and + # every headless) launches that many replicas for its own stage. 
+ self._omni_dp_size_local: int = int(kwargs.get("omni_dp_size_local") or 1) + if self._omni_dp_size_local < 1: + raise ValueError(f"--omni-dp-size-local must be >= 1, got {self._omni_dp_size_local}") + self._omni_lb_policy: str = str(kwargs.get("omni_lb_policy") or "random") + self._omni_heartbeat_timeout: float = float(kwargs.get("omni_heartbeat_timeout") or 30.0) + if self._omni_heartbeat_timeout <= 0: + raise ValueError(f"--omni-heartbeat-timeout must be > 0, got {self._omni_heartbeat_timeout}") + # Coordinator runtime (head-distributed only). + self._coordinator_runtime: Any | None = None + # Per-stage construction context, captured inside _initialize_stages + # and used by ``_build_remote_replica`` (the RemoteReplicaFactory + # passed to Orchestrator) when a headless replica registers. + self._stage_remote_factory_contexts: dict[int, _StageRemoteFactoryContext] = {} + if single_stage_mode: logger.info( "[AsyncOmniEngine] Single-stage mode enabled (stage_id_filter=%s, master=%s:%s)", @@ -300,9 +358,14 @@ def __init__( self.supported_tasks: tuple[str, ...] = ("generate",) self.default_sampling_params_list: list[Any] = [] self.stage_metadata: list[dict[str, Any]] = [] - self.request_queue: janus.Queue[dict[str, Any]] | None = None - self.output_queue: janus.Queue[dict[str, Any]] | None = None - self.rpc_output_queue: janus.Queue[dict[str, Any]] | None = None + # Janus queues are constructed here (not deferred to the orchestrator + # thread) so the master server's ROUTER thread always sees a non-None + # ``self.request_queue`` when on_register fires. ``async_q`` lazily + # binds to whatever event loop first awaits on it, which is the + # orchestrator loop — so cross-thread use stays correct.
+ self.request_queue: janus.Queue[dict[str, Any]] = janus.Queue() + self.output_queue: janus.Queue[dict[str, Any]] = janus.Queue() + self.rpc_output_queue: janus.Queue[dict[str, Any]] = janus.Queue() self._shutdown_called = False self._weak_finalizer: weakref.finalize | None = None self._rpc_lock = threading.Lock() @@ -422,30 +485,37 @@ def _shutdown_initialized_clients(clients: Sequence[Any]) -> None: ) def _validate_single_stage_mode_replica_constraints(self) -> None: - """Reject unsupported replica fan-out in single-stage mode.""" + """Apply --omni-dp-size-local to the local stage's runtime.num_replicas. + + In the previous revision this method rejected LLM stages with + ``num_replicas > 1``. The whole point of ``--omni-dp-size-local`` is + to lift that restriction for the *local* stage, so the rejection is + gone. We now use this hook to write ``--omni-dp-size-local`` onto + the self-stage's runtime config so downstream code + (``compute_replica_layout`` → ``_build_logical_stage_init_plans``) + sees a consistent view. 
+ """ if not self.single_stage_mode: return + target_stage_id = self._single_stage_id_filter + if target_stage_id is None: + return - unsupported: list[tuple[int, int]] = [] for idx, stage_cfg in enumerate(self.stage_configs): - runtime_cfg = getattr(stage_cfg, "runtime", {}) - num_replicas = int( - runtime_cfg.get("num_replicas", 1) - if hasattr(runtime_cfg, "get") - else getattr(runtime_cfg, "num_replicas", 1) - ) - if num_replicas <= 1: - continue - if getattr(stage_cfg, "stage_type", "llm") == "diffusion": - continue stage_id = int(getattr(stage_cfg, "stage_id", idx)) - unsupported.append((stage_id, num_replicas)) - - if unsupported: - raise ValueError( - "single_stage_mode only supports num_replicas > 1 for diffusion stages; " - f"found non-diffusion stages {unsupported}" - ) + runtime_cfg = getattr(stage_cfg, "runtime", None) + if runtime_cfg is None: + continue + if stage_id == target_stage_id: + # Self stage: take --omni-dp-size-local from this process. + try: + runtime_cfg.num_replicas = self._omni_dp_size_local + except Exception: + if hasattr(runtime_cfg, "__setitem__"): + runtime_cfg["num_replicas"] = self._omni_dp_size_local + # Other stages keep their config-declared num_replicas; in + # head-distributed mode they will be launched as ``launch_mode + # == "remote"`` with the configured count. def _build_logical_stage_init_plans( self, @@ -519,9 +589,18 @@ def _build_logical_stage_init_plans( replica_metadata = extract_stage_metadata(replica_cfg) replica_metadata.replica_id = replica_id - if self.single_stage_mode: - if replica_metadata.stage_type != "diffusion": - replica_metadata.runtime_cfg = None + # In single_stage_mode the head only owns its self stage's + # replicas; the remote-stage metadata exists only so the + # orchestrator can route requests through StagePool. Wiping + # ``runtime_cfg`` there makes sense (we don't manage those + # devices). 
For the *self stage* we MUST keep the + # per-replica runtime so ``setup_stage_devices`` can apply + # the device split from ``compute_replica_layout`` — + # otherwise every replica inherits the parent's full + # CUDA_VISIBLE_DEVICES and stacks on cuda:0 (OOM with any + # model whose footprint exceeds ~1/(2N) of the card). + if launch_mode == "remote" and replica_metadata.stage_type != "diffusion": + replica_metadata.runtime_cfg = None replicas.append( ReplicaInitPlan( @@ -568,11 +647,22 @@ def _start_omni_master_server(self, stage_plans: Sequence[LogicalStageInitPlan]) all_stage_ids.append(stage_id) stage_replica_counts[stage_id] = len(plan.replicas) + # Start the OmniCoordinator runtime first so its router address is + # available to publish in every registration reply. + from vllm_omni.distributed.omni_coordinator import OmniCoordinatorRuntime + + self._coordinator_runtime = OmniCoordinatorRuntime( + host=self._omni_master_address, + heartbeat_timeout=self._omni_heartbeat_timeout, + ) + self._omni_master_server = OmniMasterServer( master_address=self._omni_master_address, master_port=self._omni_master_port, stage_ids=all_stage_ids, stage_replica_counts=stage_replica_counts, + coordinator_router_address=self._coordinator_runtime.router_address, + on_register=self._dispatch_master_register, ) self._omni_master_server.start() logger.info( @@ -580,6 +670,175 @@ def _start_omni_master_server(self, stage_plans: Sequence[LogicalStageInitPlan]) all_stage_ids, ) + # ------------------------------------------------------------------ + # Remote replica factory (head-side client construction) + # ------------------------------------------------------------------ + + def _capture_stage_factory_contexts( + self, stage_plans: Sequence[LogicalStageInitPlan] + ) -> dict[int, _StageRemoteFactoryContext]: + """Snapshot per-stage construction context for dynamic replica attach. + + Called once from ``_initialize_stages``, before the master server starts.
The captured + context holds everything :meth:`_build_remote_replica` needs to + build a fresh head-side client when a new headless replica + registers (vllm_config / executor_class for LLM, batch_size for + diffusion, plus the base stage metadata). + + Per-replica fields like ``replica_id`` are filled in at build + time, not at capture time. + """ + contexts: dict[int, _StageRemoteFactoryContext] = {} + for plan in stage_plans: + if not plan.replicas: + # No replica plan is available to use as a template, so no + # context is captured for this stage; a later call to + # ``_build_remote_replica`` for it raises RuntimeError. + continue + template = plan.replicas[0] + stage_id = int(plan.configured_stage_id) + stage_type = template.metadata.stage_type or "llm" + contexts[stage_id] = _StageRemoteFactoryContext( + stage_id=stage_id, + stage_type=stage_type, + stage_cfg=template.stage_cfg, + base_metadata=template.metadata, + vllm_config=template.stage_vllm_config, + executor_class=template.executor_class, + diffusion_batch_size=self.diffusion_batch_size, + ) + return contexts + + async def _build_remote_replica(self, stage_id: int, replica_id: int) -> Any: + """Construct a head-side stage client for a newly-registered remote replica. + + Used by :class:`Orchestrator` as its ``remote_replica_factory``. + The orchestrator awaits this from its own asyncio loop, so the + client is created in the same loop that owns ZMQ sockets — no + cross-thread setup is required. + + Raises if the stage is not known or if its construction context + was not captured (e.g. the stage was empty at bring-up time).
+ """ + ctx = self._stage_remote_factory_contexts.get(stage_id) + if ctx is None: + raise RuntimeError( + f"no factory context captured for stage {stage_id}; " + f"known stages: {sorted(self._stage_remote_factory_contexts.keys())}" + ) + if self._omni_master_server is None: + raise RuntimeError("OmniMasterServer is not running; cannot build remote replica") + + alloc = self._omni_master_server.get_allocation(stage_id, replica_id=replica_id) + + # Build a per-replica copy of the base metadata so ``replica_id`` + # is correct (StageMetadata is a plain dataclass-like). + metadata = copy.copy(ctx.base_metadata) + try: + metadata.replica_id = replica_id + except Exception: + # Best-effort: if the metadata object is frozen / unusual, the + # downstream client will fall back to ``replica_id = 0``. + pass + + if ctx.stage_type == "diffusion": + client = StageDiffusionClient.from_addresses( + metadata, + request_address=alloc.input_bind_address, + response_address=alloc.output_bind_address, + batch_size=ctx.diffusion_batch_size, + ) + logger.info( + "[AsyncOmniEngine] Built remote diffusion client for stage=%d replica=%d (req=%s resp=%s)", + stage_id, + replica_id, + alloc.input_bind_address, + alloc.output_bind_address, + ) + return client + + # LLM path + if ctx.vllm_config is None or ctx.executor_class is None: + raise RuntimeError(f"stage {stage_id} factory context is missing vllm_config / executor_class") + + # The headless's StageEngineCoreProc subprocess calls + # vllm.v1.engine.core.startup_handshake at boot and blocks until the + # head's handshake ROUTER answers — without this step it hits the + # built-in 5-minute timeout and exits. The bootstrap (pre-allocated) + # path runs `connect_remote_engine_cores` to perform that + # handshake; dynamic attach must do the same, otherwise every + # replica that comes in via `on_register` (auto-assigned or + # explicit-id-beyond-pre-alloc) deadlocks. 
Run the blocking + # handshake in a thread so the orchestrator loop stays responsive, + # then build the async client on this loop where `make_async_mp_client` + # expects to be invoked. + master_server = self._omni_master_server + ctx_vllm_config = ctx.vllm_config + ctx_executor_class = ctx.executor_class + + def _run_handshake() -> Any: + with connect_remote_engine_cores( + vllm_config=ctx_vllm_config, + omni_master_server=master_server, + stage_id=stage_id, + replica_id=replica_id, + ) as remote_resources: + _engine_manager, _coordinator, addresses, _ = remote_resources + return _engine_manager, _coordinator, addresses + + engine_manager, coordinator, addresses = await asyncio.to_thread(_run_handshake) + + client_addresses: dict[str, str] = { + "input_address": addresses.inputs[0], + "output_address": addresses.outputs[0], + } + if addresses.frontend_stats_publish_address is not None: + client_addresses["stats_update_address"] = addresses.frontend_stats_publish_address + + client = StageEngineCoreClientBase.make_async_mp_client( + vllm_config=ctx_vllm_config, + executor_class=ctx_executor_class, + metadata=metadata, + client_addresses=client_addresses, + proc=None, + engine_manager=engine_manager, + coordinator=coordinator, + ) + logger.info( + "[AsyncOmniEngine] Built remote LLM client for stage=%d replica=%d (input=%s)", + stage_id, + replica_id, + client_addresses["input_address"], + ) + return client + + # ------------------------------------------------------------------ + # OmniCoordinator on_register proxy + # ------------------------------------------------------------------ + + def _dispatch_master_register(self, stage_id: int, replica_id: int, alloc: Any) -> None: + """Forward a master-server registration to the orchestrator queue. + + Called on :class:`OmniMasterServer`'s ROUTER thread (not the + orchestrator loop). Must return promptly. 
``self.request_queue`` + is created in :meth:`__init__` before the master server starts, + so it is always non-None here; ``janus.Queue.sync_q`` is + thread-safe by construction. + """ + msg = { + "type": "register_remote_replica", + "stage_id": int(stage_id), + "replica_id": int(replica_id), + } + try: + self.request_queue.sync_q.put_nowait(msg) + except Exception: + logger.exception( + "[AsyncOmniEngine] Failed to enqueue register_remote_replica for stage=%d replica=%d", + stage_id, + replica_id, + ) + def _initialize_llm_replica( self, plan: ReplicaInitPlan, @@ -665,6 +924,11 @@ def _initialize_llm_replica( stage_init_timeout, ) if self.single_stage_mode and self._omni_master_server is not None: + coord_router_addr: str | None = ( + self._coordinator_runtime.router_address + if self._coordinator_runtime is not None + else None + ) engine_manager, coordinator, addresses = launch_stack.enter_context( launch_omni_core_engines( vllm_config=vllm_config, @@ -674,6 +938,7 @@ def _initialize_llm_replica( stage_id=plan.metadata.stage_id, stage_config=stage_cfg, replica_id=plan.replica_id, + omni_coordinator_address=coord_router_addr, ) ) else: @@ -809,12 +1074,20 @@ def _initialize_diffusion_replica( "[AsyncOmniEngine] Stage %s diffusion registration completed", plan.metadata.stage_id, ) + coord_router_addr: str | None = ( + self._coordinator_runtime.router_address + if self._coordinator_runtime is not None + else None + ) proc, _, _, _ = spawn_diffusion_proc( self.model, od_config, handshake_address=handshake_address, request_address=request_address, response_address=response_address, + omni_coordinator_address=coord_router_addr, + omni_stage_id=plan.metadata.stage_id, + omni_replica_id=plan.replica_id, ) complete_diffusion_handshake(proc, handshake_address, stage_init_timeout) logger.info( @@ -1028,7 +1301,11 @@ def _initialize_stages(self, stage_init_timeout: int) -> None: replicas_per_stage, replica_devices_map, ) + # Capture per-stage context now (before 
_start_omni_master_server) + # so the on_register proxy can build head-side clients for + # registrations that arrive immediately after the server starts. if self.single_stage_mode: + self._stage_remote_factory_contexts = self._capture_stage_factory_contexts(stage_plans) self._start_omni_master_server(stage_plans) stage_pools: list[StagePool] = [] @@ -1064,6 +1341,14 @@ def _initialize_stages(self, stage_init_timeout: int) -> None: self._omni_master_server.stop() except Exception: logger.exception("[AsyncOmniEngine] Failed to stop OmniMasterServer during stage-init cleanup") + if self._coordinator_runtime is not None: + try: + self._coordinator_runtime.close() + except Exception: + logger.exception( + "[AsyncOmniEngine] Failed to close OmniCoordinatorRuntime during stage-init cleanup" + ) + self._coordinator_runtime = None raise self.stage_pools = stage_pools @@ -1083,13 +1368,6 @@ def _initialize_stages(self, stage_init_timeout: int) -> None: supported_tasks.add("speech") self.supported_tasks = tuple(supported_tasks) if supported_tasks else ("generate",) - def _initialize_janus_queues(self) -> None: - """Initialize janus queues inside orchestrator thread loop context.""" - self.request_queue = janus.Queue() - self.output_queue = janus.Queue() - self.rpc_output_queue = janus.Queue() - logger.debug("[AsyncOmniEngine] janus queues initialized in orchestrator thread loop") - def _bootstrap_orchestrator( self, stage_init_timeout: int, @@ -1101,10 +1379,15 @@ def _bootstrap_orchestrator( asyncio.set_event_loop(loop) async def _run_orchestrator() -> None: - self._initialize_janus_queues() - self._initialize_stages(stage_init_timeout) pd_config = self._detect_pd_config() + coordinator_pub_address: str | None = None + load_balancer_factory: Callable[[], LoadBalancer] | None = None + remote_replica_factory: Callable[[int, int], Awaitable[Any]] | None = None + if self._coordinator_runtime is not None: + coordinator_pub_address = self._coordinator_runtime.pub_address + 
load_balancer_factory = _build_load_balancer_factory(self._omni_lb_policy) + remote_replica_factory = self._build_remote_replica orchestrator = Orchestrator( request_async_queue=self.request_queue.async_q, output_async_queue=self.output_queue.async_q, @@ -1112,6 +1395,9 @@ async def _run_orchestrator() -> None: stage_pools=self.stage_pools, async_chunk=self.async_chunk, pd_config=pd_config, + coordinator_pub_address=coordinator_pub_address, + load_balancer_factory=load_balancer_factory, + remote_replica_factory=remote_replica_factory, ) if not startup_future.done(): startup_future.set_result(asyncio.get_running_loop()) @@ -1128,10 +1414,8 @@ async def _run_orchestrator() -> None: error_text = str(e) or "Orchestrator thread crashed" try: error_msg = {"type": "error", "error": error_text, "fatal": True} - if self.output_queue is not None: - self.output_queue.sync_q.put_nowait(error_msg) - if self.rpc_output_queue is not None: - self.rpc_output_queue.sync_q.put_nowait(error_msg) + self.output_queue.sync_q.put_nowait(error_msg) + self.rpc_output_queue.sync_q.put_nowait(error_msg) except Exception: pass raise @@ -1759,8 +2043,6 @@ def add_request( reasoning_ended=reasoning_ended, resumable=resumable, ) - if self.request_queue is None: - raise RuntimeError("request_queue is not initialized") self.request_queue.sync_q.put_nowait(msg) # CFG companion expansion: create and enqueue companion requests @@ -1828,8 +2110,6 @@ def add_streaming_update( resumable=resumable, message_type="streaming_update", ) - if self.request_queue is None: - raise RuntimeError("request_queue is not initialized") self.request_queue.sync_q.put_nowait(msg) async def add_streaming_update_async( @@ -1856,8 +2136,6 @@ async def add_streaming_update_async( def try_get_output(self, timeout: float = 0.001) -> dict[str, Any] | None: """Read one output message from the Orchestrator output queue.""" - if self.output_queue is None: - return None try: return self.output_queue.sync_q.get(timeout=timeout) except 
queue.Empty: @@ -1867,8 +2145,6 @@ def try_get_output(self, timeout: float = 0.001) -> dict[str, Any] | None: async def try_get_output_async(self) -> dict[str, Any] | None: """Async read from the Orchestrator output queue.""" - if self.output_queue is None: - return None try: return self.output_queue.sync_q.get_nowait() except queue.Empty: @@ -1882,8 +2158,6 @@ def get_stage_metadata(self, stage_id: int) -> dict[str, Any]: def abort(self, request_ids: list[str]) -> None: """Send abort message to the Orchestrator.""" - if self.request_queue is None: - raise RuntimeError("request_queue is not initialized") self.request_queue.sync_q.put_nowait( { "type": "abort", @@ -1908,11 +2182,6 @@ def collective_rpc( This uses a dedicated RPC output queue so control-plane messages do not race with the normal request output polling loop. """ - if self.request_queue is None: - raise RuntimeError("request_queue is not initialized") - if self.rpc_output_queue is None: - raise RuntimeError("rpc_output_queue is not initialized") - rpc_id = uuid.uuid4().hex msg = { "type": "collective_rpc", @@ -1990,8 +2259,7 @@ def shutdown(self) -> None: logger.info("[AsyncOmniEngine] Shutting down Orchestrator") try: - if self.request_queue is not None: - self.request_queue.sync_q.put_nowait({"type": "shutdown"}) + self.request_queue.sync_q.put_nowait({"type": "shutdown"}) except Exception: pass if self.is_alive(): @@ -2000,8 +2268,6 @@ def shutdown(self) -> None: logger.warning("[AsyncOmniEngine] Orchestrator thread did not exit in time") for q in (self.request_queue, self.output_queue, self.rpc_output_queue): - if q is None: - continue try: q.close() except Exception: @@ -2014,6 +2280,13 @@ def shutdown(self) -> None: logger.exception("[AsyncOmniEngine] Failed to stop OmniMasterServer during shutdown") self._omni_master_server = None + if self._coordinator_runtime is not None: + try: + self._coordinator_runtime.close() + except Exception: + logger.exception("[AsyncOmniEngine] Failed to close 
OmniCoordinatorRuntime during shutdown") + self._coordinator_runtime = None + def _try_shutdown(self, *args, **kwargs) -> None: try: self.shutdown() diff --git a/vllm_omni/engine/omni_core_engine_proc_manager.py b/vllm_omni/engine/omni_core_engine_proc_manager.py new file mode 100644 index 00000000000..9d69a8af2df --- /dev/null +++ b/vllm_omni/engine/omni_core_engine_proc_manager.py @@ -0,0 +1,157 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Process manager for omni stage engine subprocesses. + +This is a drop-in replacement for vLLM's :class:`CoreEngineProcManager` that +spawns :meth:`StageEngineCoreProc.run_stage_core` instead of the upstream +``EngineCoreProc.run_engine_core``, and forwards omni-specific kwargs +(coordinator address, stage id, per-rank replica id). + +Each spawned subprocess corresponds to exactly one omni *replica*: it has its +own ZMQ allocation from :class:`OmniMasterServer` and (when an +``omni_coordinator_address`` is provided) its own +:class:`OmniCoordClientForStage` reporting heartbeat / status. + +Liveness monitoring and shutdown are inherited from +:class:`CoreEngineProcManager` unchanged. 
+""" + +from __future__ import annotations + +import contextlib +import threading +import weakref +from multiprocessing.process import BaseProcess +from multiprocessing.queues import Queue + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils import numa_utils +from vllm.utils.system_utils import get_mp_context +from vllm.v1.engine.utils import CoreEngineProcManager +from vllm.v1.executor import Executor +from vllm.v1.utils import shutdown + +from vllm_omni.engine.stage_engine_core_proc import StageEngineCoreProc + +logger = init_logger(__name__) + +try: + # ``set_device_control_env_var`` lives next to CoreEngineProcManager and + # is only required for non-CUDA DP, so we tolerate its absence on + # older / future vLLM revisions. + from vllm.v1.engine.utils import set_device_control_env_var # type: ignore +except ImportError: # pragma: no cover - depends on vLLM build + set_device_control_env_var = None # type: ignore[assignment] + + +class OmniCoreEngineProcManager(CoreEngineProcManager): + """Spawn :class:`StageEngineCoreProc` subprocesses with omni kwargs. + + The body mirrors :class:`CoreEngineProcManager.__init__` because the + upstream class hardcodes ``target=EngineCoreProc.run_engine_core`` and + does not expose an extensibility hook. The differences from upstream are: + + * ``target`` is :meth:`StageEngineCoreProc.run_stage_core`. + * Per-rank ``omni_replica_id`` is computed as + ``base_replica_id + rank_idx`` and added to each subprocess's kwargs. + * ``omni_coordinator_address`` (if provided) and ``omni_stage_id`` are + added to every subprocess's kwargs. 
+ """ + + def __init__( + self, + local_engine_count: int, + start_index: int, + local_start_index: int, + vllm_config: VllmConfig, + local_client: bool, + handshake_address: str, + executor_class: type[Executor], + log_stats: bool, + *, + omni_stage_id: int, + omni_coordinator_address: str | None = None, + omni_replica_base_id: int = 0, + client_handshake_address: str | None = None, + tensor_queue: Queue | None = None, + ) -> None: + # NOTE: we intentionally do not call ``super().__init__`` — the + # parent's body hardcodes the wrong target. We re-implement it here + # while reusing the parent's instance methods (shutdown, monitor). + if local_engine_count <= 0: + raise ValueError(f"local_engine_count must be > 0, got {local_engine_count}") + + context = get_mp_context() + common_kwargs: dict[str, object] = { + "vllm_config": vllm_config, + "local_client": local_client, + "handshake_address": handshake_address, + "executor_class": executor_class, + "log_stats": log_stats, + "tensor_queue": tensor_queue, + "omni_stage_id": int(omni_stage_id), + "omni_coordinator_address": omni_coordinator_address, + } + + if client_handshake_address: + common_kwargs["client_handshake_address"] = client_handshake_address + + is_dp = vllm_config.parallel_config.data_parallel_size > 1 + + self.processes: list[BaseProcess] = [] + local_dp_ranks: list[int] = [] + for index in range(local_engine_count): + local_index = local_start_index + index + global_index = start_index + index + # Each spawned subprocess is one omni replica. The replica id + # is contiguous within this manager; the master server may have + # pre-allocated a contiguous block starting at ``omni_replica_base_id``. 
+ omni_replica_id = omni_replica_base_id + index + + local_dp_ranks.append(local_index) + self.processes.append( + context.Process( + target=StageEngineCoreProc.run_stage_core, + name=( + f"StageEngineCoreProc_stage{omni_stage_id}" + f"_replica{omni_replica_id}" + (f"_DP{global_index}" if is_dp else "") + ), + kwargs=common_kwargs + | { + "dp_rank": global_index, + "local_dp_rank": local_index, + "omni_replica_id": omni_replica_id, + }, + ) + ) + + self._finalizer = weakref.finalize(self, shutdown, self.processes) + self.manager_stopped = threading.Event() + self.failed_proc_name: str | None = None + + try: + for proc, local_dp_rank in zip(self.processes, local_dp_ranks): + device_control_context: contextlib.AbstractContextManager[None] = contextlib.nullcontext() + if ( + is_dp + and set_device_control_env_var is not None + and (not current_platform.is_cuda_alike() or vllm_config.parallel_config.use_ray) + ): + device_control_context = set_device_control_env_var(vllm_config, local_dp_rank) + + with ( + device_control_context, + numa_utils.configure_subprocess( + vllm_config, + local_rank=0, + dp_local_rank=local_dp_rank, + process_kind="EngineCore", + ), + ): + proc.start() + finally: + if self.finished_procs(): + self.shutdown() diff --git a/vllm_omni/engine/orchestrator.py b/vllm_omni/engine/orchestrator.py index 2d2ac47cbb3..ee611c5ad3e 100644 --- a/vllm_omni/engine/orchestrator.py +++ b/vllm_omni/engine/orchestrator.py @@ -4,12 +4,20 @@ Runs inside a background thread with its own asyncio event loop. Owns logical request progression across stage pools and handles stage-to-stage transfer logic. 
+ +In distributed mode (``coordinator_pub_address`` provided), it also +owns the single :class:`OmniCoordClientForHub`, runs a +:meth:`_watch_replica_list` task that converts replica disappearances +into ``unregister_remote_replica`` control messages, and handles the +``register_remote_replica`` / ``unregister_remote_replica`` flow that +attaches / detaches head-side stage clients for headless replicas. """ from __future__ import annotations import asyncio import time as _time +from collections.abc import Awaitable, Callable from dataclasses import dataclass, field from typing import Any @@ -22,12 +30,28 @@ from vllm.v1.engine import EngineCoreOutputs from vllm.v1.engine.exceptions import EngineDeadError +from vllm_omni.distributed.omni_coordinator import ( + LoadBalancer, + OmniCoordClientForHub, + RandomBalancer, + StageStatus, +) from vllm_omni.engine import OmniEngineCoreRequest from vllm_omni.engine.cfg_companion_tracker import CfgCompanionTracker from vllm_omni.engine.serialization import serialize_additional_information from vllm_omni.engine.stage_pool import StagePool from vllm_omni.outputs import OmniRequestOutput +# Factory signature for building a head-side stage client for a +# *dynamically attached* (auto-assigned) remote replica. +# +# Receives ``(stage_id, replica_id)`` and returns an awaitable yielding the +# constructed client (any type — it must satisfy the shape expected by the +# matching :class:`StagePool`, i.e. expose ``client_addresses["input_address"]`` +# or ``request_address``, plus the usual ``add_request_async`` / +# ``get_output_async`` / ``shutdown`` surface). +RemoteReplicaFactory = Callable[[int, int], Awaitable[Any]] + logger = init_logger(__name__) @@ -113,6 +137,10 @@ class StreamingInputState: class Orchestrator: """Runs inside a background thread's asyncio event loop.""" + # Cadence at which the replica-list watcher polls for disappearances. 
+ _WATCH_REPLICA_INTERVAL_S: float = 0.5 + _WATCH_REPLICA_IDLE_INTERVAL_S: float = 1.0 + def __init__( self, request_async_queue: janus.AsyncQueue[dict[str, Any]], @@ -122,6 +150,9 @@ def __init__( *, async_chunk: bool = False, pd_config: dict[str, Any] | None = None, + coordinator_pub_address: str | None = None, + load_balancer_factory: Callable[[], LoadBalancer] | None = None, + remote_replica_factory: RemoteReplicaFactory | None = None, ) -> None: self.request_async_queue = request_async_queue self.output_async_queue = output_async_queue @@ -148,6 +179,28 @@ def __init__( self._fatal_error: str | None = None self._fatal_error_stage_id: int | None = None + # Background tasks for fire-and-forget message handlers (currently + # only ``register_remote_replica`` and ``unregister_remote_replica``). + # Held as a set so each task's reference survives the loop and the + # task can self-deregister on completion. + self._membership_tasks: set[asyncio.Task[None]] = set() + + # Distributed-mode wiring. The hub is constructed on the + # orchestrator's asyncio loop because it spawns a SUB background + # thread; building it from another thread would race the + # ``_init_done`` event. + self._hub: OmniCoordClientForHub | None = ( + OmniCoordClientForHub(coordinator_pub_address) if coordinator_pub_address is not None else None + ) + self._remote_replica_factory = remote_replica_factory + # Inject hub + per-pool LB into each StagePool so they can run + # distributed dispatch via ``StagePool.pick``. 
+ if self._hub is not None: + factory = load_balancer_factory or RandomBalancer + for pool in self.stage_pools: + pool.attach_hub(self._hub) + pool.attach_load_balancer(factory()) + async def run(self) -> None: """Main entry point for the Orchestrator event loop.""" logger.info("[Orchestrator] Starting event loop") @@ -157,9 +210,16 @@ async def run(self) -> None: self._orchestration_output_handler(), name="orchestrator-stage-output-handler", ) + # The replica watcher only runs in distributed mode. It's still + # created in both cases so ``run()`` has a uniform task graph; + # ``_watch_replica_list`` is a no-op poll when ``self._hub`` is None. + watch_task = asyncio.create_task( + self._watch_replica_list(), + name="orchestrator-replica-watcher", + ) try: - await asyncio.gather(request_task, output_task) + await asyncio.gather(request_task, output_task, watch_task) except asyncio.CancelledError: raise except Exception: @@ -167,11 +227,11 @@ async def run(self) -> None: raise finally: self._shutdown_event.set() - for task in (request_task, output_task): + for task in (request_task, output_task, watch_task): if not task.done(): task.cancel() try: - await asyncio.gather(request_task, output_task, return_exceptions=True) + await asyncio.gather(request_task, output_task, watch_task, return_exceptions=True) except Exception: pass @@ -181,8 +241,32 @@ async def run(self) -> None: if self._fatal_error is not None: await self._drain_pending_requests_on_fatal() + # Wait briefly for any in-flight membership handlers (register / + # unregister remote replica) to finish so they don't leave the + # head-side pool in a half-attached state. Cancel anything that + # hasn't completed in time; the generic pending-task sweep below + # will collect the cancellations. 
+ if self._membership_tasks: + try: + await asyncio.wait_for( + asyncio.gather(*self._membership_tasks, return_exceptions=True), + timeout=10.0, + ) + except (asyncio.TimeoutError, Exception): + for t in self._membership_tasks: + if not t.done(): + t.cancel() + self._shutdown_stages() + # Close the hub last so any in-flight dispatch still has access. + if self._hub is not None: + try: + self._hub.close() + except RuntimeError: + pass + self._hub = None + loop = asyncio.get_running_loop() pending = [t for t in asyncio.all_tasks(loop) if t is not asyncio.current_task() and not t.done()] for task in pending: @@ -190,6 +274,27 @@ async def run(self) -> None: if pending: await asyncio.gather(*pending, return_exceptions=True) + # ---- Background task helpers ---- + + def _spawn_membership_task(self, coro: Awaitable[None], *, label: str) -> None: + """Run a fire-and-forget membership-change coroutine. + + Holds a strong reference until completion (asyncio would otherwise + garbage-collect a bare task), and logs any uncaught exception. + """ + task = asyncio.create_task(coro, name=f"orchestrator-{label}") + self._membership_tasks.add(task) + + def _on_done(t: asyncio.Task[None]) -> None: + self._membership_tasks.discard(t) + if t.cancelled(): + return + exc = t.exception() + if exc is not None: + logger.error("[Orchestrator] %s task crashed", label, exc_info=exc) + + task.add_done_callback(_on_done) + # ---- Request handling ---- async def _request_handler(self) -> None: @@ -208,6 +313,19 @@ async def _request_handler(self) -> None: await self._handle_abort(msg) elif msg_type == "collective_rpc": await self._handle_collective_rpc(msg) + elif msg_type == "register_remote_replica": + # Dynamic-attach involves a ~5s blocking handshake (run in a + # thread by ``_build_remote_replica``); ``await`` here would + # block the queue and stall the next ``add_request`` until + # the attach finishes. Dispatch as a background task so the + # main message loop keeps draining. 
+ self._spawn_membership_task(self._handle_register_remote_replica(msg), label="register_remote_replica") + elif msg_type == "unregister_remote_replica": + # Symmetric with register: keep the main queue flowing. + self._spawn_membership_task( + self._handle_unregister_remote_replica(msg), + label="unregister_remote_replica", + ) elif msg_type == "shutdown": logger.info("[Orchestrator] Received shutdown signal") self._shutdown_event.set() @@ -381,7 +499,7 @@ async def _handle_collective_rpc(self, msg: dict[str, Any]) -> None: results: list[Any] = [] stage_ids: list[int] = [] for pool in target_pools: - for replica_id in range(pool.num_replicas): + for replica_id in pool.live_replica_ids(): stage_result = await pool.collective_rpc( replica_id=replica_id, method=method, @@ -418,7 +536,7 @@ async def _orchestration_loop(self) -> None: idle = True for stage_id in range(self.num_stages): pool = self.stage_pools[stage_id] - for replica_id in range(pool.num_replicas): + for replica_id in pool.live_replica_ids(): if self._shutdown_event.is_set(): return @@ -1079,14 +1197,139 @@ async def _drain_pending_requests_on_fatal(self) -> None: ) self.request_states.pop(req_id, None) + # ---- Distributed-mode replica attach / detach ---- + + async def _watch_replica_list(self) -> None: + """Convert hub replica disappearances into unregister control messages.""" + last_up: set[tuple[int, str]] = set() + while not self._shutdown_event.is_set(): + if self._hub is None: + # No coordinator wired up; sleep coarsely and re-check shutdown. 
+ try: + await asyncio.sleep(self._WATCH_REPLICA_IDLE_INTERVAL_S) + except asyncio.CancelledError: + raise + continue + + try: + snap = self._hub.get_replica_list() + current = {(rep.stage_id, rep.input_addr) for rep in snap.replicas if rep.status == StageStatus.UP} + for stage_id, addr in last_up - current: + await self.request_async_queue.put( + { + "type": "unregister_remote_replica", + "stage_id": stage_id, + "input_addr": addr, + } + ) + last_up = current + except asyncio.CancelledError: + raise + except Exception: + logger.exception("[Orchestrator] _watch_replica_list iteration failed") + + try: + await asyncio.sleep(self._WATCH_REPLICA_INTERVAL_S) + except asyncio.CancelledError: + raise + + async def _handle_register_remote_replica(self, msg: dict[str, Any]) -> None: + """Bind a head-side client for a newly registered remote replica.""" + stage_id = int(msg["stage_id"]) + replica_id = int(msg["replica_id"]) + if not (0 <= stage_id < self.num_stages): + logger.warning( + "[Orchestrator] register_remote_replica: stage_id %d out of range (num_stages=%d)", + stage_id, + self.num_stages, + ) + return + if self._remote_replica_factory is None: + logger.warning( + "[Orchestrator] register_remote_replica received for stage=%d replica=%d but no factory installed", + stage_id, + replica_id, + ) + return + + try: + await self._attach_remote_replica(stage_id, replica_id) + except Exception: + logger.exception( + "[Orchestrator] failed to attach remote replica stage=%d replica=%d", + stage_id, + replica_id, + ) + + async def _handle_unregister_remote_replica(self, msg: dict[str, Any]) -> None: + """Tear down the head-side client for a vanished remote replica.""" + stage_id = int(msg["stage_id"]) + input_addr = str(msg["input_addr"]) + if not (0 <= stage_id < self.num_stages): + return + pool = self.stage_pools[stage_id] + affected = pool.invalidate_addr(input_addr) + self._detach_remote_replica(stage_id, input_addr) + if affected: + await 
self._cleanup_request_ids(affected, abort=True) + for req_id in affected: + await self.output_async_queue.put( + { + "type": "error", + "request_id": req_id, + "stage_id": stage_id, + "error": "stage replica disappeared", + } + ) + + async def _attach_remote_replica(self, stage_id: int, replica_id: int) -> None: + """Build a head-side stage client via the injected factory and register it.""" + factory = self._remote_replica_factory + if factory is None: + return + pool = self.stage_pools[stage_id] + client = await factory(stage_id, replica_id) + input_addr = StagePool._client_input_addr(client) + if input_addr is None: + raise RuntimeError( + f"remote replica factory for stage {stage_id} produced a client without a discoverable input address" + ) + pool.add_client(input_addr, client) + logger.info( + "[Orchestrator] attached remote replica stage=%d replica=%d addr=%s", + stage_id, + replica_id, + input_addr, + ) + + def _detach_remote_replica(self, stage_id: int, input_addr: str) -> None: + """Shut down + remove the head-side client at ``input_addr``.""" + pool = self.stage_pools[stage_id] + client = pool.remove_client(input_addr) + if client is None: + return + try: + client.shutdown() + except Exception: + logger.exception( + "[Orchestrator] failed to shutdown client for stage=%d addr=%s", + stage_id, + input_addr, + ) + logger.info( + "[Orchestrator] detached remote replica stage=%d addr=%s", + stage_id, + input_addr, + ) + def _shutdown_stages(self) -> None: """Shutdown all stage pools.""" if self._stages_shutdown: return self._stages_shutdown = True - total = sum(pool.num_replicas for pool in self.stage_pools) + total = sum(pool.live_num_replicas for pool in self.stage_pools) logger.info("[Orchestrator] Shutting down all %d client(s)", total) for pool in self.stage_pools: - for replica_id in range(pool.num_replicas): + for replica_id in pool.live_replica_ids(): pool.shutdown_replica(replica_id) diff --git a/vllm_omni/engine/stage_engine_core_proc.py 
b/vllm_omni/engine/stage_engine_core_proc.py index 2ab8b37dd5f..6b77388b41f 100644 --- a/vllm_omni/engine/stage_engine_core_proc.py +++ b/vllm_omni/engine/stage_engine_core_proc.py @@ -7,6 +7,7 @@ from __future__ import annotations +import contextlib import signal from multiprocessing.process import BaseProcess from typing import TYPE_CHECKING, Any @@ -33,6 +34,8 @@ ) from vllm.v1.utils import shutdown +from vllm_omni.distributed.omni_coordinator import OmniCoordClientForStage + if TYPE_CHECKING: from vllm.config import VllmConfig from vllm.v1.executor import Executor @@ -53,35 +56,82 @@ def run_stage_core( *args: Any, dp_rank: int = 0, local_dp_rank: int = 0, + omni_coordinator_address: str | None = None, + omni_stage_id: int | None = None, + omni_replica_id: int = 0, **kwargs: Any, ) -> None: - """Launch StageEngineCoreProc busy loop in background process.""" + """Launch StageEngineCoreProc busy loop in background process. + + Omni-specific kwargs: + - ``omni_coordinator_address``: ROUTER address of the head-side + :class:`OmniCoordinator`. When provided, this subprocess + instantiates an :class:`OmniCoordClientForStage` after the + HELLO/INIT/READY handshake completes and reports its status + + queue length via heartbeats. The hook is wired so each + heartbeat refreshes ``queue_length`` from the live scheduler. + - ``omni_stage_id``: logical stage id this replica belongs to. + Required when ``omni_coordinator_address`` is provided. + - ``omni_replica_id``: cluster-unique replica id within the + stage (assigned by :class:`OmniMasterServer`). Used for + logging / metrics only. 
+ """ signal_callback: SignalCallback | None = None maybe_register_config_serialize_by_value() engine_core: StageEngineCoreProc | None = None + coord_client: OmniCoordClientForStage | None = None try: - vllm_config: VllmConfig = kwargs["vllm_config"] - parallel_config = vllm_config.parallel_config + # NOTE: previous revisions hardcoded data_parallel_size=1 here + # (TODO referencing issue #984). The hardcoding has been removed + # so the DP fields propagate through from the caller exactly + # like upstream vLLM. - set_process_title(f"StageEngineCoreProc_DP{dp_rank}") + stage_label = f"stage{omni_stage_id}" if omni_stage_id is not None else "noid" + set_process_title(f"StageEngineCoreProc_{stage_label}_replica{omni_replica_id}_DP{dp_rank}") decorate_logs() - # the current vllm-omni does not support data parallelism, - # so we set the data parallel size to 1. - # [TODO] support data parallelism in the future. - # https://github.com/vllm-project/vllm-omni/issues/984 - parallel_config.data_parallel_size = 1 - parallel_config.data_parallel_size_local = 1 - parallel_config.data_parallel_rank = 0 - parallel_config.data_parallel_index = dp_rank - engine_core = StageEngineCoreProc( *args, engine_index=dp_rank, **kwargs, ) + # Each subprocess corresponds to exactly one omni replica with + # its own OmniMasterServer allocation, so the heartbeat client + # runs unconditionally — there is no dp_rank-based gating. 
+ if omni_coordinator_address is not None: + if omni_stage_id is None: + raise ValueError("omni_stage_id must be provided when omni_coordinator_address is set") + addresses: EngineZmqAddresses = engine_core.addresses + if not addresses.inputs or not addresses.outputs: + raise RuntimeError( + "EngineCore handshake did not populate input/output addresses; " + "cannot start OmniCoordClientForStage" + ) + coord_client = OmniCoordClientForStage( + coord_zmq_addr=omni_coordinator_address, + input_addr=addresses.inputs[0], + output_addr=addresses.outputs[0], + stage_id=int(omni_stage_id), + ) + + def _refresh_queue_length() -> None: + """Pre-heartbeat hook: refresh queue_length from scheduler.""" + scheduler = getattr(engine_core, "scheduler", None) + if scheduler is None: + return + try: + coord_client._queue_length = int( # type: ignore[union-attr] + scheduler.get_num_unfinished_requests() + ) + except Exception: + # Live scheduler stats are best-effort — heartbeats + # must not fail because of a stats lookup error. 
+ pass + + coord_client._on_heartbeat = _refresh_queue_length + def wakeup_engine() -> None: engine_core.input_queue.put_nowait((EngineCoreRequestType.WAKEUP, None)) @@ -111,6 +161,9 @@ def signal_handler(signum: int, frame: Any) -> None: signal.signal(signal.SIGINT, signal.SIG_DFL) if signal_callback is not None: signal_callback.stop() + if coord_client is not None: + with contextlib.suppress(RuntimeError): + coord_client.close() if engine_core is not None: engine_core.shutdown() diff --git a/vllm_omni/engine/stage_engine_startup.py b/vllm_omni/engine/stage_engine_startup.py index 05bcdf7d138..16c990d9322 100644 --- a/vllm_omni/engine/stage_engine_startup.py +++ b/vllm_omni/engine/stage_engine_startup.py @@ -4,8 +4,9 @@ import contextlib import dataclasses +import socket import threading -from collections.abc import Iterator +from collections.abc import Callable, Iterator from dataclasses import dataclass from typing import Any @@ -31,6 +32,15 @@ StageRoute = tuple[int, int] +# Sentinel that signals "auto-assign me a replica_id" on the wire. Negative +# values are not valid replica ids, so any sub-zero value works equivalently. +AUTO_ASSIGN_REPLICA_ID = -1 + +# Callback signature for OmniMasterServer.on_register. Fires only for +# auto-assigned replicas (new, headless-launched). The arguments are +# (stage_id, replica_id, allocation). +OnRegisterCallback = Callable[[int, int, "StageAllocation"], None] + # Poll period (ms) used by the registration/handshake loop. _POLL_PERIOD_MS = 5_000 # Default timeout (s) for a stage to send READY. 
@@ -110,6 +120,9 @@ def __init__( master_port: int, stage_ids: list[int], stage_replica_counts: dict[int, int] | None = None, + *, + coordinator_router_address: str | None = None, + on_register: OnRegisterCallback | None = None, ) -> None: self._address = master_address self._port = master_port @@ -117,25 +130,31 @@ def __init__( self._stage_configs: dict[StageRoute, Any] = {} self._stage_coordinator_addresses: dict[StageRoute, StageCoordinatorAddresses] = {} self._stage_config_events: dict[StageRoute, threading.Event] = {} + # Coordinator ROUTER address echoed back in every registration reply + # so OmniCoordClientForStage knows where to connect from inside the + # engine subprocess. + self._coordinator_router_address = coordinator_router_address + # Fires only for *newly assigned* (auto-assigned) replicas, not for + # head-side pre-allocated slots that already have head-side clients. + self._on_register = on_register + # Per-stage allocation lock + auto-assign cursor, so concurrent + # registrations from multiple headless processes for the same stage + # don't race on the routing table. + self._alloc_lock = threading.Lock() + self._stage_ids_known: set[int] = set(int(sid) for sid in stage_ids) self._thread: threading.Thread | None = None self._stop_event = threading.Event() stage_replica_counts = dict(stage_replica_counts or {}) for sid in stage_ids: - replica_count = max(1, int(stage_replica_counts.get(sid, 1))) + replica_count = int(stage_replica_counts.get(sid, 1)) + # Allow 0 explicitly so non-self stages (head distributed mode) + # can declare "no local replicas; remote ones will register + # dynamically". 
+ if replica_count < 0: + raise ValueError(f"stage_replica_counts[{sid}] must be >= 0, got {replica_count}") for replica_id in range(replica_count): - route = (sid, replica_id) - self._stage_config_events[route] = threading.Event() - self._stage_coordinator_addresses[route] = StageCoordinatorAddresses() - hs_port, inp_port, out_port = get_open_ports_list(count=3) - self._stage_routes[route] = StageAllocation( - handshake_bind_address=f"tcp://{master_address}:{hs_port}", - handshake_connect_address=f"tcp://{master_address}:{hs_port}", - input_bind_address=f"tcp://{master_address}:{inp_port}", - input_connect_address=f"tcp://{master_address}:{inp_port}", - output_bind_address=f"tcp://{master_address}:{out_port}", - output_connect_address=f"tcp://{master_address}:{out_port}", - ) + self._allocate_route_locked(sid, replica_id) logger.info( "[OmniMasterServer] Pre-allocated addresses for stages %s (master=%s:%d)", @@ -157,10 +176,74 @@ def port(self) -> int: """Return the registration port exposed to stage launchers.""" return self._port + @property + def coordinator_router_address(self) -> str | None: + """Return the OmniCoordinator ROUTER address echoed to replicas.""" + return self._coordinator_router_address + def get_allocation(self, stage_id: int, replica_id: int = 0) -> StageAllocation: """Return the full address allocation for *stage_id*.""" return self._stage_routes[(stage_id, replica_id)] + # ------------------------------------------------------------------ + # Allocation + # ------------------------------------------------------------------ + + def _allocate_route_locked(self, stage_id: int, replica_id: int) -> StageAllocation: + """Allocate handshake/input/output ports for ``(stage_id, replica_id)``. + + Idempotent: if the route already exists, returns the existing + allocation unchanged. Caller is responsible for holding + ``self._alloc_lock`` when needed. 
+ """ + route = (stage_id, replica_id) + existing = self._stage_routes.get(route) + if existing is not None: + return existing + + self._stage_config_events[route] = threading.Event() + self._stage_coordinator_addresses[route] = StageCoordinatorAddresses() + hs_port, inp_port, out_port = get_open_ports_list(count=3) + alloc = StageAllocation( + handshake_bind_address=f"tcp://{self._address}:{hs_port}", + handshake_connect_address=f"tcp://{self._address}:{hs_port}", + input_bind_address=f"tcp://{self._address}:{inp_port}", + input_connect_address=f"tcp://{self._address}:{inp_port}", + output_bind_address=f"tcp://{self._address}:{out_port}", + output_connect_address=f"tcp://{self._address}:{out_port}", + ) + self._stage_routes[route] = alloc + return alloc + + def _next_free_replica_id(self, stage_id: int) -> int: + """Return the next replica id to assign for an auto-assign registration. + + Strategy: prefer filling a pre-allocated-but-unfilled slot (one that + ``__init__`` reserved in ``_stage_routes`` but no registration has + completed yet) so the head's bootstrap path — which waits on + ``_stage_config_events[(stage_id, replica_id)]`` for specific + pre-allocated ids — unblocks. Only when every pre-allocated slot for + this stage has been filled do we allocate a fresh id. + + Without this, a headless contributor using ``--omni-dp-size-local > 1`` + (auto-assign mode) would skip past pre-allocated slot 0 and pick ids + beyond ``num_replicas``, deadlocking the head's + ``connect_remote_engine_cores`` wait. + """ + # Pre-allocated slots that haven't received a registration yet are + # tracked by absence from ``_stage_configs``. + for sid, rid in sorted(self._stage_routes): + if sid != stage_id: + continue + if (sid, rid) not in self._stage_configs: + return rid + # Every pre-allocated slot is filled; allocate a fresh id. 
+ used = {rid for (sid, rid) in self._stage_routes if sid == stage_id} + rid = 0 + while rid in used: + rid += 1 + return rid + def register_stage_config( self, stage_id: int, @@ -280,12 +363,21 @@ def _serve(self, ctx: zmq.Context) -> None: # type: ignore[type-arg] poller = zmq.Poller() poller.register(reg_socket, zmq.POLLIN) + # The server runs until ``stop()`` is called so that headless replicas + # spawned after the head finished its initial bring-up can still + # register dynamically. ``pending`` is kept around purely for + # debug-level logging of which pre-allocated slots have not yet + # registered; once empty it does not terminate the loop. pending: set[StageRoute] = set(self._stage_routes.keys()) - while pending and not self._stop_event.is_set(): + while not self._stop_event.is_set(): events: list[tuple[zmq.Socket, int]] = poller.poll(_POLL_PERIOD_MS) # type: ignore[assignment] if not events: - logger.debug("[OmniMasterServer] Still waiting for registration from stages: %s", pending) + if pending: + logger.debug( + "[OmniMasterServer] Still waiting for registration from pre-allocated slots: %s", + pending, + ) continue for sock, _ in events: @@ -296,12 +388,12 @@ def _serve(self, ctx: zmq.Context) -> None: # type: ignore[type-arg] # Cleanup reg_socket.close(linger=0) - logger.info("[OmniMasterServer] All stages registered; server thread exiting.") + logger.info("[OmniMasterServer] Server thread exiting.") def _handle_registration(self, reg_socket: zmq.Socket) -> StageRoute | None: # type: ignore[type-arg] """Receive a stage registration and reply with the handshake address. - Returns the registered stage_id on success, or None on failure. + Returns ``(stage_id, replica_id)`` on success or ``None`` on failure. 
""" frames = reg_socket.recv_multipart() if len(frames) < 2: @@ -318,45 +410,157 @@ def _handle_registration(self, reg_socket: zmq.Socket) -> StageRoute | None: # logger.warning("[OmniMasterServer] Failed to decode registration message: %s", exc) return None - stage_id: int | None = msg.get("stage_id") - replica_id = int(msg.get("replica_id", 0) or 0) - key = (stage_id, replica_id) - if key not in self._stage_routes: + stage_id_raw = msg.get("stage_id") + if not isinstance(stage_id_raw, int) or stage_id_raw < 0: logger.warning( - "[OmniMasterServer] Received registration for unknown stage_id=%s replica_id=%s", - stage_id, - replica_id, + "[OmniMasterServer] Registration missing or invalid stage_id: %r", + stage_id_raw, ) return None + stage_id: int = stage_id_raw + + incoming_replica_id = int(msg.get("replica_id", 0) or 0) + was_auto_assigned = incoming_replica_id < 0 + + # Distinguish two registration shapes: + # - Pre-allocated slots (concrete replica_id >= 0): the head built + # this slot during _initialize_stages. Just confirm it; do NOT + # fire ``on_register`` (the head already has a head-side client). + # - Auto-assigned slots (replica_id == AUTO_ASSIGN_REPLICA_ID): + # a *new* replica from a headless launcher. Allocate, then + # fire ``on_register`` so the orchestrator attaches. + with self._alloc_lock: + if was_auto_assigned: + replica_id = self._next_free_replica_id(stage_id) + # When auto-assign picks a slot the head pre-allocated (and + # is therefore waiting on in ``connect_remote_engine_cores``), + # the head's bootstrap path builds the head-side client. We + # must NOT also fire ``on_register`` for it; otherwise the + # orchestrator would build a duplicate client and overwrite + # the bootstrap-built one in the pool, leaking it. 
+ preexisting_slot = (stage_id, replica_id) in self._stage_routes + alloc = self._allocate_route_locked(stage_id, replica_id) + if preexisting_slot: + was_auto_assigned = False + else: + replica_id = incoming_replica_id + if (stage_id, replica_id) not in self._stage_routes: + # Tolerate explicit replica_ids that haven't been + # pre-allocated (e.g. headless that wants a specific id). + alloc = self._allocate_route_locked(stage_id, replica_id) + was_auto_assigned = True + else: + alloc = self._stage_routes[(stage_id, replica_id)] + + # Cross-host override: when the registering replica advertised its + # own bind address + ports, rewrite the StageAllocation so the + # subsequent reply (and any later head-side lookup of this route + # via ``_stage_routes``) uses addresses that are actually local to + # the replica's host. The master's pre-allocated ports (picked on + # the master's host) only work for co-located replicas. + new_bind_address = msg.get("replica_bind_address") + if new_bind_address: + hs_port = int(msg["replica_handshake_port"]) + inp_port = int(msg["replica_input_port"]) + out_port = int(msg["replica_output_port"]) + alloc = StageAllocation( + handshake_bind_address=f"tcp://{new_bind_address}:{hs_port}", + handshake_connect_address=f"tcp://{new_bind_address}:{hs_port}", + input_bind_address=f"tcp://{new_bind_address}:{inp_port}", + input_connect_address=f"tcp://{new_bind_address}:{inp_port}", + output_bind_address=f"tcp://{new_bind_address}:{out_port}", + output_connect_address=f"tcp://{new_bind_address}:{out_port}", + ) + self._stage_routes[(stage_id, replica_id)] = alloc + logger.info( + "[OmniMasterServer] Stage %d replica %d advertised cross-host bind: %s", + stage_id, + replica_id, + alloc.handshake_bind_address, + ) - self.register_stage_config( - stage_id, - msg.get("stage_config"), - coordinator_addresses=StageCoordinatorAddresses( - coordinator_input=msg.get("coordinator_input"), - coordinator_output=msg.get("coordinator_output"), - 
frontend_stats_publish_address=msg.get("frontend_stats_publish_address"), - ), - replica_id=replica_id, - ) + # Mark the slot as filled *inside* the lock. Without this, + # concurrent auto-assign registrations from a second headless + # could call ``_next_free_replica_id`` between the lock + # releasing above and the ``register_stage_config`` call + # below, observe the slot as unfilled, and hand the same + # pre-allocated handshake/input/output addresses to two + # different replicas — which then collide on + # ``zmq_socket_ctx(handshake_address, ROUTER, bind=True)``. + self.register_stage_config( + stage_id, + msg.get("stage_config"), + coordinator_addresses=StageCoordinatorAddresses( + coordinator_input=msg.get("coordinator_input"), + coordinator_output=msg.get("coordinator_output"), + frontend_stats_publish_address=msg.get("frontend_stats_publish_address"), + ), + replica_id=replica_id, + ) + + # Fire on_register only for genuinely new (auto-assigned or newly + # allocated) replicas, on the ROUTER thread. Callback is expected to + # be cheap and non-blocking (e.g. enqueue onto an asyncio queue). + if was_auto_assigned and self._on_register is not None: + try: + self._on_register(stage_id, replica_id, alloc) + except Exception: + logger.exception( + "[OmniMasterServer] on_register callback failed for stage=%d replica=%d", + stage_id, + replica_id, + ) - alloc = self._stage_routes[key] response = msgspec.msgpack.encode( { "handshake_address": alloc.handshake_connect_address, "input_address": alloc.input_bind_address, "output_address": alloc.output_bind_address, + "replica_id": replica_id, + "coordinator_router_address": self._coordinator_router_address, } ) # ROUTER-DEALER: reply is [identity, payload] (no empty delimiter). 
reg_socket.send_multipart([identity, response]) logger.info( - "[OmniMasterServer] Stage %d replica %d registered; assigned handshake=%s", + "[OmniMasterServer] Stage %d replica %d registered (auto=%s); handshake=%s", stage_id, replica_id, + was_auto_assigned, alloc.handshake_connect_address, ) - return key + return (stage_id, replica_id) + + +@dataclass(frozen=True) +class StageRegistrationResponse: + """Reply payload returned by :class:`OmniMasterServer` after a successful registration.""" + + handshake_address: str + input_address: str + output_address: str + replica_id: int + coordinator_router_address: str | None + + +def _detect_local_bind_address(master_address: str, master_port: int) -> str: + """Return the local IP the kernel would use to reach the master. + + Uses a connected UDP socket as a routing-table probe: ``connect()`` on + SOCK_DGRAM sends no packets but forces a route lookup, after which + ``getsockname()[0]`` exposes the source IP that an outbound packet to + ``(master_address, master_port)`` would carry. For a co-located master + this returns the loopback or eth0 IP (same effect as the legacy + ``self._address`` behaviour); for a remote master it returns the + NIC IP that's actually reachable from the master — which is exactly + the address the headless's per-stage ZMQ sockets must bind on. + """ + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + s.connect((master_address, master_port)) + return s.getsockname()[0] + finally: + s.close() def register_stage_with_omni_master( @@ -367,23 +571,36 @@ def register_stage_with_omni_master( omni_stage_config: Any = None, coordinator: DPCoordinator | None = None, return_addresses: bool = False, - replica_id: int = 0, -) -> str | tuple[str, str, str]: + replica_id: int | None = 0, + return_full_response: bool = False, + replica_bind_address: str | None = None, +) -> str | tuple[str, str, str] | StageRegistrationResponse: """Register a stage with the omni master server. 
Returns the per-stage handshake address by default. When ``return_addresses`` is true, also returns the stage input/output - addresses allocated by the master. + addresses allocated by the master. When ``return_full_response`` is + true, returns the full :class:`StageRegistrationResponse` including the + assigned ``replica_id`` and the OmniCoordinator ROUTER address (if + published by the master). + + Pass ``replica_id=None`` to request auto-assignment of a free replica + id by the master (used by headless launchers). """ + if replica_id is None: + wire_replica_id = AUTO_ASSIGN_REPLICA_ID + else: + wire_replica_id = int(replica_id) + reg_ctx = zmq.Context() try: reg_sock: zmq.Socket = reg_ctx.socket(zmq.DEALER) # type: ignore[attr-defined] try: reg_sock.connect(f"tcp://{omni_master_address}:{omni_master_port}") - payload = { + payload: dict[str, Any] = { "stage_id": omni_stage_id, - "replica_id": replica_id, + "replica_id": wire_replica_id, "stage_config": _serialize_stage_config(omni_stage_config), } if coordinator is not None: @@ -392,6 +609,24 @@ def register_stage_with_omni_master( payload["coordinator_output"] = coordinator_output payload["frontend_stats_publish_address"] = coordinator.get_stats_publish_address() + # Always advertise THIS host's local bind address + 3 locally + # free ports so the master can root the per-stage socket + # allocation on the replica's own interface. For a co-located + # replica the detected IP matches the master's address and + # the override is a no-op semantically; for a cross-host + # replica it's what makes the headless's ROUTER bind succeed + # (otherwise the master would hand back ``tcp://:port`` + # and ``zmq.bind`` would EADDRNOTAVAIL on the remote host). 
+ if replica_bind_address is None: + replica_bind_address = _detect_local_bind_address( + omni_master_address, omni_master_port + ) + hs_port, inp_port, out_port = get_open_ports_list(count=3) + payload["replica_bind_address"] = replica_bind_address + payload["replica_handshake_port"] = hs_port + payload["replica_input_port"] = inp_port + payload["replica_output_port"] = out_port + reg_sock.send(msgspec.msgpack.encode(payload)) timeout_ms = _DEFAULT_STARTUP_TIMEOUT_S * 1_000 if not reg_sock.poll(timeout=timeout_ms): @@ -402,13 +637,16 @@ def register_stage_with_omni_master( f"for stage {omni_stage_id}." ) response_bytes = reg_sock.recv() - response = msgspec.msgpack.decode(response_bytes) - handshake_address: str = response["handshake_address"] - input_address: str = response["input_address"] - output_address: str = response["output_address"] + response_msg = msgspec.msgpack.decode(response_bytes) + handshake_address: str = response_msg["handshake_address"] + input_address: str = response_msg["input_address"] + output_address: str = response_msg["output_address"] + assigned_replica_id: int = int(response_msg.get("replica_id", wire_replica_id)) + coord_router_addr: str | None = response_msg.get("coordinator_router_address") logger.info( - "Stage %d registered; handshake_address=%s", + "Stage %d replica %d registered; handshake_address=%s", omni_stage_id, + assigned_replica_id, handshake_address, ) finally: @@ -416,6 +654,14 @@ def register_stage_with_omni_master( finally: reg_ctx.term() + if return_full_response: + return StageRegistrationResponse( + handshake_address=handshake_address, + input_address=input_address, + output_address=output_address, + replica_id=assigned_replica_id, + coordinator_router_address=coord_router_addr, + ) if return_addresses: return handshake_address, input_address, output_address return handshake_address @@ -543,8 +789,16 @@ def launch_omni_core_engines( stage_id: int, stage_config: Any = None, replica_id: int = 0, + *, + 
omni_coordinator_address: str | None = None, ) -> Iterator[tuple[CoreEngineProcManager, DPCoordinator | None, EngineZmqAddresses]]: - """Launch local engine cores using the omni registration flow.""" + """Launch local engine cores using the omni registration flow. + + When ``omni_coordinator_address`` is provided, the spawned engine + subprocesses use :class:`OmniCoreEngineProcManager` and each + instantiates an :class:`OmniCoordClientForStage` after the handshake + completes so the head's :class:`OmniCoordinator` knows about them. + """ addresses = omni_master_server.get_zmq_addresses(stage_id, replica_id=replica_id) parallel_config = vllm_config.parallel_config # Determine the number of local engines and their ranks. @@ -608,16 +862,35 @@ def launch_omni_core_engines( handshake_bind_address = omni_master_server.get_allocation(stage_id, replica_id=replica_id).handshake_bind_address with zmq_socket_ctx(handshake_bind_address, zmq.ROUTER, bind=True) as handshake_socket: - local_engine_manager = CoreEngineProcManager( - local_engine_count=local_engine_count, - start_index=start_index, - local_start_index=local_start_index, - vllm_config=vllm_config, - local_client=True, - handshake_address=handshake_address, - executor_class=executor_class, - log_stats=log_stats, - ) + if omni_coordinator_address is not None: + # Use the omni subclass so each spawned subprocess instantiates + # an OmniCoordClientForStage and heartbeats to the coordinator. 
+ from vllm_omni.engine.omni_core_engine_proc_manager import OmniCoreEngineProcManager + + local_engine_manager: CoreEngineProcManager = OmniCoreEngineProcManager( + local_engine_count=local_engine_count, + start_index=start_index, + local_start_index=local_start_index, + vllm_config=vllm_config, + local_client=True, + handshake_address=handshake_address, + executor_class=executor_class, + log_stats=log_stats, + omni_stage_id=stage_id, + omni_coordinator_address=omni_coordinator_address, + omni_replica_base_id=replica_id, + ) + else: + local_engine_manager = CoreEngineProcManager( + local_engine_count=local_engine_count, + start_index=start_index, + local_start_index=local_start_index, + vllm_config=vllm_config, + local_client=True, + handshake_address=handshake_address, + executor_class=executor_class, + log_stats=log_stats, + ) yield local_engine_manager, coordinator, addresses diff --git a/vllm_omni/engine/stage_init_utils.py b/vllm_omni/engine/stage_init_utils.py index ce68a23daa4..e29cdc90386 100644 --- a/vllm_omni/engine/stage_init_utils.py +++ b/vllm_omni/engine/stage_init_utils.py @@ -479,9 +479,19 @@ def get_stage_devices_per_replica(stage_cfg: Any) -> int: def compute_replica_layout( stage_configs: Sequence[Any], + *, + allow_zero: bool = False, ) -> tuple[list[int], dict[int, list[str]]]: """Compute per-stage replica counts and device assignments. + Args: + stage_configs: per-stage config objects with a ``runtime`` sub-config + exposing ``num_replicas`` and ``devices``. + allow_zero: when True, ``num_replicas == 0`` is honored (used by + single-stage / head-distributed mode for non-self stages that + will be filled dynamically by remote registrations); when False + (default), the count is clamped to at least 1. + Returns: replicas_per_stage: num_replicas per logical stage. 
replica_devices_map: stage_idx -> per-replica device strings @@ -495,7 +505,9 @@ def compute_replica_layout( if hasattr(runtime_cfg, "get") else getattr(runtime_cfg, "num_replicas", 1) ) - replicas_per_stage.append(max(1, num_replicas)) + if num_replicas < 0: + raise ValueError(f"num_replicas must be >= 0, got {num_replicas}") + replicas_per_stage.append(num_replicas if allow_zero else max(1, num_replicas)) replica_devices_map: dict[int, list[str]] = {} for stage_id, stage_cfg in enumerate(stage_configs): diff --git a/vllm_omni/engine/stage_pool.py b/vllm_omni/engine/stage_pool.py index 6f745427112..bb8a08a2a61 100644 --- a/vllm_omni/engine/stage_pool.py +++ b/vllm_omni/engine/stage_pool.py @@ -10,6 +10,13 @@ from vllm.logger import init_logger from vllm.v1.engine import EngineCoreOutputs +from vllm_omni.distributed.omni_coordinator import ( + LoadBalancer, + OmniCoordClientForHub, + ReplicaInfo, + StageStatus, +) +from vllm_omni.distributed.omni_coordinator.load_balancer import Task from vllm_omni.metrics.stats import StageRequestStats as StageRequestMetrics from vllm_omni.metrics.stats import StageStats from vllm_omni.metrics.utils import count_tokens_from_outputs @@ -30,48 +37,108 @@ class _ReplicaMetrics: class StagePool: - """Replicas of one logical stage with RR + affinity selection.""" + """Replicas of one logical stage + per-stage routing (LB + affinity). + + The pool owns the head-side stage clients for one logical stage. It also + absorbs the per-stage dispatch responsibility (load balancing, affinity + tracking, bounded-wait pick) that used to live in a separate + ``StageDispatcher`` class — see the design doc for the rationale. + + In distributed mode (when an :class:`OmniCoordClientForHub` and a + :class:`LoadBalancer` are injected via :meth:`attach_hub` / + :meth:`attach_load_balancer`), :meth:`pick` consults the hub's cached + replica list and routes via the load balancer, sticking subsequent calls + for the same ``request_id`` to the same replica. 
+ + In non-distributed mode (no hub attached), :meth:`pick` falls back to the + legacy ``select_replica_id`` round-robin path so the multi-stage + in-process invocation is unchanged. + + Dynamic replica membership: when a remote replica is added or removed + (driven by :class:`Orchestrator` via :meth:`add_client` / + :meth:`remove_client`), the pool keeps stable integer ``replica_id``s by + storing clients in a list whose entries can be ``None`` after a removal. + Iteration callers should use :meth:`live_replica_ids` rather than + ``range(pool.num_replicas)`` to skip the gaps. + """ + + DISPATCH_WAIT_TIMEOUT_S: float = 10.0 + DISPATCH_RETRY_INTERVAL_S: float = 0.1 def __init__( self, stage_id: int, - clients: Any | list[Any], + clients: object | list[object], *, output_processor: Any = None, stage_vllm_config: Any = None, ) -> None: if isinstance(clients, list): - normalized_clients = list(clients) + normalized_clients: list[object] = list(clients) else: normalized_clients = [clients] - if not normalized_clients: - raise ValueError(f"StagePool for stage {stage_id} has no replicas") + # Allow empty pools when running in distributed head mode for a + # non-self stage; clients will arrive via add_client(...). self.stage_id = stage_id - self.clients: list[Any] = normalized_clients + self.clients: list[Any | None] = list(normalized_clients) self._output_processor = output_processor self._stage_vllm_config = stage_vllm_config self._next_replica_id = 0 self._request_bindings: dict[str, int] = {} self._replica_metrics: list[_ReplicaMetrics] = [_ReplicaMetrics() for _ in self.clients] + # Distributed-mode state. Populated by add_client / remove_client. + self._addr_to_replica_id: dict[str, int] = {} + for replica_id, client in enumerate(self.clients): + if client is not None: + addr = self._client_input_addr(client) + if addr is not None: + self._addr_to_replica_id[addr] = replica_id + + # Distributed-mode dispatch hooks (injected by Orchestrator on bring-up). 
+ self._hub: OmniCoordClientForHub | None = None + self._lb: LoadBalancer | None = None + # ``request_id`` → ``input_addr`` affinity (distributed mode only). + # Kept separate from the legacy ``_request_bindings`` so the two + # binding shapes do not collide. + self._affinity: dict[str, str] = {} + # ---- Stage-level properties ---- @property def num_replicas(self) -> int: + """Total slot count, including ``None`` holes from removed replicas. + + Use :meth:`live_replica_ids` to iterate only live entries. + """ return len(self.clients) + @property + def live_num_replicas(self) -> int: + """Number of currently live (non-None) replicas in this pool.""" + return sum(1 for c in self.clients if c is not None) + + def live_replica_ids(self) -> list[int]: + """Return the indices of currently live replicas in this pool.""" + return [i for i, c in enumerate(self.clients) if c is not None] + @property def stage_type(self) -> str | None: - return getattr(self.stage_client, "stage_type", None) + client = self.stage_client + return None if client is None else getattr(client, "stage_type", None) @property def final_output(self) -> bool: - return bool(getattr(self.clients[0], "final_output", False)) + client = self.stage_client + return False if client is None else bool(getattr(client, "final_output", False)) @property - def stage_client(self) -> Any: - return self.clients[0] + def stage_client(self) -> Any | None: + for client in self.clients: + if client is not None: + return client + return None @property def stage_vllm_config(self) -> Any: @@ -81,11 +148,209 @@ def stage_vllm_config(self) -> Any: def output_processor(self) -> Any: return self._output_processor - # ---- Route binding lifecycle ---- + @property + def is_distributed(self) -> bool: + """True iff a hub has been attached (i.e. 
running in head-distributed mode).""" + return self._hub is not None + + # ---- Distributed-mode dispatch hooks ---- + + def attach_hub(self, hub: OmniCoordClientForHub | None) -> None: + """Inject the shared :class:`OmniCoordClientForHub`. + + Called once by :class:`Orchestrator` after the hub is constructed. + ``hub=None`` keeps the pool in legacy mode (no behavior change). + """ + self._hub = hub + + def attach_load_balancer(self, lb: LoadBalancer | None) -> None: + """Inject the per-pool :class:`LoadBalancer` for distributed-mode pick.""" + self._lb = lb + + # ---- Dynamic membership (distributed mode) ---- + + @staticmethod + def _client_input_addr(client: Any) -> str | None: + """Return the input ZMQ address advertised by ``client`` if any. + + LLM clients expose ``client_addresses["input_address"]``; diffusion + clients expose ``request_address``. Both are stable strings used by + :class:`OmniCoordinator` to key replicas. + """ + request_address = getattr(client, "request_address", None) + if isinstance(request_address, str) and request_address: + return request_address + addrs = getattr(client, "client_addresses", None) + if isinstance(addrs, dict): + addr = addrs.get("input_address") + if isinstance(addr, str) and addr: + return addr + return None + + def add_client(self, input_addr: str, client: Any) -> int: + """Register a head-side client for ``input_addr``. + + Returns the assigned ``replica_id`` (index into :attr:`clients`). + If the address is already known, replaces the existing client and + returns its existing id (this should not happen in practice — the + master server assigns unique slots — but the contract is idempotent + to keep the dispatch layer robust). 
+ """ + if not input_addr: + raise ValueError("input_addr must be a non-empty string") + + existing = self._addr_to_replica_id.get(input_addr) + if existing is not None: + self.clients[existing] = client + return existing + + replica_id = len(self.clients) + self.clients.append(client) + self._addr_to_replica_id[input_addr] = replica_id + self._replica_metrics.append(_ReplicaMetrics()) + return replica_id + + def remove_client(self, input_addr: str) -> Any | None: + """Remove the client at ``input_addr``. Returns the removed client or ``None``. + + Slot is marked ``None`` to preserve indices for outstanding bindings. + """ + replica_id = self._addr_to_replica_id.pop(input_addr, None) + if replica_id is None: + return None + client = self.clients[replica_id] + self.clients[replica_id] = None + return client + + def get_client_by_addr(self, input_addr: str) -> Any | None: + """Return the live client for ``input_addr`` if present.""" + replica_id = self._addr_to_replica_id.get(input_addr) + if replica_id is None: + return None + return self.clients[replica_id] + + def get_replica_id_by_addr(self, input_addr: str) -> int | None: + """Return the stable replica_id for ``input_addr`` if registered.""" + return self._addr_to_replica_id.get(input_addr) + + # ---- Per-request distributed dispatch ---- + + async def pick( + self, + request_id: str, + task: Task | None = None, + *, + affinity_request_id: str | None = None, + ) -> int: + """Return a replica id for ``request_id``. + + In distributed mode: consults the hub for UP replicas, runs the load + balancer, and records affinity so future picks for the same + ``request_id`` return the same replica. Bounded wait up to + ``DISPATCH_WAIT_TIMEOUT_S`` when no UP replica is currently usable. + + In non-distributed (legacy) mode: delegates to + :meth:`select_replica_id`. + """ + if self._hub is None or self._lb is None: + return self.select_replica_id(request_id, affinity_request_id=affinity_request_id) + + # 1. 
Sticky: previously bound and still serviceable? + bound_addr = self._affinity.get(request_id) + if bound_addr is not None: + replica_id = self._serviceable_replica_id_for_addr(bound_addr) + if replica_id is not None: + return replica_id + # Bound replica is gone or DOWN — fall through to re-select. + self._affinity.pop(request_id, None) + + # 2. Inherited affinity (CFG companion sharing a parent request_id). + if affinity_request_id is not None: + parent_addr = self._affinity.get(affinity_request_id) + if parent_addr is not None: + replica_id = self._serviceable_replica_id_for_addr(parent_addr) + if replica_id is not None: + self._affinity[request_id] = parent_addr + return replica_id + + # 3. Fresh pick: poll hub + LB with bounded wait. + task = task or Task(request_id=request_id) + deadline = _time.monotonic() + self.DISPATCH_WAIT_TIMEOUT_S + while True: + candidates = self._collect_serviceable_replicas() + if candidates: + # LB chose an index *into our candidates list*. + lb_idx = self._lb.select(task, [rep for rep, _ in candidates]) + replica_info, replica_id = candidates[lb_idx] + self._affinity[request_id] = replica_info.input_addr + return replica_id + + now = _time.monotonic() + if now >= deadline: + raise RuntimeError(f"no UP replica for stage {self.stage_id} after {self.DISPATCH_WAIT_TIMEOUT_S:.1f}s") + await asyncio.sleep(min(self.DISPATCH_RETRY_INTERVAL_S, deadline - now)) + + def _collect_serviceable_replicas(self) -> list[tuple[ReplicaInfo, int]]: + """Return list of ``(ReplicaInfo, replica_id)`` for UP, attached replicas.""" + if self._hub is None: + return [] + snap = self._hub.get_replicas_for_stage(self.stage_id) + out: list[tuple[ReplicaInfo, int]] = [] + for rep in snap.replicas: + if rep.status != StageStatus.UP: + continue + replica_id = self._addr_to_replica_id.get(rep.input_addr) + if replica_id is None: + continue # Hub knows about it but head-side client not attached yet. 
+ if self.clients[replica_id] is None: + continue + out.append((rep, replica_id)) + return out + + def _serviceable_replica_id_for_addr(self, input_addr: str) -> int | None: + """Return ``replica_id`` for ``input_addr`` iff currently UP + attached.""" + if self._hub is None: + return None + replica_id = self._addr_to_replica_id.get(input_addr) + if replica_id is None or self.clients[replica_id] is None: + return None + snap = self._hub.get_replicas_for_stage(self.stage_id) + for rep in snap.replicas: + if rep.input_addr == input_addr and rep.status == StageStatus.UP: + return replica_id + return None + + def bind(self, request_id: str, input_addr: str) -> None: + """Explicitly record affinity (distributed mode).""" + self._affinity[request_id] = input_addr + + def release(self, request_id: str) -> None: + """Drop affinity (distributed mode) and legacy binding for ``request_id``.""" + self._affinity.pop(request_id, None) + self.release_binding(request_id) + + def invalidate_addr(self, input_addr: str) -> list[str]: + """Drop affinity rows pointing at ``input_addr``; return affected request ids.""" + affected: list[str] = [rid for rid, addr in self._affinity.items() if addr == input_addr] + for rid in affected: + self._affinity.pop(rid, None) + return affected + + # ---- Legacy (non-distributed) route binding ---- def get_bound_replica_id(self, request_id: str) -> int | None: - """Return the currently bound replica id for *request_id* if present.""" - return self._request_bindings.get(request_id) + """Return the currently bound replica id for *request_id* if present. + + In distributed mode the binding may have been recorded via + :meth:`pick`; we honor it transparently here. 
+ """ + legacy = self._request_bindings.get(request_id) + if legacy is not None: + return legacy + addr = self._affinity.get(request_id) + if addr is None: + return None + return self._addr_to_replica_id.get(addr) def get_bound_client(self, request_id: str) -> Any | None: """Return the currently bound client for *request_id* if present.""" @@ -97,6 +362,7 @@ def get_bound_client(self, request_id: str) -> Any | None: def release_binding(self, request_id: str) -> None: """Drop the route binding for *request_id* in this stage.""" self._request_bindings.pop(request_id, None) + self._affinity.pop(request_id, None) def release_bindings(self, request_ids: list[str]) -> None: """Drop route bindings for the given request ids in this stage.""" @@ -109,18 +375,28 @@ def select_replica_id( *, affinity_request_id: str | None = None, ) -> int: - """Pick a replica id for *request_id* and cache the choice.""" + """Pick a replica id for *request_id* and cache the choice (legacy path).""" cached = self.get_bound_replica_id(request_id) - if cached is not None: + if cached is not None and self.clients[cached] is not None: return cached - chosen = self.get_bound_replica_id(affinity_request_id) if affinity_request_id is not None else None + chosen: int | None = None + if affinity_request_id is not None: + parent = self.get_bound_replica_id(affinity_request_id) + if parent is not None and self.clients[parent] is not None: + chosen = parent + if chosen is None: - if self.num_replicas == 1: - chosen = 0 + live = self.live_replica_ids() + if not live: + raise RuntimeError(f"stage {self.stage_id} has no live replicas") + if len(live) == 1: + chosen = live[0] else: - chosen = self._next_replica_id - self._next_replica_id = (self._next_replica_id + 1) % self.num_replicas + # Round-robin over live replicas only. 
+ start = self._next_replica_id % len(live) + chosen = live[start] + self._next_replica_id = (self._next_replica_id + 1) % len(live) self._request_bindings[request_id] = chosen return chosen @@ -184,21 +460,26 @@ async def submit_initial( params = params_override if params_override is not None else req_state.sampling_params_list[self.stage_id] submit_kwargs = dict(submit_kwargs or {}) if self.stage_type == "diffusion": - replica_id = self.select_replica_id( + replica_id = await self._pick_or_select( request_id, affinity_request_id=affinity_request_id, ) client = self.clients[replica_id] + if client is None: + raise RuntimeError(f"stage {self.stage_id} replica {replica_id} is not attached") if isinstance(request, list): await client.add_batch_request_async(request_id, request, params, **submit_kwargs) else: await client.add_request_async(request_id, request, params, **submit_kwargs) return replica_id - replica_id = self.select_replica_id( + replica_id = await self._pick_or_select( request_id, affinity_request_id=affinity_request_id, ) + client = self.clients[replica_id] + if client is None: + raise RuntimeError(f"stage {self.stage_id} replica {replica_id} is not attached") try: self.output_processor.add_request( request=request, @@ -212,7 +493,7 @@ async def submit_initial( raise try: - await self.clients[replica_id].add_request_async(request, **submit_kwargs) + await client.add_request_async(request, **submit_kwargs) except Exception: self.release_binding(request_id) rollback = getattr(self.output_processor, "remove_request", None) @@ -240,11 +521,15 @@ async def submit_update( """Submit a streaming update to an already admitted request.""" params = req_state.sampling_params_list[self.stage_id] replica_id = self.get_bound_replica_id(request_id) - if replica_id is None: - replica_id = self.select_replica_id(request_id) + if replica_id is None or self.clients[replica_id] is None: + replica_id = await self._pick_or_select(request_id) + + client = 
self.clients[replica_id] + if client is None: + raise RuntimeError(f"stage {self.stage_id} replica {replica_id} is not attached") if self.stage_type == "diffusion": - await self.clients[replica_id].add_request_async(request_id, request, params) + await client.add_request_async(request_id, request, params) else: # Refresh the shared output-processor state before yielding to the # stage client so streaming segments are merged against the latest @@ -256,9 +541,20 @@ async def submit_update( request_index=0, queue=None, ) - await self.clients[replica_id].add_request_async(request) + await client.add_request_async(request) return replica_id + async def _pick_or_select( + self, + request_id: str, + *, + affinity_request_id: str | None = None, + ) -> int: + """Bridge to ``pick`` in distributed mode or ``select_replica_id`` legacy.""" + if self.is_distributed: + return await self.pick(request_id, affinity_request_id=affinity_request_id) + return self.select_replica_id(request_id, affinity_request_id=affinity_request_id) + # ---- Stage-local polling ---- async def _poll_stage_raw(self, client: Any) -> EngineCoreOutputs | None: @@ -275,6 +571,8 @@ async def process_llm_raw_outputs( ) -> list[Any]: """Run the shared LLM output processor on one raw poll result.""" client = self.clients[replica_id] + if client is None: + return [] processor = self.output_processor processed = processor.process_outputs( raw_outputs.outputs, @@ -298,6 +596,8 @@ async def poll_llm_raw_output( ) -> EngineCoreOutputs | None: """Poll raw EngineCore outputs from one LLM replica once.""" client = self.clients[replica_id] + if client is None: + return None try: return await asyncio.wait_for( self._poll_stage_raw(client), @@ -317,7 +617,10 @@ async def poll_llm_raw_output( def poll_diffusion_output(self, replica_id: int) -> Any | None: """Drain one ready diffusion output from the given replica if present.""" - return self.clients[replica_id].get_diffusion_output_nowait() + client = 
self.clients[replica_id] + if client is None: + return None + return client.get_diffusion_output_nowait() # ---- Stage-local control plane ---- @@ -325,7 +628,7 @@ async def abort_requests(self, request_ids: list[str]) -> None: """Abort the given requests in this stage pool. Request-bound abort routing stays inside the pool because route affinity - (`request_id -> replica_id`) is pool-owned. + (``request_id -> replica_id``) is pool-owned. """ if not request_ids: return @@ -333,13 +636,16 @@ async def abort_requests(self, request_ids: list[str]) -> None: request_ids_by_replica: dict[int, list[str]] = {} for request_id in request_ids: replica_id = self.get_bound_replica_id(request_id) - if replica_id is None: - logger.debug("[StagePool] abort: no binding for req=%s in stage-%s", request_id, self.stage_id) + if replica_id is None or self.clients[replica_id] is None: + logger.debug("[StagePool] abort: no live binding for req=%s in stage-%s", request_id, self.stage_id) continue request_ids_by_replica.setdefault(replica_id, []).append(request_id) for replica_id, replica_request_ids in request_ids_by_replica.items(): - await self.clients[replica_id].abort_requests_async(replica_request_ids) + client = self.clients[replica_id] + if client is None: + continue + await client.abort_requests_async(replica_request_ids) # Clean up OutputProcessor state (e.g. mm_accumulated tensors) that # would otherwise leak — aborted requests never produce a final @@ -355,10 +661,15 @@ async def collective_rpc( timeout: float | None = None, args: tuple[Any, ...] 
= (), kwargs: dict[str, Any] | None = None, - ) -> Any: + ) -> dict[str, Any] | Any: """Dispatch a stage-scoped control-plane RPC to one physical route.""" kwargs = dict(kwargs or {}) client = self.clients[replica_id] + if client is None: + return { + "supported": False, + "error": f"stage {self.stage_id} replica {replica_id} is not attached", + } try: if hasattr(client, "collective_rpc_async"): return await client.collective_rpc_async( @@ -386,7 +697,11 @@ async def collective_rpc( def shutdown_replica(self, replica_id: int) -> None: """Shutdown one backend handle in this stage pool.""" + if replica_id >= len(self.clients): + return client = self.clients[replica_id] + if client is None: + return try: client.shutdown() logger.info( diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py index d714d1f53ff..3cc08eb7887 100644 --- a/vllm_omni/entrypoints/cli/serve.py +++ b/vllm_omni/entrypoints/cli/serve.py @@ -9,6 +9,8 @@ import json import os import signal +import threading +from multiprocessing import connection from types import FrameType from typing import Any @@ -103,6 +105,36 @@ def validate(self, args: argparse.Namespace) -> None: if args.stage_id is not None and (args.omni_master_address is None or args.omni_master_port is None): raise ValueError("--stage-id requires both --omni-master-address and --omni-master-port to be set") + # --omni-replica-address is only consulted in run_headless(); reject it + # on the head so a misconfigured launch fails loudly instead of being + # silently ignored. + if getattr(args, "omni_replica_address", None) is not None and not args.headless: + raise ValueError("--omni-replica-address requires --headless to be set") + + # --omni-dp-size-local is process-local. A value other than 1 only + # makes sense when this process owns a stage (head or headless). 
+ omni_dp_size_local = getattr(args, "omni_dp_size_local", None) + if omni_dp_size_local is not None: + if omni_dp_size_local < 1: + raise ValueError(f"--omni-dp-size-local must be >= 1, got {omni_dp_size_local}") + if omni_dp_size_local != 1 and args.stage_id is None: + raise ValueError("--omni-dp-size-local != 1 requires --stage-id to be set") + + # --omni-lb-policy is validated against the LoadBalancingPolicy enum. + omni_lb_policy = getattr(args, "omni_lb_policy", None) + if omni_lb_policy is not None: + from vllm_omni.distributed.omni_coordinator import LoadBalancingPolicy + + try: + LoadBalancingPolicy(omni_lb_policy) + except ValueError as exc: + valid = ", ".join(p.value for p in LoadBalancingPolicy) + raise ValueError(f"--omni-lb-policy={omni_lb_policy!r} is not one of: {valid}") from exc + + omni_heartbeat_timeout = getattr(args, "omni_heartbeat_timeout", None) + if omni_heartbeat_timeout is not None and omni_heartbeat_timeout <= 0: + raise ValueError(f"--omni-heartbeat-timeout must be > 0, got {omni_heartbeat_timeout}") + # Skip validation for diffusion models as they have different requirements from vllm_omni.diffusion.utils.hf_utils import is_diffusion_model @@ -253,6 +285,50 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu type=int, help="Port of the Omni orchestrator (master).", ) + omni_config_group.add_argument( + "--omni-replica-address", + "-ora", + type=str, + default=None, + help=( + "Local bind address (this host's IP) that the headless stage " + "advertises to the Omni master for its handshake/input/output " + "ZMQ sockets. If unset, auto-detected via a UDP-connect " + "routing probe against --omni-master-address. Override only " + "when the auto-detected IP is wrong (e.g. multi-NIC host " + "where the master is reachable on the wrong interface)." 
+ ), + ) + omni_config_group.add_argument( + "--omni-dp-size-local", + type=int, + default=1, + help=( + "Number of stage replicas this runtime launches locally for its " + "own --stage-id. Process-local: head and every headless invocation " + "read their own copy; values may differ across invocations. " + "Requires --stage-id to be set when not equal to 1." + ), + ) + omni_config_group.add_argument( + "--omni-lb-policy", + type=str, + default="random", + choices=["random", "round-robin", "least-queue-length"], + help=( + "Per-stage load-balancing policy used by the head's StagePool to " + "route requests across UP replicas. Only consulted on the head runtime." + ), + ) + omni_config_group.add_argument( + "--omni-heartbeat-timeout", + type=float, + default=30.0, + help=( + "Seconds before an unreporting replica is marked ERROR in the " + "OmniCoordinator. Only consulted on the head runtime." + ), + ) # Diffusion model specific arguments omni_config_group.add_argument( @@ -565,9 +641,15 @@ def _create_default_diffusion_stage_cfg(args: argparse.Namespace) -> list[dict[s def run_headless(args: argparse.Namespace) -> None: - """Run a single stage in headless mode.""" + """Run a single stage in headless mode. + + Honors ``--omni-dp-size-local``: launches that many replicas locally for + ``--stage-id``. Each replica registers with the head's OmniMasterServer + (auto-assigned replica id when ``--omni-dp-size-local > 1`` so multiple + headless invocations can coexist) and reports heartbeats to the head's + OmniCoordinator. 
+ """ from vllm.v1.engine.coordinator import DPCoordinator - from vllm.v1.engine.utils import CoreEngineProcManager from vllm.v1.executor.multiproc_executor import MultiprocExecutor from vllm.version import __version__ as VLLM_VERSION @@ -576,6 +658,7 @@ def run_headless(args: argparse.Namespace) -> None: spawn_diffusion_proc, ) from vllm_omni.distributed.omni_connectors.utils.initialization import resolve_omni_kv_config_for_stage + from vllm_omni.engine.omni_core_engine_proc_manager import OmniCoreEngineProcManager from vllm_omni.engine.stage_engine_startup import register_stage_with_omni_master from vllm_omni.engine.stage_init_utils import ( build_diffusion_config, @@ -583,18 +666,24 @@ def run_headless(args: argparse.Namespace) -> None: build_vllm_config, extract_stage_metadata, get_stage_connector_spec, + get_stage_devices_per_replica, inject_kv_stage_info, load_omni_transfer_config_for_model, prepare_engine_environment, + setup_stage_devices, + split_devices_for_replicas, terminate_alive_proc, ) from vllm_omni.entrypoints.utils import inject_omni_kv_config, load_and_resolve_stage_configs + from vllm_omni.platforms import current_omni_platform model = args.model stage_id: int | None = args.stage_id replica_id: int = args.replica_id omni_master_address: str | None = args.omni_master_address omni_master_port: int | None = args.omni_master_port + omni_replica_address: str | None = getattr(args, "omni_replica_address", None) + omni_dp_size_local: int = max(1, int(getattr(args, "omni_dp_size_local", 1) or 1)) if stage_id is None: raise ValueError("--stage-id is required in headless mode") @@ -610,6 +699,11 @@ def run_headless(args: argparse.Namespace) -> None: args_dict = vars(args).copy() args_dict.pop("_cli_explicit_keys", None) + # Forward ``--deploy-config`` so the headless reads the same YAML the + # head was launched with — otherwise ``load_and_resolve_stage_configs`` + # falls back to ``vllm_omni/deploy/.yaml`` and the headless's + # view of 
``stage.runtime.devices`` diverges from the head's, breaking + # the per-replica device split. config_path, stage_configs = load_and_resolve_stage_configs( model, args_dict.get("stage_configs_path"), @@ -632,6 +726,33 @@ def run_headless(args: argparse.Namespace) -> None: omni_transfer_config = load_omni_transfer_config_for_model(model, config_path) omni_conn_cfg, omni_from, omni_to = resolve_omni_kv_config_for_stage(omni_transfer_config, stage_id) + # When ``--omni-dp-size-local > 1``, slice the YAML's ``devices:`` field + # into per-replica subsets so each subprocess we spawn below sees a + # narrowed ``CUDA_VISIBLE_DEVICES`` and doesn't stack on cuda:0. Mirrors + # the head-side per-replica device application at + # ``async_omni_engine.py`` (setup_stage_devices around each launch). + runtime_cfg = getattr(stage_cfg, "runtime", None) + devices_str: str | None = None + if runtime_cfg is not None: + devices_str = ( + runtime_cfg.get("devices") if hasattr(runtime_cfg, "get") else getattr(runtime_cfg, "devices", None) + ) + devices_per_replica = get_stage_devices_per_replica(stage_cfg) + if omni_dp_size_local > 1 and devices_str: + per_replica_devices: list[str | None] = split_devices_for_replicas( + devices_str, omni_dp_size_local, devices_per_replica, stage_id + ) + logger.info( + "[Headless] Stage %d: %d local replicas, devices_per_replica=%d, per-replica devices: %s", + stage_id, + omni_dp_size_local, + devices_per_replica, + per_replica_devices, + ) + else: + per_replica_devices = [None] * omni_dp_size_local + device_control_env = current_omni_platform.device_control_env_var + if stage_cfg.stage_type == "diffusion": metadata = extract_stage_metadata(stage_cfg) if omni_conn_cfg: @@ -640,39 +761,109 @@ def run_headless(args: argparse.Namespace) -> None: od_config = build_diffusion_config(model, stage_cfg, metadata) logger.info( - "[Headless] Launching diffusion stage %d replica %d via OmniMasterServer at %s:%d", + "[Headless] Launching %d diffusion replica(s) for 
stage %d via OmniMasterServer at %s:%d", + omni_dp_size_local, stage_id, - replica_id, omni_master_address, omni_master_port, ) - proc = None + procs: list[Any] = [] try: - handshake_address, request_address, response_address = register_stage_with_omni_master( - omni_master_address=omni_master_address, - omni_master_port=omni_master_port, - omni_stage_id=stage_id, - omni_stage_config=stage_cfg, - return_addresses=True, - replica_id=replica_id, - ) - proc, _, _, _ = spawn_diffusion_proc( - model, - od_config, - handshake_address=handshake_address, - request_address=request_address, - response_address=response_address, + for _rep_idx in range(omni_dp_size_local): + # Auto-assign replica id when launching multiple replicas + # so independent headless invocations can coexist for the + # same stage. The user-supplied --replica-id is honored + # only when launching exactly one replica. + req_replica_id: int | None = replica_id if omni_dp_size_local == 1 else None + response = register_stage_with_omni_master( + omni_master_address=omni_master_address, + omni_master_port=omni_master_port, + omni_stage_id=stage_id, + omni_stage_config=stage_cfg, + replica_id=req_replica_id, + return_full_response=True, + replica_bind_address=omni_replica_address, + ) + # Apply this replica's CUDA_VISIBLE_DEVICES (only when + # ``--omni-dp-size-local > 1`` and the YAML's stage devices + # field is set). The spawned subprocess inherits the env at + # spawn time; we restore the parent env afterwards so the + # next replica's setup sees the same baseline. + previous_visible_devices = os.environ.get(device_control_env) + try: + if per_replica_devices[_rep_idx] is not None: + setup_stage_devices(stage_id, {"devices": per_replica_devices[_rep_idx]}) + # Each StageDiffusionProc starts its own + # torch.distributed group bound to + # ``od_config.master_port``. 
Without an explicit + # per-replica override all spawned subprocesses + # share the value ``OmniDiffusionConfig.__post_init__`` + # picked once (and the second binder hits EADDRINUSE + # on ``init_process_group``). We can't use + # kernel-ephemeral allocation either, because the + # master server's pre-allocated ZMQ ports (returned + # by ``register_stage_with_omni_master``) also live + # in the ephemeral range and are not actually bound + # until the headless ``_perform_diffusion_handshake`` + # runs — so picking an ephemeral port here can steal + # a port the master server already promised to a + # sibling headless. Use ``settle_port`` from a base + # above the Linux default ephemeral range + # (32768-60999) so torch.distributed master ports + # never overlap with ZMQ allocations. + if omni_dp_size_local > 1: + od_config.master_port = od_config.settle_port( + 61000 + _rep_idx * 100, + port_inc=37, + ) + proc, _, _, _ = spawn_diffusion_proc( + model, + od_config, + handshake_address=response.handshake_address, + request_address=response.input_address, + response_address=response.output_address, + omni_coordinator_address=response.coordinator_router_address, + omni_stage_id=stage_id, + omni_replica_id=response.replica_id, + ) + finally: + if previous_visible_devices is None: + current_omni_platform.unset_device_control_env_var() + else: + current_omni_platform.set_device_control_env_var(previous_visible_devices) + complete_diffusion_handshake(proc, response.handshake_address, args.stage_init_timeout) + procs.append(proc) + logger.info( + "[Headless] Diffusion replica id=%d for stage %d is up (coord=%s)", + response.replica_id, + stage_id, + response.coordinator_router_address, + ) + + # Block on the sentinel set so any replica crash is detected + # immediately (the previous per-proc join loop only noticed + # crashes in registration order). Any exit triggers fleet + # shutdown via the finally block; non-zero exits propagate. 
+ sentinel_to_proc = {p.sentinel: p for p in procs} + died = connection.wait(list(sentinel_to_proc.keys())) + first = sentinel_to_proc[died[0]] + logger.info( + "[Headless] Diffusion replica %s exited (code=%s); shutting down stage %d.", + first.name, + first.exitcode, + stage_id, ) - complete_diffusion_handshake(proc, handshake_address, args.stage_init_timeout) - proc.join() - if proc.exitcode not in (None, 0): - raise RuntimeError(f"Diffusion stage {stage_id} replica {replica_id} exited with code {proc.exitcode}") + if first.exitcode not in (None, 0): + raise RuntimeError( + f"Diffusion stage {stage_id} replica {first.name!r} exited with code {first.exitcode}" + ) return finally: - logger.info("[Headless] Shutting down stage %d replica %d.", stage_id, replica_id) - if proc is not None and proc.is_alive(): - terminate_alive_proc(proc) + logger.info("[Headless] Shutting down %d diffusion replica(s) for stage %d.", len(procs), stage_id) + for p in procs: + if p.is_alive(): + terminate_alive_proc(p) stage_connector_spec = get_stage_connector_spec( omni_transfer_config=omni_transfer_config, @@ -680,8 +871,10 @@ def run_headless(args: argparse.Namespace) -> None: async_chunk=False, ) - # Device assignment is managed externally (e.g. CUDA_VISIBLE_DEVICES); - # runtime_cfg is intentionally ignored in headless mode. + # ``runtime_cfg`` is mostly inherited from the parent's + # CUDA_VISIBLE_DEVICES; when ``--omni-dp-size-local > 1`` we additionally + # bracket each replica's spawn below with setup_stage_devices so they + # don't all stack on cuda:0 (see ``per_replica_devices`` above). 
engine_args_dict = build_engine_args_dict( stage_cfg, model, @@ -754,50 +947,98 @@ def signal_handler(signum: int, frame: FrameType | None) -> None: ) logger.info( - "[Headless] Launching %d engine core(s) for stage %d replica %d via OmniMasterServer at %s:%d", + "[Headless] Launching %d omni replica(s) (vLLM dp_size_local=%d each) for stage %d " + "via OmniMasterServer at %s:%d", + omni_dp_size_local, local_engine_count, stage_id, - replica_id, omni_master_address, omni_master_port, ) - # Headless mode launches all local engine cores for a single stage. - # The OmniMasterServer allocates one handshake endpoint per stage, so we - # register the stage once here and let every local engine core reuse the - # returned handshake address directly. - handshake_address = register_stage_with_omni_master( - omni_master_address=omni_master_address, - omni_master_port=omni_master_port, - omni_stage_id=stage_id, - omni_stage_config=stage_cfg, - coordinator=coordinator, - replica_id=replica_id, - ) - - engine_manager = None + # One OmniMasterServer registration per omni replica; each registration + # yields its own (handshake, input, output) allocation and the head's + # OmniCoordinator ROUTER address. We then spawn one + # OmniCoreEngineProcManager per replica so its subprocess gets the + # right replica id wired into its OmniCoordClientForStage. 
log_stats = bool(args.log_stats) if args.disable_log_stats: log_stats = False + engine_managers: list[Any] = [] + monitor_threads: list[threading.Thread] = [] + + def _monitor_target(mgr: Any) -> None: + try: + mgr.monitor_engine_liveness() + except Exception: + logger.exception("[Headless] monitor_engine_liveness raised") + try: - engine_manager = CoreEngineProcManager( - local_engine_count=local_engine_count, - start_index=dp_rank, - local_start_index=0, - vllm_config=vllm_config, - local_client=False, - handshake_address=handshake_address, - executor_class=executor_class, - log_stats=log_stats, - ) - # vllm>=0.19 renamed CoreEngineProcManager.join_first() to - # monitor_engine_liveness() (see upstream PR #35862). - engine_manager.monitor_engine_liveness() + for _rep_idx in range(omni_dp_size_local): + req_replica_id: int | None = replica_id if omni_dp_size_local == 1 else None + response = register_stage_with_omni_master( + omni_master_address=omni_master_address, + omni_master_port=omni_master_port, + omni_stage_id=stage_id, + omni_stage_config=stage_cfg, + coordinator=coordinator, + replica_id=req_replica_id, + return_full_response=True, + replica_bind_address=omni_replica_address, + ) + # Per-replica CUDA_VISIBLE_DEVICES, same pattern as the diffusion + # branch above. OmniCoreEngineProcManager.__init__ spawns its + # subprocesses via context.Process inside the constructor, so we + # must set the env *before* instantiation and restore after. 
+ previous_visible_devices = os.environ.get(device_control_env) + try: + if per_replica_devices[_rep_idx] is not None: + setup_stage_devices(stage_id, {"devices": per_replica_devices[_rep_idx]}) + mgr = OmniCoreEngineProcManager( + local_engine_count=local_engine_count, + start_index=dp_rank, + local_start_index=0, + vllm_config=vllm_config, + local_client=False, + handshake_address=response.handshake_address, + executor_class=executor_class, + log_stats=log_stats, + omni_stage_id=stage_id, + omni_coordinator_address=response.coordinator_router_address, + omni_replica_base_id=response.replica_id, + ) + finally: + if previous_visible_devices is None: + current_omni_platform.unset_device_control_env_var() + else: + current_omni_platform.set_device_control_env_var(previous_visible_devices) + engine_managers.append(mgr) + logger.info( + "[Headless] Stage %d replica id=%d up (coord=%s)", + stage_id, + response.replica_id, + response.coordinator_router_address, + ) + + # Run all managers' liveness monitors in parallel. Each blocks + # until its own subprocesses exit (or fail). 
+ if len(engine_managers) == 1: + engine_managers[0].monitor_engine_liveness() + else: + for mgr in engine_managers: + t = threading.Thread(target=_monitor_target, args=(mgr,), name=f"omni-replica-monitor-{id(mgr):x}") + t.start() + monitor_threads.append(t) + for t in monitor_threads: + t.join() finally: - logger.info("[Headless] Shutting down stage %d.", stage_id) - if engine_manager is not None: - engine_manager.shutdown() + logger.info("[Headless] Shutting down stage %d (%d managers).", stage_id, len(engine_managers)) + for mgr in engine_managers: + try: + mgr.shutdown() + except Exception: + logger.exception("[Headless] engine manager shutdown failed") if coordinator is not None: coordinator.shutdown() diff --git a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py index 1e057a71efa..fc94bf2d709 100644 --- a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py +++ b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py @@ -1918,8 +1918,6 @@ def sample( min_score = torch.finfo(logits.dtype).min - assert logits.shape[0] == 1, f"HunyuanImage3 sampler requires max_num_seqs=1, got batch size {logits.shape[0]}" - for req_idx in range(logits.shape[0]): decoded_tokens: list[int] = ( sampling_metadata.output_token_ids[req_idx] if req_idx < len(sampling_metadata.output_token_ids) else []