Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions vllm_omni/diffusion/stage_diffusion_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,11 @@ def _initialize_client(
self.engine_input_source = getattr(metadata, "engine_input_source", [])
self._proc = proc
self._owns_process = proc is not None
# Expose the ZMQ addresses on the instance so callers (e.g.
# ``StagePool._client_input_addr``) can identify the diffusion
# replica by its bound address.
self.request_address = request_address
self.response_address = response_address

self._zmq_ctx = zmq.Context()
self._request_socket = self._zmq_ctx.socket(zmq.PUSH)
Expand Down
75 changes: 74 additions & 1 deletion vllm_omni/diffusion/stage_diffusion_proc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from __future__ import annotations

import asyncio
import contextlib
import signal
import time
from concurrent.futures import ThreadPoolExecutor
Expand All @@ -30,6 +31,7 @@
OmniMsgpackDecoder,
OmniMsgpackEncoder,
)
from vllm_omni.distributed.omni_coordinator import OmniCoordClientForStage
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput

Expand All @@ -54,6 +56,19 @@ def __init__(self, model: str, od_config: OmniDiffusionConfig) -> None:
self._engine: DiffusionEngine | None = None
self._executor: ThreadPoolExecutor | None = None
self._closed = False
# Set by ``run_loop`` to the live dispatch task dict so
# :attr:`queue_length` can report in-flight requests for the
# OmniCoordinator heartbeat hook.
self._active_tasks: dict[str, asyncio.Task] | None = None

@property
def queue_length(self) -> int:
"""Number of in-flight diffusion requests.

Returns 0 before :meth:`run_loop` starts and after it exits.
"""
tasks = self._active_tasks
return 0 if tasks is None else len(tasks)

# ------------------------------------------------------------------
# Initialization
Expand Down Expand Up @@ -309,6 +324,9 @@ async def run_loop(
decoder = OmniMsgpackDecoder()

tasks: dict[str, asyncio.Task] = {}
# Expose the live task dict so :attr:`queue_length` (used by the
# OmniCoordinator heartbeat hook) can read the in-flight count.
self._active_tasks = tasks

async def _dispatch_request(
request_id: str,
Expand Down Expand Up @@ -466,6 +484,7 @@ async def _dispatch_batch(
if tasks:
await asyncio.gather(*tasks.values(), return_exceptions=True)

self._active_tasks = None
request_socket.close()
response_socket.close()
ctx.term()
Expand Down Expand Up @@ -504,8 +523,22 @@ def run_diffusion_proc(
handshake_address: str,
request_address: str,
response_address: str,
*,
omni_coordinator_address: str | None = None,
omni_stage_id: int | None = None,
omni_replica_id: int = 0,
) -> None:
"""Entry point for the diffusion subprocess."""
"""Entry point for the diffusion subprocess.

Omni-specific kwargs (mirroring :meth:`StageEngineCoreProc.run_stage_core`):
- ``omni_coordinator_address``: ROUTER address of the head-side
OmniCoordinator. When set, a :class:`OmniCoordClientForStage`
reports the diffusion replica's status + queue length.
- ``omni_stage_id``: logical stage id; required when
``omni_coordinator_address`` is set.
- ``omni_replica_id``: cluster-unique replica id within the
stage (logging / metrics only).
"""
shutdown_requested = False

def signal_handler(signum: int, frame: Any) -> None:
Expand All @@ -518,6 +551,7 @@ def signal_handler(signum: int, frame: Any) -> None:
signal.signal(signal.SIGINT, signal_handler)

proc = cls(model, od_config)
coord_client: OmniCoordClientForStage | None = None
try:
proc.initialize()

Expand All @@ -529,6 +563,32 @@ def signal_handler(signum: int, frame: Any) -> None:
handshake_socket.close()
handshake_ctx.term()

# Wire OmniCoordClientForStage *after* READY so that the head
# has bound its head-side request/response sockets — the
# address pair we report is the same pair this proc binds to
# (request/response addresses passed in).
if omni_coordinator_address is not None:
if omni_stage_id is None:
raise ValueError("omni_stage_id must be provided when omni_coordinator_address is set")
coord_client = OmniCoordClientForStage(
coord_zmq_addr=omni_coordinator_address,
input_addr=request_address,
output_addr=response_address,
stage_id=int(omni_stage_id),
)

def _refresh_queue_length() -> None:
coord_client._queue_length = proc.queue_length # type: ignore[union-attr]

coord_client._on_heartbeat = _refresh_queue_length

logger.info(
"StageDiffusionProc registered with OmniCoordinator (stage_id=%d replica_id=%d coord=%s)",
omni_stage_id,
omni_replica_id,
omni_coordinator_address,
)

# Run async event loop
asyncio.run(proc.run_loop(request_address, response_address))

Expand All @@ -539,6 +599,9 @@ def signal_handler(signum: int, frame: Any) -> None:
logger.exception("StageDiffusionProc encountered a fatal error.")
raise
finally:
if coord_client is not None:
with contextlib.suppress(RuntimeError):
coord_client.close()
proc.close()


Expand All @@ -551,10 +614,17 @@ def spawn_diffusion_proc(
handshake_address: str | None = None,
request_address: str | None = None,
response_address: str | None = None,
*,
omni_coordinator_address: str | None = None,
omni_stage_id: int | None = None,
omni_replica_id: int = 0,
) -> tuple[BaseProcess, str, str, str]:
"""Spawn a StageDiffusionProc subprocess.

Returns ``(proc, handshake_address, request_address, response_address)``.

Pass ``omni_coordinator_address`` / ``omni_stage_id`` / ``omni_replica_id``
to have the subprocess publish heartbeats to an OmniCoordinator.
"""
handshake_address = handshake_address or get_open_zmq_ipc_path()
request_address = request_address or get_open_zmq_ipc_path()
Expand All @@ -570,6 +640,9 @@ def spawn_diffusion_proc(
"handshake_address": handshake_address,
"request_address": request_address,
"response_address": response_address,
"omni_coordinator_address": omni_coordinator_address,
"omni_stage_id": omni_stage_id,
"omni_replica_id": omni_replica_id,
},
)
proc.start()
Expand Down
10 changes: 6 additions & 4 deletions vllm_omni/distributed/omni_coordinator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,19 @@
RoundRobinBalancer,
Task,
)
from .messages import InstanceEvent, InstanceInfo, InstanceList, StageStatus
from .messages import ReplicaEvent, ReplicaInfo, ReplicaList, StageStatus
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve the coordinator Instance aliases

Removing the exported InstanceEvent / InstanceInfo / InstanceList names breaks existing coordinator imports in this repo (for example the tests/distributed/omni_coordinator tests still import InstanceInfo and InstanceList) and any downstream code using the public package export. Unless all call sites are migrated in the same change, keep aliases to the new Replica* classes so those imports continue to work.

Useful? React with 👍 / 👎.

from .omni_coord_client_for_hub import OmniCoordClientForHub
from .omni_coord_client_for_stage import OmniCoordClientForStage
from .omni_coordinator import OmniCoordinator
from .runtime import OmniCoordinatorRuntime

__all__ = [
"OmniCoordinator",
"OmniCoordinatorRuntime",
"StageStatus",
"InstanceEvent",
"InstanceInfo",
"InstanceList",
"ReplicaEvent",
"ReplicaInfo",
"ReplicaList",
"OmniCoordClientForStage",
"OmniCoordClientForHub",
"Task",
Expand Down
85 changes: 34 additions & 51 deletions vllm_omni/distributed/omni_coordinator/load_balancer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@
from enum import Enum
from typing import Any, TypedDict

from .messages import InstanceInfo
from .messages import ReplicaInfo


class Task(TypedDict, total=False):
"""Task structure passed from async_omni (stage.submit(task)).
"""Task structure passed to ``StagePool.pick`` / ``LoadBalancer.select``.

Mirrors the dict built in AsyncOmni with request_id, engine_inputs,
sampling_params. Future load-balancing policies may use these fields.
Mirrors the dict built around a stage submission with request_id and any
payload-related fields a future load-balancing policy might inspect.
"""

request_id: str
Expand All @@ -28,10 +28,7 @@ class LoadBalancingPolicy(str, Enum):
"""Enumeration for load balancing policies.

These policies are used by :class:`LoadBalancer` implementations to route
tasks to a subset of available instances.

TODO(NumberWan): Map enum values to balancer classes when OmniCoordinator
integration lands. Tracked in https://github.com/vllm-project/vllm-omni/pull/2448
tasks to a subset of available replicas.
"""

RANDOM = "random"
Expand All @@ -42,97 +39,83 @@ class LoadBalancingPolicy(str, Enum):
class LoadBalancer(ABC):
"""Abstract base class for load balancers.

Subclasses implement :meth:`select` to choose an instance for a given task.
Subclasses implement :meth:`select` to choose a replica for a given task.
"""

@abstractmethod
def select(self, task: Task, instances: list[InstanceInfo]) -> int:
"""Route a task to one of the available instances.
def select(self, task: Task, replicas: list[ReplicaInfo]) -> int:
"""Route a task to one of the available replicas.

Args:
task: The task to route. Not used by the random policy but reserved
for future strategies that may inspect task metadata.
instances: List of available instances to choose from.
replicas: List of available replicas to choose from.

Returns:
Index of the selected instance in ``instances``.
Index of the selected replica in ``replicas``.

Raises:
ValueError: If ``instances`` is empty.
ValueError: If ``replicas`` is empty.
"""

raise NotImplementedError


class RandomBalancer(LoadBalancer):
"""Load balancer that selects an instance uniformly at random.

It intentionally ignores the task payload and chooses a random index from
the provided instance list. More sophisticated policies (e.g. round-robin,
least-queue-length) can be implemented as additional subclasses of
:class:`LoadBalancer`.
"""
"""Load balancer that selects a replica uniformly at random."""

def select(self, task: Task, instances: list[InstanceInfo]) -> int: # noqa: ARG002
if not instances:
raise ValueError("instances must not be empty")
def select(self, task: Task, replicas: list[ReplicaInfo]) -> int: # noqa: ARG002
if not replicas:
raise ValueError("replicas must not be empty")

return random.randrange(len(instances))
return random.randrange(len(replicas))


class RoundRobinBalancer(LoadBalancer):
"""Load balancer that selects instances in a round-robin fashion.
"""Load balancer that selects replicas in a round-robin fashion.

This implementation keeps a running index modulo ``len(instances)``. It
therefore depends on the **order and stable meaning** of the ``instances``
This implementation keeps a running index modulo ``len(replicas)``. It
therefore depends on the **order and stable meaning** of the ``replicas``
list between calls. If the list length or ordering changes, the sequence
of picks may skip or repeat entries relative to a fixed set of backends.

When instance membership changes dynamically, callers should reset routing
state—for example by constructing a new ``RoundRobinBalancer`` or resetting
``_next_index``—similar to rebuilding ``itertools.cycle`` after mutating
the instance list (as in vLLM's disaggregated proxy examples).

Concurrency: ``select`` is synchronous and is expected to run on the
coordinator asyncio event loop thread without ``await`` inside this
method, so a single invocation is not interleaved with another on that
thread. A :class:`threading.Lock` still serializes updates to
``_next_index`` for callers that might invoke ``select`` from multiple
threads or alongside threaded infrastructure (e.g. ZMQ receive threads).
Concurrency: a ``threading.Lock`` serializes updates to ``_next_index``
for callers that invoke ``select`` from multiple threads or alongside
threaded infrastructure (e.g. ZMQ receive threads).
"""

def __init__(self, start_index: int = 0) -> None:
self._next_index = start_index
self._lock = threading.Lock()

def select(self, task: Task, instances: list[InstanceInfo]) -> int: # noqa: ARG002
if not instances:
raise ValueError("instances must not be empty")
def select(self, task: Task, replicas: list[ReplicaInfo]) -> int: # noqa: ARG002
if not replicas:
raise ValueError("replicas must not be empty")

n = len(instances)
n = len(replicas)
with self._lock:
idx = self._next_index % n
self._next_index = (self._next_index + 1) % n
return idx


class LeastQueueLengthBalancer(LoadBalancer):
"""Select the instance with the smallest ``queue_length``.
"""Select the replica with the smallest ``queue_length``.

If multiple instances share the same minimum queue length, one of them is
If multiple replicas share the same minimum queue length, one of them is
chosen uniformly at random.

Raises:
ValueError: If any instance has a negative ``queue_length``.
ValueError: If any replica has a negative ``queue_length``.
"""

def select(self, task: Task, instances: list[InstanceInfo]) -> int: # noqa: ARG002
if not instances:
raise ValueError("instances must not be empty")
def select(self, task: Task, replicas: list[ReplicaInfo]) -> int: # noqa: ARG002
if not replicas:
raise ValueError("replicas must not be empty")

queue_lengths = [inst.queue_length for inst in instances]
queue_lengths = [rep.queue_length for rep in replicas]
if any(q < 0 for q in queue_lengths):
raise ValueError("queue_length must be non-negative for all instances")
raise ValueError("queue_length must be non-negative for all replicas")
min_q = min(queue_lengths)
candidates = [i for i, q in enumerate(queue_lengths) if q == min_q]
return random.choice(candidates)
Expand Down
Loading
Loading