diff --git a/docs/design/feature/disaggregated_inference.md b/docs/design/feature/disaggregated_inference.md index 83c35ac0107..0696aa116af 100644 --- a/docs/design/feature/disaggregated_inference.md +++ b/docs/design/feature/disaggregated_inference.md @@ -8,6 +8,7 @@ Backend-specific setup lives in separate docs: - [SharedMemoryConnector](omni_connectors/shared_memory_connector.md) - [MooncakeStoreConnector](omni_connectors/mooncake_store_connector.md) - [MooncakeTransferEngineConnector](omni_connectors/mooncake_transfer_engine_connector.md) +- [MoriTransferEngineConnector](omni_connectors/mori_transfer_engine_connector.md) - [YuanrongConnector](omni_connectors/yuanrong_connector.md) ## Overview @@ -22,6 +23,7 @@ Current connectors operate in D2H2D (device to host to device) mode. | Single node | SharedMemoryConnector | Auto-configured if no connector is specified. | | Multi node (Mooncake Store) | MooncakeStoreConnector | TCP-based, requires Mooncake Master + metadata server. | | Multi node (Mooncake RDMA) | MooncakeTransferEngineConnector | RDMA/TCP direct transfer with managed memory pool. Fastest. | +| Multi node (Mori RDMA) | MoriTransferEngineConnector | RDMA direct transfer via Mori IOEngine. | | Multi node (Yuanrong) | YuanrongConnector | Requires Yuanrong Datasystem + etcd. | ## Core API diff --git a/docs/design/feature/omni_connectors/mori_transfer_engine_connector.md b/docs/design/feature/omni_connectors/mori_transfer_engine_connector.md new file mode 100644 index 00000000000..367dd2d2872 --- /dev/null +++ b/docs/design/feature/omni_connectors/mori_transfer_engine_connector.md @@ -0,0 +1,77 @@ +# MoriTransferEngineConnector + +## When to Use + +Currently supports intra-node deployment with Mori. + +As noted in #1742, inter-node support will be added back in a future +refactor. + +## Mechanism + +Uses Mori's `IOEngine` / `MemoryDesc` API for zero-copy RDMA transfers. + +- Data Plane: RDMA (InfiniBand/RoCE) with managed memory pool. +- Control Plane: ZMQ for pull-request handshake and async completion. + +## Installation + +See the [Mori repository](https://github.com/ROCm/mori) for installation instructions. + +## Configuration + +Mori is configured through the new deploy-config schema (see +[`docs/configuration/stage_configs.md`](../../../configuration/stage_configs.md)). +Define the connector at the top level of the deploy YAML and reference it +by name from each stage's `input_connectors` / `output_connectors`: + +```yaml +connectors: + mori_connector: + name: MoriTransferEngineConnector + extra: + host: "auto" + zmq_port: 50051 + device_name: "" + memory_pool_size: 536870912 + memory_pool_device: "cuda" + +stages: + - stage_id: 0 + output_connectors: + to_stage_1: mori_connector + + - stage_id: 1 + input_connectors: + from_stage_0: mori_connector +``` + +A ready-to-run intra-node example for Qwen3-Omni-MoE on AMD MI300X lives +at +[`vllm_omni/deploy/qwen3_omni_moe_mori_intranode.yaml`](../../../../vllm_omni/deploy/qwen3_omni_moe_mori_intranode.yaml) +and can be loaded with: + +```bash +vllm-omni serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --log-stats \ + --deploy-config vllm_omni/deploy/qwen3_omni_moe_mori_intranode.yaml +``` + +The yaml wires `MoriTransferEngineConnector` (with `backend_type: xgmi`) to +the chunk_transfer_adapter path (`async_chunk: true`) so stage-to-stage +hidden-state and codec-frame streams ship GPU-to-GPU over AMD Infinity +Fabric instead of SHM. Qwen2.5-Omni + Mori is not yet functional on +the chunk path: it needs `thinker2talker_async_chunk` / +`talker2code2wav_async_chunk` input processors that do not exist yet +(the orchestrator-level path the upstream PR originally targeted was +lost during the entrypoints → engine refactor and removed in #1742). + +Parameters: + +- host: local RDMA IP (`"auto"` for auto-detect). +- zmq_port: ZMQ base port for control-plane communication. +- device_name: RDMA device (e.g., `"mlx5_0"`), empty for auto-detect. +- memory_pool_size: RDMA memory pool size in bytes. +- memory_pool_device: `"cpu"` (pinned) or `"cuda"` (GPUDirect / XGMI RDMA). + +For more details, refer to the +[Mori repository](https://github.com/ROCm/mori). diff --git a/tests/distributed/omni_connectors/test_omni_connector_configs.py b/tests/distributed/omni_connectors/test_omni_connector_configs.py index 3dc80b57bf5..c954ecb4488 100644 --- a/tests/distributed/omni_connectors/test_omni_connector_configs.py +++ b/tests/distributed/omni_connectors/test_omni_connector_configs.py @@ -6,7 +6,12 @@ import pytest # Use the new import path for initialization utilities -from vllm_omni.distributed.omni_connectors.utils.initialization import load_omni_transfer_config +from vllm_omni.distributed.omni_connectors.utils.config import ConnectorSpec, OmniTransferConfig +from vllm_omni.distributed.omni_connectors.utils.initialization import ( + _inject_chunk_path_endpoints, + get_connectors_config_for_stage, + load_omni_transfer_config, +) pytestmark = [pytest.mark.core_model, pytest.mark.cpu] @@ -64,3 +69,197 @@ def test_load_qwen_yaml_configs(yaml_file): except Exception as e: pytest.fail(f"Failed to load config {yaml_file.name}: {e}") + + +# --------------------------------------------------------------------------- +# Framework-level per-stage role + endpoint derivation for the chunk +# transfer adapter path. +# +# ``get_connectors_config_for_stage`` is responsible for turning a +# role-neutral edge-level ConnectorSpec into a per-stage view where each +# stage carries: +# +# * ``role=sender`` if the stage only has outgoing edges; +# * ``role=receiver`` if the stage only has incoming edges; +# * ``role=dual`` if the stage has both (middle stage in a 3+ stage +# pipeline; e.g. the Qwen3-Omni-MoE talker stage). Dual stages +# emit ``from_stage_*`` and ``to_stage_*`` entries that share the +# same composite extra so downstream flattening (engine-side +# ``get_stage_connector_spec`` returning the first spec) always +# recovers a self-consistent config. +# +# For role-bound ZMQ connectors (Mori / Mooncake) the function also +# pre-computes ``zmq_port`` / ``sender_host`` / ``sender_zmq_port`` so +# an intranode pipeline can come up without an external handshake. +# +# The orchestrator-level path (``create_connectors_from_config``) has +# its own Mooncake-specific port adjustment and is intentionally NOT +# exercised here. +# --------------------------------------------------------------------------- + + +def _linear_pipeline_config( + connector_name: str, + extra: dict | None = None, + edges: tuple[tuple[str, str], ...] = (("0", "1"), ("1", "2")), +) -> OmniTransferConfig: + shared_extra = dict(extra or {}) + specs = {edge: ConnectorSpec(name=connector_name, extra=dict(shared_extra)) for edge in edges} + return OmniTransferConfig(connectors=specs) + + +@pytest.fixture +def _stable_local_ip(monkeypatch: pytest.MonkeyPatch) -> str: + """Pin framework-level local-IP detection so endpoint assertions are deterministic.""" + import vllm_omni.distributed.omni_connectors.utils.initialization as init_mod + + monkeypatch.setattr(init_mod, "_detect_local_ip", lambda: "10.20.30.40") + return "10.20.30.40" + + +@pytest.mark.parametrize( + "connector_name", + ["MoriTransferEngineConnector", "MooncakeTransferEngineConnector"], +) +def test_stage_0_is_sender_only(connector_name, _stable_local_ip): + """Stage 0 has only outgoing edges → role=sender, listener port = base + 0.""" + cfg = _linear_pipeline_config(connector_name, extra={"zmq_port": 50051, "host": "auto"}) + + stage0 = get_connectors_config_for_stage(cfg, 0) + + assert list(stage0.keys()) == ["to_stage_1"], "Sender-only stage should emit only to_stage_*" + extra = stage0["to_stage_1"]["spec"]["extra"] + assert extra["role"] == "sender" + assert extra["zmq_port"] == 50051 + assert "sender_host" not in extra + assert "sender_zmq_port" not in extra + + +@pytest.mark.parametrize( + "connector_name", + ["MoriTransferEngineConnector", "MooncakeTransferEngineConnector"], +) +def test_final_stage_is_receiver_only(connector_name, _stable_local_ip): + """Stage 2 has only incoming edges → role=receiver, points at upstream sender.""" + cfg = _linear_pipeline_config(connector_name, extra={"zmq_port": 50051, "host": "auto"}) + + stage2 = get_connectors_config_for_stage(cfg, 2) + + assert list(stage2.keys()) == ["from_stage_1"], "Receiver-only stage should emit only from_stage_*" + extra = stage2["from_stage_1"]["spec"]["extra"] + assert extra["role"] == "receiver" + assert extra["sender_zmq_port"] == 50052 # base + upstream stage id (1) + assert extra["sender_host"] == "10.20.30.40" + + +@pytest.mark.parametrize( + "connector_name", + ["MoriTransferEngineConnector", "MooncakeTransferEngineConnector"], +) +def test_middle_stage_is_dual(connector_name, _stable_local_ip): + """Middle stage has both → role=dual, both entries share composite spec.""" + cfg = _linear_pipeline_config(connector_name, extra={"zmq_port": 50051, "host": "auto"}) + + stage1 = get_connectors_config_for_stage(cfg, 1) + + # Both directions exposed, so whichever one get_stage_connector_spec + # (which does "return first") picks, it recovers the same dual spec. + assert set(stage1.keys()) == {"from_stage_0", "to_stage_2"} + incoming = stage1["from_stage_0"]["spec"]["extra"] + outgoing = stage1["to_stage_2"]["spec"]["extra"] + + for extra in (incoming, outgoing): + assert extra["role"] == "dual" + assert extra["zmq_port"] == 50052 # this stage's listener, base + own stage id (1) + assert extra["sender_host"] == "10.20.30.40" + assert extra["sender_zmq_port"] == 50051 # upstream sender at base + 0 + + # Composite must be identical so order of iteration does not matter. + assert incoming == outgoing + + +def test_shm_connector_is_untouched_by_endpoint_injection(_stable_local_ip): + """SharedMemoryConnector is not role-bound -> no port/host fields appear.""" + cfg = _linear_pipeline_config( + "SharedMemoryConnector", + extra={"shm_threshold_bytes": 65536}, + ) + + for sid in (0, 1, 2): + stage_cfg = get_connectors_config_for_stage(cfg, sid) + for entry in stage_cfg.values(): + extra = entry["spec"]["extra"] + assert "zmq_port" not in extra + assert "sender_host" not in extra + assert "sender_zmq_port" not in extra + assert extra["shm_threshold_bytes"] == 65536 + # Role is still injected (passive connectors ignore it, which + # is fine -- the string is just metadata to them). + assert extra["role"] in {"sender", "receiver", "dual"} + + +def test_explicit_sender_host_and_port_override_win(_stable_local_ip): + """User-provided ``sender_host`` / ``sender_zmq_port`` beat framework derivation.""" + cfg = _linear_pipeline_config( + "MoriTransferEngineConnector", + extra={ + "zmq_port": 50051, + "host": "auto", + "sender_host": "192.168.1.10", + "sender_zmq_port": 60000, + }, + ) + + stage1 = get_connectors_config_for_stage(cfg, 1) + recv_extra = stage1["from_stage_0"]["spec"]["extra"] + assert recv_extra["sender_host"] == "192.168.1.10" + assert recv_extra["sender_zmq_port"] == 60000 + + # Sender-side zmq_port still offsets so co-located stage listeners do + # not collide; the override is the upstream peer's address, not this + # stage's own bind port. + assert recv_extra["zmq_port"] == 50052 + + +def test_explicit_non_auto_host_cascades_to_sender_host(_stable_local_ip): + """Non-auto ``host`` is reused as ``sender_host`` for the receiver side.""" + cfg = _linear_pipeline_config( + "MoriTransferEngineConnector", + extra={"zmq_port": 50051, "host": "172.16.0.5"}, + ) + + stage1 = get_connectors_config_for_stage(cfg, 1) + recv_extra = stage1["from_stage_0"]["spec"]["extra"] + assert recv_extra["sender_host"] == "172.16.0.5" + assert recv_extra["sender_zmq_port"] == 50051 + + +def test_inject_helper_is_noop_for_unknown_connector(_stable_local_ip): + """Non-role-bound connectors are left untouched by the helper.""" + extra: dict = {"zmq_port": 50051, "host": "auto"} + _inject_chunk_path_endpoints( + extra, + connector_name="SomeFutureConnector", + role="dual", + own_stage="1", + upstream_stage="0", + ) + assert extra == {"zmq_port": 50051, "host": "auto"} + + +def test_inject_helper_is_noop_for_non_integer_stage(_stable_local_ip): + """Non-integer pipeline keys (e.g. ``"prefill"``) short-circuit safely.""" + extra: dict = {"zmq_port": 50051, "host": "auto"} + _inject_chunk_path_endpoints( + extra, + connector_name="MoriTransferEngineConnector", + role="sender", + own_stage="prefill", + upstream_stage=None, + ) + assert extra == {"zmq_port": 50051, "host": "auto"} + + +def test_get_connectors_config_for_stage_none_transfer_config(): + """Callers pass ``None`` when no yaml was loaded; return an empty dict.""" + assert get_connectors_config_for_stage(None, 0) == {} diff --git a/vllm_omni/deploy/qwen3_omni_moe_mooncake_intranode.yaml b/vllm_omni/deploy/qwen3_omni_moe_mooncake_intranode.yaml new file mode 100644 index 00000000000..fd4af021540 --- /dev/null +++ b/vllm_omni/deploy/qwen3_omni_moe_mooncake_intranode.yaml @@ -0,0 +1,118 @@ +# Qwen3-Omni-MoE intra-node deploy using Mooncake transfer (benchmark companion). +# +# Sibling of ``qwen3_omni_moe_mori_intranode.yaml`` intended for +# Mori-vs-Mooncake benchmark comparison on a single AMD Instinct MI300X +# OAM node. Pipeline topology and sampling params are identical to the +# Mori yaml; only the connector definition and its backend-specific +# knobs differ. +# +# Transport note (default: ``rdma`` via Mellanox + GPUDirect RDMA): +# ``MooncakeTransferEngineConnector`` forwards ``protocol`` to +# ``mooncake.engine.TransferEngine.initialize``. The default here +# (``"rdma"``) drives Mooncake's RoCE/IB transport; on MI300X with +# ``memory_pool_device: "cuda"`` this ships tensors GPU-to-GPU via +# GPUDirect RDMA through the Mellanox HCAs. This is the path +# end-to-end validated on this branch (three consecutive chat +# completions finish in 18.7 / 16.8 / 16.4 s with 93 per-chunk +# ``[RDMA GET]`` transfers at median ~0.7 ms on edge 0 -> 1). +# +# Switching to true XGMI: +# Mooncake PRs #1742 and #1550 add a first-class HIP transport that +# rides AMD Infinity Fabric XGMI directly, and expose it at the +# Python binding level via a ``xgmi -> hip`` normalization helper +# (so setting ``protocol: "xgmi"`` below asks for HIP). Doing so +# requires: +# (a) a mooncake wheel built with ``-DUSE_HIP=ON`` and reinstalled +# into the container's Python env, +# (b) ``memory_pool_device: "cuda"`` (required for the HIP memory +# allocator), and +# (c) setting ``MC_FORCE_MNNVL=1`` in the vllm-omni launch env so +# ``TransferEngineImpl::init()`` installs the HIP transport +# instead of falling back to RDMA on boxes that also have +# HCAs discovered by the auto-topology scan. +# When all three are in place, Mooncake prints ``Selected HIP +# memory allocator`` at boot and subsequent transfers go over XGMI +# Infinity Fabric directly. The RDMA default above is intentional +# until that full toolchain is validated end-to-end on this +# hardware. Cross-node deployments must keep ``protocol: "rdma"``. +# +# Pipeline wiring (stage_id, execution_type, model_stage, input_sources, +# custom processors, async-chunk builders) lives in +# ``vllm_omni/model_executor/models/qwen3_omni/pipeline.py`` and is +# merged into the stages below by the deploy loader. +# +# Usage: +# vllm-omni serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --log-stats \ +# --deploy-config vllm_omni/deploy/qwen3_omni_moe_mooncake_intranode.yaml +# +# Dependencies: mooncake (``pip install mooncake``), torch-rocm>=2.9, +# vllm-omni with #2383 deploy schema, ibverbs/RDMA NICs exposed via +# /dev/infiniband. +async_chunk: true + +connectors: + mooncake_connector: + name: MooncakeTransferEngineConnector + extra: + host: "auto" + zmq_port: 50051 + protocol: "rdma" # see transport note above; set "xgmi" to opt into HIP + device_name: "" # leave empty to honour $RDMA_DEVICE_NAME + memory_pool_size: 536870912 # 512 MB, matches Mori yaml + memory_pool_device: "cuda" # required for HIP / XGMI memory allocator + # Chunk-transfer knobs consumed by talker2code2wav_async_chunk: + # 25 decoded codec frames per connector.put (~40 ms at 625 Hz + # codec rate) with a 25-frame left context to satisfy the + # code2wav receptive field. Kept identical to the Mori yaml so + # accuracy stays comparable when swapping backends. + codec_chunk_frames: 25 + codec_left_context_frames: 25 + +stages: + - stage_id: 0 + gpu_memory_utilization: 0.9 + enforce_eager: true # ROCm override: avoid flashinfer autotune path + devices: "0" + output_connectors: + to_stage_1: mooncake_connector + default_sampling_params: + temperature: 0.4 + top_p: 0.9 + top_k: 1 + max_tokens: 2048 + seed: 42 + repetition_penalty: 1.05 + + - stage_id: 1 + gpu_memory_utilization: 0.9 + enforce_eager: true + devices: "1" + input_connectors: + from_stage_0: mooncake_connector + output_connectors: + to_stage_2: mooncake_connector + default_sampling_params: + temperature: 0.9 + top_k: 50 + max_tokens: 4096 + seed: 42 + repetition_penalty: 1.05 + + - stage_id: 2 + gpu_memory_utilization: 0.3 + max_num_seqs: 1 + enforce_eager: true + async_scheduling: false + # Codec prefill length (Q * num_frames) exceeds 32k default; + # matches qwen3_omni_moe.yaml code2wav sizing. + max_num_batched_tokens: 51200 + devices: "2" + input_connectors: + from_stage_1: mooncake_connector + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 65536 + seed: 42 + repetition_penalty: 1.1 diff --git a/vllm_omni/deploy/qwen3_omni_moe_mori_intranode.yaml b/vllm_omni/deploy/qwen3_omni_moe_mori_intranode.yaml new file mode 100644 index 00000000000..791b4fa9d6a --- /dev/null +++ b/vllm_omni/deploy/qwen3_omni_moe_mori_intranode.yaml @@ -0,0 +1,104 @@ +# Qwen3-Omni-MoE intra-node deploy using Mori XGMI transfer. +# +# Targets a single AMD Instinct MI300X OAM node (8x GPU, 192GB HBM3, +# gfx942). Each stage occupies one dedicated GPU (3 GPUs total) and +# stage-to-stage streaming chunks are shipped via Mori's RDMA IOEngine +# (with ``backend_type: xgmi`` selecting the AMD Infinity Fabric +# GPU-to-GPU direct path for intranode transfers). +# +# Pipeline wiring (stage_id, execution_type, model_stage, input_sources, +# custom processors, async-chunk builders) lives in +# ``vllm_omni/model_executor/models/qwen3_omni/pipeline.py`` and is +# merged into the stages below by the deploy loader. +# +# Why this yaml exists: +# * On CUDA/H100 Qwen3-Omni-MoE ships via ``SharedMemoryConnector`` +# (see ``qwen3_omni_moe.yaml``). On MI300X the SHM path still works +# but XGMI RDMA can substantially reduce stage-transition latency +# for the thinker->talker hidden-state stream and the talker->code2wav +# codec-frame stream, both of which run in async-chunk mode. +# * This config swaps the connector definition to +# ``MoriTransferEngineConnector`` while keeping the pipeline topology, +# sampling params, and async_chunk=true gating identical to the CUDA +# deploy, so the chunk_transfer_adapter path (scheduler-owned +# background put/get threads) is fully exercised. +# +# Usage: +# vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --log-stats \ +# --deploy-config vllm_omni/deploy/qwen3_omni_moe_mori_intranode.yaml +# +# Dependencies: mori (``pip install mori``), torch-rocm>=2.9, vllm-omni +# with #2383 deploy schema, ibverbs/RDMA NICs exposed via /dev/infiniband. +async_chunk: true + +connectors: + mori_connector: + name: MoriTransferEngineConnector + extra: + host: "auto" + zmq_port: 50051 + backend_type: "xgmi" # AMD Infinity Fabric GPU-to-GPU direct + device_name: "" # leave empty to honour $MORI_RDMA_DEVICES + memory_pool_size: 536870912 # 512 MB, matches Qwen2.5-Omni Mori deploy + memory_pool_device: "cuda" # GPU memory required for XGMI RDMA transfers + xgmi_num_streams: 64 + xgmi_num_events: 64 + qp_per_transfer: 1 + num_worker_threads: 1 + post_batch_size: -1 + # Chunk-transfer knobs consumed by talker2code2wav_async_chunk: + # 25 decoded codec frames per connector.put (~40 ms at 625 Hz + # codec rate) with a 25-frame left context to satisfy the + # code2wav receptive field. Kept identical to the SHM deploy so + # accuracy stays comparable when swapping backends. + codec_chunk_frames: 25 + codec_left_context_frames: 25 + +stages: + - stage_id: 0 + gpu_memory_utilization: 0.9 + enforce_eager: true # ROCm override: avoid flashinfer autotune path + devices: "0" + output_connectors: + to_stage_1: mori_connector + default_sampling_params: + temperature: 0.4 + top_p: 0.9 + top_k: 1 + max_tokens: 2048 + seed: 42 + repetition_penalty: 1.05 + + - stage_id: 1 + gpu_memory_utilization: 0.9 + enforce_eager: true + devices: "1" + input_connectors: + from_stage_0: mori_connector + output_connectors: + to_stage_2: mori_connector + default_sampling_params: + temperature: 0.9 + top_k: 50 + max_tokens: 4096 + seed: 42 + repetition_penalty: 1.05 + + - stage_id: 2 + gpu_memory_utilization: 0.3 + max_num_seqs: 1 + enforce_eager: true + async_scheduling: false + # Codec prefill length (Q * num_frames) exceeds 32k default; + # matches qwen3_omni_moe.yaml code2wav sizing. + max_num_batched_tokens: 51200 + devices: "2" + input_connectors: + from_stage_1: mori_connector + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 65536 + seed: 42 + repetition_penalty: 1.1 diff --git a/vllm_omni/distributed/__init__.py b/vllm_omni/distributed/__init__.py index f43a833f8b8..a451fab2044 100644 --- a/vllm_omni/distributed/__init__.py +++ b/vllm_omni/distributed/__init__.py @@ -6,6 +6,7 @@ MooncakeConnector, MooncakeStoreConnector, MooncakeTransferEngineConnector, + MoriTransferEngineConnector, OmniConnectorBase, OmniConnectorFactory, OmniTransferConfig, @@ -24,6 +25,7 @@ "MooncakeConnector", # compat alias "MooncakeStoreConnector", "MooncakeTransferEngineConnector", + "MoriTransferEngineConnector", "SharedMemoryConnector", "YuanrongConnector", # Utilities diff --git a/vllm_omni/distributed/omni_connectors/__init__.py b/vllm_omni/distributed/omni_connectors/__init__.py index 383267928b4..c537ce939f8 100644 --- a/vllm_omni/distributed/omni_connectors/__init__.py +++ b/vllm_omni/distributed/omni_connectors/__init__.py @@ -10,6 +10,11 @@ from .connectors.mooncake_transfer_engine_connector import MooncakeTransferEngineConnector except ImportError: MooncakeTransferEngineConnector = None # RDMA deps (msgspec/zmq/mooncake) not installed + +try: + from .connectors.mori_transfer_engine_connector import MoriTransferEngineConnector +except ImportError: + MoriTransferEngineConnector = None # RDMA deps (msgspec/zmq/mori) not installed from .factory import OmniConnectorFactory from .utils.config import ConnectorSpec, OmniTransferConfig from .utils.initialization import ( @@ -37,6 +42,7 @@ "MooncakeConnector", # compat alias → MooncakeStoreConnector "MooncakeStoreConnector", "MooncakeTransferEngineConnector", + "MoriTransferEngineConnector", "SharedMemoryConnector", "YuanrongConnector", # Utilities diff --git a/vllm_omni/distributed/omni_connectors/connectors/mooncake_transfer_engine_connector.py b/vllm_omni/distributed/omni_connectors/connectors/mooncake_transfer_engine_connector.py index bd4160f3e63..a33361feec5 100644 --- a/vllm_omni/distributed/omni_connectors/connectors/mooncake_transfer_engine_connector.py +++ b/vllm_omni/distributed/omni_connectors/connectors/mooncake_transfer_engine_connector.py @@ -281,6 +281,15 @@ def __init__(self, config: dict[str, Any]): } self.config = config + + # Stage id is read by OmniChunkTransferAdapter.process_pending_chunks + # (and the AR / generation schedulers) to decide whether this + # connector sits at stage 0 of the pipeline. Kept parallel to + # SharedMemoryConnector and MoriTransferEngineConnector. Populated + # by OmniEngineArgs.create_model_config via the ``stage_id`` key + # that get_connectors_config_for_stage injects into extra. + self.stage_id = config.get("stage_id", -1) + host_config = config.get("host") host_value = "auto" if host_config is None else str(host_config) # Default sender/receiver bootstrap to a routable local IP so the @@ -315,15 +324,30 @@ def __init__(self, config: dict[str, Any]): # --- Role --- # "sender": bind ZMQ listener, accept put() calls. - # "receiver": skip ZMQ bind, only accept get() calls. + # "receiver": skip ZMQ bind, only accept get() calls by querying + # an upstream sender at + # ``sender_host`` / ``sender_zmq_port``. + # "dual": a single instance that simultaneously binds the + # listener (so a downstream receiver can pull from + # it) AND queries an upstream sender (so its own + # get() works). Used by middle stages on the + # chunk_transfer_adapter path -- the adapter keeps + # its historical one-connector-per-stage model + # (like ``SharedMemoryConnector``), and a dual + # instance is how this role-bound RDMA connector + # exposes put+get from the same object. The role + # is chosen by the framework (see + # ``get_connectors_config_for_stage``) based on + # whether the stage has incoming / outgoing edges. # The orchestration layer (get_connectors_config_for_stage / # kv_transfer_manager) is responsible for injecting the correct role. role = str(config.get("role", "sender")).lower() - if role not in {"sender", "receiver"}: + if role not in {"sender", "receiver", "dual"}: raise ValueError( - f"Invalid role={role!r} for MooncakeTransferEngineConnector. Expected 'sender' or 'receiver'." + f"Invalid role={role!r} for MooncakeTransferEngineConnector. Expected 'sender', 'receiver', or 'dual'." ) - self.can_put = role == "sender" + self.role = role + self.can_put = role in ("sender", "dual") self.engine_id = str(uuid.uuid4()) @@ -373,7 +397,7 @@ def __init__(self, config: dict[str, Any]): f" Role: can_put={self.can_put}, configured_role={config.get('role', 'sender')}" ) - # Only sender needs ZMQ listener to handle pull requests + # Sender and dual both need the ZMQ listener to handle pull requests. if self.can_put: self._last_ttl_check = _time_mod.monotonic() # reset after slow init self._listener_thread = threading.Thread(target=self._zmq_listener_loop, daemon=True) @@ -386,9 +410,16 @@ def __init__(self, config: dict[str, Any]): f"MooncakeTransferEngineConnector failed to bind ZMQ on " f"{self.host}:{self.zmq_port}: {self._bind_error}" ) from self._bind_error - logger.info( - f"MooncakeTransferEngineConnector started as SENDER (ZMQ listener on {self.host}:{self.zmq_port})" - ) + if self.role == "dual": + logger.info( + f"MooncakeTransferEngineConnector started as DUAL " + f"(ZMQ listener on {self.host}:{self.zmq_port}, " + f"upstream sender at {self.sender_host}:{self.sender_zmq_port})" + ) + else: + logger.info( + f"MooncakeTransferEngineConnector started as SENDER (ZMQ listener on {self.host}:{self.zmq_port})" + ) else: # Receiver mode — sender address is provided per-request via # metadata from put() through the queue, not pre-configured. diff --git a/vllm_omni/distributed/omni_connectors/connectors/mori_transfer_engine_connector.py b/vllm_omni/distributed/omni_connectors/connectors/mori_transfer_engine_connector.py new file mode 100644 index 00000000000..9b6fdd6c462 --- /dev/null +++ b/vllm_omni/distributed/omni_connectors/connectors/mori_transfer_engine_connector.py @@ -0,0 +1,1063 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""OmniConnector backed by the Mori RDMA transfer engine. + +Implements the ``OmniConnector`` interface using Mori's ``IOEngine`` / +``MemoryDesc`` API to perform zero-copy RDMA transfers between +disaggregated prefill and decode workers. + +Notable design points: + * Remote peers must be registered via ``register_remote_engine()`` + with the peer's ``EngineDesc`` before any data transfer can proceed. + * Memory regions are described by ``MemoryDesc`` objects (serialisable + via pack/unpack) rather than raw virtual addresses. + * Transfers are dispatched through ``IOEngine.batch_write()`` and + tracked asynchronously via ``TransferStatus`` objects. + +ZMQ handshake protocol: + * **Pull request** – receiver sends ``MoriPullRequest`` (msgspec-encoded) + containing its ``EngineDesc`` and pool ``MemoryDesc`` so the sender can + register the remote engine and RDMA-write directly into the receiver's + pool at the specified offset. + * **Query request** – receiver sends ``QUERY_INFO`` prefix + ``QueryRequest`` + to check whether data is ready on the sender side (metadata-less get path). +""" + +import os +import queue +import socket +import threading +import time as _time_mod +import uuid +from concurrent.futures import ThreadPoolExecutor +from typing import Any + +import msgspec +import torch +import zmq + +from ..utils.logging import get_connector_logger +from ..utils.serialization import OmniSerializer +from .base import OmniConnectorBase + +logger = get_connector_logger(__name__) + +try: + from mori.cpp import TransferStatus # noqa: F401 + from mori.io import ( + BackendType, + EngineDesc, + IOEngine, + IOEngineConfig, + MemoryDesc, + PollCqMode, + RdmaBackendConfig, + XgmiBackendConfig, + ) +except ImportError: + IOEngine = None + +# Supported backend types for Mori. Kept as string constants so configuration +# (YAML / CLI / dict) stays transport-agnostic. ``rdma`` uses NIC-based RDMA +# (RoCE / IB, GDR-capable); ``xgmi`` uses AMD Infinity Fabric GPU-to-GPU +# direct links and therefore requires a CUDA pool. +_SUPPORTED_BACKENDS = ("rdma", "xgmi") + +_BUFFER_TTL_SECONDS = 300 + +TRANS_DONE = b"trans_done" +TRANS_ERROR = b"trans_error" +QUERY_INFO = b"query_info" +INFO_NOT_FOUND = b"info_not_found" + + +# --------------------------------------------------------------------------- +# ZMQ message types +# --------------------------------------------------------------------------- + + +class MoriPullRequest(msgspec.Struct): + """Receiver → sender: request an RDMA write into the receiver's pool.""" + + request_id: str + engine_desc_packed: bytes + mem_desc_packed: bytes + dst_offset: int + length: int + + +class QueryRequest(msgspec.Struct): + """Receiver → sender: query whether data for *request_id* is ready.""" + + request_id: str + + +class QueryResponse(msgspec.Struct): + """Sender → receiver: metadata about a ready buffer.""" + + request_id: str + data_size: int + is_fast_path: bool + + +# --------------------------------------------------------------------------- +# Pool memory management (shared with MooncakeTransferEngineConnector) +# --------------------------------------------------------------------------- + + +class BufferAllocator: + """Simple first-fit allocator over a contiguous memory pool.""" + + def __init__(self, total_size: int, alignment: int = 4096): + self.total_size = total_size + self.alignment = alignment + self.lock = threading.Lock() + self.free_blocks: list[tuple[int, int]] = [(0, total_size)] + + def alloc(self, size: int) -> int: + aligned = (size + self.alignment - 1) // self.alignment * self.alignment + with self.lock: + for i, (start, bsz) in enumerate(self.free_blocks): + if bsz >= aligned: + remainder = bsz - aligned + if remainder > 0: + self.free_blocks[i] = (start + aligned, remainder) + else: + self.free_blocks.pop(i) + return start + raise MemoryError(f"Out of memory in buffer pool. Requested {size} bytes (aligned {aligned}).") + + def free(self, offset: int, size: int) -> None: + aligned = (size + self.alignment - 1) // self.alignment * self.alignment + with self.lock: + for start, length in self.free_blocks: + if offset == start and aligned == length: + return + if offset >= start and offset + aligned <= start + length: + return + if not (offset + aligned <= start or start + length <= offset): + raise RuntimeError( + f"Memory corruption: freeing {offset}-{offset + aligned} " + f"overlaps with free block {start}-{start + length}" + ) + self.free_blocks.append((offset, aligned)) + self.free_blocks.sort() + + i = 0 + while i < len(self.free_blocks) - 1: + cs, csz = self.free_blocks[i] + ns, nsz = self.free_blocks[i + 1] + if cs + csz == ns: + self.free_blocks[i] = (cs, csz + nsz) + self.free_blocks.pop(i + 1) + else: + i += 1 + + +class ManagedBuffer: + """Zero-copy view into the global memory pool. + + Callers **must** call :meth:`release` (or use the context manager) after + they are done reading/writing the buffer so the region is returned to the + :class:`BufferAllocator`. + """ + + def __init__( + self, + allocator: BufferAllocator, + offset: int, + size: int, + pool_tensor: torch.Tensor, + ): + self.allocator = allocator + self.offset = offset + self.size = size + self.pool_tensor = pool_tensor + self._released = False + + def release(self) -> None: + if not self._released: + self.allocator.free(self.offset, self.size) + self._released = True + + def __del__(self) -> None: + self.release() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.release() + + @property + def tensor(self) -> torch.Tensor: + return self.pool_tensor[self.offset : self.offset + self.size] + + def as_tensor(self, dtype: torch.dtype, shape: tuple) -> torch.Tensor: + itemsize = torch.tensor([], dtype=dtype).element_size() + expected = itemsize + for d in shape: + expected *= d + if expected != self.size: + raise ValueError(f"Shape {shape} dtype {dtype} needs {expected} bytes, buffer has {self.size}") + if self.offset % itemsize != 0: + raise RuntimeError(f"Buffer offset {self.offset} not aligned for {dtype}") + return self.tensor.view(dtype).reshape(shape) + + def to_bytes(self) -> bytes: + t = self.tensor + if t.is_cuda: + t = t.cpu() + return t.numpy().tobytes() + + +# --------------------------------------------------------------------------- +# Connector +# --------------------------------------------------------------------------- + + +class MoriTransferEngineConnector(OmniConnectorBase): + """OmniConnector backed by the Mori RDMA IOEngine. + + Topology: 1 sender ↔ 1 receiver per key (same as + ``MooncakeTransferEngineConnector``). + """ + + supports_raw_data: bool = True + + # ------------------------------------------------------------------ init + def __init__(self, config: dict[str, Any]): + if IOEngine is None: + raise ImportError("Mori is not available. Install via: pip install mori") + + self._closed = False + self._bind_error: Exception | None = None + + # Teardown-safe defaults + self._stop_event = threading.Event() + self._sender_executor: ThreadPoolExecutor | None = None + self._listener_thread: threading.Thread | None = None + self._listener_ready = threading.Event() + self._local_buffers: dict[str, Any] = {} + self._local_buffers_lock = threading.Lock() + self._req_local = threading.local() + self._worker_local = threading.local() + self._last_ttl_check: float = _time_mod.monotonic() + + # Track remote engines already registered with IOEngine + self._registered_engines: set[str] = set() + self._registered_engines_lock = threading.Lock() + + self._metrics = { + "puts": 0, + "gets": 0, + "bytes_transferred": 0, + "errors": 0, + "timeouts": 0, + } + + self.config = config + + # Stage id is read by OmniChunkTransferAdapter (process_pending_chunks / + # _send_single_request / _poll_single_request / ...) to decide whether + # this connector sits at the sender or receiver end of a stage pair. + # Kept parallel to ``SharedMemoryConnector`` and other OmniConnector + # implementations. Populated by ``build_stage_connectors`` via the + # ``stage_id`` key that ``get_connectors_config_for_stage`` injects. + self.stage_id = config.get("stage_id", -1) + + # ---- Host / ZMQ ---- + host_cfg = config.get("host", "127.0.0.1") + if host_cfg.lower() == "auto": + self.host = self._get_local_ip() + logger.info(f"Auto-detected local IP: {self.host}") + else: + self.host = host_cfg + self.zmq_port = config.get("zmq_port", 50051) + + # ---- Backend selection ---- + # Defaults to "rdma" so existing deployments (pre-XGMI support) keep + # working unchanged. ``xgmi`` selects AMD Infinity Fabric GPU-to-GPU + # direct links; see ``mori.io.XgmiBackendConfig``. + backend_type_cfg = str(config.get("backend_type", "rdma")).lower() + if backend_type_cfg not in _SUPPORTED_BACKENDS: + raise ValueError(f"Invalid backend_type={backend_type_cfg!r}. Supported: {list(_SUPPORTED_BACKENDS)}.") + self.backend_type = backend_type_cfg + + # ---- RDMA device (RDMA backend only) ---- + self.device_name = "" + if self.backend_type == "rdma": + self.device_name = config.get("device_name", "") + if not self.device_name: + env_dev = os.environ.get("MORI_RDMA_DEVICES", "") + if env_dev: + self.device_name = env_dev + logger.info(f"Using MORI_RDMA_DEVICES from env: {self.device_name}") + elif config.get("device_name"): + logger.warning( + "device_name=%r is ignored for backend_type='xgmi' (XGMI does not use an RDMA NIC).", + config.get("device_name"), + ) + + # ---- Pool config ---- + self.pool_size = config.get("memory_pool_size", 1024**3) + self.pool_device = config.get("memory_pool_device", "cpu") + if self.backend_type == "xgmi" and self.pool_device == "cpu": + raise ValueError( + "backend_type='xgmi' requires memory_pool_device='cuda': " + "XGMI is a GPU-to-GPU fabric and cannot address CPU memory." + ) + + # ---- Sender info (receiver / dual uses this when metadata=None) ---- + self.sender_host = config.get("sender_host", None) + self.sender_zmq_port = config.get("sender_zmq_port", None) + + # ---- Role ---- + # ``sender`` : binds the ZMQ listener, accepts ``put()``. + # ``receiver`` : no listener, only ``get()`` by querying an + # upstream sender whose endpoint lives in + # ``sender_host`` / ``sender_zmq_port``. + # ``dual`` : single instance that simultaneously binds a + # listener (serves inbound pull requests from a + # downstream receiver) AND maintains an upstream + # endpoint (so its own ``get()`` can pull from an + # upstream sender). Used by middle stages on the + # chunk_transfer_adapter path -- the adapter keeps + # its historical "one connector per stage" model + # (like ``SharedMemoryConnector``), and a dual + # instance is how a role-bound RDMA connector + # exposes put+get from the same object. The role + # is chosen by the framework (see + # ``get_connectors_config_for_stage``) based on + # whether the stage has incoming / outgoing edges; + # ``MoriTransferEngineConnector`` itself does not + # reverse-infer its deployment mode. + role = str(config.get("role", "sender")).lower() + if role not in {"sender", "receiver", "dual"}: + raise ValueError( + f"Invalid role={role!r} for MoriTransferEngineConnector. Expected 'sender', 'receiver', or 'dual'." + ) + self.role = role + self.can_put = role in ("sender", "dual") + + # ---- Mori IOEngine ---- + if self.device_name: + os.environ["MORI_RDMA_DEVICES"] = self.device_name + + engine_config = IOEngineConfig(host=self.host, port=0) + self.engine_key = f"omni-{role}-{uuid.uuid4().hex[:8]}-pid{os.getpid()}-{self.host}" + self.engine = IOEngine(self.engine_key, engine_config) + + # ---- Backend creation (per backend_type) ---- + # ``IOEngine.batch_write`` is backend-agnostic, so only construction + # differs; the data-plane / ZMQ handshake paths below are identical. + if self.backend_type == "xgmi": + xgmi_cfg = XgmiBackendConfig() + xgmi_cfg.num_streams = config.get("xgmi_num_streams", 64) + xgmi_cfg.num_events = config.get("xgmi_num_events", 64) + self.engine.create_backend(BackendType.XGMI, xgmi_cfg) + logger.info(f"Mori backend: XGMI (num_streams={xgmi_cfg.num_streams}, num_events={xgmi_cfg.num_events})") + else: # rdma + qp_per_transfer = config.get("qp_per_transfer", 1) + post_batch_size = config.get("post_batch_size", -1) + num_workers = config.get("num_worker_threads", 1) + rdma_cfg = RdmaBackendConfig( + qp_per_transfer, + post_batch_size, + num_workers, + PollCqMode.POLLING, + False, + ) + self.engine.create_backend(BackendType.RDMA, rdma_cfg) + logger.info( + f"Mori backend: RDMA (qp_per_transfer={qp_per_transfer}, " + f"num_workers={num_workers}, device={self.device_name or 'auto'})" + ) + + self.engine_desc: EngineDesc = self.engine.get_engine_desc() + self.engine_desc_packed: bytes = self.engine_desc.pack() + logger.info(f"Mori IOEngine ready: key={self.engine_key} at {self.engine_desc.host}:{self.engine_desc.port}") + + # ---- Pool allocation & Mori memory registration ---- + logger.info( + f"Allocating {self.backend_type.upper()} pool: {self.pool_size / 1024**2:.2f} MB on {self.pool_device}" + ) + try: + if self.pool_device == "cpu": + self.pool = torch.empty(self.pool_size, dtype=torch.uint8).pin_memory() + else: + self.pool = torch.empty( + self.pool_size, + dtype=torch.uint8, + device=self.pool_device, + ) + self.pool_mem_desc: MemoryDesc = self.engine.register_torch_tensor(self.pool) + self.pool_mem_desc_packed: bytes = self.pool_mem_desc.pack() + self.base_ptr = self.pool.data_ptr() + except Exception as e: + logger.error(f"Failed to allocate/register pool: {e}") + raise + + self.allocator = BufferAllocator(self.pool_size, alignment=4096) + + # ---- ZMQ / threading ---- + self.zmq_ctx = zmq.Context() + self._sender_executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="mori-sender") + + logger.info( + f"MoriTransferEngineConnector config:\n" + f" Local: host={self.host}, zmq_port={self.zmq_port}\n" + f" Remote: sender_host={self.sender_host}, " + f"sender_zmq_port={self.sender_zmq_port}\n" + f" Role: can_put={self.can_put}, " + f"configured_role={config.get('role', 'sender')}" + ) + + if self.can_put: + self._last_ttl_check = _time_mod.monotonic() + self._listener_thread = threading.Thread(target=self._zmq_listener_loop, daemon=True) + self._listener_thread.start() + self._listener_ready.wait(timeout=1.0) + if self._bind_error is not None: + raise RuntimeError( + f"MoriTransferEngineConnector failed to bind ZMQ on {self.host}:{self.zmq_port}: {self._bind_error}" + ) from self._bind_error + if self.role == "dual": + logger.info( + f"MoriTransferEngineConnector DUAL ready " + f"(ZMQ listening on {self.host}:{self.zmq_port}, " + f"upstream sender at {self.sender_host}:{self.sender_zmq_port})" + ) + else: + logger.info(f"MoriTransferEngineConnector SENDER ready (ZMQ on {self.host}:{self.zmq_port})") + else: + if not self.sender_host or self.sender_host.lower() == "auto": + logger.info("MoriTransferEngineConnector RECEIVER: awaiting sender info via update_sender_info().") + else: + logger.info( + f"MoriTransferEngineConnector RECEIVER ready (sender at {self.sender_host}:{self.sender_zmq_port})" + ) + + # -------------------------------------------------------- public helpers + def get_connection_info(self) -> dict[str, Any]: + return { + "host": self.host, + "zmq_port": self.zmq_port, + "engine_key": self.engine_key, + "can_put": self.can_put, + } + + def update_sender_info(self, sender_host: str, sender_zmq_port: int) -> None: + """Inject the sender's ZMQ endpoint into the receiver connector.""" + self.sender_host = sender_host + self.sender_zmq_port = sender_zmq_port + logger.info(f"Sender info updated: host={sender_host!r}, zmq_port={sender_zmq_port}") + + # -------------------------------------------------- internal helpers + @staticmethod + def _get_local_ip() -> str: + try: + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: + s.connect(("8.8.8.8", 80)) + return s.getsockname()[0] + except Exception: + try: + return socket.gethostbyname(socket.gethostname()) + except Exception: + return "127.0.0.1" + + def _get_req_socket(self, zmq_addr: str, timeout_ms: int) -> zmq.Socket: + cache: dict[str, zmq.Socket] | None = getattr(self._req_local, "cache", None) + if cache is None: + cache = {} + self._req_local.cache = cache + + sock = cache.get(zmq_addr) + if sock is None: + sock = self.zmq_ctx.socket(zmq.REQ) + sock.connect(zmq_addr) + cache[zmq_addr] = sock + sock.setsockopt(zmq.SNDTIMEO, timeout_ms) + sock.setsockopt(zmq.RCVTIMEO, timeout_ms) + return sock + + def _invalidate_req_socket(self, zmq_addr: str) -> None: + cache: dict[str, zmq.Socket] | None = getattr(self._req_local, "cache", None) + if cache is None: + return + sock = cache.pop(zmq_addr, None) + if sock is not None: + try: + sock.close(linger=0) + except Exception: + pass + + def _ensure_remote_registered(self, engine_desc_packed: bytes) -> str: + """Register a remote engine with the local IOEngine (idempotent).""" + desc = EngineDesc.unpack(engine_desc_packed) + key = desc.key + with self._registered_engines_lock: + if key not in self._registered_engines: + self.engine.register_remote_engine(desc) + self._registered_engines.add(key) + logger.debug(f"Registered remote Mori engine: {key}") + return key + + # ----------------------------------------------------------------- put() + def put( + self, + from_stage: str, + to_stage: str, + put_key: str, + data: Any, + ) -> tuple[bool, int, dict[str, Any] | None]: + if self._closed: + raise RuntimeError("Cannot put: MoriTransferEngineConnector is closed") + if not self.can_put: + logger.warning(f"Rejecting put for {put_key}: connector is receiver-only") + return False, 0, None + + put_key = self._make_key(put_key, from_stage, to_stage) + try: + src_offset = 0 + size = 0 + holder = None + is_fast_path = True + should_release = False + + # Reject empty payloads early + if isinstance(data, bytes) and len(data) == 0: + return False, 0, None + if isinstance(data, torch.Tensor) and data.nbytes == 0: + return False, 0, None + + # Serialize non-raw types + if not isinstance(data, (ManagedBuffer, torch.Tensor, bytes)): + data = OmniSerializer.serialize(data) + is_fast_path = False + + if isinstance(data, ManagedBuffer): + if data.pool_tensor.data_ptr() != self.pool.data_ptr(): + data = data.tensor.contiguous() + else: + src_offset = data.offset + size = data.size + holder = data + should_release = False + + if isinstance(data, (torch.Tensor, bytes)): + if isinstance(data, torch.Tensor): + size = data.nbytes + tensor_data = data + else: + size = len(data) + tensor_data = torch.frombuffer(data, dtype=torch.uint8) + + try: + offset = self.allocator.alloc(size) + holder = ManagedBuffer(self.allocator, offset, size, self.pool) + should_release = True + except MemoryError: + logger.error(f"Pool exhausted for {size} bytes") + return False, 0, None + + try: + dst_t = holder.tensor + if isinstance(data, torch.Tensor): + if not data.is_contiguous(): + data = data.contiguous() + src_view = data.view(torch.uint8).flatten() + if src_view.device != dst_t.device: + dst_t.copy_(src_view, non_blocking=True) + if src_view.is_cuda: + with torch.cuda.device(src_view.device): + torch.cuda.current_stream().synchronize() + elif dst_t.is_cuda: + with torch.cuda.device(dst_t.device): + torch.cuda.current_stream().synchronize() + else: + dst_t.copy_(src_view) + if dst_t.is_cuda: + with torch.cuda.device(dst_t.device): + torch.cuda.current_stream().synchronize() + else: + dst_t.copy_(tensor_data) + if dst_t.is_cuda: + with torch.cuda.device(dst_t.device): + torch.cuda.current_stream().synchronize() + except Exception as e: + holder.release() + logger.error(f"Failed to copy data to pool: {e}") + return False, 0, None + + src_offset = offset + + if size <= 0: + if should_release and isinstance(holder, ManagedBuffer): + holder.release() + return False, 0, None + + with self._local_buffers_lock: + old = self._local_buffers.pop(put_key, None) + if old: + _, _, oh, osr, _, _ = old + if osr and isinstance(oh, ManagedBuffer): + oh.release() + logger.warning(f"Released stale buffer for duplicate key: {put_key}") + # (offset, size, holder, should_release, is_fast_path, ts) + self._local_buffers[put_key] = ( + src_offset, + size, + holder, + should_release, + is_fast_path, + _time_mod.monotonic(), + ) + + metadata = { + "source_host": self.host, + "source_port": self.zmq_port, + "data_size": size, + "is_fast_path": is_fast_path, + } + self._metrics["puts"] += 1 + self._metrics["bytes_transferred"] += size + return True, size, metadata + + except Exception as e: + self._metrics["errors"] += 1 + logger.error(f"Put failed for {put_key}: {e}", exc_info=True) + return False, 0, None + + # ------------------------------------------------ query (no-metadata get) + def _query_metadata_from_sender(self, get_key: str) -> dict[str, Any] | None: + zmq_addr = f"tcp://{self.sender_host}:{self.sender_zmq_port}" + req_socket = self._get_req_socket(zmq_addr, timeout_ms=5000) + try: + q = QueryRequest(request_id=get_key) + req_socket.send(QUERY_INFO + msgspec.msgpack.encode(q)) + resp = req_socket.recv() + if resp == INFO_NOT_FOUND: + return None + qr = msgspec.msgpack.decode(resp, type=QueryResponse) + return { + "source_host": self.sender_host, + "source_port": self.sender_zmq_port, + "data_size": qr.data_size, + "is_fast_path": qr.is_fast_path, + } + except Exception as e: + self._invalidate_req_socket(zmq_addr) + logger.debug(f"Metadata query failed for {get_key}: {e}") + return None + + # ----------------------------------------------------------------- get() + def get( + self, + from_stage: str, + to_stage: str, + get_key: str, + metadata: dict[str, Any] | None = None, + ) -> tuple[Any, int] | None: + if self._closed: + raise RuntimeError("Cannot get: MoriTransferEngineConnector is closed") + + get_key = self._make_key(get_key, from_stage, to_stage) + _t0 = _time_mod.perf_counter() + + # Resolve metadata + if not metadata: + if not self.sender_host or not self.sender_zmq_port or str(self.sender_host).lower() == "auto": + raise RuntimeError("get(metadata=None) requires sender info. Call update_sender_info() first.") + metadata = self._query_metadata_from_sender(get_key) + if not metadata: + return None + + _t1 = _time_mod.perf_counter() + _query_ms = (_t1 - _t0) * 1000 + + src_host = metadata.get("source_host") + src_port = metadata.get("source_port") + data_size = metadata.get("data_size", 0) + is_fast_path = metadata.get("is_fast_path", False) + + if not src_host or not src_port or str(src_host).lower() == "auto": + logger.error(f"Invalid metadata for {get_key}") + return None + if data_size == 0: + logger.warning(f"Skipping get for {get_key}: data_size is 0") + return None + + # Allocate destination buffer + try: + offset = self.allocator.alloc(data_size) + recv_buf = ManagedBuffer(self.allocator, offset, data_size, self.pool) + except MemoryError: + logger.error(f"Failed to allocate {data_size} bytes for get") + return None + + _t2 = _time_mod.perf_counter() + _alloc_ms = (_t2 - _t1) * 1000 + + # Build pull request with Mori-specific EngineDesc + MemoryDesc + pull_req = MoriPullRequest( + request_id=get_key, + engine_desc_packed=self.engine_desc_packed, + mem_desc_packed=self.pool_mem_desc_packed, + dst_offset=offset, + length=data_size, + ) + + _base_timeout_ms = 30000 + _size_timeout_ms = max(0, (data_size // (100 * 1024 * 1024))) * 5000 + _total_timeout_ms = _base_timeout_ms + _size_timeout_ms + zmq_addr = f"tcp://{src_host}:{src_port}" + req_socket = self._get_req_socket(zmq_addr, timeout_ms=_total_timeout_ms) + + try: + req_socket.send(msgspec.msgpack.encode(pull_req)) + resp = req_socket.recv() + + _t3 = _time_mod.perf_counter() + _rdma_ms = (_t3 - _t2) * 1000 + + if resp == TRANS_DONE: + if self.pool.is_cuda: + with torch.cuda.device(self.pool.device): + torch.cuda.current_stream().synchronize() + + _t4 = _time_mod.perf_counter() + _sync_ms = (_t4 - _t3) * 1000 + + if is_fast_path: + _total_ms = (_time_mod.perf_counter() - _t0) * 1000 + _mbps = (data_size / 1024 / 1024) / (_total_ms / 1000) if _total_ms > 0 else 0 + logger.info( + f"[MORI GET] {get_key}: query={_query_ms:.1f}ms, " + f"alloc={_alloc_ms:.1f}ms, rdma={_rdma_ms:.1f}ms, " + f"sync={_sync_ms:.1f}ms, total={_total_ms:.1f}ms, " + f"{_mbps:.1f} MB/s (fast_path)" + ) + self._metrics["gets"] += 1 + self._metrics["bytes_transferred"] += data_size + return recv_buf, data_size + else: + try: + raw = recv_buf.to_bytes() + val = OmniSerializer.deserialize(raw) + _total_ms = (_time_mod.perf_counter() - _t0) * 1000 + _mbps = (data_size / 1024 / 1024) / (_total_ms / 1000) if _total_ms > 0 else 0 + logger.info(f"[MORI GET] {get_key}: total={_total_ms:.1f}ms, {_mbps:.1f} MB/s (deserialized)") + self._metrics["gets"] += 1 + self._metrics["bytes_transferred"] += data_size + return val, data_size + finally: + recv_buf.release() + else: + self._metrics["errors"] += 1 + logger.error(f"MORI get failed: received {resp} instead of TRANS_DONE") + recv_buf.release() + return None + except Exception as e: + self._invalidate_req_socket(zmq_addr) + self._metrics["timeouts"] += 1 + logger.error(f"MORI get error: {e}", exc_info=True) + recv_buf.release() + return None + + # ------------------------------------------------------------- cleanup() + def cleanup( + self, + request_id: str, + from_stage: str | None = None, + to_stage: str | None = None, + ) -> None: + if (from_stage is None) != (to_stage is None): + raise ValueError("cleanup() requires both from_stage and to_stage, or neither.") + if from_stage is not None and to_stage is not None: + request_id = self._make_key(request_id, from_stage, to_stage) + with self._local_buffers_lock: + item = self._local_buffers.pop(request_id, None) + if item: + _, _, holder, sr, _, _ = item + if sr and isinstance(holder, ManagedBuffer): + holder.release() + + # -------------------------------------------------------------- health() + def health(self) -> dict[str, Any]: + if self._closed: + return {"status": "unhealthy", "error": "Connector is closed"} + return { + "status": "healthy", + "host": self.host, + "pool_device": self.pool_device, + "pool_size": self.pool_size, + "engine_key": self.engine_key, + **self._metrics, + } + + # --------------------------------------------------------------- close() + def close(self) -> None: + if getattr(self, "_closed", True): + return + self._closed = True + logger.info("Closing MoriTransferEngineConnector...") + + self._stop_event.set() + + if self._listener_thread is not None and self._listener_thread.is_alive(): + self._listener_thread.join(timeout=2.0) + if self._listener_thread.is_alive(): + logger.warning("Listener thread did not stop gracefully") + + if self._sender_executor is not None: + self._sender_executor.shutdown(wait=True, cancel_futures=False) + + with self._local_buffers_lock: + for _k, item in list(self._local_buffers.items()): + _, _, holder, sr, _, _ = item + if sr and isinstance(holder, ManagedBuffer): + holder.release() + self._local_buffers.clear() + + cache: dict[str, zmq.Socket] | None = getattr(self._req_local, "cache", None) + if cache: + for _addr, sock in cache.items(): + try: + sock.close(linger=0) + except Exception: + pass + cache.clear() + + try: + if hasattr(self, "zmq_ctx"): + self.zmq_ctx.term() + except Exception as e: + logger.warning(f"Failed to terminate ZMQ context: {e}") + + self.pool = None # type: ignore[assignment] + logger.info("MoriTransferEngineConnector closed.") + + # ============================================================== LISTENER + def _cleanup_stale_buffers(self) -> None: + """Reclaim buffers older than ``_BUFFER_TTL_SECONDS``. + + Prevents permanent memory leaks when a receiver crashes or times out + without ever pulling the data. + + TODO(zejwang): In extreme rare case, long transfer time, there might + exist TTL cleanup vs in-flight RDMA transfer conflict, which will be + handled in a follow-up PR. Same race is acknowledged in + ``MooncakeTransferEngineConnector._cleanup_stale_buffers``. + """ + now = _time_mod.monotonic() + with self._local_buffers_lock: + stale = [k for k, v in self._local_buffers.items() if now - v[5] > _BUFFER_TTL_SECONDS] + for k in stale: + item = self._local_buffers.pop(k) + _, _, holder, sr, _, _ = item + if sr and isinstance(holder, ManagedBuffer): + holder.release() + logger.warning(f"TTL expired ({_BUFFER_TTL_SECONDS}s): reclaimed buffer for {k}") + + def _zmq_listener_loop(self) -> None: + sock = self.zmq_ctx.socket(zmq.ROUTER) + + try: + sock.bind(f"tcp://{self.host}:{self.zmq_port}") + except zmq.ZMQError as exc: + # Any bind failure (EADDRINUSE, EADDRNOTAVAIL, EACCES, etc.) is + # fatal for a sender — fail fast so __init__ propagates the error. + # There is no silent receiver fallback; roles are explicitly + # assigned (matches MooncakeTransferEngineConnector). + logger.error(f"ZMQ bind failed on {self.host}:{self.zmq_port}: {exc} (errno={exc.errno})") + self.can_put = False + self._bind_error = exc + self._listener_ready.set() + return + + self._listener_ready.set() + + notify_addr = f"inproc://mori-notify-{id(self)}" + notify_recv = self.zmq_ctx.socket(zmq.PULL) + notify_recv.bind(notify_addr) + + poller = zmq.Poller() + poller.register(sock, zmq.POLLIN) + poller.register(notify_recv, zmq.POLLIN) + + response_queue: queue.Queue = queue.Queue() + + try: + while not self._stop_event.is_set(): + try: + events = dict(poller.poll(1000)) + + if notify_recv in events: + while True: + try: + notify_recv.recv(zmq.NOBLOCK) + except zmq.Again: + break + + while True: + try: + identity, response = response_queue.get_nowait() + sock.send_multipart([identity, b"", response]) + except queue.Empty: + break + + now = _time_mod.monotonic() + if now - self._last_ttl_check >= 10.0: + self._last_ttl_check = now + self._cleanup_stale_buffers() + + if sock in events: + frames = sock.recv_multipart() + if len(frames) >= 2: + self._sender_executor.submit( + self._handle_pull_request, + response_queue, + notify_addr, + frames[0], + frames[-1], + ) + except zmq.ContextTerminated: + break + except Exception: + logger.debug("Listener loop error", exc_info=True) + finally: + try: + notify_recv.close(linger=0) + sock.close(linger=0) + except Exception: + pass + + def _handle_pull_request( + self, + response_queue: queue.Queue, + notify_addr: str, + identity: bytes, + payload: bytes, + ) -> None: + try: + if payload.startswith(QUERY_INFO): + self._handle_query_request( + response_queue, + notify_addr, + identity, + payload[len(QUERY_INFO) :], + ) + return + + pull = msgspec.msgpack.decode(payload, type=MoriPullRequest) + + with self._local_buffers_lock: + item = self._local_buffers.get(pull.request_id) + + if not item: + response_queue.put((identity, TRANS_ERROR)) + self._notify_listener(notify_addr) + return + + src_offset, src_size, _, _, _, _ = item + + # Validate against sender's own src_size before RDMA write: + # length < src_size silently truncates the payload; length > + # src_size reads past this allocation into adjacent in-flight + # buffers because pool_mem_desc covers the whole pool, not just + # this slot, so Mori does not bound-check. Refuse on mismatch + # rather than corrupt or leak data silently. + if pull.length != src_size: + logger.error( + "Length mismatch for %s: sender src_size=%d != receiver length=%d; refusing transfer", + pull.request_id, + src_size, + pull.length, + ) + response_queue.put((identity, TRANS_ERROR)) + self._notify_listener(notify_addr) + return + + # Register the receiver's IOEngine (idempotent) + self._ensure_remote_registered(pull.engine_desc_packed) + + # Reconstruct the receiver's pool MemoryDesc + remote_mem = MemoryDesc.unpack(pull.mem_desc_packed) + + # RDMA write from local pool → remote pool. Use src_size (sender's + # own record); it has just been validated to equal pull.length. + transfer_uid = self.engine.allocate_transfer_uid() + statuses = self.engine.batch_write( + [self.pool_mem_desc], + [[src_offset]], + [remote_mem], + [[pull.dst_offset]], + [[src_size]], + [transfer_uid], + ) + + success = True + for st in statuses: + st.Wait() + if st.Failed(): + logger.error(f"RDMA write failed for {pull.request_id}: {st.Message()}") + success = False + + if success: + self.cleanup(pull.request_id) + response_queue.put((identity, TRANS_DONE)) + else: + logger.warning(f"RDMA write failed for {pull.request_id}. Buffer retained for retry.") + response_queue.put((identity, TRANS_ERROR)) + + except Exception as e: + logger.error(f"Pull request handler error: {e}", exc_info=True) + response_queue.put((identity, TRANS_ERROR)) + + self._notify_listener(notify_addr) + + def _handle_query_request( + self, + response_queue: queue.Queue, + notify_addr: str, + identity: bytes, + payload: bytes, + ) -> None: + try: + q = msgspec.msgpack.decode(payload, type=QueryRequest) + with self._local_buffers_lock: + item = self._local_buffers.get(q.request_id) + if not item: + response_queue.put((identity, INFO_NOT_FOUND)) + else: + _, data_size, _, _, is_fast, _ = item + resp = QueryResponse( + request_id=q.request_id, + data_size=data_size, + is_fast_path=is_fast, + ) + response_queue.put((identity, msgspec.msgpack.encode(resp))) + except Exception as e: + logger.error(f"Query handler error: {e}") + response_queue.put((identity, INFO_NOT_FOUND)) + + self._notify_listener(notify_addr) + + def _notify_listener(self, notify_addr: str) -> None: + try: + local = self._worker_local + sock = getattr(local, "notify_socket", None) + cached_addr = getattr(local, "notify_addr", None) + if sock is None or cached_addr != notify_addr: + if sock is not None: + sock.close(linger=0) + sock = self.zmq_ctx.socket(zmq.PUSH) + sock.connect(notify_addr) + local.notify_socket = sock + local.notify_addr = notify_addr + sock.send(b"", zmq.NOBLOCK) + except Exception: + local.notify_socket = None + local.notify_addr = None diff --git a/vllm_omni/distributed/omni_connectors/factory.py b/vllm_omni/distributed/omni_connectors/factory.py index a680da953ee..9ba95a2e1e0 100644 --- a/vllm_omni/distributed/omni_connectors/factory.py +++ b/vllm_omni/distributed/omni_connectors/factory.py @@ -102,10 +102,22 @@ def _create_mooncake_transfer_engine_connector(config: dict[str, Any]) -> OmniCo return MooncakeTransferEngineConnector(config) +def _create_mori_transfer_engine_connector(config: dict[str, Any]) -> OmniConnectorBase: + try: + from .connectors.mori_transfer_engine_connector import MoriTransferEngineConnector + except ImportError: + import sys + + sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) + from omni_connectors.connectors.mori_transfer_engine_connector import MoriTransferEngineConnector + return MoriTransferEngineConnector(config) + + # Register connectors OmniConnectorFactory.register_connector("MooncakeStoreConnector", _create_mooncake_store_connector) OmniConnectorFactory.register_connector("MooncakeTransferEngineConnector", _create_mooncake_transfer_engine_connector) OmniConnectorFactory.register_connector("SharedMemoryConnector", _create_shm_connector) OmniConnectorFactory.register_connector("YuanrongConnector", _create_yuanrong_connector) +OmniConnectorFactory.register_connector("MoriTransferEngineConnector", _create_mori_transfer_engine_connector) # Backward-compatible aliases – will be removed in the future OmniConnectorFactory.register_connector("MooncakeConnector", _create_mooncake_store_connector) diff --git a/vllm_omni/distributed/omni_connectors/utils/initialization.py b/vllm_omni/distributed/omni_connectors/utils/initialization.py index f012af3c9c3..e2313d8c897 100644 --- a/vllm_omni/distributed/omni_connectors/utils/initialization.py +++ b/vllm_omni/distributed/omni_connectors/utils/initialization.py @@ -4,6 +4,7 @@ """Utilities for OmniConnector configuration and validation.""" import json +import socket import sys from pathlib import Path from typing import TYPE_CHECKING, Any @@ -28,6 +29,44 @@ # Formula: zmq_port = base + KV_TRANSFER_PORT_OFFSET + rank * STRIDE + stage KV_RANK_PORT_STRIDE = 16 +# Connector types that carry a ZMQ-based sender/receiver handshake and +# therefore need per-edge endpoint derivation on the +# ``OmniChunkTransferAdapter`` path. Without this, a 3-stage intranode +# pipeline would instantiate three co-located senders trying to bind the +# same base ``zmq_port`` from yaml, and every receiver / dual-role +# instance would start with no ``sender_host`` / ``sender_zmq_port`` and +# therefore no way to reach its upstream listener. +# +# The orchestrator-level path (``create_connectors_from_config`` below) +# has its own Mooncake-specific port adjustment for PD disaggregation +# and is intentionally independent of this set. +_ROLE_BOUND_ZMQ_CONNECTORS = frozenset( + { + "MoriTransferEngineConnector", + "MooncakeTransferEngineConnector", + } +) + + +def _detect_local_ip() -> str: + """Best-effort local IP for framework-level endpoint derivation. + + Mirrors the connector-side ``_get_local_ip`` behaviour (used by + Mori/Mooncake when ``host: "auto"``) so that, on an intranode + pipeline, a sender's advertised listener host and its downstream + receiver's precomputed upstream host agree without an explicit + yaml setting. + """ + try: + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: + s.connect(("8.8.8.8", 80)) + return s.getsockname()[0] + except Exception: + try: + return socket.gethostbyname(socket.gethostname()) + except Exception: + return "127.0.0.1" + def initialize_connectors_from_config( config_path: str | Path | None = None, @@ -132,44 +171,174 @@ def create_connectors_from_config( return connectors -def get_connectors_config_for_stage(transfer_config: OmniTransferConfig | None, stage_id: str | int) -> dict[str, Any]: +def _primary_upstream_stage(incoming_edges: list[tuple[str, ConnectorSpec]]) -> str | None: + """Return the lowest-stage-id upstream edge id, or *None* if there are no + incoming edges. + + Used by ``_inject_chunk_path_endpoints`` to pick which upstream sender a + dual / receiver instance should point at when a (future) fan-in topology + has more than one. Linear pipelines have exactly one incoming edge so + the choice is unambiguous today; picking the sorted-first keeps behaviour + deterministic if fan-in is introduced later. + """ + if not incoming_edges: + return None + try: + return sorted(incoming_edges, key=lambda e: int(e[0]))[0][0] + except (TypeError, ValueError): + return sorted(incoming_edges, key=lambda e: str(e[0]))[0][0] + + +def _inject_chunk_path_endpoints( + extra: dict[str, Any], + connector_name: str, + role: str, + own_stage: str, + upstream_stage: str | None, +) -> None: + """Compute per-edge ZMQ endpoints in-place for chunk-adapter connectors. + + Populates, for role-bound ZMQ connectors listed in + ``_ROLE_BOUND_ZMQ_CONNECTORS``: + + * ``zmq_port`` (sender / dual only): bind port for the local listener + = base_port + own_stage_id. Offsetting by stage id prevents multiple + co-located stage listeners from all binding the same base port on a + single-node pipeline. + * ``sender_zmq_port`` (receiver / dual only): upstream sender's bind + port = base_port + upstream_stage_id. + * ``sender_host`` (receiver / dual only): defaults to the + framework-detected local IP when the yaml uses ``host: "auto"``, + or to the yaml's explicit ``host`` value otherwise -- so an + intranode receiver can locate its upstream sender without an + external handshake. + + ``base_port`` is read from ``extra["zmq_port"]`` (default 50051). + Values already present in *extra* (explicit yaml override) are kept, + so cross-node deployments that already know the peer endpoint can + set ``sender_host`` / ``sender_zmq_port`` directly and are unaffected. + + For non-role-bound connectors (``SharedMemoryConnector``, + ``YuanrongConnector``, ...) this function is a no-op. Non-integer + stage ids (unusual pipeline keys such as ``"prefill"``) also return + without side effects. """ - Extract connector configurations relevant for a specific stage worker. + if connector_name not in _ROLE_BOUND_ZMQ_CONNECTORS: + return + try: + own_stage_id = int(own_stage) + except (TypeError, ValueError): + return + base_port = int(extra.get("zmq_port", 50051)) - Returns a dict compatible with worker initialization: - { - "from_stage_X": { - "spec": { - "name": "ConnectorName", - "extra": {...} - } - }, - ... - } + if role in ("sender", "dual"): + extra["zmq_port"] = base_port + own_stage_id + + if role in ("receiver", "dual") and upstream_stage is not None: + try: + upstream_stage_id = int(upstream_stage) + except (TypeError, ValueError): + upstream_stage_id = None + if upstream_stage_id is not None: + if "sender_zmq_port" not in extra or extra["sender_zmq_port"] is None: + extra["sender_zmq_port"] = base_port + upstream_stage_id + if not extra.get("sender_host"): + host_cfg = extra.get("host", "auto") + if isinstance(host_cfg, str) and host_cfg.lower() == "auto": + extra["sender_host"] = _detect_local_ip() + else: + extra["sender_host"] = host_cfg + + +def get_connectors_config_for_stage(transfer_config: OmniTransferConfig | None, stage_id: str | int) -> dict[str, Any]: + """Extract connector configurations relevant for a specific stage worker. + + Returns a dict shape compatible with worker initialization:: + + { + "from_stage_X": {"spec": {"name": ..., "extra": {..., "role": ...}}}, + "to_stage_Y": {"spec": {"name": ..., "extra": {..., "role": ...}}}, + ... + } + + The per-stage role is chosen by looking at which edges the stage + participates in: + + * stage has only outgoing edges -> ``role="sender"`` + * stage has only incoming edges -> ``role="receiver"`` + * stage has both (middle stage in a 3+ stage pipeline) -> + ``role="dual"``, and *both* the ``from_stage_*`` and + ``to_stage_*`` entries share the same composite extra (same + ``role``, same ``zmq_port``, same ``sender_host`` / + ``sender_zmq_port``) so that downstream flattening (e.g. + ``engine/stage_init_utils.get_stage_connector_spec`` returning + the first spec it sees) always recovers a self-consistent + per-stage connector config. + + For role-bound ZMQ connectors (see ``_ROLE_BOUND_ZMQ_CONNECTORS``) + per-edge ``zmq_port`` / ``sender_host`` / ``sender_zmq_port`` are + pre-computed from the sender's stage id so a co-located intranode + pipeline does not need an external handshake. Role-neutral + connectors (``SharedMemoryConnector``, ``YuanrongConnector``) are + unaffected by the endpoint injection but still receive the + direction-specific ``role`` key (connectors that ignore it keep + working verbatim). + + The orchestrator-level ``build_stage_connectors`` only filters + ``from_stage_*`` keys, so orchestrator instantiation behaviour is + bit-for-bit unchanged. """ if not transfer_config: return {} - stage_connectors_config = {} target_stage = str(stage_id) - # Iterate through all configured edges and inject direction-specific role. - # The shared edge-level ConnectorSpec is role-neutral; each stage gets - # the correct role ("sender" or "receiver") based on its position in - # the edge so that MooncakeTransferEngineConnector (and any future - # role-aware connector) initializes correctly. - for (from_stage, to_stage), spec in transfer_config.connectors.items(): - if to_stage == target_stage: - # Incoming edge → this stage is the receiver - extra = dict(spec.extra) if spec.extra else {} - extra.setdefault("role", "receiver") - stage_connectors_config[f"from_stage_{from_stage}"] = {"spec": {"name": spec.name, "extra": extra}} - elif from_stage == target_stage and target_stage == "0": - # Outgoing edge for stage 0 — included for async_chunk spec - # extraction (omni_stage.py), NOT for connector instantiation. - extra = dict(spec.extra) if spec.extra else {} - extra.setdefault("role", "sender") - stage_connectors_config[f"to_stage_{to_stage}"] = {"spec": {"name": spec.name, "extra": extra}} + incoming_edges: list[tuple[str, ConnectorSpec]] = [ + (from_s, spec) for (from_s, to_s), spec in transfer_config.connectors.items() if to_s == target_stage + ] + outgoing_edges: list[tuple[str, ConnectorSpec]] = [ + (to_s, spec) for (from_s, to_s), spec in transfer_config.connectors.items() if from_s == target_stage + ] + + if not incoming_edges and not outgoing_edges: + return {} + + if incoming_edges and outgoing_edges: + effective_role = "dual" + elif outgoing_edges: + effective_role = "sender" + else: + effective_role = "receiver" + + upstream_stage = _primary_upstream_stage(incoming_edges) + + stage_connectors_config: dict[str, Any] = {} + + # Incoming edges -> this stage is at the ``to`` end. + for from_s, spec in incoming_edges: + extra = dict(spec.extra) if spec.extra else {} + extra.setdefault("role", effective_role) + _inject_chunk_path_endpoints( + extra, + connector_name=spec.name, + role=effective_role, + own_stage=target_stage, + upstream_stage=upstream_stage, + ) + stage_connectors_config[f"from_stage_{from_s}"] = {"spec": {"name": spec.name, "extra": extra}} + + # Outgoing edges -> this stage is at the ``from`` end. + for to_s, spec in outgoing_edges: + extra = dict(spec.extra) if spec.extra else {} + extra.setdefault("role", effective_role) + _inject_chunk_path_endpoints( + extra, + connector_name=spec.name, + role=effective_role, + own_stage=target_stage, + upstream_stage=upstream_stage, + ) + stage_connectors_config[f"to_stage_{to_s}"] = {"spec": {"name": spec.name, "extra": extra}} return stage_connectors_config diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py index e670a52d8a2..a46f9acc286 100644 --- a/vllm_omni/engine/async_omni_engine.py +++ b/vllm_omni/engine/async_omni_engine.py @@ -1069,6 +1069,7 @@ async def _run_orchestrator() -> None: self._initialize_janus_queues() self._initialize_stages(stage_init_timeout) + pd_config = self._detect_pd_config() orchestrator = Orchestrator( request_async_queue=self.request_queue.async_q,