diff --git a/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py b/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py index 0525e944004..75e6d49c587 100644 --- a/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py +++ b/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py @@ -21,6 +21,7 @@ build_rank_aware_send_keys, get_kv_target_ranks, get_local_tp_rank, + get_omni_replica_id, get_tp_world_size, kv_zmq_port, merge_received_rank_shards, @@ -415,7 +416,13 @@ def connector(self): stage_int = int(self.config.from_stage) if self.config.from_stage is not None else 0 except (TypeError, ValueError): stage_int = 0 - zmq_port = kv_zmq_port(base_port, stage_int, self._tp_topo.local_rank) + replica_id = get_omni_replica_id() + zmq_port = kv_zmq_port( + base_port, + stage_int, + self._tp_topo.local_rank, + replica_id=replica_id, + ) if self.config.need_send_cache: c_extra["role"] = "sender" diff --git a/vllm_omni/distributed/omni_connectors/utils/initialization.py b/vllm_omni/distributed/omni_connectors/utils/initialization.py index f012af3c9c3..765da65789b 100644 --- a/vllm_omni/distributed/omni_connectors/utils/initialization.py +++ b/vllm_omni/distributed/omni_connectors/utils/initialization.py @@ -25,9 +25,17 @@ # Port stride between TP ranks so each worker binds a unique ZMQ port # when TP > 1. Must be larger than the maximum number of pipeline stages. -# Formula: zmq_port = base + KV_TRANSFER_PORT_OFFSET + rank * STRIDE + stage +# Formula: +# zmq_port = base + KV_TRANSFER_PORT_OFFSET +# + replica * KV_REPLICA_PORT_STRIDE +# + rank * KV_RANK_PORT_STRIDE +# + stage KV_RANK_PORT_STRIDE = 16 +# Port stride between Omni replicas of the same stage. This reserves a +# comfortably sized block per replica for TP-rank and stage offsets. +KV_REPLICA_PORT_STRIDE = 1024 + def initialize_connectors_from_config( config_path: str | Path | None = None, diff --git a/vllm_omni/distributed/omni_connectors/utils/kv_utils.py b/vllm_omni/distributed/omni_connectors/utils/kv_utils.py index 12b9b3d4f77..fe0df21f25b 100644 --- a/vllm_omni/distributed/omni_connectors/utils/kv_utils.py +++ b/vllm_omni/distributed/omni_connectors/utils/kv_utils.py @@ -16,7 +16,7 @@ ) from vllm.logger import init_logger -from .initialization import KV_RANK_PORT_STRIDE, KV_TRANSFER_PORT_OFFSET +from .initialization import KV_RANK_PORT_STRIDE, KV_REPLICA_PORT_STRIDE, KV_TRANSFER_PORT_OFFSET logger = init_logger(__name__) @@ -94,20 +94,42 @@ def get_tp_world_size() -> int: return 1 +def get_omni_replica_id() -> int: + """Return the Omni replica id for this worker process.""" + try: + replica_id = int(os.environ.get("VLLM_OMNI_REPLICA_ID", "0")) + except (ValueError, TypeError): + return 0 + return max(replica_id, 0) + + # ------------------------------------------------------------------ # # ZMQ port computation # ------------------------------------------------------------------ # -def kv_zmq_port(base_port: int, from_stage: int, local_rank: int = 0) -> int: +def kv_zmq_port( + base_port: int, + from_stage: int, + local_rank: int = 0, + replica_id: int | None = None, +) -> int: """Compute the ZMQ port for a KV-transfer connector. - Each TP rank gets its own port so that TP > 1 deployments do not - cause ``EADDRINUSE`` when multiple sender workers bind on the same - host. The formula is backward-compatible: rank 0 produces the same - port as the previous ``base + OFFSET + stage`` formula. + Each Omni replica and TP rank gets its own port so multi-replica or + TP > 1 deployments do not cause ``EADDRINUSE`` when multiple sender + workers bind on the same host. The formula is backward-compatible: + replica 0 / rank 0 produces the previous ``base + OFFSET + stage`` port. + """ - return base_port + KV_TRANSFER_PORT_OFFSET + local_rank * KV_RANK_PORT_STRIDE + from_stage + replica = get_omni_replica_id() if replica_id is None else max(int(replica_id), 0) + return ( + base_port + + KV_TRANSFER_PORT_OFFSET + + replica * KV_REPLICA_PORT_STRIDE + + local_rank * KV_RANK_PORT_STRIDE + + from_stage + ) # ------------------------------------------------------------------ # diff --git a/vllm_omni/engine/stage_engine_core_client.py b/vllm_omni/engine/stage_engine_core_client.py index 04a047fb3b6..8849319429a 100644 --- a/vllm_omni/engine/stage_engine_core_client.py +++ b/vllm_omni/engine/stage_engine_core_client.py @@ -20,9 +20,8 @@ from vllm.v1.engine.core_client import AsyncMPClient, DPLBAsyncMPClient from vllm.v1.engine.exceptions import EngineDeadError -from vllm_omni.distributed.omni_connectors.utils.initialization import ( - KV_TRANSFER_PORT_OFFSET, -) +from vllm_omni.distributed.omni_connectors.utils.initialization import KV_TRANSFER_PORT_OFFSET +from vllm_omni.distributed.omni_connectors.utils.kv_utils import kv_zmq_port from vllm_omni.engine.stage_init_utils import StageMetadata if TYPE_CHECKING: @@ -341,7 +340,12 @@ def _initialize_kv_sender_endpoint(self) -> None: try: # Orchestrator always reports rank-0's port; receiver # workers add their own local_rank * KV_RANK_PORT_STRIDE. - sender_port = int(base_port) + KV_TRANSFER_PORT_OFFSET + int(from_stage) + sender_port = kv_zmq_port( + int(base_port), + int(from_stage), + local_rank=0, + replica_id=self.replica_id, + ) except (TypeError, ValueError): logger.warning( "[StageEngineCoreClient] stage-%s [rep-%s] could not resolve sender_zmq_port " @@ -383,7 +387,12 @@ def get_kv_sender_info( # rank-0 base port; receiver workers adjust per KV_RANK_PORT_STRIDE. return { "host": self._kv_sender_host, - "zmq_port": base_port + kv_transfer_port_offset + int(self.stage_id), + "zmq_port": kv_zmq_port( + base_port - KV_TRANSFER_PORT_OFFSET + kv_transfer_port_offset, + int(self.stage_id), + local_rank=0, + replica_id=self.replica_id, + ), } def set_engine_outputs(self, engine_outputs: EngineCoreOutput) -> None: diff --git a/vllm_omni/engine/stage_engine_core_proc.py b/vllm_omni/engine/stage_engine_core_proc.py index 6b77388b41f..0b3dadcbd1c 100644 --- a/vllm_omni/engine/stage_engine_core_proc.py +++ b/vllm_omni/engine/stage_engine_core_proc.py @@ -8,6 +8,7 @@ from __future__ import annotations import contextlib +import os import signal from multiprocessing.process import BaseProcess from typing import TYPE_CHECKING, Any @@ -90,6 +91,7 @@ def run_stage_core( stage_label = f"stage{omni_stage_id}" if omni_stage_id is not None else "noid" set_process_title(f"StageEngineCoreProc_{stage_label}_replica{omni_replica_id}_DP{dp_rank}") decorate_logs() + os.environ["VLLM_OMNI_REPLICA_ID"] = str(max(int(omni_replica_id), 0)) engine_core = StageEngineCoreProc( *args,