Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 184 additions & 0 deletions tests/distributed/omni_connectors/test_shm_connector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for SharedMemoryConnector focusing on TP / CFG / metadata fallback."""

import pytest

from vllm_omni.distributed.omni_connectors.connectors.shm_connector import (
SharedMemoryConnector,
)

pytestmark = [pytest.mark.core_model, pytest.mark.cpu]


@pytest.fixture()
def connector():
    """Provide a SharedMemoryConnector for one test and close it afterwards.

    The connector is constructed with ``shm_threshold_bytes`` set to 64.
    """
    shm_connector = SharedMemoryConnector({"shm_threshold_bytes": 64})
    yield shm_connector
    # Teardown: executed once the test that requested the fixture has finished.
    shm_connector.close()


# ── Key-based read (the fundamental SHM path) ────────────────────────


class TestKeyBasedReadWrite:
    """Round-trip tests that read data back purely by key (``metadata=None``)."""

    def test_put_then_get_by_key(self, connector):
        """A put registers a pending key; the matching get returns the payload and clears it."""
        payload = {"hello": "world", "n": 42}

        stored, written, put_meta = connector.put("s0", "s1", "test_key_1", payload)
        assert stored
        assert written > 0
        assert "shm" in put_meta
        assert "test_key_1" in connector._pending_keys

        fetched = connector.get("s0", "s1", "test_key_1", metadata=None)
        assert fetched is not None
        received, read_size = fetched
        assert received == payload
        assert read_size == written
        # Reading by key is expected to drop the key from the pending set.
        assert "test_key_1" not in connector._pending_keys

    def test_get_nonexistent_key_returns_none(self, connector):
        """Reading a key that was never written yields ``None``."""
        assert connector.get("s0", "s1", "no_such_key_xyz", metadata=None) is None

    def test_rank_aware_keys_independent(self, connector):
        """Each TP rank writes/reads its own key — simulates homogeneous TP."""
        payloads = {}
        for rank in range(4):
            payloads[rank] = {"rank": rank, "values": list(range(rank, rank + 3))}
            stored, _, _ = connector.put("s0", "s1", f"req1_s0_0_{rank}_{rank}", payloads[rank])
            assert stored

        # Every rank must get back exactly what it wrote under its own key.
        for rank, expected in payloads.items():
            fetched = connector.get("s0", "s1", f"req1_s0_0_{rank}_{rank}", metadata=None)
            assert fetched is not None
            received, _ = fetched
            assert received == expected


# ── Metadata fallback behaviour ──────────────────────────────────────


class TestMetadataFallback:
    """Tests for how ``get`` treats the optional ``metadata`` argument.

    Every test first asserts that the ``put`` succeeded, so a failure in the
    subsequent ``get`` can be attributed to the read path rather than to a
    write that silently failed.
    """

    def test_rdma_style_metadata_falls_back_to_key(self, connector):
        """source_host/source_port metadata should be ignored; key read used."""
        data = {"payload": True}
        ok, _, _ = connector.put("s0", "s1", "fb_key_1", data)
        assert ok

        rdma_meta = {"source_host": "10.0.0.1", "source_port": 12345}
        result = connector.get("s0", "s1", "fb_key_1", metadata=rdma_meta)
        assert result is not None
        obj, _ = result
        assert obj == data

    def test_non_dict_metadata_falls_back_to_key(self, connector):
        """Metadata that is not a dict should not prevent a read by key."""
        data = {"val": 99}
        ok, _, _ = connector.put("s0", "s1", "fb_key_2", data)
        assert ok

        result = connector.get("s0", "s1", "fb_key_2", metadata="not_a_dict")
        assert result is not None
        obj, _ = result
        assert obj == data

    def test_empty_dict_metadata_falls_back_to_key(self, connector):
        """An empty metadata dict should not prevent a read by key."""
        data = {"x": 1}
        ok, _, _ = connector.put("s0", "s1", "fb_key_3", data)
        assert ok

        result = connector.get("s0", "s1", "fb_key_3", metadata={})
        assert result is not None
        obj, _ = result
        assert obj == data

    def test_shm_handle_metadata_still_works(self, connector):
        """When metadata contains a proper 'shm' handle, use it directly."""
        data = {"direct": True}
        ok, _, meta = connector.put("s0", "s1", "shm_direct_1", data)
        assert ok

        # Feed the metadata returned by put straight back into get.
        result = connector.get("s0", "s1", "shm_direct_1", metadata=meta)
        assert result is not None
        obj, _ = result
        assert obj == data

    def test_metadata_keyed_by_request_id(self, connector):
        """Metadata wrapped as {get_key: actual_meta} should be unwrapped."""
        data = {"wrapped": True}
        ok, _, meta = connector.put("s0", "s1", "wrap_key", data)
        assert ok

        wrapped = {"wrap_key": meta}
        result = connector.get("s0", "s1", "wrap_key", metadata=wrapped)
        assert result is not None
        obj, _ = result
        assert obj == data


# ── Heterogeneous TP multi-key read ──────────────────────────────────


class TestHeteroTPMultiKey:
    """Key-based round trips using the per-rank key layout of heterogeneous TP.

    Keys follow the ``req1_s0_0_{sender_rank}_{receiver_rank}`` pattern used
    by the tests below. Each test checks that every ``put`` succeeded and that
    the full payload (not just one field) comes back unchanged.
    """

    def test_receiver_reads_multiple_sender_keys(self, connector):
        """Simulates from_tp=2 -> to_tp=1: receiver reads both sender keys."""
        expected = []
        for sender_rank in range(2):
            key = f"req1_s0_0_{sender_rank}_0"
            data = {"sender": sender_rank, "shard": [sender_rank * 10]}
            ok, _, _ = connector.put("s0", "s1", key, data)
            assert ok
            expected.append(data)

        shards = []
        for sender_rank in range(2):
            key = f"req1_s0_0_{sender_rank}_0"
            result = connector.get("s0", "s1", key, metadata=None)
            assert result is not None
            obj, _ = result
            shards.append(obj)

        # Whole-payload comparison also covers the count and per-sender order.
        assert shards == expected

    def test_sender_writes_multiple_receiver_keys(self, connector):
        """Simulates from_tp=1 -> to_tp=2: sender writes 2 sliced keys."""
        expected = {}
        for recv_rank in range(2):
            key = f"req1_s0_0_0_{recv_rank}"
            data = {"target": recv_rank, "slice": list(range(recv_rank, recv_rank + 2))}
            ok, _, _ = connector.put("s0", "s1", key, data)
            assert ok
            expected[recv_rank] = data

        for recv_rank in range(2):
            key = f"req1_s0_0_0_{recv_rank}"
            result = connector.get("s0", "s1", key, metadata=None)
            assert result is not None
            obj, _ = result
            # Each receiver rank must see exactly the slice written for it.
            assert obj == expected[recv_rank]


# ── Cleanup ──────────────────────────────────────────────────────────


class TestCleanup:
    """Tests for ``cleanup(request_id)`` and ``close()`` on pending keys."""

    def test_cleanup_removes_unconsumed_segment(self, connector):
        """cleanup() drops a key that was written but never read."""
        data = {"leak": True}
        ok, _, _ = connector.put("s0", "s1", "cleanup_req_42", data)
        assert ok
        assert "cleanup_req_42" in connector._pending_keys

        connector.cleanup("req_42")
        assert "cleanup_req_42" not in connector._pending_keys

        # After cleanup the data must no longer be readable by key.
        result = connector.get("s0", "s1", "cleanup_req_42", metadata=None)
        assert result is None

    def test_cleanup_noop_for_consumed_segment(self, connector):
        """cleanup() is harmless when the key has already been consumed."""
        data = {"consumed": True}
        ok, _, _ = connector.put("s0", "s1", "consumed_req_99", data)
        assert ok

        # Pin the premise: the read really happened and cleared the pending
        # key, so the cleanup call below has nothing left to remove.
        result = connector.get("s0", "s1", "consumed_req_99", metadata=None)
        assert result is not None
        assert "consumed_req_99" not in connector._pending_keys

        connector.cleanup("req_99")
        assert "consumed_req_99" not in connector._pending_keys

    def test_close_cleans_all_pending(self, connector):
        """close() clears every key that is still pending."""
        for i in range(3):
            ok, _, _ = connector.put("s0", "s1", f"close_test_{i}", {"i": i})
            assert ok

        assert len(connector._pending_keys) == 3
        # NOTE: the `connector` fixture calls close() again at teardown, so
        # this test also relies on close() tolerating a second call.
        connector.close()
        assert len(connector._pending_keys) == 0
10 changes: 9 additions & 1 deletion vllm_omni/distributed/omni_connectors/connectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,21 @@ def put(self, from_stage: str, to_stage: str, put_key: str, data: Any) -> tuple[
pass

@abstractmethod
def get(self, from_stage: str, to_stage: str, get_key: str, metadata=None) -> tuple[Any, int] | None:
def get(
self, from_stage: str, to_stage: str, get_key: str, metadata: dict[str, Any] | None = None
) -> tuple[Any, int] | None:
"""Retrieve Python object and payload size (bytes).

Args:
from_stage: Source stage identifier
to_stage: Destination stage identifier
get_key: Unique request identifier
metadata: Optional transport-specific metadata. When provided,
the connector uses it directly (e.g. source_host, source_port,
data_size) instead of querying the sender. For heterogeneous
TP the manager may supply partial metadata (host/port only);
the connector will query the sender at that address to fill
in data_size.

Returns:
Tuple of (Python object, serialized byte size) if found, None otherwise
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,24 @@ def put(self, from_stage: str, to_stage: str, put_key: str, data: Any) -> tuple[
try:
serialized_data = self.serialize_obj(data)
key = self._make_key(put_key, from_stage, to_stage)
self.store.put(key, serialized_data, self.pin)
put_rc = self.store.put(key, serialized_data, self.pin)
Comment on lines 79 to +81

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will the 1-receiver-to-N-senders pattern (partial metadata / per-rank endpoint) also extend to MooncakeStoreConnector? Right now only MooncakeTransferEngineConnector has the multi-sender routing, but the store connector still uses a single store.put(key, data) — curious if heterogeneous TP will need similar changes there.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MooncakeStoreConnector does not need to create p2p side channel peers, and it naturally supports this.


if isinstance(put_rc, bool):
put_ok = put_rc
else:
put_ok = put_rc is None or put_rc == 0

if not put_ok:
self._metrics["errors"] += 1
logger.error(
"MooncakeStoreConnector put failed for %s (%s -> %s), rc=%r, %d bytes",
key,
from_stage,
to_stage,
put_rc,
len(serialized_data),
)
return False, 0, None

self._metrics["puts"] += 1
self._metrics["bytes_transferred"] += len(serialized_data)
Expand Down
Loading
Loading