From 7b131e4e4ed9972c0818ad339839f895781cc92b Mon Sep 17 00:00:00 2001 From: NickLucche Date: Fri, 20 Mar 2026 10:46:37 +0100 Subject: [PATCH 1/8] init Signed-off-by: NickLucche --- .../kv_connector/unit/test_nixl_connector.py | 337 +++++++++ .../kv_connector/unit/test_tp_kv_topology.py | 685 ++++++++++++++++++ .../kv_connector/v1/nixl_connector.py | 124 +++- vllm/envs.py | 19 + 4 files changed, 1153 insertions(+), 12 deletions(-) create mode 100644 tests/v1/kv_connector/unit/test_tp_kv_topology.py diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 472599747087..028f55b3dbcd 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -2380,6 +2380,343 @@ def test_compatibility_hash_validation( assert len(result) == 1 +class TestHeartbeatLeaseManagement: + """Tests for the heartbeat-based lease management system.""" + + @patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper, + ) + def test_d_side_heartbeat_sending(self, default_vllm_config, dist_init, monkeypatch): + """Test that D-side sends heartbeats to P engines with pending transfers.""" + # Set a short renewal interval for testing + monkeypatch.setenv("VLLM_NIXL_LEASE_RENEWAL_INTERVAL", "0.1") + + vllm_config = create_vllm_config() + connector = NixlConnector( + vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) + ) + connector.connector_worker = FakeNixlConnectorWorker( + vllm_config, connector.engine_id, hand_shake_latency=0 + ) + worker = connector.connector_worker + # Override the renewal interval since env var is read at init + worker._lease_renewal_interval = 0.1 + + # Simulate remote agent registration (handshake complete) + remote_engine_id = FakeNixlConnectorWorker.REMOTE_ENGINE_ID + worker._remote_agents[remote_engine_id] = {0: "fake_agent"} + + # Track pending transfers for this engine + worker._pending_transfers_by_engine[remote_engine_id] = {"req1", "req2", "req3"} + + # Track sent notifications + sent_notifs: list[tuple[str, bytes]] = [] + + def mock_send_notif(agent_name: str, notif_msg: bytes): + sent_notifs.append((agent_name, notif_msg)) + + worker.nixl_wrapper.send_notif = mock_send_notif + + # First call should send heartbeat (no previous renewal time) + worker._send_lease_heartbeats() + + assert len(sent_notifs) == 1 + agent_name, msg = sent_notifs[0] + assert agent_name == "fake_agent" + # Verify message format: "HB:req1,req2,req3" (order may vary) + assert msg.startswith(b"HB:") + req_ids = msg.decode()[3:].split(",") + assert set(req_ids) == {"req1", "req2", "req3"} + + # Immediate second call should NOT send (within renewal interval) + worker._send_lease_heartbeats() + assert len(sent_notifs) == 1 # Still just 1 + + # Wait for renewal interval and call again + time.sleep(0.15) + worker._send_lease_heartbeats() + assert len(sent_notifs) == 2 # Now 2 + + @patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper, + ) + def test_p_side_heartbeat_extends_lease( + self, default_vllm_config, dist_init, monkeypatch + ): + """Test that P-side extends lease when receiving heartbeat.""" + monkeypatch.setenv("VLLM_NIXL_LEASE_EXTENSION", "30") + + vllm_config = create_vllm_config() + connector = NixlConnector( + vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) + ) + connector.connector_worker = FakeNixlConnectorWorker( + vllm_config, connector.engine_id, 
hand_shake_latency=0 + ) + worker = connector.connector_worker + + # Setup kv_topo needed by _get_new_notifs + backend = get_current_attn_backend(vllm_config) + test_shape = backend.get_kv_cache_shape( + num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 + ) + worker.kv_topo = TpKVTopology( + tp_rank=worker.tp_rank, + engine_id=worker.engine_id, + remote_tp_size=worker._tp_size, + remote_block_size=worker._block_size, + is_mla=worker.use_mla, + total_num_kv_heads=worker.model_config.get_total_num_kv_heads(), + attn_backends=[backend], + tensor_shape=test_shape, + ) + + # Register requests waiting to be sent (P-side tracking) + now = time.perf_counter() + initial_expiry = now + 5.0 # Expires in 5 seconds + worker._reqs_to_send["req1"] = initial_expiry + worker._reqs_to_send["req2"] = initial_expiry + worker._reqs_to_send["req3"] = initial_expiry + worker._reqs_to_process = {"req1", "req2", "req3"} + + # Simulate heartbeat notification from D + heartbeat_msg = b"HB:req1,req2" + + worker.nixl_wrapper.get_new_notifs = lambda: {"agent": [heartbeat_msg]} + + # Process notifications + worker._get_new_notifs() + + # req1 and req2 should have extended leases (now + 30s) + # req3 should still have original expiry + assert worker._reqs_to_send["req1"] > initial_expiry + 20 + assert worker._reqs_to_send["req2"] > initial_expiry + 20 + assert worker._reqs_to_send["req3"] == initial_expiry + + @patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper, + ) + def test_p_side_lease_expiration(self, default_vllm_config, dist_init, monkeypatch): + """Test that P-side expires leases when no heartbeats received.""" + vllm_config = create_vllm_config() + connector = NixlConnector( + vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) + ) + connector.connector_worker = FakeNixlConnectorWorker( + vllm_config, connector.engine_id, hand_shake_latency=0 + ) + worker = connector.connector_worker + + # Setup kv_topo needed by get_finished + backend = get_current_attn_backend(vllm_config) + test_shape = backend.get_kv_cache_shape( + num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 + ) + worker.kv_topo = TpKVTopology( + tp_rank=worker.tp_rank, + engine_id=worker.engine_id, + remote_tp_size=worker._tp_size, + remote_block_size=worker._block_size, + is_mla=worker.use_mla, + total_num_kv_heads=worker.model_config.get_total_num_kv_heads(), + attn_backends=[backend], + tensor_shape=test_shape, + ) + + # Register a request with an already-expired lease + now = time.perf_counter() + worker._reqs_to_send["expired_req"] = now - 1.0 # Already expired + worker._reqs_to_process.add("expired_req") + + # Also register a request that's not expired yet + worker._reqs_to_send["valid_req"] = now + 100.0 # Far future + worker._reqs_to_process.add("valid_req") + + # get_finished should return the expired request + done_sending, _ = connector.get_finished(finished_req_ids=set()) + + assert "expired_req" in done_sending + assert "valid_req" not in done_sending + + # Verify expired request is cleaned up + assert "expired_req" not in worker._reqs_to_send + assert "expired_req" not in worker._reqs_to_process + + # Valid request should still be tracked + assert "valid_req" in worker._reqs_to_send + assert "valid_req" in worker._reqs_to_process + + @patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper, + ) + def test_heartbeat_with_empty_requests(self, default_vllm_config, dist_init): + """Test heartbeat handling with 
empty request string.""" + vllm_config = create_vllm_config() + connector = NixlConnector( + vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) + ) + connector.connector_worker = FakeNixlConnectorWorker( + vllm_config, connector.engine_id, hand_shake_latency=0 + ) + worker = connector.connector_worker + + # Setup kv_topo + backend = get_current_attn_backend(vllm_config) + test_shape = backend.get_kv_cache_shape( + num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 + ) + worker.kv_topo = TpKVTopology( + tp_rank=worker.tp_rank, + engine_id=worker.engine_id, + remote_tp_size=worker._tp_size, + remote_block_size=worker._block_size, + is_mla=worker.use_mla, + total_num_kv_heads=worker.model_config.get_total_num_kv_heads(), + attn_backends=[backend], + tensor_shape=test_shape, + ) + + # Empty heartbeat message (just "HB:" with no request IDs) + worker.nixl_wrapper.get_new_notifs = lambda: {"agent": [b"HB:"]} + + # Should handle gracefully without error + notified = worker._get_new_notifs() + assert len(notified) == 0 + + @patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper, + ) + def test_d_side_cleanup_on_transfer_complete( + self, default_vllm_config, dist_init + ): + """Test that D-side removes completed transfers from heartbeat tracking.""" + vllm_config = create_vllm_config() + connector = NixlConnector( + vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) + ) + connector.connector_worker = FakeNixlConnectorWorker( + vllm_config, connector.engine_id, hand_shake_latency=0 + ) + worker = connector.connector_worker + + remote_engine_id = FakeNixlConnectorWorker.REMOTE_ENGINE_ID + + # Setup kv_topo + backend = get_current_attn_backend(vllm_config) + test_shape = backend.get_kv_cache_shape( + num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 + ) + worker.kv_topo = TpKVTopology( + tp_rank=worker.tp_rank, + engine_id=worker.engine_id, + remote_tp_size=worker._tp_size, + remote_block_size=worker._block_size, + is_mla=worker.use_mla, + total_num_kv_heads=worker.model_config.get_total_num_kv_heads(), + attn_backends=[backend], + tensor_shape=test_shape, + ) + + # Simulate transfer metadata + from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( + RemoteMeta, + ReqMeta, + ) + + req_id = "test_transfer" + worker._recving_metadata[req_id] = ReqMeta( + local_block_ids=([1, 2, 3],), + local_physical_block_ids=([1, 2, 3],), + tp_size=1, + remote=RemoteMeta( + block_ids=([4, 5, 6],), + host="localhost", + port=1234, + engine_id=remote_engine_id, + request_id=f"prefill-{req_id}", + ), + ) + + # Add to pending transfers (D-side heartbeat tracking) + worker._pending_transfers_by_engine[remote_engine_id].add(req_id) + + # Simulate transfer handle completion + handle = 12345 + worker._recving_transfers[req_id] = [handle] + + # Mock check_xfer_state to return DONE + worker.nixl_wrapper._cycles_before_xfer_done = 0 + + # get_finished should complete the transfer and clean up + _, done_recving = connector.get_finished(finished_req_ids=set()) + + assert req_id in done_recving + # Should be removed from pending transfers + assert req_id not in worker._pending_transfers_by_engine[remote_engine_id] + + @patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper, + ) + def test_heartbeat_send_failure_logged( + self, default_vllm_config, dist_init, caplog + ): + """Test that heartbeat send failures are logged as warnings.""" + import logging + + 
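+        # The worker must swallow the send failure itself; this test only
+        # asserts that a WARNING record reaches the module logger, captured
+        # via the handler installed below.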
vllm_config = create_vllm_config() + connector = NixlConnector( + vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) + ) + connector.connector_worker = FakeNixlConnectorWorker( + vllm_config, connector.engine_id, hand_shake_latency=0 + ) + worker = connector.connector_worker + worker._lease_renewal_interval = 0 # Send immediately + + remote_engine_id = FakeNixlConnectorWorker.REMOTE_ENGINE_ID + worker._remote_agents[remote_engine_id] = {0: "fake_agent"} + worker._pending_transfers_by_engine[remote_engine_id] = {"req1"} + + # Make send_notif raise an exception + def failing_send_notif(agent_name: str, notif_msg: bytes): + raise RuntimeError("Network error") + + worker.nixl_wrapper.send_notif = failing_send_notif + + # Capture logs from the nixl_connector logger + nixl_logger = logging.getLogger( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector" + ) + captured_logs: list[logging.LogRecord] = [] + + class LogCapture(logging.Handler): + def emit(self, record): + captured_logs.append(record) + + handler = LogCapture() + handler.setLevel(logging.WARNING) + nixl_logger.addHandler(handler) + + try: + # Should not raise, just log warning + worker._send_lease_heartbeats() + finally: + nixl_logger.removeHandler(handler) + + # Verify warning was logged + warning_logs = [r for r in captured_logs if r.levelno == logging.WARNING] + assert len(warning_logs) >= 1 + assert any( + "Failed to send heartbeat" in r.message for r in warning_logs + ) + + @pytest.mark.parametrize( "error_scenario", [ diff --git a/tests/v1/kv_connector/unit/test_tp_kv_topology.py b/tests/v1/kv_connector/unit/test_tp_kv_topology.py new file mode 100644 index 000000000000..8c0552d5d33a --- /dev/null +++ b/tests/v1/kv_connector/unit/test_tp_kv_topology.py @@ -0,0 +1,685 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Unit tests for TpKVTopology with various attention backend configurations. + +These tests validate layout detection, block_size_position, and multi-backend +model behavior without loading any models. We use mock backends that replicate +the get_kv_cache_shape signatures of real backends. 
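+Layout detection only inspects the shape tuple that get_kv_cache_shape
+returns (plus the stride order for cross-layer layouts), so no model or
+device is needed.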
+ +Backend shape families: + - FlashAttn-like: (2, N, B, H, D) -- KV-first + - FlashInfer-like: (N, 2, B, H, D) -- blocks-first + - MLA-like: (N, B, D) -- 3-dim, no KV split + - Mamba-like: NotImplementedError -- no KV cache shape + - TritonAttn-like: (N, 2, B, H, D) -- blocks-first (same as FI) +""" + +import pytest +import torch + +from vllm.distributed.kv_transfer.kv_connector.utils import TpKVTopology +from vllm.v1.attention.backend import AttentionBackend + + +# --------------------------------------------------------------------------- +# Mock Attention Backends +# --------------------------------------------------------------------------- +class MockFlashAttnBackend(AttentionBackend): + """Mimics FlashAttentionBackend: shape = (2, N, B, H, D)""" + + @staticmethod + def get_name() -> str: + return "MOCK_FLASH_ATTN" + + @staticmethod + def get_impl_cls(): + raise NotImplementedError + + @staticmethod + def get_builder_cls(): + raise NotImplementedError + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + return (2, num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def get_kv_cache_stride_order( + include_num_layers_dimension: bool = False, + ) -> tuple[int, ...]: + if include_num_layers_dimension: + # HND cross-layer: (num_blocks, num_kv_heads, num_layers, 2, + # block_size, head_size) + return (2, 4, 0, 1, 3, 5) + return (0, 1, 3, 2, 4) + + +class MockFlashInferBackend(AttentionBackend): + """Mimics FlashInferBackend: shape = (N, 2, B, H, D)""" + + @staticmethod + def get_name() -> str: + return "MOCK_FLASHINFER" + + @staticmethod + def get_impl_cls(): + raise NotImplementedError + + @staticmethod + def get_builder_cls(): + raise NotImplementedError + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + return (num_blocks, 2, block_size, num_kv_heads, head_size) + + @staticmethod + def get_kv_cache_stride_order( + include_num_layers_dimension: bool = False, + ) -> tuple[int, ...]: + if include_num_layers_dimension: + return (1, 0, 2, 3, 4, 5) + return (0, 1, 2, 3, 4) + + +class MockTritonAttnBackend(AttentionBackend): + """Mimics TritonAttentionBackend: shape = (N, 2, B, H, D) -- same as FI""" + + @staticmethod + def get_name() -> str: + return "MOCK_TRITON_ATTN" + + @staticmethod + def get_impl_cls(): + raise NotImplementedError + + @staticmethod + def get_builder_cls(): + raise NotImplementedError + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + return (num_blocks, 2, block_size, num_kv_heads, head_size) + + +class MockMLABackend(AttentionBackend): + """Mimics MLA backends (FlashMLA, etc.): shape = (N, B, D) -- 3 dims""" + + @staticmethod + def get_name() -> str: + return "MOCK_MLA" + + @staticmethod + def get_impl_cls(): + raise NotImplementedError + + @staticmethod + def get_builder_cls(): + raise NotImplementedError + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + assert num_kv_heads == 1 + return (num_blocks, block_size, head_size) + + +class MockMambaBackend(AttentionBackend): + """Mimics Mamba backends: get_kv_cache_shape is not implemented.""" + + 
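+    # Raising NotImplementedError mirrors real Mamba backends, whose state
+    # is managed outside the paged KV cache that NIXL registers.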
@staticmethod + def get_name() -> str: + return "MOCK_MAMBA" + + @staticmethod + def get_impl_cls(): + raise NotImplementedError + + @staticmethod + def get_builder_cls(): + raise NotImplementedError + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + raise NotImplementedError("Mamba backends do not have a KV cache shape") + + +# A ChunkedLocal backend that inherits FA's get_kv_cache_shape, +# exactly as the real ChunkedLocalAttention backend does. +class MockChunkedLocalFABackend(MockFlashAttnBackend): + """ + Mimics ChunkedLocalAttention backed by FlashAttn. + Inherits get_kv_cache_shape from FlashAttn -- same layout. + """ + + @staticmethod + def get_name() -> str: + return "MOCK_CHUNKED_LOCAL_FA" + + +class MockCPUAttnBackend(AttentionBackend): + """ + Mimics CPU attention backend: shape = (2, N, H, B, D) + Note different position of block_size vs num_kv_heads compared to FA. + """ + + @staticmethod + def get_name() -> str: + return "MOCK_CPU_ATTN" + + @staticmethod + def get_impl_cls(): + raise NotImplementedError + + @staticmethod + def get_builder_cls(): + raise NotImplementedError + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + return (2, num_blocks, num_kv_heads, block_size, head_size) + + +# --------------------------------------------------------------------------- +# Helper to build TpKVTopology with minimal required fields +# --------------------------------------------------------------------------- +def make_topo( + attn_backend: type[AttentionBackend], + is_mla: bool = False, + tp_rank: int = 0, + tp_size: int = 1, + total_num_kv_heads: int = 8, + block_size: int = 16, + tensor_shape: torch.Size | None = None, + engine_id: str = "test-engine", +) -> TpKVTopology: + remote_tp_size = {engine_id: tp_size} + remote_block_size = {engine_id: block_size} + return TpKVTopology( + tp_rank=tp_rank, + engine_id=engine_id, + remote_tp_size=remote_tp_size, + remote_block_size=remote_block_size, + is_mla=is_mla, + total_num_kv_heads=total_num_kv_heads, + attn_backend=attn_backend, + tensor_shape=tensor_shape, + ) + + +# =================================================================== +# 1. 
Layout Detection Tests +# =================================================================== +class TestLayoutDetection: + """Test is_kv_layout_blocks_first and split_k_and_v for each backend.""" + + def test_flash_attn_standard(self): + """FA: (2, N, B, H, D) -> blocks_first=False, split_k_and_v=True""" + topo = make_topo(MockFlashAttnBackend) + assert topo.is_kv_layout_blocks_first is False + assert topo.split_k_and_v is True + assert topo.cross_layers_blocks is False + + def test_flashinfer_standard(self): + """FI: (N, 2, B, H, D) -> blocks_first=True, split_k_and_v=False""" + topo = make_topo(MockFlashInferBackend) + assert topo.is_kv_layout_blocks_first is True + assert topo.split_k_and_v is False + assert topo.cross_layers_blocks is False + + def test_triton_attn_standard(self): + """Triton: (N, 2, B, H, D) -> same as FI (blocks_first=True)""" + topo = make_topo(MockTritonAttnBackend) + assert topo.is_kv_layout_blocks_first is True + assert topo.split_k_and_v is False + + def test_flash_attn_mla(self): + """FA with MLA: blocks_first=False, split_k_and_v=False (MLA overrides)""" + topo = make_topo(MockFlashAttnBackend, is_mla=True) + assert topo.is_kv_layout_blocks_first is False + assert topo.split_k_and_v is False + + def test_flashinfer_mla(self): + """FI with MLA: blocks_first=True, split_k_and_v=False""" + topo = make_topo(MockFlashInferBackend, is_mla=True) + assert topo.is_kv_layout_blocks_first is True + assert topo.split_k_and_v is False + + def test_mla_backend_3dim(self): + """ + Pure MLA backend (3-dim shape): blocks_first=False. + Shape is (N, B, D) -- 3 dims, first dim is num_blocks=1 (mock), + so the 5-dim blocks_first check fails. + """ + topo = make_topo(MockMLABackend, is_mla=True, total_num_kv_heads=1) + assert topo.is_kv_layout_blocks_first is False + assert topo.split_k_and_v is False + + def test_flash_attn_cross_layers(self): + """ + FA with cross-layer blocks: tensor_shape has one extra dim. + Shape from backend = (2, 1, 16, 1, 1) -> 5 dims + tensor_shape = (80, 2, 1, 16, 1, 1) -> 6 dims = 5 + 1 + => cross_layers_blocks=True, split_k_and_v=False + """ + kv_shape = MockFlashAttnBackend.get_kv_cache_shape( + num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 + ) + cross_layer_shape = torch.Size((80,) + kv_shape) + topo = make_topo(MockFlashAttnBackend, tensor_shape=cross_layer_shape) + assert topo.cross_layers_blocks is True + assert topo.split_k_and_v is False + + def test_flashinfer_cross_layers(self): + """FI with cross-layer blocks.""" + kv_shape = MockFlashInferBackend.get_kv_cache_shape( + num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 + ) + cross_layer_shape = torch.Size((80,) + kv_shape) + topo = make_topo(MockFlashInferBackend, tensor_shape=cross_layer_shape) + assert topo.cross_layers_blocks is True + assert topo.split_k_and_v is False + + def test_no_cross_layers_same_ndim(self): + """ + When tensor_shape has same ndim as kv_cache_shape, + cross_layers_blocks should be False. + """ + kv_shape = MockFlashAttnBackend.get_kv_cache_shape( + num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 + ) + topo = make_topo( + MockFlashAttnBackend, tensor_shape=torch.Size(kv_shape) + ) + assert topo.cross_layers_blocks is False + + def test_cpu_attn_layout(self): + """ + CPU attention: (2, N, H, B, D). + Not blocks_first (first dim is 2 with num_blocks mocked to 1), + and kv_cache_shape[0] != 1 when we have 5 dims. 
+ """ + topo = make_topo(MockCPUAttnBackend) + # Shape with mocked values: (2, 1, 1, 16, 1) + # First dim = 2 (not 1), and len=5, so blocks_first check: + # len == 5 and shape[0] == 1? -> 5 dims but shape[0]=2 -> False + assert topo.is_kv_layout_blocks_first is False + assert topo.split_k_and_v is True + + +# =================================================================== +# 2. Block Size Position Tests +# =================================================================== +class TestBlockSizePosition: + """ + Verify block_size_position is correctly detected. + block_size_position is a negative index into the shape indicating + where the block_size dimension lives. + """ + + def test_flash_attn_block_size_position(self): + """FA shape: (2, N, B=16, H, D) -> B is at index 2, negative = -3""" + topo = make_topo(MockFlashAttnBackend) + assert topo.block_size_position == -3 + + def test_flashinfer_block_size_position(self): + """FI shape: (N, 2, B=16, H, D) -> B is at index 2, negative = -3""" + topo = make_topo(MockFlashInferBackend) + assert topo.block_size_position == -3 + + def test_mla_block_size_position(self): + """MLA shape: (N, B=16, D) -> B is at index 1, negative = -2""" + topo = make_topo(MockMLABackend, is_mla=True, total_num_kv_heads=1) + assert topo.block_size_position == -2 + + def test_cpu_attn_block_size_position(self): + """CPU shape: (2, N, H, B=16, D) -> B is at index 3, negative = -2""" + topo = make_topo(MockCPUAttnBackend) + assert topo.block_size_position == -2 + + def test_flash_attn_cross_layers_block_size_position(self): + """ + FA cross-layer: logical shape (L, 2, N, B, H, D), but after + stride_order permutation for HND cross-layer, the physical position + of B changes. + + Stride order for FA HND cross-layer: (2, 4, 0, 1, 3, 5) + Logical shape: (80, 2, 1, 16, 1, 1) + After permute: shape[2,4,0,1,3,5] = (1, 1, 80, 2, 16, 1) + B=16 is at physical index 4 -> negative = -2 + """ + kv_shape = MockFlashAttnBackend.get_kv_cache_shape( + num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 + ) + cross_layer_shape = torch.Size((80,) + kv_shape) + topo = make_topo(MockFlashAttnBackend, tensor_shape=cross_layer_shape) + assert topo.cross_layers_blocks is True + assert topo.block_size_position == -2 + + +# =================================================================== +# 3. Multi-Backend Model Configuration Tests +# =================================================================== +class TestMultiBackendModels: + """ + Test TpKVTopology behavior for model architectures that use + different attention backends across layers. + """ + + def test_qwen3_like_uniform_full_attn(self): + """ + Qwen3-like: All layers use FullAttentionSpec with FlashAttn backend. + Single backend family, all properties should be standard FA. + """ + topo = make_topo(MockFlashAttnBackend, is_mla=False) + assert topo.is_kv_layout_blocks_first is False + assert topo.split_k_and_v is True + assert topo.cross_layers_blocks is False + assert topo.block_size_position == -3 + + def test_deepseek_v3_mla(self): + """ + DeepSeek V3: All layers use MLAAttentionSpec with MLA backend. + 3-dim shape, is_mla=True, no KV split. 
+ """ + topo = make_topo(MockMLABackend, is_mla=True, total_num_kv_heads=1) + assert topo.is_kv_layout_blocks_first is False + assert topo.split_k_and_v is False + assert topo.cross_layers_blocks is False + assert topo.block_size_position == -2 + + def test_llama4_hybrid_full_and_chunked(self): + """ + Llama4: Mix of FullAttentionSpec (global NoPE layers) and + ChunkedLocalAttentionSpec (local RoPE layers). + + Both backends inherit FlashAttn's get_kv_cache_shape, so + constructing TpKVTopology with either backend gives the same result. + This test documents that FA and ChunkedLocal-FA are interchangeable + for topology purposes. + """ + topo_fa = make_topo(MockFlashAttnBackend, is_mla=False) + topo_chunked = make_topo(MockChunkedLocalFABackend, is_mla=False) + + # Both should produce identical topology properties + assert topo_fa.is_kv_layout_blocks_first == topo_chunked.is_kv_layout_blocks_first + assert topo_fa.split_k_and_v == topo_chunked.split_k_and_v + assert topo_fa.cross_layers_blocks == topo_chunked.cross_layers_blocks + assert topo_fa.block_size_position == topo_chunked.block_size_position + + # Confirm they're standard FA properties + assert topo_fa.is_kv_layout_blocks_first is False + assert topo_fa.split_k_and_v is True + + def test_gemma3_sliding_window(self): + """ + Gemma3: All layers use FullAttentionSpec (some with sliding_window set). + From TpKVTopology's perspective, sliding_window doesn't change the + backend or cache shape. All layers use the same FA backend. + """ + # sliding_window is a KVCacheSpec concern, not a backend shape concern + topo = make_topo(MockFlashAttnBackend, is_mla=False) + assert topo.is_kv_layout_blocks_first is False + assert topo.split_k_and_v is True + assert topo.block_size_position == -3 + + def test_jamba_hybrid_mamba_backend_crashes(self): + """ + Jamba-like hybrid: If get_current_attn_backend() returns a Mamba + backend (because the first layer is Mamba), TpKVTopology construction + crashes because Mamba backends don't implement get_kv_cache_shape. + + This documents the current limitation that NIXL cannot work with + models where the first layer is a Mamba layer. + """ + with pytest.raises(NotImplementedError): + make_topo(MockMambaBackend, is_mla=False) + + def test_jamba_hybrid_attention_first_works(self): + """ + Jamba-like hybrid: If the first layer is an attention layer, + get_current_attn_backend() returns FA, and TpKVTopology works. + The Mamba layers are simply not registered with NIXL (they use + separate state management). + """ + # Simulates the case where first layer happens to be attention + topo = make_topo(MockFlashAttnBackend, is_mla=False) + assert topo.is_kv_layout_blocks_first is False + assert topo.split_k_and_v is True + + def test_flashinfer_with_chunked_local_inheriting(self): + """ + If a model uses ChunkedLocal attention backed by FlashInfer, + verify the topology correctly detects the FI layout. + """ + + class MockChunkedLocalFIBackend(MockFlashInferBackend): + @staticmethod + def get_name() -> str: + return "MOCK_CHUNKED_LOCAL_FI" + + topo = make_topo(MockChunkedLocalFIBackend, is_mla=False) + assert topo.is_kv_layout_blocks_first is True + assert topo.split_k_and_v is False + + def test_mixed_fa_and_fi_backends_differ(self): + """ + Hypothetical model with both FA and FI layers. + TpKVTopology constructed with FA vs FI gives different properties. + This documents why a single backend assumption matters. 
+ """ + topo_fa = make_topo(MockFlashAttnBackend, is_mla=False) + topo_fi = make_topo(MockFlashInferBackend, is_mla=False) + + # Key property that differs between the two + assert topo_fa.is_kv_layout_blocks_first is False + assert topo_fi.is_kv_layout_blocks_first is True + + # split_k_and_v also differs + assert topo_fa.split_k_and_v is True + assert topo_fi.split_k_and_v is False + + # block_size_position is the same though + assert topo_fa.block_size_position == topo_fi.block_size_position == -3 + + +# =================================================================== +# 4. get_current_attn_backend Behavior Tests +# =================================================================== +class TestGetCurrentAttnBackend: + """ + Test get_current_attn_backend behavior with mocked static_forward_context. + """ + + def test_returns_first_layers_backend(self): + """ + get_current_attn_backend iterates static_forward_context (dict order) + and returns the first layer's backend. + """ + from unittest.mock import MagicMock, patch + + from vllm.distributed.kv_transfer.kv_connector.utils import ( + get_current_attn_backend, + ) + from vllm.model_executor.layers.attention_layer_base import ( + AttentionLayerBase, + ) + + # Create mock layers with different backends + layer0 = MagicMock(spec=AttentionLayerBase) + layer0.get_attn_backend.return_value = MockFlashAttnBackend + + layer1 = MagicMock(spec=AttentionLayerBase) + layer1.get_attn_backend.return_value = MockFlashInferBackend + + mock_context = {"attn_layer_0": layer0, "attn_layer_1": layer1} + + mock_config = MagicMock() + mock_config.compilation_config.static_forward_context = mock_context + + backend = get_current_attn_backend(mock_config) + assert backend is MockFlashAttnBackend + + def test_returns_second_when_first_is_different(self): + """ + Verify that only the FIRST layer's backend is returned, + even if subsequent layers use a different backend. + """ + from unittest.mock import MagicMock + + from vllm.distributed.kv_transfer.kv_connector.utils import ( + get_current_attn_backend, + ) + from vllm.model_executor.layers.attention_layer_base import ( + AttentionLayerBase, + ) + + # First layer is FI, second is FA + layer0 = MagicMock(spec=AttentionLayerBase) + layer0.get_attn_backend.return_value = MockFlashInferBackend + + layer1 = MagicMock(spec=AttentionLayerBase) + layer1.get_attn_backend.return_value = MockFlashAttnBackend + + mock_context = {"layer_0": layer0, "layer_1": layer1} + + mock_config = MagicMock() + mock_config.compilation_config.static_forward_context = mock_context + + backend = get_current_attn_backend(mock_config) + # Should be the first one + assert backend is MockFlashInferBackend + + def test_mamba_first_layer_returns_mamba(self): + """ + If the first layer is Mamba, get_current_attn_backend returns + the Mamba backend. This would cause TpKVTopology to crash. + + This documents the current problematic behavior that needs fixing: + get_current_attn_backend should skip non-attention backends. 
+ """ + from unittest.mock import MagicMock + + from vllm.distributed.kv_transfer.kv_connector.utils import ( + get_current_attn_backend, + ) + from vllm.model_executor.layers.attention_layer_base import ( + AttentionLayerBase, + ) + + # First layer is Mamba, second is FA + mamba_layer = MagicMock(spec=AttentionLayerBase) + mamba_layer.get_attn_backend.return_value = MockMambaBackend + + attn_layer = MagicMock(spec=AttentionLayerBase) + attn_layer.get_attn_backend.return_value = MockFlashAttnBackend + + mock_context = {"mamba_layer_0": mamba_layer, "attn_layer_0": attn_layer} + + mock_config = MagicMock() + mock_config.compilation_config.static_forward_context = mock_context + + backend = get_current_attn_backend(mock_config) + # Current behavior: returns Mamba (the first layer's backend) + assert backend is MockMambaBackend + + # This will crash TpKVTopology: + with pytest.raises(NotImplementedError): + make_topo(backend, is_mla=False) + + def test_fallback_when_no_layers(self): + """ + When static_forward_context is empty, get_current_attn_backend + falls back to the attention selector. + """ + from unittest.mock import MagicMock, patch + + from vllm.distributed.kv_transfer.kv_connector.utils import ( + get_current_attn_backend, + ) + + mock_config = MagicMock() + mock_config.compilation_config.static_forward_context = {} + mock_config.model_config.get_head_size.return_value = 64 + mock_config.model_config.dtype = torch.float16 + mock_config.cache_config.cache_dtype = "auto" + mock_config.cache_config.block_size = 16 + mock_config.model_config.use_mla = False + + with patch( + "vllm.distributed.kv_transfer.kv_connector.utils.get_attn_backend" + ) as mock_selector: + mock_selector.return_value = MockFlashAttnBackend + backend = get_current_attn_backend(mock_config) + assert backend is MockFlashAttnBackend + mock_selector.assert_called_once() + + def test_all_layers_same_backend_consistency(self): + """ + When all layers use the same backend, any layer can be used + to construct TpKVTopology with identical results. + """ + from unittest.mock import MagicMock + + from vllm.distributed.kv_transfer.kv_connector.utils import ( + get_current_attn_backend, + ) + from vllm.model_executor.layers.attention_layer_base import ( + AttentionLayerBase, + ) + + layers = {} + for i in range(10): + layer = MagicMock(spec=AttentionLayerBase) + layer.get_attn_backend.return_value = MockFlashAttnBackend + layers[f"layer_{i}"] = layer + + mock_config = MagicMock() + mock_config.compilation_config.static_forward_context = layers + + backend = get_current_attn_backend(mock_config) + assert backend is MockFlashAttnBackend + + # All produce the same topology + topo = make_topo(backend) + for layer in layers.values(): + other = make_topo(layer.get_attn_backend()) + assert topo.is_kv_layout_blocks_first == other.is_kv_layout_blocks_first + assert topo.split_k_and_v == other.split_k_and_v + assert topo.block_size_position == other.block_size_position diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index a86a52a6a6fb..1d02549f2ed9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -970,15 +970,18 @@ def request_finished( delay_free_blocks = any(len(group) > 0 for group in block_ids) if delay_free_blocks: - # Prefill request on remote. It will be read from D upon completion + # Prefill request on remote. 
It will be read from D upon completion. + # Use initial lease duration - D will send heartbeats to extend. + initial_lease = envs.VLLM_NIXL_INITIAL_LEASE_DURATION logger.debug( - "NIXLConnector request_finished(%s) waiting for %d seconds " - "for remote decode to fetch blocks", + "NIXLConnector request_finished(%s) waiting for initial lease " + "of %d seconds for remote decode to fetch blocks " + "(will be extended by heartbeats)", request.request_id, - envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT, + initial_lease, ) self._reqs_need_send[request.request_id] = ( - time.perf_counter() + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT + time.perf_counter() + initial_lease ) # NOTE HMA will "mark" empty/null blocks in groups with 0s (eg SWA ones), # trimming down after allocating for the whole sequence length. Empty @@ -1188,6 +1191,18 @@ def __init__( # requests that skipped transfer (handshake or transfer failures) self._failed_recv_reqs: set[ReqId] = set() + # Heartbeat/lease management for D-side (consumer). + # Track pending transfers by P engine for batched heartbeat sending. + self._pending_transfers_by_engine: dict[EngineId, set[ReqId]] = defaultdict( + set + ) + # Last time we sent a heartbeat to each P engine. + self._last_lease_renewal: dict[EngineId, float] = {} + # Lease renewal interval from env var. + self._lease_renewal_interval: float = float( + envs.VLLM_NIXL_LEASE_RENEWAL_INTERVAL + ) + # Handshake metadata of this worker for NIXL transfers. self.xfer_handshake_metadata: NixlHandshakePayload | None = None # Background thread for initializing new NIXL handshakes. @@ -2295,6 +2310,10 @@ def get_finished(self) -> tuple[set[str], set[str]]: meta = self._recving_metadata.pop(req_id, None) assert meta is not None, f"{req_id} not found in recving_metadata list" assert meta.remote is not None + + # Remove from pending transfers (D-side heartbeat tracking). + self._pending_transfers_by_engine[meta.remote.engine_id].discard(req_id) + if self.use_host_buffer: self.sync_recved_kv_to_device(req_id, meta) @@ -2315,7 +2334,9 @@ def get_finished(self) -> tuple[set[str], set[str]]: ) in block_ids_for_blocksize_post_process.items(): self.post_process_device_kv_on_receive(block_size_ratio, block_ids_list) - # Handle timeout to avoid stranding blocks on remote. + # Handle lease expiration to avoid stranding blocks on remote. + # Leases start with VLLM_NIXL_INITIAL_LEASE_DURATION and are extended + # by heartbeats from D. If no heartbeat is received, lease expires. now = time.perf_counter() while self._reqs_to_send: req_id, expires = next(iter(self._reqs_to_send.items())) @@ -2325,11 +2346,11 @@ def get_finished(self) -> tuple[set[str], set[str]]: count = self.consumer_notification_counts_by_req.pop(req_id, 0) self.xfer_stats.record_kv_expired_req() logger.warning( - "Releasing expired KV blocks for request %s which were " - "retrieved by %d decode worker(s) within %d seconds.", + "Releasing KV blocks for request %s: lease expired " + "(retrieved by %d decode worker(s)). This may indicate " + "D crashed or network issues preventing heartbeats.", req_id, count, - envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT, ) self._reqs_to_process.remove(req_id) del self._reqs_to_send[req_id] @@ -2342,12 +2363,40 @@ def _get_new_notifs(self) -> set[str]: Get req_ids which got a remote xfer message. When multiple consumers are reading from the same producer (heterogeneous TP scenario), wait for all consumers to be done pulling. + + Also handles heartbeat messages from D (consumer) to extend KV block + leases. Heartbeat format: "HB:,,..." 
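+        i.e. the literal prefix "HB:" followed by a comma-separated list
+        of request ids, e.g. "HB:req-1,req-2".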
""" assert self.kv_topo is not None notified_req_ids: set[str] = set() + now = time.perf_counter() + lease_extension = float(envs.VLLM_NIXL_LEASE_EXTENSION) + for notifs in self.nixl_wrapper.get_new_notifs().values(): for notif in notifs: - req_id, tp_size = notif.decode("utf-8").rsplit(":", 1) + decoded = notif.decode("utf-8") + + # Handle heartbeat messages from D (consumer). + if decoded.startswith("HB:"): + req_ids_str = decoded[3:] # Skip "HB:" prefix + if not req_ids_str: + continue + heartbeat_req_ids = req_ids_str.split(",") + extended_count = 0 + for req_id in heartbeat_req_ids: + if req_id in self._reqs_to_send: + # Extend the lease for this request. + self._reqs_to_send[req_id] = now + lease_extension + extended_count += 1 + if extended_count > 0: + logger.debug( + "Extended lease for %d requests via heartbeat", + extended_count, + ) + continue + + # Handle completion notifications (original format: "req_id:tp_size") + req_id, tp_size = decoded.rsplit(":", 1) if ( req_id not in self._reqs_to_send and req_id not in self._reqs_to_process @@ -2503,12 +2552,63 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): if req_id in self._reqs_to_process: self._reqs_to_send[req_id] = expiration_time + # Send heartbeats to P engines with pending transfers (D-side). + self._send_lease_heartbeats() + + def _send_lease_heartbeats(self) -> None: + """ + Send batched heartbeat notifications to P engines to extend KV block leases. + + This is called periodically from start_load_kv on the D (consumer) side. + Heartbeats are batched per P engine to minimize notification overhead. + Format: "HB:,,..." + """ + now = time.perf_counter() + + for engine_id, req_ids in self._pending_transfers_by_engine.items(): + if not req_ids: + continue + + # Check if enough time has passed since last heartbeat to this engine. + last_renewal = self._last_lease_renewal.get(engine_id, 0.0) + if now - last_renewal < self._lease_renewal_interval: + continue + + # Get any agent for this engine (all P workers share state). + remote_agents = self._remote_agents.get(engine_id) + if not remote_agents: + # Handshake not yet complete for this engine. + continue + + # Build batched heartbeat message: "HB:req1,req2,req3,..." + heartbeat_msg = ("HB:" + ",".join(req_ids)).encode() + + # Send to the first available agent for this engine. + agent_name = next(iter(remote_agents.values())) + try: + self.nixl_wrapper.send_notif(agent_name, notif_msg=heartbeat_msg) + self._last_lease_renewal[engine_id] = now + logger.debug( + "Sent heartbeat to engine %s for %d pending requests", + engine_id, + len(req_ids), + ) + except Exception as e: + logger.warning( + "Failed to send heartbeat to engine %s: %s", engine_id, e + ) + def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): assert meta.remote is not None and self.kv_topo is not None + remote_engine_id = meta.remote.engine_id + + # Track this transfer for heartbeat sending (D-side lease renewal). + self._pending_transfers_by_engine[remote_engine_id].add(req_id) + remote_ranks = self.kv_topo.get_target_remote_ranks_from_engine_id( - meta.remote.engine_id + remote_engine_id ) - tp_ratio = self.kv_topo.tp_ratio_from_engine_id(meta.remote.engine_id) + tp_ratio = self.kv_topo.tp_ratio_from_engine_id(remote_engine_id) # D may have to perform multiple reads from different remote ranks. 
for i, remote_rank in enumerate(remote_ranks): if self.use_mla and tp_ratio < 0 and i > 0: diff --git a/vllm/envs.py b/vllm/envs.py index 2f93b2cb3e0d..45ff4b3ace4b 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -191,6 +191,9 @@ VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: int | None = None VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 480 + VLLM_NIXL_INITIAL_LEASE_DURATION: int = 15 + VLLM_NIXL_LEASE_EXTENSION: int = 30 + VLLM_NIXL_LEASE_RENEWAL_INTERVAL: int = 5 VLLM_MORIIO_CONNECTOR_READ_MODE: bool = False VLLM_MORIIO_QP_PER_TRANSFER: int = 1 VLLM_MORIIO_POST_BATCH_SIZE: int = -1 @@ -1391,6 +1394,22 @@ def _get_or_set_default() -> str: "VLLM_NIXL_ABORT_REQUEST_TIMEOUT": lambda: int( os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "480") ), + # Initial lease duration (in seconds) for KV blocks on the producer (P) + # side. If no heartbeat is received from the consumer (D) within this + # time, the blocks will be freed. Used in disaggregated P/D setup. + "VLLM_NIXL_INITIAL_LEASE_DURATION": lambda: int( + os.getenv("VLLM_NIXL_INITIAL_LEASE_DURATION", "15") + ), + # Lease extension (in seconds) granted when a heartbeat is received from + # the consumer (D). Each heartbeat extends the lease by this amount. + "VLLM_NIXL_LEASE_EXTENSION": lambda: int( + os.getenv("VLLM_NIXL_LEASE_EXTENSION", "30") + ), + # Interval (in seconds) at which the consumer (D) sends heartbeat + # notifications to the producer (P) to extend KV block leases. + "VLLM_NIXL_LEASE_RENEWAL_INTERVAL": lambda: int( + os.getenv("VLLM_NIXL_LEASE_RENEWAL_INTERVAL", "5") + ), # Controls the read mode for the Mori-IO connector "VLLM_MORIIO_CONNECTOR_READ_MODE": lambda: ( os.getenv("VLLM_MORIIO_CONNECTOR_READ_MODE", "False").lower() in ("true", "1") From 5e76e7bfb63fbfcc3a66e592bcc0cdde3815bdd9 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Tue, 24 Mar 2026 15:40:08 +0100 Subject: [PATCH 2/8] clean up and handle p tp > d tp Signed-off-by: NickLucche --- .../kv_connector/v1/nixl_connector.py | 103 ++++++++++-------- vllm/envs.py | 28 ++--- 2 files changed, 69 insertions(+), 62 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 1d02549f2ed9..afef4a0f66f2 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -972,7 +972,7 @@ def request_finished( if delay_free_blocks: # Prefill request on remote. It will be read from D upon completion. # Use initial lease duration - D will send heartbeats to extend. - initial_lease = envs.VLLM_NIXL_INITIAL_LEASE_DURATION + initial_lease = envs.VLLM_NIXL_KV_LEASE_DURATION logger.debug( "NIXLConnector request_finished(%s) waiting for initial lease " "of %d seconds for remote decode to fetch blocks " @@ -1178,9 +1178,14 @@ def __init__( self._registered_descs: list[Any] = [] # In progress transfers. - # [req_id -> list[handle]] + # In-progress transfer tracking (D-side / consumer). + # Keyed by req_id to ensure ALL handles complete before marking done. self._recving_metadata: dict[ReqId, ReqMeta] = {} self._recving_transfers = defaultdict[ReqId, list[TransferHandle]](list) + # Track which engines have pending transfers for each request. + # Used for batched heartbeat sending per P engine. + self._pending_transfers_by_engine = defaultdict[EngineId, set[ReqId]](set) + # Track the expiration time of requests that are waiting to be sent. 
self._reqs_to_send: dict[ReqId, float] = {} # Set of requests that have been part of a batch, regardless of status. @@ -1192,16 +1197,11 @@ def __init__( self._failed_recv_reqs: set[ReqId] = set() # Heartbeat/lease management for D-side (consumer). - # Track pending transfers by P engine for batched heartbeat sending. - self._pending_transfers_by_engine: dict[EngineId, set[ReqId]] = defaultdict( - set - ) - # Last time we sent a heartbeat to each P engine. - self._last_lease_renewal: dict[EngineId, float] = {} - # Lease renewal interval from env var. - self._lease_renewal_interval: float = float( - envs.VLLM_NIXL_LEASE_RENEWAL_INTERVAL - ) + # Single timestamp suffices - heartbeat interval limits overall send rate, + # not per-engine. New engines get fresh leases on P-side anyway. + self._last_heartbeat_time: float = 0.0 + self._heartbeat_interval: float = float(envs.VLLM_NIXL_KV_HEARTBEAT_INTERVAL) + self._lease_extension = float(envs.VLLM_NIXL_KV_LEASE_EXTENSION) # Handshake metadata of this worker for NIXL transfers. self.xfer_handshake_metadata: NixlHandshakePayload | None = None @@ -1501,6 +1501,12 @@ def _background_nixl_handshake( fut = self._handshake_futures.get(remote_engine_id) if fut is None: assert meta.remote is not None + # Opportunistically clean up empty (dead?) remote engine_ids entries from + # _pending_transfers_by_engine. Asymptotically this data structure can only + # grow indefinitely when new P remotes are added. + for k in list(self._pending_transfers_by_engine.keys()): + if not self._pending_transfers_by_engine[k]: + del self._pending_transfers_by_engine[k] fut = self._handshake_initiation_executor.submit( self._nixl_handshake, meta.remote.host, @@ -2311,9 +2317,6 @@ def get_finished(self) -> tuple[set[str], set[str]]: assert meta is not None, f"{req_id} not found in recving_metadata list" assert meta.remote is not None - # Remove from pending transfers (D-side heartbeat tracking). - self._pending_transfers_by_engine[meta.remote.engine_id].discard(req_id) - if self.use_host_buffer: self.sync_recved_kv_to_device(req_id, meta) @@ -2335,7 +2338,7 @@ def get_finished(self) -> tuple[set[str], set[str]]: self.post_process_device_kv_on_receive(block_size_ratio, block_ids_list) # Handle lease expiration to avoid stranding blocks on remote. - # Leases start with VLLM_NIXL_INITIAL_LEASE_DURATION and are extended + # Leases start with VLLM_NIXL_KV_LEASE_DURATION and are extended # by heartbeats from D. If no heartbeat is received, lease expires. now = time.perf_counter() while self._reqs_to_send: @@ -2370,24 +2373,27 @@ def _get_new_notifs(self) -> set[str]: assert self.kv_topo is not None notified_req_ids: set[str] = set() now = time.perf_counter() - lease_extension = float(envs.VLLM_NIXL_LEASE_EXTENSION) for notifs in self.nixl_wrapper.get_new_notifs().values(): for notif in notifs: decoded = notif.decode("utf-8") - # Handle heartbeat messages from D (consumer). + # Handle heartbeat messages from D (consumer) to hold KV Cache blocks. if decoded.startswith("HB:"): req_ids_str = decoded[3:] # Skip "HB:" prefix - if not req_ids_str: - continue heartbeat_req_ids = req_ids_str.split(",") extended_count = 0 for req_id in heartbeat_req_ids: if req_id in self._reqs_to_send: # Extend the lease for this request. - self._reqs_to_send[req_id] = now + lease_extension + self._reqs_to_send[req_id] = now + self._lease_extension extended_count += 1 + else: + logger.warning( + "Received heartbeat message for unknown request %s. 
" + "This may indicate the request has already expired.", + req_id, + ) if extended_count > 0: logger.debug( "Extended lease for %d requests via heartbeat", @@ -2474,8 +2480,12 @@ def _pop_done_transfers(self, transfers: dict[str, list[int]]) -> set[str]: # Only report request as completed when all transfers are done. done_req_ids.add(req_id) del transfers[req_id] + # Clean up from pending_transfers_by_engine + for engine_reqs in self._pending_transfers_by_engine.values(): + engine_reqs.discard(req_id) else: transfers[req_id] = in_progress + return done_req_ids def _handle_failed_transfer(self, req_id: str, handle: int): @@ -2565,16 +2575,16 @@ def _send_lease_heartbeats(self) -> None: """ now = time.perf_counter() + # Check if enough time has passed since last heartbeat. + if now - self._last_heartbeat_time < self._heartbeat_interval: + return + + sent_any = False for engine_id, req_ids in self._pending_transfers_by_engine.items(): if not req_ids: continue - # Check if enough time has passed since last heartbeat to this engine. - last_renewal = self._last_lease_renewal.get(engine_id, 0.0) - if now - last_renewal < self._lease_renewal_interval: - continue - - # Get any agent for this engine (all P workers share state). + # Get agents for this engine. remote_agents = self._remote_agents.get(engine_id) if not remote_agents: # Handshake not yet complete for this engine. @@ -2583,28 +2593,28 @@ def _send_lease_heartbeats(self) -> None: # Build batched heartbeat message: "HB:req1,req2,req3,..." heartbeat_msg = ("HB:" + ",".join(req_ids)).encode() - # Send to the first available agent for this engine. - agent_name = next(iter(remote_agents.values())) - try: - self.nixl_wrapper.send_notif(agent_name, notif_msg=heartbeat_msg) - self._last_lease_renewal[engine_id] = now - logger.debug( - "Sent heartbeat to engine %s for %d pending requests", - engine_id, - len(req_ids), - ) - except Exception as e: - logger.warning( - "Failed to send heartbeat to engine %s: %s", engine_id, e - ) + # Send to ALL remote agents we handhshaked with for this remote. + # Important for P TP > D TP case where we have multiple remote workers. + # For other cases, we actually only heartbeat one remote agent. + for agent_name in remote_agents.values(): + try: + self.nixl_wrapper.send_notif(agent_name, notif_msg=heartbeat_msg) + sent_any = True + except Exception as e: + logger.warning( + "Failed to send heartbeat to engine %s agent %s: %s", + engine_id, + agent_name, + e, + ) + + if sent_any: + self._last_heartbeat_time = now def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): assert meta.remote is not None and self.kv_topo is not None remote_engine_id = meta.remote.engine_id - # Track this transfer for heartbeat sending (D-side lease renewal). - self._pending_transfers_by_engine[remote_engine_id].add(req_id) - remote_ranks = self.kv_topo.get_target_remote_ranks_from_engine_id( remote_engine_id ) @@ -2784,8 +2794,10 @@ def _read_blocks( # Begin async xfer. self.nixl_wrapper.transfer(handle) - # Use handle to check completion in future step(). + # Track handle for completion checking (keyed by req_id). self._recving_transfers[request_id].append(handle) + # Track engine for batched heartbeat sending (keyed by remote id). 
+ self._pending_transfers_by_engine[dst_engine_id].add(request_id) except Exception as e: # mark all (logical) blocks for this request as invalid self._log_failure( @@ -2981,6 +2993,7 @@ def shutdown(self): for handle in handles: self.nixl_wrapper.release_xfer_handle(handle) self._recving_transfers.clear() + self._pending_transfers_by_engine.clear() for handle in self.src_xfer_handles_by_block_size.values(): self.nixl_wrapper.release_dlist_handle(handle) self.src_xfer_handles_by_block_size.clear() diff --git a/vllm/envs.py b/vllm/envs.py index 45ff4b3ace4b..4da45c5cd092 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -190,10 +190,9 @@ ] = "NONE" VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: int | None = None - VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 480 - VLLM_NIXL_INITIAL_LEASE_DURATION: int = 15 - VLLM_NIXL_LEASE_EXTENSION: int = 30 - VLLM_NIXL_LEASE_RENEWAL_INTERVAL: int = 5 + VLLM_NIXL_KV_LEASE_DURATION: int = 15 + VLLM_NIXL_KV_LEASE_EXTENSION: int = 30 + VLLM_NIXL_KV_HEARTBEAT_INTERVAL: int = 5 VLLM_MORIIO_CONNECTOR_READ_MODE: bool = False VLLM_MORIIO_QP_PER_TRANSFER: int = 1 VLLM_MORIIO_POST_BATCH_SIZE: int = -1 @@ -1387,28 +1386,23 @@ def _get_or_set_default() -> str: "VLLM_USE_NVFP4_CT_EMULATIONS": lambda: bool( int(os.getenv("VLLM_USE_NVFP4_CT_EMULATIONS", "0")) ), - # Time (in seconds) after which the KV cache on the producer side is - # automatically cleared if no READ notification is received from the - # consumer. This is only applicable when using NixlConnector in a - # disaggregated decode-prefill setup. - "VLLM_NIXL_ABORT_REQUEST_TIMEOUT": lambda: int( - os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "480") - ), # Initial lease duration (in seconds) for KV blocks on the producer (P) # side. If no heartbeat is received from the consumer (D) within this # time, the blocks will be freed. Used in disaggregated P/D setup. - "VLLM_NIXL_INITIAL_LEASE_DURATION": lambda: int( - os.getenv("VLLM_NIXL_INITIAL_LEASE_DURATION", "15") + # D sends periodic heartbeats to extend the lease (see KV_LEASE_EXTENSION). + "VLLM_NIXL_KV_LEASE_DURATION": lambda: int( + os.getenv("VLLM_NIXL_KV_LEASE_DURATION", "15") ), # Lease extension (in seconds) granted when a heartbeat is received from # the consumer (D). Each heartbeat extends the lease by this amount. - "VLLM_NIXL_LEASE_EXTENSION": lambda: int( - os.getenv("VLLM_NIXL_LEASE_EXTENSION", "30") + "VLLM_NIXL_KV_LEASE_EXTENSION": lambda: int( + os.getenv("VLLM_NIXL_KV_LEASE_EXTENSION", "30") ), # Interval (in seconds) at which the consumer (D) sends heartbeat # notifications to the producer (P) to extend KV block leases. - "VLLM_NIXL_LEASE_RENEWAL_INTERVAL": lambda: int( - os.getenv("VLLM_NIXL_LEASE_RENEWAL_INTERVAL", "5") + # Should be less than VLLM_NIXL_KV_LEASE_EXTENSION to ensure timely renewal. 
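+    # Heartbeats are only emitted while transfers are pending, so idle
+    # consumers add no notification traffic.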
+ "VLLM_NIXL_KV_HEARTBEAT_INTERVAL": lambda: int( + os.getenv("VLLM_NIXL_KV_HEARTBEAT_INTERVAL", "5") ), # Controls the read mode for the Mori-IO connector "VLLM_MORIIO_CONNECTOR_READ_MODE": lambda: ( From 4575240c4f8bd59d56134dc49950f57a609e24b2 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Tue, 24 Mar 2026 15:40:20 +0100 Subject: [PATCH 3/8] update docs Signed-off-by: NickLucche --- docs/features/nixl_connector_usage.md | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/docs/features/nixl_connector_usage.md b/docs/features/nixl_connector_usage.md index a9039f0daf84..f63f17b576c0 100644 --- a/docs/features/nixl_connector_usage.md +++ b/docs/features/nixl_connector_usage.md @@ -124,9 +124,18 @@ python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py \ - Set when prefiller and decoder are on different machines - Connection info is passed via KVTransferParams from prefiller to decoder for handshake -- `VLLM_NIXL_ABORT_REQUEST_TIMEOUT`: Timeout (in seconds) for automatically releasing the prefiller’s KV cache for a particular request. (Optional) - - Default: 480 - - If a request is aborted and the decoder has not yet read the KV-cache blocks through the nixl channel, the prefill instance will release its KV-cache blocks after this timeout to avoid holding them indefinitely. +- `VLLM_NIXL_KV_LEASE_DURATION`: Initial lease duration (in seconds) for KV blocks on the prefiller. (Optional) + - Default: 15 + - After a prefill completes, the prefiller holds KV blocks for this duration. The decoder sends periodic heartbeats to extend the lease. + - If no heartbeat is received within the lease duration, blocks are released. + +- `VLLM_NIXL_KV_LEASE_EXTENSION`: Lease extension (in seconds) granted per heartbeat. (Optional) + - Default: 30 + - Each heartbeat from the decoder extends the lease by this amount. + +- `VLLM_NIXL_KV_HEARTBEAT_INTERVAL`: Interval (in seconds) at which the decoder sends heartbeats. (Optional) + - Default: 5 + - Should be less than `VLLM_NIXL_KV_LEASE_EXTENSION` to ensure timely renewal. 
## Multi-Instance Setup From fb062465e6876ea5924f1118a73c0bf955f3e936 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Tue, 24 Mar 2026 15:40:41 +0100 Subject: [PATCH 4/8] tests update Signed-off-by: NickLucche --- .../kv_connector/unit/test_nixl_connector.py | 104 ++++++++++++++---- 1 file changed, 83 insertions(+), 21 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 028f55b3dbcd..e015ede1aaf9 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -1346,7 +1346,7 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend): timeout = 6 monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") - monkeypatch.setenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", str(timeout)) + monkeypatch.setenv("VLLM_NIXL_KV_LEASE_DURATION", str(timeout)) def run_test_and_cleanup(): llm = LLM(**llm_kwargs) @@ -1361,7 +1361,7 @@ def run_test_and_cleanup(): runtime_env = { "working_dir": working_dir, # ship fake nixl package "env_vars": { - "VLLM_NIXL_ABORT_REQUEST_TIMEOUT": str(timeout), + "VLLM_NIXL_KV_LEASE_DURATION": str(timeout), # TODO: for ray to carry over, remove once we set "NIXL_TELEMETRY_ENABLE": "1", }, @@ -1810,6 +1810,7 @@ def test_shutdown_cleans_up_resources(default_vllm_config, dist_init): patch.object(nixl_wrapper, "deregister_memory") as mock_dereg, ): worker._recving_transfers = {"req1": [123]} + worker._pending_transfers_by_engine = {"engine1": {"req1"}} # Mock register_kv_cache which registers local handle worker.src_xfer_handles_by_block_size = {worker.block_size: 455} # P TP = 2 * D TP case, we should register 2 local handles @@ -2387,10 +2388,12 @@ class TestHeartbeatLeaseManagement: "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", FakeNixlWrapper, ) - def test_d_side_heartbeat_sending(self, default_vllm_config, dist_init, monkeypatch): + def test_d_side_heartbeat_sending( + self, default_vllm_config, dist_init, monkeypatch + ): """Test that D-side sends heartbeats to P engines with pending transfers.""" # Set a short renewal interval for testing - monkeypatch.setenv("VLLM_NIXL_LEASE_RENEWAL_INTERVAL", "0.1") + monkeypatch.setenv("VLLM_NIXL_KV_HEARTBEAT_INTERVAL", "0.1") vllm_config = create_vllm_config() connector = NixlConnector( @@ -2401,14 +2404,21 @@ def test_d_side_heartbeat_sending(self, default_vllm_config, dist_init, monkeypa ) worker = connector.connector_worker # Override the renewal interval since env var is read at init - worker._lease_renewal_interval = 0.1 + worker._heartbeat_interval = 0.1 # Simulate remote agent registration (handshake complete) remote_engine_id = FakeNixlConnectorWorker.REMOTE_ENGINE_ID worker._remote_agents[remote_engine_id] = {0: "fake_agent"} - # Track pending transfers for this engine - worker._pending_transfers_by_engine[remote_engine_id] = {"req1", "req2", "req3"} + # Track pending transfers for this engine. + # _recving_transfers keyed by req_id for completion tracking. + worker._recving_transfers = {"req1": [1], "req2": [2], "req3": [3]} + # _pending_transfers_by_engine keyed by engine_id for heartbeat batching. 
+ worker._pending_transfers_by_engine[remote_engine_id] = { + "req1", + "req2", + "req3", + } # Track sent notifications sent_notifs: list[tuple[str, bytes]] = [] @@ -2438,6 +2448,57 @@ def mock_send_notif(agent_name: str, notif_msg: bytes): worker._send_lease_heartbeats() assert len(sent_notifs) == 2 # Now 2 + @patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper, + ) + def test_d_side_heartbeat_sends_to_all_p_workers( + self, default_vllm_config, dist_init, monkeypatch + ): + """Test that D-side sends heartbeats to ALL P workers (P TP > D TP case).""" + monkeypatch.setenv("VLLM_NIXL_KV_HEARTBEAT_INTERVAL", "0.1") + + vllm_config = create_vllm_config() + connector = NixlConnector( + vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) + ) + connector.connector_worker = FakeNixlConnectorWorker( + vllm_config, connector.engine_id, hand_shake_latency=0 + ) + worker = connector.connector_worker + worker._heartbeat_interval = 0.1 + + # Simulate P TP=2 case: two remote agents for same engine + remote_engine_id = FakeNixlConnectorWorker.REMOTE_ENGINE_ID + worker._remote_agents[remote_engine_id] = { + 0: "p_worker_0", + 1: "p_worker_1", + } + + # Track pending transfers + worker._recving_transfers = {"req1": [1]} + worker._pending_transfers_by_engine[remote_engine_id] = {"req1"} + + # Track sent notifications + sent_notifs: list[tuple[str, bytes]] = [] + + def mock_send_notif(agent_name: str, notif_msg: bytes): + sent_notifs.append((agent_name, notif_msg)) + + worker.nixl_wrapper.send_notif = mock_send_notif + + # Send heartbeats + worker._send_lease_heartbeats() + + # Should send to BOTH P workers + assert len(sent_notifs) == 2 + agent_names = {notif[0] for notif in sent_notifs} + assert agent_names == {"p_worker_0", "p_worker_1"} + + # Both should have the same message + for _, msg in sent_notifs: + assert msg == b"HB:req1" + @patch( "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", FakeNixlWrapper, @@ -2446,7 +2507,7 @@ def test_p_side_heartbeat_extends_lease( self, default_vllm_config, dist_init, monkeypatch ): """Test that P-side extends lease when receiving heartbeat.""" - monkeypatch.setenv("VLLM_NIXL_LEASE_EXTENSION", "30") + monkeypatch.setenv("VLLM_NIXL_KV_LEASE_EXTENSION", "30") vllm_config = create_vllm_config() connector = NixlConnector( @@ -2591,9 +2652,7 @@ def test_heartbeat_with_empty_requests(self, default_vllm_config, dist_init): "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", FakeNixlWrapper, ) - def test_d_side_cleanup_on_transfer_complete( - self, default_vllm_config, dist_init - ): + def test_d_side_cleanup_on_transfer_complete(self, default_vllm_config, dist_init): """Test that D-side removes completed transfers from heartbeat tracking.""" vllm_config = create_vllm_config() connector = NixlConnector( @@ -2642,12 +2701,12 @@ def test_d_side_cleanup_on_transfer_complete( ), ) - # Add to pending transfers (D-side heartbeat tracking) - worker._pending_transfers_by_engine[remote_engine_id].add(req_id) - - # Simulate transfer handle completion + # Simulate transfer handle completion. + # _recving_transfers keyed by req_id for completion tracking. handle = 12345 worker._recving_transfers[req_id] = [handle] + # _pending_transfers_by_engine keyed by engine_id for heartbeat batching. 
+ worker._pending_transfers_by_engine[remote_engine_id] = {req_id} # Mock check_xfer_state to return DONE worker.nixl_wrapper._cycles_before_xfer_done = 0 @@ -2656,8 +2715,12 @@ def test_d_side_cleanup_on_transfer_complete( _, done_recving = connector.get_finished(finished_req_ids=set()) assert req_id in done_recving - # Should be removed from pending transfers - assert req_id not in worker._pending_transfers_by_engine[remote_engine_id] + # Should be removed from _recving_transfers + assert req_id not in worker._recving_transfers + # Should be removed from _pending_transfers_by_engine + assert req_id not in worker._pending_transfers_by_engine.get( + remote_engine_id, set() + ) @patch( "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", @@ -2677,10 +2740,11 @@ def test_heartbeat_send_failure_logged( vllm_config, connector.engine_id, hand_shake_latency=0 ) worker = connector.connector_worker - worker._lease_renewal_interval = 0 # Send immediately + worker._heartbeat_interval = 0 # Send immediately remote_engine_id = FakeNixlConnectorWorker.REMOTE_ENGINE_ID worker._remote_agents[remote_engine_id] = {0: "fake_agent"} + worker._recving_transfers = {"req1": [1]} worker._pending_transfers_by_engine[remote_engine_id] = {"req1"} # Make send_notif raise an exception @@ -2712,9 +2776,7 @@ def emit(self, record): # Verify warning was logged warning_logs = [r for r in captured_logs if r.levelno == logging.WARNING] assert len(warning_logs) >= 1 - assert any( - "Failed to send heartbeat" in r.message for r in warning_logs - ) + assert any("Failed to send heartbeat" in r.message for r in warning_logs) @pytest.mark.parametrize( From 0d699aef7dcacc521159869f0ff214a6c10c3284 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Tue, 24 Mar 2026 17:34:55 +0000 Subject: [PATCH 5/8] cruft Signed-off-by: NickLucche --- .../kv_connector/unit/test_tp_kv_topology.py | 685 ------------------ 1 file changed, 685 deletions(-) delete mode 100644 tests/v1/kv_connector/unit/test_tp_kv_topology.py diff --git a/tests/v1/kv_connector/unit/test_tp_kv_topology.py b/tests/v1/kv_connector/unit/test_tp_kv_topology.py deleted file mode 100644 index 8c0552d5d33a..000000000000 --- a/tests/v1/kv_connector/unit/test_tp_kv_topology.py +++ /dev/null @@ -1,685 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -Unit tests for TpKVTopology with various attention backend configurations. - -These tests validate layout detection, block_size_position, and multi-backend -model behavior without loading any models. We use mock backends that replicate -the get_kv_cache_shape signatures of real backends. 
- -Backend shape families: - - FlashAttn-like: (2, N, B, H, D) -- KV-first - - FlashInfer-like: (N, 2, B, H, D) -- blocks-first - - MLA-like: (N, B, D) -- 3-dim, no KV split - - Mamba-like: NotImplementedError -- no KV cache shape - - TritonAttn-like: (N, 2, B, H, D) -- blocks-first (same as FI) -""" - -import pytest -import torch - -from vllm.distributed.kv_transfer.kv_connector.utils import TpKVTopology -from vllm.v1.attention.backend import AttentionBackend - - -# --------------------------------------------------------------------------- -# Mock Attention Backends -# --------------------------------------------------------------------------- -class MockFlashAttnBackend(AttentionBackend): - """Mimics FlashAttentionBackend: shape = (2, N, B, H, D)""" - - @staticmethod - def get_name() -> str: - return "MOCK_FLASH_ATTN" - - @staticmethod - def get_impl_cls(): - raise NotImplementedError - - @staticmethod - def get_builder_cls(): - raise NotImplementedError - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - cache_dtype_str: str = "auto", - ) -> tuple[int, ...]: - return (2, num_blocks, block_size, num_kv_heads, head_size) - - @staticmethod - def get_kv_cache_stride_order( - include_num_layers_dimension: bool = False, - ) -> tuple[int, ...]: - if include_num_layers_dimension: - # HND cross-layer: (num_blocks, num_kv_heads, num_layers, 2, - # block_size, head_size) - return (2, 4, 0, 1, 3, 5) - return (0, 1, 3, 2, 4) - - -class MockFlashInferBackend(AttentionBackend): - """Mimics FlashInferBackend: shape = (N, 2, B, H, D)""" - - @staticmethod - def get_name() -> str: - return "MOCK_FLASHINFER" - - @staticmethod - def get_impl_cls(): - raise NotImplementedError - - @staticmethod - def get_builder_cls(): - raise NotImplementedError - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - cache_dtype_str: str = "auto", - ) -> tuple[int, ...]: - return (num_blocks, 2, block_size, num_kv_heads, head_size) - - @staticmethod - def get_kv_cache_stride_order( - include_num_layers_dimension: bool = False, - ) -> tuple[int, ...]: - if include_num_layers_dimension: - return (1, 0, 2, 3, 4, 5) - return (0, 1, 2, 3, 4) - - -class MockTritonAttnBackend(AttentionBackend): - """Mimics TritonAttentionBackend: shape = (N, 2, B, H, D) -- same as FI""" - - @staticmethod - def get_name() -> str: - return "MOCK_TRITON_ATTN" - - @staticmethod - def get_impl_cls(): - raise NotImplementedError - - @staticmethod - def get_builder_cls(): - raise NotImplementedError - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - cache_dtype_str: str = "auto", - ) -> tuple[int, ...]: - return (num_blocks, 2, block_size, num_kv_heads, head_size) - - -class MockMLABackend(AttentionBackend): - """Mimics MLA backends (FlashMLA, etc.): shape = (N, B, D) -- 3 dims""" - - @staticmethod - def get_name() -> str: - return "MOCK_MLA" - - @staticmethod - def get_impl_cls(): - raise NotImplementedError - - @staticmethod - def get_builder_cls(): - raise NotImplementedError - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - cache_dtype_str: str = "auto", - ) -> tuple[int, ...]: - assert num_kv_heads == 1 - return (num_blocks, block_size, head_size) - - -class MockMambaBackend(AttentionBackend): - """Mimics Mamba backends: get_kv_cache_shape is not implemented.""" - - 
@staticmethod - def get_name() -> str: - return "MOCK_MAMBA" - - @staticmethod - def get_impl_cls(): - raise NotImplementedError - - @staticmethod - def get_builder_cls(): - raise NotImplementedError - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - cache_dtype_str: str = "auto", - ) -> tuple[int, ...]: - raise NotImplementedError("Mamba backends do not have a KV cache shape") - - -# A ChunkedLocal backend that inherits FA's get_kv_cache_shape, -# exactly as the real ChunkedLocalAttention backend does. -class MockChunkedLocalFABackend(MockFlashAttnBackend): - """ - Mimics ChunkedLocalAttention backed by FlashAttn. - Inherits get_kv_cache_shape from FlashAttn -- same layout. - """ - - @staticmethod - def get_name() -> str: - return "MOCK_CHUNKED_LOCAL_FA" - - -class MockCPUAttnBackend(AttentionBackend): - """ - Mimics CPU attention backend: shape = (2, N, H, B, D) - Note different position of block_size vs num_kv_heads compared to FA. - """ - - @staticmethod - def get_name() -> str: - return "MOCK_CPU_ATTN" - - @staticmethod - def get_impl_cls(): - raise NotImplementedError - - @staticmethod - def get_builder_cls(): - raise NotImplementedError - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - cache_dtype_str: str = "auto", - ) -> tuple[int, ...]: - return (2, num_blocks, num_kv_heads, block_size, head_size) - - -# --------------------------------------------------------------------------- -# Helper to build TpKVTopology with minimal required fields -# --------------------------------------------------------------------------- -def make_topo( - attn_backend: type[AttentionBackend], - is_mla: bool = False, - tp_rank: int = 0, - tp_size: int = 1, - total_num_kv_heads: int = 8, - block_size: int = 16, - tensor_shape: torch.Size | None = None, - engine_id: str = "test-engine", -) -> TpKVTopology: - remote_tp_size = {engine_id: tp_size} - remote_block_size = {engine_id: block_size} - return TpKVTopology( - tp_rank=tp_rank, - engine_id=engine_id, - remote_tp_size=remote_tp_size, - remote_block_size=remote_block_size, - is_mla=is_mla, - total_num_kv_heads=total_num_kv_heads, - attn_backend=attn_backend, - tensor_shape=tensor_shape, - ) - - -# =================================================================== -# 1. 
Layout Detection Tests -# =================================================================== -class TestLayoutDetection: - """Test is_kv_layout_blocks_first and split_k_and_v for each backend.""" - - def test_flash_attn_standard(self): - """FA: (2, N, B, H, D) -> blocks_first=False, split_k_and_v=True""" - topo = make_topo(MockFlashAttnBackend) - assert topo.is_kv_layout_blocks_first is False - assert topo.split_k_and_v is True - assert topo.cross_layers_blocks is False - - def test_flashinfer_standard(self): - """FI: (N, 2, B, H, D) -> blocks_first=True, split_k_and_v=False""" - topo = make_topo(MockFlashInferBackend) - assert topo.is_kv_layout_blocks_first is True - assert topo.split_k_and_v is False - assert topo.cross_layers_blocks is False - - def test_triton_attn_standard(self): - """Triton: (N, 2, B, H, D) -> same as FI (blocks_first=True)""" - topo = make_topo(MockTritonAttnBackend) - assert topo.is_kv_layout_blocks_first is True - assert topo.split_k_and_v is False - - def test_flash_attn_mla(self): - """FA with MLA: blocks_first=False, split_k_and_v=False (MLA overrides)""" - topo = make_topo(MockFlashAttnBackend, is_mla=True) - assert topo.is_kv_layout_blocks_first is False - assert topo.split_k_and_v is False - - def test_flashinfer_mla(self): - """FI with MLA: blocks_first=True, split_k_and_v=False""" - topo = make_topo(MockFlashInferBackend, is_mla=True) - assert topo.is_kv_layout_blocks_first is True - assert topo.split_k_and_v is False - - def test_mla_backend_3dim(self): - """ - Pure MLA backend (3-dim shape): blocks_first=False. - Shape is (N, B, D) -- 3 dims, first dim is num_blocks=1 (mock), - so the 5-dim blocks_first check fails. - """ - topo = make_topo(MockMLABackend, is_mla=True, total_num_kv_heads=1) - assert topo.is_kv_layout_blocks_first is False - assert topo.split_k_and_v is False - - def test_flash_attn_cross_layers(self): - """ - FA with cross-layer blocks: tensor_shape has one extra dim. - Shape from backend = (2, 1, 16, 1, 1) -> 5 dims - tensor_shape = (80, 2, 1, 16, 1, 1) -> 6 dims = 5 + 1 - => cross_layers_blocks=True, split_k_and_v=False - """ - kv_shape = MockFlashAttnBackend.get_kv_cache_shape( - num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 - ) - cross_layer_shape = torch.Size((80,) + kv_shape) - topo = make_topo(MockFlashAttnBackend, tensor_shape=cross_layer_shape) - assert topo.cross_layers_blocks is True - assert topo.split_k_and_v is False - - def test_flashinfer_cross_layers(self): - """FI with cross-layer blocks.""" - kv_shape = MockFlashInferBackend.get_kv_cache_shape( - num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 - ) - cross_layer_shape = torch.Size((80,) + kv_shape) - topo = make_topo(MockFlashInferBackend, tensor_shape=cross_layer_shape) - assert topo.cross_layers_blocks is True - assert topo.split_k_and_v is False - - def test_no_cross_layers_same_ndim(self): - """ - When tensor_shape has same ndim as kv_cache_shape, - cross_layers_blocks should be False. - """ - kv_shape = MockFlashAttnBackend.get_kv_cache_shape( - num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 - ) - topo = make_topo( - MockFlashAttnBackend, tensor_shape=torch.Size(kv_shape) - ) - assert topo.cross_layers_blocks is False - - def test_cpu_attn_layout(self): - """ - CPU attention: (2, N, H, B, D). - Not blocks_first (first dim is 2 with num_blocks mocked to 1), - and kv_cache_shape[0] != 1 when we have 5 dims. 
- """ - topo = make_topo(MockCPUAttnBackend) - # Shape with mocked values: (2, 1, 1, 16, 1) - # First dim = 2 (not 1), and len=5, so blocks_first check: - # len == 5 and shape[0] == 1? -> 5 dims but shape[0]=2 -> False - assert topo.is_kv_layout_blocks_first is False - assert topo.split_k_and_v is True - - -# =================================================================== -# 2. Block Size Position Tests -# =================================================================== -class TestBlockSizePosition: - """ - Verify block_size_position is correctly detected. - block_size_position is a negative index into the shape indicating - where the block_size dimension lives. - """ - - def test_flash_attn_block_size_position(self): - """FA shape: (2, N, B=16, H, D) -> B is at index 2, negative = -3""" - topo = make_topo(MockFlashAttnBackend) - assert topo.block_size_position == -3 - - def test_flashinfer_block_size_position(self): - """FI shape: (N, 2, B=16, H, D) -> B is at index 2, negative = -3""" - topo = make_topo(MockFlashInferBackend) - assert topo.block_size_position == -3 - - def test_mla_block_size_position(self): - """MLA shape: (N, B=16, D) -> B is at index 1, negative = -2""" - topo = make_topo(MockMLABackend, is_mla=True, total_num_kv_heads=1) - assert topo.block_size_position == -2 - - def test_cpu_attn_block_size_position(self): - """CPU shape: (2, N, H, B=16, D) -> B is at index 3, negative = -2""" - topo = make_topo(MockCPUAttnBackend) - assert topo.block_size_position == -2 - - def test_flash_attn_cross_layers_block_size_position(self): - """ - FA cross-layer: logical shape (L, 2, N, B, H, D), but after - stride_order permutation for HND cross-layer, the physical position - of B changes. - - Stride order for FA HND cross-layer: (2, 4, 0, 1, 3, 5) - Logical shape: (80, 2, 1, 16, 1, 1) - After permute: shape[2,4,0,1,3,5] = (1, 1, 80, 2, 16, 1) - B=16 is at physical index 4 -> negative = -2 - """ - kv_shape = MockFlashAttnBackend.get_kv_cache_shape( - num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 - ) - cross_layer_shape = torch.Size((80,) + kv_shape) - topo = make_topo(MockFlashAttnBackend, tensor_shape=cross_layer_shape) - assert topo.cross_layers_blocks is True - assert topo.block_size_position == -2 - - -# =================================================================== -# 3. Multi-Backend Model Configuration Tests -# =================================================================== -class TestMultiBackendModels: - """ - Test TpKVTopology behavior for model architectures that use - different attention backends across layers. - """ - - def test_qwen3_like_uniform_full_attn(self): - """ - Qwen3-like: All layers use FullAttentionSpec with FlashAttn backend. - Single backend family, all properties should be standard FA. - """ - topo = make_topo(MockFlashAttnBackend, is_mla=False) - assert topo.is_kv_layout_blocks_first is False - assert topo.split_k_and_v is True - assert topo.cross_layers_blocks is False - assert topo.block_size_position == -3 - - def test_deepseek_v3_mla(self): - """ - DeepSeek V3: All layers use MLAAttentionSpec with MLA backend. - 3-dim shape, is_mla=True, no KV split. 
- """ - topo = make_topo(MockMLABackend, is_mla=True, total_num_kv_heads=1) - assert topo.is_kv_layout_blocks_first is False - assert topo.split_k_and_v is False - assert topo.cross_layers_blocks is False - assert topo.block_size_position == -2 - - def test_llama4_hybrid_full_and_chunked(self): - """ - Llama4: Mix of FullAttentionSpec (global NoPE layers) and - ChunkedLocalAttentionSpec (local RoPE layers). - - Both backends inherit FlashAttn's get_kv_cache_shape, so - constructing TpKVTopology with either backend gives the same result. - This test documents that FA and ChunkedLocal-FA are interchangeable - for topology purposes. - """ - topo_fa = make_topo(MockFlashAttnBackend, is_mla=False) - topo_chunked = make_topo(MockChunkedLocalFABackend, is_mla=False) - - # Both should produce identical topology properties - assert topo_fa.is_kv_layout_blocks_first == topo_chunked.is_kv_layout_blocks_first - assert topo_fa.split_k_and_v == topo_chunked.split_k_and_v - assert topo_fa.cross_layers_blocks == topo_chunked.cross_layers_blocks - assert topo_fa.block_size_position == topo_chunked.block_size_position - - # Confirm they're standard FA properties - assert topo_fa.is_kv_layout_blocks_first is False - assert topo_fa.split_k_and_v is True - - def test_gemma3_sliding_window(self): - """ - Gemma3: All layers use FullAttentionSpec (some with sliding_window set). - From TpKVTopology's perspective, sliding_window doesn't change the - backend or cache shape. All layers use the same FA backend. - """ - # sliding_window is a KVCacheSpec concern, not a backend shape concern - topo = make_topo(MockFlashAttnBackend, is_mla=False) - assert topo.is_kv_layout_blocks_first is False - assert topo.split_k_and_v is True - assert topo.block_size_position == -3 - - def test_jamba_hybrid_mamba_backend_crashes(self): - """ - Jamba-like hybrid: If get_current_attn_backend() returns a Mamba - backend (because the first layer is Mamba), TpKVTopology construction - crashes because Mamba backends don't implement get_kv_cache_shape. - - This documents the current limitation that NIXL cannot work with - models where the first layer is a Mamba layer. - """ - with pytest.raises(NotImplementedError): - make_topo(MockMambaBackend, is_mla=False) - - def test_jamba_hybrid_attention_first_works(self): - """ - Jamba-like hybrid: If the first layer is an attention layer, - get_current_attn_backend() returns FA, and TpKVTopology works. - The Mamba layers are simply not registered with NIXL (they use - separate state management). - """ - # Simulates the case where first layer happens to be attention - topo = make_topo(MockFlashAttnBackend, is_mla=False) - assert topo.is_kv_layout_blocks_first is False - assert topo.split_k_and_v is True - - def test_flashinfer_with_chunked_local_inheriting(self): - """ - If a model uses ChunkedLocal attention backed by FlashInfer, - verify the topology correctly detects the FI layout. - """ - - class MockChunkedLocalFIBackend(MockFlashInferBackend): - @staticmethod - def get_name() -> str: - return "MOCK_CHUNKED_LOCAL_FI" - - topo = make_topo(MockChunkedLocalFIBackend, is_mla=False) - assert topo.is_kv_layout_blocks_first is True - assert topo.split_k_and_v is False - - def test_mixed_fa_and_fi_backends_differ(self): - """ - Hypothetical model with both FA and FI layers. - TpKVTopology constructed with FA vs FI gives different properties. - This documents why a single backend assumption matters. 
- """ - topo_fa = make_topo(MockFlashAttnBackend, is_mla=False) - topo_fi = make_topo(MockFlashInferBackend, is_mla=False) - - # Key property that differs between the two - assert topo_fa.is_kv_layout_blocks_first is False - assert topo_fi.is_kv_layout_blocks_first is True - - # split_k_and_v also differs - assert topo_fa.split_k_and_v is True - assert topo_fi.split_k_and_v is False - - # block_size_position is the same though - assert topo_fa.block_size_position == topo_fi.block_size_position == -3 - - -# =================================================================== -# 4. get_current_attn_backend Behavior Tests -# =================================================================== -class TestGetCurrentAttnBackend: - """ - Test get_current_attn_backend behavior with mocked static_forward_context. - """ - - def test_returns_first_layers_backend(self): - """ - get_current_attn_backend iterates static_forward_context (dict order) - and returns the first layer's backend. - """ - from unittest.mock import MagicMock, patch - - from vllm.distributed.kv_transfer.kv_connector.utils import ( - get_current_attn_backend, - ) - from vllm.model_executor.layers.attention_layer_base import ( - AttentionLayerBase, - ) - - # Create mock layers with different backends - layer0 = MagicMock(spec=AttentionLayerBase) - layer0.get_attn_backend.return_value = MockFlashAttnBackend - - layer1 = MagicMock(spec=AttentionLayerBase) - layer1.get_attn_backend.return_value = MockFlashInferBackend - - mock_context = {"attn_layer_0": layer0, "attn_layer_1": layer1} - - mock_config = MagicMock() - mock_config.compilation_config.static_forward_context = mock_context - - backend = get_current_attn_backend(mock_config) - assert backend is MockFlashAttnBackend - - def test_returns_second_when_first_is_different(self): - """ - Verify that only the FIRST layer's backend is returned, - even if subsequent layers use a different backend. - """ - from unittest.mock import MagicMock - - from vllm.distributed.kv_transfer.kv_connector.utils import ( - get_current_attn_backend, - ) - from vllm.model_executor.layers.attention_layer_base import ( - AttentionLayerBase, - ) - - # First layer is FI, second is FA - layer0 = MagicMock(spec=AttentionLayerBase) - layer0.get_attn_backend.return_value = MockFlashInferBackend - - layer1 = MagicMock(spec=AttentionLayerBase) - layer1.get_attn_backend.return_value = MockFlashAttnBackend - - mock_context = {"layer_0": layer0, "layer_1": layer1} - - mock_config = MagicMock() - mock_config.compilation_config.static_forward_context = mock_context - - backend = get_current_attn_backend(mock_config) - # Should be the first one - assert backend is MockFlashInferBackend - - def test_mamba_first_layer_returns_mamba(self): - """ - If the first layer is Mamba, get_current_attn_backend returns - the Mamba backend. This would cause TpKVTopology to crash. - - This documents the current problematic behavior that needs fixing: - get_current_attn_backend should skip non-attention backends. 
- """ - from unittest.mock import MagicMock - - from vllm.distributed.kv_transfer.kv_connector.utils import ( - get_current_attn_backend, - ) - from vllm.model_executor.layers.attention_layer_base import ( - AttentionLayerBase, - ) - - # First layer is Mamba, second is FA - mamba_layer = MagicMock(spec=AttentionLayerBase) - mamba_layer.get_attn_backend.return_value = MockMambaBackend - - attn_layer = MagicMock(spec=AttentionLayerBase) - attn_layer.get_attn_backend.return_value = MockFlashAttnBackend - - mock_context = {"mamba_layer_0": mamba_layer, "attn_layer_0": attn_layer} - - mock_config = MagicMock() - mock_config.compilation_config.static_forward_context = mock_context - - backend = get_current_attn_backend(mock_config) - # Current behavior: returns Mamba (the first layer's backend) - assert backend is MockMambaBackend - - # This will crash TpKVTopology: - with pytest.raises(NotImplementedError): - make_topo(backend, is_mla=False) - - def test_fallback_when_no_layers(self): - """ - When static_forward_context is empty, get_current_attn_backend - falls back to the attention selector. - """ - from unittest.mock import MagicMock, patch - - from vllm.distributed.kv_transfer.kv_connector.utils import ( - get_current_attn_backend, - ) - - mock_config = MagicMock() - mock_config.compilation_config.static_forward_context = {} - mock_config.model_config.get_head_size.return_value = 64 - mock_config.model_config.dtype = torch.float16 - mock_config.cache_config.cache_dtype = "auto" - mock_config.cache_config.block_size = 16 - mock_config.model_config.use_mla = False - - with patch( - "vllm.distributed.kv_transfer.kv_connector.utils.get_attn_backend" - ) as mock_selector: - mock_selector.return_value = MockFlashAttnBackend - backend = get_current_attn_backend(mock_config) - assert backend is MockFlashAttnBackend - mock_selector.assert_called_once() - - def test_all_layers_same_backend_consistency(self): - """ - When all layers use the same backend, any layer can be used - to construct TpKVTopology with identical results. 
- """ - from unittest.mock import MagicMock - - from vllm.distributed.kv_transfer.kv_connector.utils import ( - get_current_attn_backend, - ) - from vllm.model_executor.layers.attention_layer_base import ( - AttentionLayerBase, - ) - - layers = {} - for i in range(10): - layer = MagicMock(spec=AttentionLayerBase) - layer.get_attn_backend.return_value = MockFlashAttnBackend - layers[f"layer_{i}"] = layer - - mock_config = MagicMock() - mock_config.compilation_config.static_forward_context = layers - - backend = get_current_attn_backend(mock_config) - assert backend is MockFlashAttnBackend - - # All produce the same topology - topo = make_topo(backend) - for layer in layers.values(): - other = make_topo(layer.get_attn_backend()) - assert topo.is_kv_layout_blocks_first == other.is_kv_layout_blocks_first - assert topo.split_k_and_v == other.split_k_and_v - assert topo.block_size_position == other.block_size_position From 487a4f32e646207a4c7f183b8c558b8b6f654514 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Tue, 24 Mar 2026 18:14:00 +0000 Subject: [PATCH 6/8] update tests and default values Signed-off-by: NickLucche --- .../kv_connector/unit/test_nixl_connector.py | 19 ++++--------------- .../kv_connector/v1/nixl_connector.py | 7 ++++--- vllm/envs.py | 8 ++++---- 3 files changed, 12 insertions(+), 22 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index e015ede1aaf9..ddb7f40b0126 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -2665,21 +2665,10 @@ def test_d_side_cleanup_on_transfer_complete(self, default_vllm_config, dist_ini remote_engine_id = FakeNixlConnectorWorker.REMOTE_ENGINE_ID - # Setup kv_topo - backend = get_current_attn_backend(vllm_config) - test_shape = backend.get_kv_cache_shape( - num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 - ) - worker.kv_topo = TpKVTopology( - tp_rank=worker.tp_rank, - engine_id=worker.engine_id, - remote_tp_size=worker._tp_size, - remote_block_size=worker._block_size, - is_mla=worker.use_mla, - total_num_kv_heads=worker.model_config.get_total_num_kv_heads(), - attn_backends=[backend], - tensor_shape=test_shape, - ) + # Register remote engine in shared topology dicts so that + # block_size_ratio_from_engine_id can resolve the remote engine. 
+        worker._tp_size[remote_engine_id] = 1
+        worker._block_size[remote_engine_id] = 16
 
         # Simulate transfer metadata
         from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index afef4a0f66f2..397d7664b95e 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -2579,7 +2579,7 @@ def _send_lease_heartbeats(self) -> None:
         if now - self._last_heartbeat_time < self._heartbeat_interval:
             return
 
-        sent_any = False
+        num_notifs = 0
         for engine_id, req_ids in self._pending_transfers_by_engine.items():
             if not req_ids:
                 continue
@@ -2599,7 +2599,7 @@ def _send_lease_heartbeats(self) -> None:
             for agent_name in remote_agents.values():
                 try:
                     self.nixl_wrapper.send_notif(agent_name, notif_msg=heartbeat_msg)
-                    sent_any = True
+                    num_notifs += 1
                 except Exception as e:
                     logger.warning(
                         "Failed to send heartbeat to engine %s agent %s: %s",
@@ -2608,8 +2608,9 @@ def _send_lease_heartbeats(self) -> None:
                         e,
                     )
 
-        if sent_any:
+        if num_notifs > 0:
             self._last_heartbeat_time = now
+            logger.debug("Sent %d heartbeat notifications", num_notifs)
 
     def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
         assert meta.remote is not None and self.kv_topo is not None
diff --git a/vllm/envs.py b/vllm/envs.py
index 4da45c5cd092..9d4946b0b332 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -192,7 +192,7 @@
     VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: int | None = None
-    VLLM_NIXL_KV_LEASE_DURATION: int = 15
-    VLLM_NIXL_KV_LEASE_EXTENSION: int = 30
-    VLLM_NIXL_KV_HEARTBEAT_INTERVAL: int = 5
+    VLLM_NIXL_KV_LEASE_DURATION: int = 30
+    VLLM_NIXL_KV_LEASE_EXTENSION: int = 20
+    VLLM_NIXL_KV_HEARTBEAT_INTERVAL: float = 5
     VLLM_MORIIO_CONNECTOR_READ_MODE: bool = False
     VLLM_MORIIO_QP_PER_TRANSFER: int = 1
     VLLM_MORIIO_POST_BATCH_SIZE: int = -1
@@ -1391,17 +1391,17 @@ def _get_or_set_default() -> str:
     # time, the blocks will be freed. Used in disaggregated P/D setup.
     # D sends periodic heartbeats to extend the lease (see KV_LEASE_EXTENSION).
     "VLLM_NIXL_KV_LEASE_DURATION": lambda: int(
-        os.getenv("VLLM_NIXL_KV_LEASE_DURATION", "15")
+        os.getenv("VLLM_NIXL_KV_LEASE_DURATION", "30")
     ),
     # Lease extension (in seconds) granted when a heartbeat is received from
     # the consumer (D). Each heartbeat extends the lease by this amount.
     "VLLM_NIXL_KV_LEASE_EXTENSION": lambda: int(
-        os.getenv("VLLM_NIXL_KV_LEASE_EXTENSION", "30")
+        os.getenv("VLLM_NIXL_KV_LEASE_EXTENSION", "20")
     ),
     # Interval (in seconds) at which the consumer (D) sends heartbeat
     # notifications to the producer (P) to extend KV block leases.
     # Should be less than VLLM_NIXL_KV_LEASE_EXTENSION to ensure timely renewal.
- "VLLM_NIXL_KV_HEARTBEAT_INTERVAL": lambda: int( + "VLLM_NIXL_KV_HEARTBEAT_INTERVAL": lambda: float( os.getenv("VLLM_NIXL_KV_HEARTBEAT_INTERVAL", "5") ), # Controls the read mode for the Mori-IO connector From fbb85bf2e43b16eacbe882c17e1911be13300066 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Tue, 24 Mar 2026 18:29:40 +0000 Subject: [PATCH 7/8] docs update Signed-off-by: NickLucche --- docs/features/nixl_connector_usage.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/features/nixl_connector_usage.md b/docs/features/nixl_connector_usage.md index f63f17b576c0..b2d34aa35eff 100644 --- a/docs/features/nixl_connector_usage.md +++ b/docs/features/nixl_connector_usage.md @@ -125,12 +125,12 @@ python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py \ - Connection info is passed via KVTransferParams from prefiller to decoder for handshake - `VLLM_NIXL_KV_LEASE_DURATION`: Initial lease duration (in seconds) for KV blocks on the prefiller. (Optional) - - Default: 15 + - Default: 30 - After a prefill completes, the prefiller holds KV blocks for this duration. The decoder sends periodic heartbeats to extend the lease. - If no heartbeat is received within the lease duration, blocks are released. - `VLLM_NIXL_KV_LEASE_EXTENSION`: Lease extension (in seconds) granted per heartbeat. (Optional) - - Default: 30 + - Default: 20 - Each heartbeat from the decoder extends the lease by this amount. - `VLLM_NIXL_KV_HEARTBEAT_INTERVAL`: Interval (in seconds) at which the decoder sends heartbeats. (Optional) From c2861cff8152e8c288afac3e260e4bf706cbf9fb Mon Sep 17 00:00:00 2001 From: NickLucche Date: Wed, 25 Mar 2026 10:35:34 +0000 Subject: [PATCH 8/8] minor Signed-off-by: NickLucche --- vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 397d7664b95e..a0922a86a824 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -2381,7 +2381,7 @@ def _get_new_notifs(self) -> set[str]: # Handle heartbeat messages from D (consumer) to hold KV Cache blocks. if decoded.startswith("HB:"): req_ids_str = decoded[3:] # Skip "HB:" prefix - heartbeat_req_ids = req_ids_str.split(",") + heartbeat_req_ids = [r for r in req_ids_str.split(",") if r] extended_count = 0 for req_id in heartbeat_req_ids: if req_id in self._reqs_to_send:
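
A note on the final hunk: in Python, `"".split(",")` returns `[""]` rather than `[]`, so an empty heartbeat payload (`b"HB:"`) would otherwise yield a single bogus empty request id. A minimal round-trip sketch of the wire format with hypothetical helper names (only the `HB:` prefix and the comma-joined ids come from the patch):

```python
def encode_heartbeat(req_ids: set[str]) -> bytes:
    # D-side: batch all pending request ids into one notification.
    return b"HB:" + ",".join(req_ids).encode()

def decode_heartbeat(msg: bytes) -> list[str]:
    # P-side: strip the "HB:" prefix and drop empty ids, as the hunk does.
    decoded = msg.decode()
    assert decoded.startswith("HB:")
    return [r for r in decoded[3:].split(",") if r]

assert set(decode_heartbeat(encode_heartbeat({"req1", "req2"}))) == {"req1", "req2"}
assert decode_heartbeat(b"HB:") == []  # empty payload: no ids, not [""]
```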