From f33501075403f3379d2bf92d1df0f25792396a3a Mon Sep 17 00:00:00 2001 From: natureofnature Date: Mon, 13 Apr 2026 15:18:23 +0800 Subject: [PATCH 1/3] connector: support 1-receiver-to-N-senders, SHM metadata fallback & cleanup MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mooncake RDMA connector — multi-source receiving: - Add per-rank sender endpoint registry (_sender_endpoints) and update_sender_info(sender_rank=...) for N-sender registration - Add _resolve_sender_endpoint() for rank-based endpoint routing - Add get() Path 2: partial metadata (host/port only, no data_size) queries the specified sender then RDMA pulls — enables heterogeneous TP where one receiver pulls KV shards from multiple sender ranks - Extract _query_metadata_at() from _query_metadata_from_sender() to deduplicate ZMQ query logic (~53 lines saved) - Fix data_size check from falsy to "data_size" not in metadata SHM connector — metadata fallback & lifecycle: - Add _get_by_key() fallback when metadata lacks SHM handles (e.g. RDMA-style metadata passed to SHM connector) - Track _pending_keys for cleanup(request_id) and close() lifecycle Other: - base.py: document metadata parameter semantics for heterogeneous TP - mooncake_store_connector: align with updated connector interface - initialization: add KV_RANK_PORT_STRIDE constant for per-rank ZMQ port - tests: add test_shm_connector covering key-based R/W, metadata fallback, heterogeneous TP multi-key, and cleanup/close Signed-off-by: natureofnature --- .../omni_connectors/test_shm_connector.py | 184 ++++++++++++++++++ .../omni_connectors/connectors/base.py | 6 + .../connectors/mooncake_store_connector.py | 19 +- .../mooncake_transfer_engine_connector.py | 172 ++++++++++------ .../connectors/shm_connector.py | 105 +++++++--- .../omni_connectors/utils/initialization.py | 5 + 6 files changed, 405 insertions(+), 86 deletions(-) create mode 100644 tests/distributed/omni_connectors/test_shm_connector.py diff --git a/tests/distributed/omni_connectors/test_shm_connector.py b/tests/distributed/omni_connectors/test_shm_connector.py new file mode 100644 index 00000000000..e702318e3f3 --- /dev/null +++ b/tests/distributed/omni_connectors/test_shm_connector.py @@ -0,0 +1,184 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for SharedMemoryConnector focusing on TP / CFG / metadata fallback.""" + +import pytest + +from vllm_omni.distributed.omni_connectors.connectors.shm_connector import ( + SharedMemoryConnector, +) + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +@pytest.fixture() +def connector(): + c = SharedMemoryConnector({"shm_threshold_bytes": 64}) + yield c + c.close() + + +# ── Key-based read (the fundamental SHM path) ──────────────────────── + + +class TestKeyBasedReadWrite: + def test_put_then_get_by_key(self, connector): + data = {"hello": "world", "n": 42} + ok, size, meta = connector.put("s0", "s1", "test_key_1", data) + assert ok + assert size > 0 + assert "shm" in meta + assert "test_key_1" in connector._pending_keys + + result = connector.get("s0", "s1", "test_key_1", metadata=None) + assert result is not None + obj, rsize = result + assert obj == data + assert rsize == size + assert "test_key_1" not in connector._pending_keys + + def test_get_nonexistent_key_returns_none(self, connector): + result = connector.get("s0", "s1", "no_such_key_xyz", metadata=None) + assert result is None + + def test_rank_aware_keys_independent(self, connector): + """Each TP rank writes/reads its own key — simulates homogeneous TP.""" + payloads = {} + for rank in range(4): + key = f"req1_s0_0_{rank}_{rank}" + data = {"rank": rank, "values": list(range(rank, rank + 3))} + ok, _, _ = connector.put("s0", "s1", key, data) + assert ok + payloads[rank] = data + + for rank in range(4): + key = f"req1_s0_0_{rank}_{rank}" + result = connector.get("s0", "s1", key, metadata=None) + assert result is not None + obj, _ = result + assert obj == payloads[rank] + + +# ── Metadata fallback behaviour ────────────────────────────────────── + + +class TestMetadataFallback: + def test_rdma_style_metadata_falls_back_to_key(self, connector): + """source_host/source_port metadata should be ignored; key read used.""" + data = {"payload": True} + connector.put("s0", "s1", "fb_key_1", data) + + rdma_meta = {"source_host": "10.0.0.1", "source_port": 12345} + result = connector.get("s0", "s1", "fb_key_1", metadata=rdma_meta) + assert result is not None + obj, _ = result + assert obj == data + + def test_non_dict_metadata_falls_back_to_key(self, connector): + data = {"val": 99} + connector.put("s0", "s1", "fb_key_2", data) + + result = connector.get("s0", "s1", "fb_key_2", metadata="not_a_dict") + assert result is not None + obj, _ = result + assert obj == data + + def test_empty_dict_metadata_falls_back_to_key(self, connector): + data = {"x": 1} + connector.put("s0", "s1", "fb_key_3", data) + + result = connector.get("s0", "s1", "fb_key_3", metadata={}) + assert result is not None + obj, _ = result + assert obj == data + + def test_shm_handle_metadata_still_works(self, connector): + """When metadata contains a proper 'shm' handle, use it directly.""" + data = {"direct": True} + ok, size, meta = connector.put("s0", "s1", "shm_direct_1", data) + assert ok + result = connector.get("s0", "s1", "shm_direct_1", metadata=meta) + assert result is not None + obj, _ = result + assert obj == data + + def test_metadata_keyed_by_request_id(self, connector): + """Metadata wrapped as {get_key: actual_meta} should be unwrapped.""" + data = {"wrapped": True} + ok, size, meta = connector.put("s0", "s1", "wrap_key", data) + assert ok + wrapped = {"wrap_key": meta} + result = connector.get("s0", "s1", "wrap_key", metadata=wrapped) + assert result is not None + obj, _ = result + assert obj == data + + +# ── Heterogeneous TP multi-key read ────────────────────────────────── + + +class TestHeteroTPMultiKey: + def test_receiver_reads_multiple_sender_keys(self, connector): + """Simulates from_tp=2 -> to_tp=1: receiver reads 2 keys and merges.""" + for sender_rank in range(2): + key = f"req1_s0_0_{sender_rank}_0" + data = {"sender": sender_rank, "shard": [sender_rank * 10]} + connector.put("s0", "s1", key, data) + + shards = [] + for sender_rank in range(2): + key = f"req1_s0_0_{sender_rank}_0" + result = connector.get("s0", "s1", key, metadata=None) + assert result is not None + obj, _ = result + shards.append(obj) + + assert len(shards) == 2 + assert shards[0]["sender"] == 0 + assert shards[1]["sender"] == 1 + + def test_sender_writes_multiple_receiver_keys(self, connector): + """Simulates from_tp=1 -> to_tp=2: sender writes 2 sliced keys.""" + for recv_rank in range(2): + key = f"req1_s0_0_0_{recv_rank}" + data = {"target": recv_rank, "slice": list(range(recv_rank, recv_rank + 2))} + connector.put("s0", "s1", key, data) + + for recv_rank in range(2): + key = f"req1_s0_0_0_{recv_rank}" + result = connector.get("s0", "s1", key, metadata=None) + assert result is not None + obj, _ = result + assert obj["target"] == recv_rank + + +# ── Cleanup ────────────────────────────────────────────────────────── + + +class TestCleanup: + def test_cleanup_removes_unconsumed_segment(self, connector): + data = {"leak": True} + connector.put("s0", "s1", "cleanup_req_42", data) + assert "cleanup_req_42" in connector._pending_keys + + connector.cleanup("req_42") + assert "cleanup_req_42" not in connector._pending_keys + + result = connector.get("s0", "s1", "cleanup_req_42", metadata=None) + assert result is None + + def test_cleanup_noop_for_consumed_segment(self, connector): + data = {"consumed": True} + connector.put("s0", "s1", "consumed_req_99", data) + connector.get("s0", "s1", "consumed_req_99", metadata=None) + + connector.cleanup("req_99") + assert "consumed_req_99" not in connector._pending_keys + + def test_close_cleans_all_pending(self, connector): + for i in range(3): + connector.put("s0", "s1", f"close_test_{i}", {"i": i}) + + assert len(connector._pending_keys) == 3 + connector.close() + assert len(connector._pending_keys) == 0 diff --git a/vllm_omni/distributed/omni_connectors/connectors/base.py b/vllm_omni/distributed/omni_connectors/connectors/base.py index 83edb2ab0ae..23748400d07 100644 --- a/vllm_omni/distributed/omni_connectors/connectors/base.py +++ b/vllm_omni/distributed/omni_connectors/connectors/base.py @@ -41,6 +41,12 @@ def get(self, from_stage: str, to_stage: str, get_key: str, metadata=None) -> tu from_stage: Source stage identifier to_stage: Destination stage identifier get_key: Unique request identifier + metadata: Optional transport-specific metadata. When provided, + the connector uses it directly (e.g. source_host, source_port, + data_size) instead of querying the sender. For heterogeneous + TP the manager may supply partial metadata (host/port only); + the connector will query the sender at that address to fill + in data_size. Returns: Tuple of (Python object, serialized byte size) if found, None otherwise diff --git a/vllm_omni/distributed/omni_connectors/connectors/mooncake_store_connector.py b/vllm_omni/distributed/omni_connectors/connectors/mooncake_store_connector.py index c672e35f793..fa1fc3286db 100644 --- a/vllm_omni/distributed/omni_connectors/connectors/mooncake_store_connector.py +++ b/vllm_omni/distributed/omni_connectors/connectors/mooncake_store_connector.py @@ -78,7 +78,24 @@ def put(self, from_stage: str, to_stage: str, put_key: str, data: Any) -> tuple[ try: serialized_data = self.serialize_obj(data) key = self._make_key(put_key, from_stage, to_stage) - self.store.put(key, serialized_data, self.pin) + put_rc = self.store.put(key, serialized_data, self.pin) + + if isinstance(put_rc, bool): + put_ok = put_rc + else: + put_ok = put_rc is None or put_rc == 0 + + if not put_ok: + self._metrics["errors"] += 1 + logger.error( + "MooncakeStoreConnector put failed for %s (%s -> %s), rc=%r, %d bytes", + key, + from_stage, + to_stage, + put_rc, + len(serialized_data), + ) + return False, 0, None self._metrics["puts"] += 1 self._metrics["bytes_transferred"] += len(serialized_data) diff --git a/vllm_omni/distributed/omni_connectors/connectors/mooncake_transfer_engine_connector.py b/vllm_omni/distributed/omni_connectors/connectors/mooncake_transfer_engine_connector.py index 96a528963f4..84224c6738d 100644 --- a/vllm_omni/distributed/omni_connectors/connectors/mooncake_transfer_engine_connector.py +++ b/vllm_omni/distributed/omni_connectors/connectors/mooncake_transfer_engine_connector.py @@ -230,16 +230,19 @@ class MooncakeTransferEngineConnector(OmniConnectorBase): sender immediately cleans up the buffer (``cleanup()``), so only the first receiver to pull a given key will succeed. Broadcast / multicast (1 sender → N receivers sharing the same data) is not yet supported. - - **1 receiver → 1 sender**: ``update_sender_info()`` stores a single - ``(sender_host, sender_zmq_port)`` pair, so a receiver can only query - metadata from one sender at a time. + - **1 receiver → N senders**: Supported via partial metadata. The + manager constructs metadata with the target sender's + ``source_host`` / ``source_port`` (computed from ``from_rank``) + and passes it to ``get(metadata=...)``. The connector detects + that ``data_size`` is missing, queries the specified sender at + the given address to fill it in, then performs the RDMA pull. + This enables heterogeneous TP (sender TP > receiver TP) where a + single receiver must pull KV shards from multiple sender ranks. Future work: - Support 1 sender → N receivers (e.g. reference-counted buffers, or explicit ``retain()`` / ``release()`` semantics so the buffer survives multiple pulls). - - Support 1 receiver → N senders (e.g. a sender registry mapping - ``get_key`` prefixes to different sender endpoints). """ # RDMA connector copies raw bytes/tensor directly to the memory pool @@ -267,6 +270,7 @@ def __init__(self, config: dict[str, Any]): self._req_local = threading.local() self._worker_local = threading.local() self._last_ttl_check: float = _time_mod.monotonic() + self._sender_endpoints: dict[int, tuple[str, int]] = {} self._metrics = { "puts": 0, @@ -408,16 +412,38 @@ def get_connection_info(self) -> dict[str, Any]: "can_put": self.can_put, } - def update_sender_info(self, sender_host: str, sender_zmq_port: int) -> None: - """ - Inject the sender's ZMQ endpoint into the receiver connector. - Used for NO METADATA GET calls.(E.g: KV-cache transfer path) - Must be called before using get() without metadata! - Otherwise, get() will raise an error. + def update_sender_info( + self, + sender_host: str, + sender_zmq_port: int, + sender_rank: int | None = None, + ) -> None: + """Inject a sender's ZMQ endpoint into the receiver connector. + + When ``sender_rank`` is ``None`` (default), sets the single default + sender used by ``get()`` when no rank is specified — this preserves + backward-compatible 1:1 semantics. + + When ``sender_rank`` is an integer, the endpoint is stored in a + per-rank registry for internal use (e.g. by + ``_query_metadata_from_sender(sender_rank=R)``). """ - self.sender_host = sender_host - self.sender_zmq_port = sender_zmq_port - logger.info(f"Sender info updated: host={sender_host!r}, zmq_port={sender_zmq_port}") + if sender_rank is not None: + self._sender_endpoints[sender_rank] = (sender_host, sender_zmq_port) + logger.info( + "Sender info updated for rank %s: host=%r, zmq_port=%s", + sender_rank, + sender_host, + sender_zmq_port, + ) + else: + self.sender_host = sender_host + self.sender_zmq_port = sender_zmq_port + logger.info( + "Sender info updated (default): host=%r, zmq_port=%s", + sender_host, + sender_zmq_port, + ) def _get_local_ip(self) -> str: """ @@ -657,56 +683,75 @@ def put(self, from_stage: str, to_stage: str, put_key: str, data: Any) -> tuple[ logger.error(f"RDMA Put failed for {put_key}: {e}", exc_info=True) return False, 0, None - def _query_metadata_from_sender(self, get_key: str) -> dict[str, Any] | None: - """Query metadata from sender via ZMQ (fallback when ``metadata=None``). - - ``get()`` supports two metadata resolution paths:: + def _resolve_sender_endpoint(self, sender_rank: int | None = None) -> tuple[str, int] | None: + """Return ``(host, zmq_port)`` for *sender_rank*. - get(metadata=?) - ├── metadata provided (adapter path) - │ → use metadata directly (source_host/port/data_size) - │ → RDMA pull - └── metadata=None (KV-transfer polling path) - → _query_metadata_from_sender(get_key) ← this method - │ - ├── sender_host resolved (via update_sender_info) - │ → ZMQ query → get data_size/is_fast_path - │ → construct metadata → RDMA pull - └── sender_host unresolved ("auto" / None) - → return None → caller retries or times out - - For the second path, the caller must call - :meth:`update_sender_info` before ``get()`` to resolve the sender's ZMQ endpoint. - Support the two paths in case that the orchestrator pushes the request info - to different stages at the same time knowing metadata or not. + Resolution order: + 1. Per-rank registry (``_sender_endpoints[sender_rank]``) + 2. Default sender (``sender_host`` / ``sender_zmq_port``) + 3. ``None`` if nothing is configured. """ - zmq_addr = f"tcp://{self.sender_host}:{self.sender_zmq_port}" + if sender_rank is not None and sender_rank in self._sender_endpoints: + return self._sender_endpoints[sender_rank] + host = getattr(self, "sender_host", None) + port = getattr(self, "sender_zmq_port", None) + if host and port and str(host).lower() != "auto": + return (host, int(port)) + return None + + def _query_metadata_at(self, get_key: str, host: str, port: int) -> dict[str, Any] | None: + """Query metadata from a sender endpoint via ZMQ. + + Returns ``{source_host, source_port, data_size, is_fast_path}`` + or ``None`` when the key is not found / the query fails. + """ + zmq_addr = f"tcp://{host}:{port}" req_socket = self._get_req_socket(zmq_addr, timeout_ms=5000) - try: - # Send query request - query = QueryRequest(request_id=get_key) - req_socket.send(QUERY_INFO + msgspec.msgpack.encode(query)) + req_socket.send(QUERY_INFO + msgspec.msgpack.encode(QueryRequest(request_id=get_key))) resp = req_socket.recv() - if resp == INFO_NOT_FOUND: return None - - # Parse response query_resp = msgspec.msgpack.decode(resp, type=QueryResponse) return { - # source_host/source_port are used for verification - "source_host": self.sender_host, - "source_port": self.sender_zmq_port, + "source_host": host, + "source_port": port, "data_size": query_resp.data_size, "is_fast_path": query_resp.is_fast_path, } except Exception as e: - # Socket may be stuck in bad state after timeout; discard it self._invalidate_req_socket(zmq_addr) - logger.debug(f"Failed to query metadata for {get_key}: {e}") + logger.debug("Failed to query metadata at %s for %s: %s", zmq_addr, get_key, e) return None + def _query_metadata_from_sender(self, get_key: str, sender_rank: int | None = None) -> dict[str, Any] | None: + """Query metadata from sender via ZMQ (fallback when ``metadata=None``). + + ``get()`` supports three metadata resolution paths:: + + get(metadata=?) + ├── Path 1: metadata has data_size (adapter path) + │ → use metadata directly → RDMA pull + ├── Path 2: metadata has source_host/port but no data_size + │ → _query_metadata_at(host, port) → get data_size → RDMA pull + └── Path 3: metadata=None (KV-transfer polling path) + → _query_metadata_from_sender(get_key) ← this method + │ + ├── sender endpoint resolved (via update_sender_info) + │ → ZMQ query → get data_size/is_fast_path + │ → construct metadata → RDMA pull + └── sender endpoint unresolved + → return None → caller retries or times out + + When *sender_rank* is provided, the query is routed to that + rank's endpoint (registered via ``update_sender_info(rank=...)``). + Otherwise the default sender is used. + """ + endpoint = self._resolve_sender_endpoint(sender_rank) + if endpoint is None: + return None + return self._query_metadata_at(get_key, *endpoint) + def get( self, from_stage: str, @@ -714,12 +759,18 @@ def get( get_key: str, metadata: dict[str, Any] | None = None, ) -> tuple[Any, int] | None: - """ - Consumer Side. - Allocates from local pool and pulls data via RDMA. + """Consumer Side. Allocates from local pool and pulls data via RDMA. + + Metadata resolution: - If metadata is not provided, will attempt to query it from sender - using configured sender_host/sender_zmq_port. + 1. ``metadata`` provided **with** ``data_size`` → use directly (RDMA pull). + 2. ``metadata`` provided with ``source_host``/``source_port`` but + **without** ``data_size`` → query that specific sender for + ``data_size`` / ``is_fast_path``, then RDMA pull. This is the + heterogeneous-TP path where the manager knows the target sender + endpoint but not the payload size. + 3. ``metadata=None`` → query the default sender (set via + ``update_sender_info()``) for the full metadata. Returns: ``(data, size)`` on success, ``None`` on failure. @@ -727,9 +778,6 @@ def get( - **is_fast_path=True** (tensor *or* bytes payload): Returns ``(ManagedBuffer, size)``. **CALLER MUST call ``ManagedBuffer.release()`` after consuming.** - Note: even if the producer ``put()`` raw ``bytes``, the consumer - receives a ``ManagedBuffer`` — use ``buf.to_bytes()`` to obtain - a ``bytes`` copy, or ``buf.tensor`` for zero-copy access. - **is_fast_path=False** (serialized Python object): Returns ``(DeserializedObject, size)``. Buffer is auto-released internally after deserialization. @@ -741,9 +789,8 @@ def get( _t0 = _time_mod.perf_counter() - # If no metadata provided, try to query from sender if not metadata: - # Must insert sender info before using get() without metadata. + # Path 3: no metadata at all — query default sender if not self.sender_host or not self.sender_zmq_port or str(self.sender_host).lower() == "auto": raise RuntimeError( f"get(metadata=None) requires sender info to be resolved, " @@ -753,6 +800,15 @@ def get( metadata = self._query_metadata_from_sender(get_key) if not metadata: return None + elif "data_size" not in metadata: + # Path 2: partial metadata (host/port only) — query that sender + partial_host = metadata.get("source_host") + partial_port = metadata.get("source_port") + if partial_host and partial_port: + queried = self._query_metadata_at(get_key, str(partial_host), int(partial_port)) + if not queried: + return None + metadata = queried _t1 = _time_mod.perf_counter() _query_ms = (_t1 - _t0) * 1000 diff --git a/vllm_omni/distributed/omni_connectors/connectors/shm_connector.py b/vllm_omni/distributed/omni_connectors/connectors/shm_connector.py index 5c7384c1f8b..6468cf2c85b 100644 --- a/vllm_omni/distributed/omni_connectors/connectors/shm_connector.py +++ b/vllm_omni/distributed/omni_connectors/connectors/shm_connector.py @@ -15,9 +15,13 @@ class SharedMemoryConnector(OmniConnectorBase): - """ - Connector that uses SharedMemory for large objects and inline data for small objects. - Acts as a unified replacement for the legacy IPC fallback logic. + """Key-addressed local shared-memory connector. + + SHM is a local-only transport: it reads/writes POSIX shared memory + segments identified purely by *key*. It does **not** understand + remote-transport metadata such as ``source_host`` / ``source_port`` + (that is the RDMA connector's job). When such metadata is passed in, + the connector silently falls back to key-based lookup. """ def __init__(self, config: dict[str, Any]): @@ -25,6 +29,7 @@ def __init__(self, config: dict[str, Any]): self.stage_id = config.get("stage_id", -1) self.device = config.get("device", "cuda:0") self.threshold = int(config.get("shm_threshold_bytes", 65536)) + self._pending_keys: set[str] = set() self._metrics = { "puts": 0, "gets": 0, @@ -59,6 +64,7 @@ def put( # meta contains {'name': ..., 'size': ...} metadata = {"shm": meta, "size": size} + self._pending_keys.add(put_key) self._metrics["shm_writes"] += 1 else: # Inline - pass bytes directly to avoid double serialization of the object @@ -93,6 +99,25 @@ def _get_data_with_lock(self, lock_file: str, shm_handle: dict): if obj and os.path.exists(lock_file): os.remove(lock_file) + def _get_by_key(self, get_key: str) -> tuple[Any, int] | None: + """Read a SHM segment addressed purely by *get_key*.""" + shm = None + try: + shm = shm_pkg.SharedMemory(name=get_key) + if shm is None or shm.size == 0: + return None + lock_file = f"/dev/shm/shm_{get_key}_lockfile.lock" + shm_handle = {"name": get_key, "size": shm.size} + result = self._get_data_with_lock(lock_file, shm_handle) + if result is not None: + self._pending_keys.discard(get_key) + return result + except Exception: + return None + finally: + if shm: + shm.close() + def get( self, from_stage: str, @@ -101,16 +126,16 @@ def get( metadata=None, ) -> tuple[Any, int] | None: if metadata is not None: - # Some callers may wrap metadata by request id. if isinstance(metadata, dict) and get_key in metadata: metadata = metadata.get(get_key) if not isinstance(metadata, dict): - return None + return self._get_by_key(get_key) if "inline_bytes" in metadata: try: obj = self.deserialize_obj(metadata["inline_bytes"]) + self._pending_keys.discard(get_key) return obj, int(metadata.get("size", 0)) except Exception as e: logger.error(f"SharedMemoryConnector inline get failed for req {get_key}: {e}") @@ -119,33 +144,59 @@ def get( if "shm" in metadata: shm_handle = metadata["shm"] lock_file = f"/dev/shm/shm_{shm_handle['name']}_lockfile.lock" - return self._get_data_with_lock(lock_file, shm_handle) + result = self._get_data_with_lock(lock_file, shm_handle) + if result is not None: + self._pending_keys.discard(get_key) + return result - return None - shm = None - try: - shm = shm_pkg.SharedMemory(name=get_key) - if shm is None or shm.size == 0: - return None - lock_file = f"/dev/shm/shm_{get_key}_lockfile.lock" - shm_handle = {"name": get_key, "size": shm.size} - return self._get_data_with_lock(lock_file, shm_handle) - except Exception: - return None - finally: - if shm: - shm.close() + # Metadata is a dict but has no SHM-specific handle (e.g. RDMA- + # style source_host/source_port). Fall back to key-based read. + return self._get_by_key(get_key) + + return self._get_by_key(get_key) def cleanup(self, request_id: str) -> None: - # SHM segments are automatically unlinked during 'get' (shm_read_bytes). - # If 'get' is never called (e.g. error flow), the SHM segment might leak. - # A robust implementation might track created segments and unlink them here - # if they haven't been consumed. - # For now, we rely on the consumer to read and unlink. - pass + """Best-effort cleanup of unconsumed SHM segments for *request_id*. + + If ``put()`` wrote a segment keyed by *request_id* (or containing it) + but ``get()`` was never called, we unlink it here so /dev/shm doesn't + leak. + """ + stale = [k for k in self._pending_keys if request_id in k] + for key in stale: + self._pending_keys.discard(key) + try: + seg = shm_pkg.SharedMemory(name=key) + seg.close() + seg.unlink() + logger.debug("cleanup: unlinked unconsumed SHM segment %s", key) + except FileNotFoundError: + pass + except Exception as e: + logger.debug("cleanup: failed to unlink SHM segment %s: %s", key, e) + lock_file = f"/dev/shm/shm_{key}_lockfile.lock" + if os.path.exists(lock_file): + try: + os.remove(lock_file) + except OSError: + pass def close(self) -> None: - pass + """Unlink all remaining tracked SHM segments.""" + for key in list(self._pending_keys): + try: + seg = shm_pkg.SharedMemory(name=key) + seg.close() + seg.unlink() + except Exception: + pass + lock_file = f"/dev/shm/shm_{key}_lockfile.lock" + if os.path.exists(lock_file): + try: + os.remove(lock_file) + except OSError: + pass + self._pending_keys.clear() def health(self) -> dict[str, Any]: return {"status": "healthy", "threshold": self.threshold, **self._metrics} diff --git a/vllm_omni/distributed/omni_connectors/utils/initialization.py b/vllm_omni/distributed/omni_connectors/utils/initialization.py index 37b7d0d7f83..0497bbb3a23 100644 --- a/vllm_omni/distributed/omni_connectors/utils/initialization.py +++ b/vllm_omni/distributed/omni_connectors/utils/initialization.py @@ -23,6 +23,11 @@ # collide with request-forwarding endpoints that share the same base port. KV_TRANSFER_PORT_OFFSET = 100 +# Port stride between TP ranks so each worker binds a unique ZMQ port +# when TP > 1. Must be larger than the maximum number of pipeline stages. +# Formula: zmq_port = base + KV_TRANSFER_PORT_OFFSET + rank * STRIDE + stage +KV_RANK_PORT_STRIDE = 16 + def initialize_connectors_from_config( config_path: str | Path | None = None, From 22898a09ffdf9c500ad455e15c4220470ee6c20d Mon Sep 17 00:00:00 2001 From: natureofnature Date: Mon, 13 Apr 2026 10:11:15 +0000 Subject: [PATCH 2/3] update meta type Signed-off-by: natureofnature --- vllm_omni/distributed/omni_connectors/connectors/base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm_omni/distributed/omni_connectors/connectors/base.py b/vllm_omni/distributed/omni_connectors/connectors/base.py index 23748400d07..0df428f2ff5 100644 --- a/vllm_omni/distributed/omni_connectors/connectors/base.py +++ b/vllm_omni/distributed/omni_connectors/connectors/base.py @@ -34,7 +34,9 @@ def put(self, from_stage: str, to_stage: str, put_key: str, data: Any) -> tuple[ pass @abstractmethod - def get(self, from_stage: str, to_stage: str, get_key: str, metadata=None) -> tuple[Any, int] | None: + def get( + self, from_stage: str, to_stage: str, get_key: str, metadata: dict[str, Any] | None = None + ) -> tuple[Any, int] | None: """Retrieve Python object and payload size (bytes). Args: From 272e520af346b5d2b39c68d830c196127b85efd6 Mon Sep 17 00:00:00 2001 From: natureofnature Date: Mon, 13 Apr 2026 15:57:11 +0000 Subject: [PATCH 3/3] update for review Signed-off-by: natureofnature --- .../mooncake_transfer_engine_connector.py | 16 +++++++++++----- .../omni_connectors/connectors/shm_connector.py | 16 ++++++++++++---- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/vllm_omni/distributed/omni_connectors/connectors/mooncake_transfer_engine_connector.py b/vllm_omni/distributed/omni_connectors/connectors/mooncake_transfer_engine_connector.py index 84224c6738d..bd4160f3e63 100644 --- a/vllm_omni/distributed/omni_connectors/connectors/mooncake_transfer_engine_connector.py +++ b/vllm_omni/distributed/omni_connectors/connectors/mooncake_transfer_engine_connector.py @@ -804,11 +804,17 @@ def get( # Path 2: partial metadata (host/port only) — query that sender partial_host = metadata.get("source_host") partial_port = metadata.get("source_port") - if partial_host and partial_port: - queried = self._query_metadata_at(get_key, str(partial_host), int(partial_port)) - if not queried: - return None - metadata = queried + if not partial_host or not partial_port: + logger.warning( + "get(%s): partial metadata missing source_host/source_port, cannot resolve data_size. metadata=%s", + get_key, + metadata, + ) + return None + queried = self._query_metadata_at(get_key, str(partial_host), int(partial_port)) + if not queried: + return None + metadata = queried _t1 = _time_mod.perf_counter() _query_ms = (_t1 - _t0) * 1000 diff --git a/vllm_omni/distributed/omni_connectors/connectors/shm_connector.py b/vllm_omni/distributed/omni_connectors/connectors/shm_connector.py index 6468cf2c85b..6cf5c2f15b5 100644 --- a/vllm_omni/distributed/omni_connectors/connectors/shm_connector.py +++ b/vllm_omni/distributed/omni_connectors/connectors/shm_connector.py @@ -112,7 +112,10 @@ def _get_by_key(self, get_key: str) -> tuple[Any, int] | None: if result is not None: self._pending_keys.discard(get_key) return result + except FileNotFoundError: + return None except Exception: + logger.debug("_get_by_key: unexpected error reading SHM segment %s", get_key, exc_info=True) return None finally: if shm: @@ -158,11 +161,16 @@ def get( def cleanup(self, request_id: str) -> None: """Best-effort cleanup of unconsumed SHM segments for *request_id*. - If ``put()`` wrote a segment keyed by *request_id* (or containing it) - but ``get()`` was never called, we unlink it here so /dev/shm doesn't - leak. + Matches pending keys where *request_id* appears as the full key, + as a ``_``-delimited prefix, or as a ``_``-delimited suffix. + If ``get()`` was never called, we unlink it here so /dev/shm + doesn't leak. """ - stale = [k for k in self._pending_keys if request_id in k] + stale = [ + k + for k in self._pending_keys + if k == request_id or k.startswith(request_id + "_") or k.endswith("_" + request_id) + ] for key in stale: self._pending_keys.discard(key) try: