From bf6924c331e01363d0220f9f2bb277ab1f8f32cc Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Tue, 3 Feb 2026 22:40:51 +0000 Subject: [PATCH 1/4] add default_vllm_config so tests pass Signed-off-by: Randall Smith --- tests/v1/kv_connector/unit/test_moriio_connector.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_moriio_connector.py b/tests/v1/kv_connector/unit/test_moriio_connector.py index 1cc6988635d8..0abd9eb604a3 100644 --- a/tests/v1/kv_connector/unit/test_moriio_connector.py +++ b/tests/v1/kv_connector/unit/test_moriio_connector.py @@ -392,7 +392,7 @@ def test_read_mode_loads_remote_block_ids(moriio_read_mode): @pytest.mark.skipif( not aiter_available, reason="Requires aiter package for ROCm FlashAttention backend" ) -def test_register_kv_caches(mock_parallel_groups): +def test_register_kv_caches(default_vllm_config, mock_parallel_groups): """Test that MoRIIOConnector.register_kv_caches correctly registers kv caches.""" ROLE = "kv_consumer" IP = get_ip() @@ -486,7 +486,7 @@ def test_register_kv_caches(mock_parallel_groups): @pytest.mark.skipif( not aiter_available, reason="Requires aiter package for ROCm FlashAttention backend" ) -def test_moriio_handshake_returns_metadata(mock_parallel_groups): +def test_moriio_handshake_returns_metadata(default_vllm_config, mock_parallel_groups): """MoRIIO handshake socket returns valid agent metadata over ZMQ.""" ROLE = "kv_consumer" From dcf163926f4222d17e8421776b70c7797fe78204 Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Tue, 3 Feb 2026 23:05:56 +0000 Subject: [PATCH 2/4] use set_current_vllm_config Signed-off-by: Randall Smith --- .../unit/test_moriio_connector.py | 243 +++++++++--------- 1 file changed, 124 insertions(+), 119 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_moriio_connector.py b/tests/v1/kv_connector/unit/test_moriio_connector.py index 0abd9eb604a3..f21bcfec3964 100644 --- a/tests/v1/kv_connector/unit/test_moriio_connector.py +++ b/tests/v1/kv_connector/unit/test_moriio_connector.py @@ -17,6 +17,7 @@ ModelConfig, SchedulerConfig, VllmConfig, + set_current_vllm_config, ) from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import ( MoRIIOAgentMetadata, @@ -392,7 +393,7 @@ def test_read_mode_loads_remote_block_ids(moriio_read_mode): @pytest.mark.skipif( not aiter_available, reason="Requires aiter package for ROCm FlashAttention backend" ) -def test_register_kv_caches(default_vllm_config, mock_parallel_groups): +def test_register_kv_caches(mock_parallel_groups): """Test that MoRIIOConnector.register_kv_caches correctly registers kv caches.""" ROLE = "kv_consumer" IP = get_ip() @@ -404,89 +405,92 @@ def test_register_kv_caches(default_vllm_config, mock_parallel_groups): backend_cls = AiterFlashAttentionBackend - # Create test kv cache tensors using proper backend shape - kv_cache_shape = backend_cls.get_kv_cache_shape( - num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 - ) - shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) - unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) - kv_caches = { - "layer0": shared_tensor, - "layer1": unique_tensor, - "layer2": shared_tensor, - } - - with ( - patch( - "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.threading.Event" - ), - patch( - "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.threading.Thread" - ), - ): - # Create connector - vllm_config.kv_transfer_config.kv_connector_extra_config.update( - { - "proxy_ip": "127.0.0.1", - "proxy_ping_port": 12345, - "http_port": 12346, - } + with set_current_vllm_config(vllm_config): + # Create test kv cache tensors using proper backend shape + kv_cache_shape = backend_cls.get_kv_cache_shape( + num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 ) + shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + kv_caches = { + "layer0": shared_tensor, + "layer1": unique_tensor, + "layer2": shared_tensor, + } - connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER) - connector.connector_worker = FakeMorIIOConnectorWorker( - vllm_config, connector.engine_id, hand_shake_latency=0 - ) + with ( + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.threading.Event" + ), + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.threading.Thread" + ), + ): + # Create connector + vllm_config.kv_transfer_config.kv_connector_extra_config.update( + { + "proxy_ip": "127.0.0.1", + "proxy_ping_port": 12345, + "http_port": 12346, + } + ) - from mori.io import ( - MemoryDesc, - ) + connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER) + connector.connector_worker = FakeMorIIOConnectorWorker( + vllm_config, connector.engine_id, hand_shake_latency=0 + ) - # Execute register_kv_caches - connector.register_kv_caches(kv_caches) - - # Verify that the MemoryDesc stored in layer_name_to_local_kv_cache_metadata - assert ( - shared_tensor.data_ptr() - == MemoryDesc.unpack( - connector.connector_worker.layer_name_to_local_kv_cache_metadata[ - "layer0" - ][0] - ).data - ) - assert ( - unique_tensor.data_ptr() - == MemoryDesc.unpack( - connector.connector_worker.layer_name_to_local_kv_cache_metadata[ - "layer1" - ][0] - ).data - ) - assert ( - shared_tensor.data_ptr() - == MemoryDesc.unpack( - connector.connector_worker.layer_name_to_local_kv_cache_metadata[ - "layer2" - ][0] - ).data - ) + from mori.io import ( + MemoryDesc, + ) - # Verify engine keys - expected_engine_key = f"{ROLE[3:]}:{IP}:{DEFAULT_PORT}:tp{TP_RANK}:dp{DP_RANK}" - assert ( - MemoryDesc.unpack( - connector.connector_worker.layer_name_to_local_kv_cache_metadata[ - "layer0" - ][0] - ).engine_key - == expected_engine_key - ) + # Execute register_kv_caches + connector.register_kv_caches(kv_caches) + + # Verify that the MemoryDesc stored in layer_name_to_local_kv_cache_metadata + assert ( + shared_tensor.data_ptr() + == MemoryDesc.unpack( + connector.connector_worker.layer_name_to_local_kv_cache_metadata[ + "layer0" + ][0] + ).data + ) + assert ( + unique_tensor.data_ptr() + == MemoryDesc.unpack( + connector.connector_worker.layer_name_to_local_kv_cache_metadata[ + "layer1" + ][0] + ).data + ) + assert ( + shared_tensor.data_ptr() + == MemoryDesc.unpack( + connector.connector_worker.layer_name_to_local_kv_cache_metadata[ + "layer2" + ][0] + ).data + ) + + # Verify engine keys + expected_engine_key = ( + f"{ROLE[3:]}:{IP}:{DEFAULT_PORT}:tp{TP_RANK}:dp{DP_RANK}" + ) + assert ( + MemoryDesc.unpack( + connector.connector_worker.layer_name_to_local_kv_cache_metadata[ + "layer0" + ][0] + ).engine_key + == expected_engine_key + ) @pytest.mark.skipif( not aiter_available, reason="Requires aiter package for ROCm FlashAttention backend" ) -def test_moriio_handshake_returns_metadata(default_vllm_config, mock_parallel_groups): +def test_moriio_handshake_returns_metadata(mock_parallel_groups): """MoRIIO handshake socket returns valid agent metadata over ZMQ.""" ROLE = "kv_consumer" @@ -495,51 +499,52 @@ def test_moriio_handshake_returns_metadata(default_vllm_config, mock_parallel_gr backend_cls = AiterFlashAttentionBackend - # Create test kv cache tensors using proper backend shape - kv_cache_shape = backend_cls.get_kv_cache_shape( - num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 - ) - shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) - unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) - kv_caches = { - "layer0": shared_tensor, - "layer1": unique_tensor, - "layer2": shared_tensor, - } - - with ( - patch( - "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_engine.MoRIIOWrapper", - FakeMorIIOWrapper, - ), - ): - handshake_port = _find_free_port() - # Create connector - vllm_config.kv_transfer_config.kv_connector_extra_config.update( - { - "proxy_ip": "127.0.0.1", - "proxy_ping_port": 12345, - "http_port": 12346, - "handshake_port": handshake_port, - } + with set_current_vllm_config(vllm_config): + # Create test kv cache tensors using proper backend shape + kv_cache_shape = backend_cls.get_kv_cache_shape( + num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 ) - connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER) - - # Execute register_kv_caches - connector.register_kv_caches(kv_caches) - - # Connect to handshake socket and request metadata - path = make_zmq_path("tcp", "127.0.0.1", handshake_port) - with zmq_ctx(zmq.DEALER, path) as sock: - sock.send(MoRIIOConstants.GET_META_MSG) - received_frame = sock.recv_multipart() - - if len(received_frame) != 2 or received_frame[0] != b"": - raise ValueError(f"Unexpected frame! {received_frame = }") + shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + kv_caches = { + "layer0": shared_tensor, + "layer1": unique_tensor, + "layer2": shared_tensor, + } - metadata_bytes = received_frame[1] - decoder = msgspec.msgpack.Decoder(MoRIIOAgentMetadata) - metadata = decoder.decode(metadata_bytes) - assert isinstance(metadata, MoRIIOAgentMetadata), ( - "Decoded metadata is not MoRIIOAgentMetadata" + with ( + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_engine.MoRIIOWrapper", + FakeMorIIOWrapper, + ), + ): + handshake_port = _find_free_port() + # Create connector + vllm_config.kv_transfer_config.kv_connector_extra_config.update( + { + "proxy_ip": "127.0.0.1", + "proxy_ping_port": 12345, + "http_port": 12346, + "handshake_port": handshake_port, + } ) + connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER) + + # Execute register_kv_caches + connector.register_kv_caches(kv_caches) + + # Connect to handshake socket and request metadata + path = make_zmq_path("tcp", "127.0.0.1", handshake_port) + with zmq_ctx(zmq.DEALER, path) as sock: + sock.send(MoRIIOConstants.GET_META_MSG) + received_frame = sock.recv_multipart() + + if len(received_frame) != 2 or received_frame[0] != b"": + raise ValueError(f"Unexpected frame! {received_frame = }") + + metadata_bytes = received_frame[1] + decoder = msgspec.msgpack.Decoder(MoRIIOAgentMetadata) + metadata = decoder.decode(metadata_bytes) + assert isinstance(metadata, MoRIIOAgentMetadata), ( + "Decoded metadata is not MoRIIOAgentMetadata" + ) From 89a9ff9894c94af951e4ac7c309873e858bcb7f7 Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Wed, 18 Feb 2026 19:33:58 +0000 Subject: [PATCH 3/4] move the with statement up Signed-off-by: Randall Smith --- .../unit/test_moriio_connector.py | 112 +++++++++--------- 1 file changed, 56 insertions(+), 56 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_moriio_connector.py b/tests/v1/kv_connector/unit/test_moriio_connector.py index f21bcfec3964..10771edec65c 100644 --- a/tests/v1/kv_connector/unit/test_moriio_connector.py +++ b/tests/v1/kv_connector/unit/test_moriio_connector.py @@ -405,36 +405,36 @@ def test_register_kv_caches(mock_parallel_groups): backend_cls = AiterFlashAttentionBackend - with set_current_vllm_config(vllm_config): - # Create test kv cache tensors using proper backend shape - kv_cache_shape = backend_cls.get_kv_cache_shape( - num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 - ) - shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) - unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) - kv_caches = { - "layer0": shared_tensor, - "layer1": unique_tensor, - "layer2": shared_tensor, - } + # Create test kv cache tensors using proper backend shape + kv_cache_shape = backend_cls.get_kv_cache_shape( + num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 + ) + shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + kv_caches = { + "layer0": shared_tensor, + "layer1": unique_tensor, + "layer2": shared_tensor, + } - with ( - patch( - "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.threading.Event" - ), - patch( - "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.threading.Thread" - ), - ): - # Create connector - vllm_config.kv_transfer_config.kv_connector_extra_config.update( - { - "proxy_ip": "127.0.0.1", - "proxy_ping_port": 12345, - "http_port": 12346, - } - ) + with ( + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.threading.Event" + ), + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.threading.Thread" + ), + ): + # Create connector + vllm_config.kv_transfer_config.kv_connector_extra_config.update( + { + "proxy_ip": "127.0.0.1", + "proxy_ping_port": 12345, + "http_port": 12346, + } + ) + with set_current_vllm_config(vllm_config): connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER) connector.connector_worker = FakeMorIIOConnectorWorker( vllm_config, connector.engine_id, hand_shake_latency=0 @@ -499,35 +499,35 @@ def test_moriio_handshake_returns_metadata(mock_parallel_groups): backend_cls = AiterFlashAttentionBackend - with set_current_vllm_config(vllm_config): - # Create test kv cache tensors using proper backend shape - kv_cache_shape = backend_cls.get_kv_cache_shape( - num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 - ) - shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) - unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) - kv_caches = { - "layer0": shared_tensor, - "layer1": unique_tensor, - "layer2": shared_tensor, - } + # Create test kv cache tensors using proper backend shape + kv_cache_shape = backend_cls.get_kv_cache_shape( + num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 + ) + shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + kv_caches = { + "layer0": shared_tensor, + "layer1": unique_tensor, + "layer2": shared_tensor, + } - with ( - patch( - "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_engine.MoRIIOWrapper", - FakeMorIIOWrapper, - ), - ): - handshake_port = _find_free_port() - # Create connector - vllm_config.kv_transfer_config.kv_connector_extra_config.update( - { - "proxy_ip": "127.0.0.1", - "proxy_ping_port": 12345, - "http_port": 12346, - "handshake_port": handshake_port, - } - ) + with ( + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_engine.MoRIIOWrapper", + FakeMorIIOWrapper, + ), + ): + handshake_port = _find_free_port() + # Create connector + vllm_config.kv_transfer_config.kv_connector_extra_config.update( + { + "proxy_ip": "127.0.0.1", + "proxy_ping_port": 12345, + "http_port": 12346, + "handshake_port": handshake_port, + } + ) + with set_current_vllm_config(vllm_config): connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER) # Execute register_kv_caches From ead9be706403a87b9a75224cee013101a88d2a3a Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Wed, 18 Feb 2026 21:26:56 +0000 Subject: [PATCH 4/4] just put the connector creation in the with Signed-off-by: Randall Smith --- .../unit/test_moriio_connector.py | 114 +++++++++--------- 1 file changed, 56 insertions(+), 58 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_moriio_connector.py b/tests/v1/kv_connector/unit/test_moriio_connector.py index 10771edec65c..1eca4964fd6c 100644 --- a/tests/v1/kv_connector/unit/test_moriio_connector.py +++ b/tests/v1/kv_connector/unit/test_moriio_connector.py @@ -440,51 +440,49 @@ def test_register_kv_caches(mock_parallel_groups): vllm_config, connector.engine_id, hand_shake_latency=0 ) - from mori.io import ( - MemoryDesc, - ) + from mori.io import ( + MemoryDesc, + ) - # Execute register_kv_caches - connector.register_kv_caches(kv_caches) - - # Verify that the MemoryDesc stored in layer_name_to_local_kv_cache_metadata - assert ( - shared_tensor.data_ptr() - == MemoryDesc.unpack( - connector.connector_worker.layer_name_to_local_kv_cache_metadata[ - "layer0" - ][0] - ).data - ) - assert ( - unique_tensor.data_ptr() - == MemoryDesc.unpack( - connector.connector_worker.layer_name_to_local_kv_cache_metadata[ - "layer1" - ][0] - ).data - ) - assert ( - shared_tensor.data_ptr() - == MemoryDesc.unpack( - connector.connector_worker.layer_name_to_local_kv_cache_metadata[ - "layer2" - ][0] - ).data - ) + # Execute register_kv_caches + connector.register_kv_caches(kv_caches) + + # Verify that the MemoryDesc stored in layer_name_to_local_kv_cache_metadata + assert ( + shared_tensor.data_ptr() + == MemoryDesc.unpack( + connector.connector_worker.layer_name_to_local_kv_cache_metadata[ + "layer0" + ][0] + ).data + ) + assert ( + unique_tensor.data_ptr() + == MemoryDesc.unpack( + connector.connector_worker.layer_name_to_local_kv_cache_metadata[ + "layer1" + ][0] + ).data + ) + assert ( + shared_tensor.data_ptr() + == MemoryDesc.unpack( + connector.connector_worker.layer_name_to_local_kv_cache_metadata[ + "layer2" + ][0] + ).data + ) - # Verify engine keys - expected_engine_key = ( - f"{ROLE[3:]}:{IP}:{DEFAULT_PORT}:tp{TP_RANK}:dp{DP_RANK}" - ) - assert ( - MemoryDesc.unpack( - connector.connector_worker.layer_name_to_local_kv_cache_metadata[ - "layer0" - ][0] - ).engine_key - == expected_engine_key - ) + # Verify engine keys + expected_engine_key = f"{ROLE[3:]}:{IP}:{DEFAULT_PORT}:tp{TP_RANK}:dp{DP_RANK}" + assert ( + MemoryDesc.unpack( + connector.connector_worker.layer_name_to_local_kv_cache_metadata[ + "layer0" + ][0] + ).engine_key + == expected_engine_key + ) @pytest.mark.skipif( @@ -530,21 +528,21 @@ def test_moriio_handshake_returns_metadata(mock_parallel_groups): with set_current_vllm_config(vllm_config): connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER) - # Execute register_kv_caches - connector.register_kv_caches(kv_caches) + # Execute register_kv_caches + connector.register_kv_caches(kv_caches) - # Connect to handshake socket and request metadata - path = make_zmq_path("tcp", "127.0.0.1", handshake_port) - with zmq_ctx(zmq.DEALER, path) as sock: - sock.send(MoRIIOConstants.GET_META_MSG) - received_frame = sock.recv_multipart() + # Connect to handshake socket and request metadata + path = make_zmq_path("tcp", "127.0.0.1", handshake_port) + with zmq_ctx(zmq.DEALER, path) as sock: + sock.send(MoRIIOConstants.GET_META_MSG) + received_frame = sock.recv_multipart() - if len(received_frame) != 2 or received_frame[0] != b"": - raise ValueError(f"Unexpected frame! {received_frame = }") + if len(received_frame) != 2 or received_frame[0] != b"": + raise ValueError(f"Unexpected frame! {received_frame = }") - metadata_bytes = received_frame[1] - decoder = msgspec.msgpack.Decoder(MoRIIOAgentMetadata) - metadata = decoder.decode(metadata_bytes) - assert isinstance(metadata, MoRIIOAgentMetadata), ( - "Decoded metadata is not MoRIIOAgentMetadata" - ) + metadata_bytes = received_frame[1] + decoder = msgspec.msgpack.Decoder(MoRIIOAgentMetadata) + metadata = decoder.decode(metadata_bytes) + assert isinstance(metadata, MoRIIOAgentMetadata), ( + "Decoded metadata is not MoRIIOAgentMetadata" + )