diff --git a/vllm_omni/distributed/omni_connectors/connectors/shm_connector.py b/vllm_omni/distributed/omni_connectors/connectors/shm_connector.py index 5c7384c1f8b..9c5118c84cd 100644 --- a/vllm_omni/distributed/omni_connectors/connectors/shm_connector.py +++ b/vllm_omni/distributed/omni_connectors/connectors/shm_connector.py @@ -48,8 +48,7 @@ def put( payload = self.serialize_obj(data) size = len(payload) - # Currently, we always use SHM. - if True: + if size > self.threshold: # Use Shared Memory lock_file = f"/dev/shm/shm_{put_key}_lockfile.lock" with open(lock_file, "wb+") as lockf: @@ -122,13 +121,21 @@ def get( return self._get_data_with_lock(lock_file, shm_handle) return None + shm = None try: shm = shm_pkg.SharedMemory(name=get_key) if shm is None or shm.size == 0: return None - lock_file = f"/dev/shm/shm_{get_key}_lockfile.lock" + except Exception as e: + # Probable cause: producer and consumer not on the same node + # and metadata is not propagated to consumer + logger.error(f"SharedMemoryConnector shm get failed for req {get_key}: {e}") + return None + + try: shm_handle = {"name": get_key, "size": shm.size} + lock_file = f"/dev/shm/shm_{shm_handle['name']}_lockfile.lock" return self._get_data_with_lock(lock_file, shm_handle) except Exception: return None diff --git a/vllm_omni/distributed/omni_connectors/connectors/test_shm_connector.py b/vllm_omni/distributed/omni_connectors/connectors/test_shm_connector.py new file mode 100644 index 00000000000..33a8b9fdbda --- /dev/null +++ b/vllm_omni/distributed/omni_connectors/connectors/test_shm_connector.py @@ -0,0 +1,41 @@ +import pytest + +from vllm_omni.distributed.omni_connectors.connectors.shm_connector import SharedMemoryConnector + + +@pytest.mark.parametrize( + "size, threshold", + [ + (50, 100), # size < threshold, inline + (200, 100), # size > threshold, shm + ], +) +def test_shared_memory_connector(size, threshold): + from_stage = "dummy_from_stage" + to_stage = "dummy_to_stage" + key = "dummy_key" + data = b" " * size + + tx = SharedMemoryConnector({"shm_threshold_bytes": threshold}) + rx = SharedMemoryConnector({"shm_threshold_bytes": threshold}) + + success, _, metadata = tx.put( + from_stage=from_stage, + to_stage=to_stage, + put_key=key, + data=data, + ) + assert success + + if size < threshold: + assert "inline_bytes" in metadata + else: + assert "shm" in metadata + + rx_data, _ = rx.get( + from_stage=from_stage, + to_stage=to_stage, + get_key=key, + metadata=metadata, + ) + assert rx_data == data