Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion vllm_omni/distributed/omni_connectors/kv_transfer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
build_rank_aware_send_keys,
get_kv_target_ranks,
get_local_tp_rank,
get_omni_replica_id,
get_tp_world_size,
kv_zmq_port,
merge_received_rank_shards,
Expand Down Expand Up @@ -415,7 +416,13 @@ def connector(self):
stage_int = int(self.config.from_stage) if self.config.from_stage is not None else 0
except (TypeError, ValueError):
stage_int = 0
zmq_port = kv_zmq_port(base_port, stage_int, self._tp_topo.local_rank)
replica_id = get_omni_replica_id()
zmq_port = kv_zmq_port(
base_port,
stage_int,
self._tp_topo.local_rank,
replica_id=replica_id,
)

if self.config.need_send_cache:
c_extra["role"] = "sender"
Expand Down
10 changes: 9 additions & 1 deletion vllm_omni/distributed/omni_connectors/utils/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,17 @@

# Port stride between TP ranks so each worker binds a unique ZMQ port
# when TP > 1. Must be larger than the maximum number of pipeline stages.
# Formula: zmq_port = base + KV_TRANSFER_PORT_OFFSET + rank * STRIDE + stage
# Formula:
# zmq_port = base + KV_TRANSFER_PORT_OFFSET
# + replica * KV_REPLICA_PORT_STRIDE
# + rank * KV_RANK_PORT_STRIDE
# + stage
KV_RANK_PORT_STRIDE = 16

# Port stride between Omni replicas of the same stage. This reserves a
# comfortably sized block per replica for TP-rank and stage offsets.
KV_REPLICA_PORT_STRIDE = 1024


def initialize_connectors_from_config(
config_path: str | Path | None = None,
Expand Down
36 changes: 29 additions & 7 deletions vllm_omni/distributed/omni_connectors/utils/kv_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)
from vllm.logger import init_logger

from .initialization import KV_RANK_PORT_STRIDE, KV_TRANSFER_PORT_OFFSET
from .initialization import KV_RANK_PORT_STRIDE, KV_REPLICA_PORT_STRIDE, KV_TRANSFER_PORT_OFFSET

logger = init_logger(__name__)

Expand Down Expand Up @@ -94,20 +94,42 @@ def get_tp_world_size() -> int:
return 1


def get_omni_replica_id() -> int:
    """Look up this worker's Omni replica index from the environment.

    Reads ``VLLM_OMNI_REPLICA_ID``, treating an unset variable as ``"0"``.

    Returns:
        The parsed replica id. A value that cannot be parsed as an integer
        yields ``0``, and negative values are clamped up to ``0``.
    """
    raw_value = os.environ.get("VLLM_OMNI_REPLICA_ID", "0")
    try:
        parsed = int(raw_value)
    except (ValueError, TypeError):
        # Malformed value: fall back to the first replica.
        return 0
    return parsed if parsed > 0 else 0


# ------------------------------------------------------------------ #
# ZMQ port computation
# ------------------------------------------------------------------ #


def kv_zmq_port(
    base_port: int,
    from_stage: int,
    local_rank: int = 0,
    replica_id: int | None = None,
) -> int:
    """Compute the ZMQ port for a KV-transfer connector.

    Each Omni replica and TP rank gets its own port so multi-replica or
    TP > 1 deployments do not cause ``EADDRINUSE`` when multiple sender
    workers bind on the same host. The formula is backward-compatible:
    replica 0 / rank 0 produces the previous ``base + OFFSET + stage`` port.

    Args:
        base_port: Base port the KV-transfer offset is added to.
        from_stage: Index of the sending stage; added as the innermost
            offset.
        local_rank: TP rank of the worker; scaled by
            ``KV_RANK_PORT_STRIDE``.
        replica_id: Omni replica index; scaled by
            ``KV_REPLICA_PORT_STRIDE``. ``None`` means "use this process's
            replica id" as reported by ``get_omni_replica_id()``. Negative
            values are clamped to ``0``.

    Returns:
        ``base_port + KV_TRANSFER_PORT_OFFSET
        + replica * KV_REPLICA_PORT_STRIDE
        + local_rank * KV_RANK_PORT_STRIDE + from_stage``.
    """
    # An explicit replica_id wins over the environment so the orchestrator
    # can compute another replica's port without touching its own env.
    replica = get_omni_replica_id() if replica_id is None else max(int(replica_id), 0)
    # NOTE(review): replica blocks only stay disjoint while
    # local_rank * KV_RANK_PORT_STRIDE + from_stage < KV_REPLICA_PORT_STRIDE;
    # nothing here enforces that bound — confirm callers stay within it.
    return (
        base_port
        + KV_TRANSFER_PORT_OFFSET
        + replica * KV_REPLICA_PORT_STRIDE
        + local_rank * KV_RANK_PORT_STRIDE
        + from_stage
    )


# ------------------------------------------------------------------ #
Expand Down
19 changes: 14 additions & 5 deletions vllm_omni/engine/stage_engine_core_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@
from vllm.v1.engine.core_client import AsyncMPClient, DPLBAsyncMPClient
from vllm.v1.engine.exceptions import EngineDeadError

from vllm_omni.distributed.omni_connectors.utils.initialization import (
KV_TRANSFER_PORT_OFFSET,
)
from vllm_omni.distributed.omni_connectors.utils.initialization import KV_TRANSFER_PORT_OFFSET
from vllm_omni.distributed.omni_connectors.utils.kv_utils import kv_zmq_port
from vllm_omni.engine.stage_init_utils import StageMetadata

if TYPE_CHECKING:
Expand Down Expand Up @@ -341,7 +340,12 @@ def _initialize_kv_sender_endpoint(self) -> None:
try:
# Orchestrator always reports rank-0's port; receiver
# workers add their own local_rank * KV_RANK_PORT_STRIDE.
sender_port = int(base_port) + KV_TRANSFER_PORT_OFFSET + int(from_stage)
sender_port = kv_zmq_port(
int(base_port),
int(from_stage),
local_rank=0,
replica_id=self.replica_id,
)
except (TypeError, ValueError):
logger.warning(
"[StageEngineCoreClient] stage-%s [rep-%s] could not resolve sender_zmq_port "
Expand Down Expand Up @@ -383,7 +387,12 @@ def get_kv_sender_info(
# rank-0 base port; receiver workers adjust per KV_RANK_PORT_STRIDE.
return {
"host": self._kv_sender_host,
"zmq_port": base_port + kv_transfer_port_offset + int(self.stage_id),
"zmq_port": kv_zmq_port(
base_port - KV_TRANSFER_PORT_OFFSET + kv_transfer_port_offset,
int(self.stage_id),
local_rank=0,
replica_id=self.replica_id,
),
}

def set_engine_outputs(self, engine_outputs: EngineCoreOutput) -> None:
Expand Down
2 changes: 2 additions & 0 deletions vllm_omni/engine/stage_engine_core_proc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from __future__ import annotations

import contextlib
import os
import signal
from multiprocessing.process import BaseProcess
from typing import TYPE_CHECKING, Any
Expand Down Expand Up @@ -90,6 +91,7 @@ def run_stage_core(
stage_label = f"stage{omni_stage_id}" if omni_stage_id is not None else "noid"
set_process_title(f"StageEngineCoreProc_{stage_label}_replica{omni_replica_id}_DP{dp_rank}")
decorate_logs()
os.environ["VLLM_OMNI_REPLICA_ID"] = str(max(int(omni_replica_id), 0))

engine_core = StageEngineCoreProc(
*args,
Expand Down