diff --git a/docs/features/nixl_connector_usage.md b/docs/features/nixl_connector_usage.md index ea668615153c..1017de7e1c0f 100644 --- a/docs/features/nixl_connector_usage.md +++ b/docs/features/nixl_connector_usage.md @@ -126,9 +126,18 @@ python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py \ - Set when prefiller and decoder are on different machines - Connection info is passed via KVTransferParams from prefiller to decoder for handshake -- `VLLM_NIXL_ABORT_REQUEST_TIMEOUT`: Timeout (in seconds) for automatically releasing the prefiller’s KV cache for a particular request. (Optional) - - Default: 480 - - If a request is aborted and the decoder has not yet read the KV-cache blocks through the nixl channel, the prefill instance will release its KV-cache blocks after this timeout to avoid holding them indefinitely. +- `VLLM_NIXL_KV_LEASE_DURATION`: Initial lease duration (in seconds) for KV blocks on the prefiller. (Optional) + - Default: 30 + - After a prefill completes, the prefiller holds KV blocks for this duration. The decoder sends periodic heartbeats to extend the lease. + - If no heartbeat is received within the lease duration, blocks are released. + +- `VLLM_NIXL_KV_LEASE_EXTENSION`: Lease extension (in seconds) granted per heartbeat. (Optional) + - Default: 20 + - Each heartbeat from the decoder extends the lease by this amount. + +- `VLLM_NIXL_KV_HEARTBEAT_INTERVAL`: Interval (in seconds) at which the decoder sends heartbeats. (Optional) + - Default: 5 + - Should be less than `VLLM_NIXL_KV_LEASE_EXTENSION` to ensure timely renewal. 
## Multi-Instance Setup diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index b4ee97cd1d74..cbe73ce2a4e7 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -1349,7 +1349,7 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend): timeout = 6 monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") - monkeypatch.setenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", str(timeout)) + monkeypatch.setenv("VLLM_NIXL_KV_LEASE_DURATION", str(timeout)) def run_test_and_cleanup(): llm = LLM(**llm_kwargs) @@ -1364,7 +1364,7 @@ def run_test_and_cleanup(): runtime_env = { "working_dir": working_dir, # ship fake nixl package "env_vars": { - "VLLM_NIXL_ABORT_REQUEST_TIMEOUT": str(timeout), + "VLLM_NIXL_KV_LEASE_DURATION": str(timeout), # TODO: for ray to carry over, remove once we set "NIXL_TELEMETRY_ENABLE": "1", }, @@ -1813,6 +1813,7 @@ def test_shutdown_cleans_up_resources(default_vllm_config, dist_init): patch.object(nixl_wrapper, "deregister_memory") as mock_dereg, ): worker._recving_transfers = {"req1": [123]} + worker._pending_transfers_by_engine = {"engine1": {"req1"}} # Mock register_kv_cache which registers local handle worker.src_xfer_handles_by_block_size = {worker.block_size: 455} # P TP = 2 * D TP case, we should register 2 local handles @@ -2383,6 +2384,393 @@ def test_compatibility_hash_validation( assert len(result) == 1 +class TestHeartbeatLeaseManagement: + """Tests for the heartbeat-based lease management system.""" + + @patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper, + ) + def test_d_side_heartbeat_sending( + self, default_vllm_config, dist_init, monkeypatch + ): + """Test that D-side sends heartbeats to P engines with pending transfers.""" + # Set a short renewal interval for testing + monkeypatch.setenv("VLLM_NIXL_KV_HEARTBEAT_INTERVAL", "0.1") + + 
vllm_config = create_vllm_config() + connector = NixlConnector( + vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) + ) + connector.connector_worker = FakeNixlConnectorWorker( + vllm_config, connector.engine_id, hand_shake_latency=0 + ) + worker = connector.connector_worker + # Override the renewal interval since env var is read at init + worker._heartbeat_interval = 0.1 + + # Simulate remote agent registration (handshake complete) + remote_engine_id = FakeNixlConnectorWorker.REMOTE_ENGINE_ID + worker._remote_agents[remote_engine_id] = {0: "fake_agent"} + + # Track pending transfers for this engine. + # _recving_transfers keyed by req_id for completion tracking. + worker._recving_transfers = {"req1": [1], "req2": [2], "req3": [3]} + # _pending_transfers_by_engine keyed by engine_id for heartbeat batching. + worker._pending_transfers_by_engine[remote_engine_id] = { + "req1", + "req2", + "req3", + } + + # Track sent notifications + sent_notifs: list[tuple[str, bytes]] = [] + + def mock_send_notif(agent_name: str, notif_msg: bytes): + sent_notifs.append((agent_name, notif_msg)) + + worker.nixl_wrapper.send_notif = mock_send_notif + + # First call should send heartbeat (no previous renewal time) + worker._send_lease_heartbeats() + + assert len(sent_notifs) == 1 + agent_name, msg = sent_notifs[0] + assert agent_name == "fake_agent" + # Verify message format: "HB:req1,req2,req3" (order may vary) + assert msg.startswith(b"HB:") + req_ids = msg.decode()[3:].split(",") + assert set(req_ids) == {"req1", "req2", "req3"} + + # Immediate second call should NOT send (within renewal interval) + worker._send_lease_heartbeats() + assert len(sent_notifs) == 1 # Still just 1 + + # Wait for renewal interval and call again + time.sleep(0.15) + worker._send_lease_heartbeats() + assert len(sent_notifs) == 2 # Now 2 + + @patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper, + ) + def 
test_d_side_heartbeat_sends_to_all_p_workers( + self, default_vllm_config, dist_init, monkeypatch + ): + """Test that D-side sends heartbeats to ALL P workers (P TP > D TP case).""" + monkeypatch.setenv("VLLM_NIXL_KV_HEARTBEAT_INTERVAL", "0.1") + + vllm_config = create_vllm_config() + connector = NixlConnector( + vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) + ) + connector.connector_worker = FakeNixlConnectorWorker( + vllm_config, connector.engine_id, hand_shake_latency=0 + ) + worker = connector.connector_worker + worker._heartbeat_interval = 0.1 + + # Simulate P TP=2 case: two remote agents for same engine + remote_engine_id = FakeNixlConnectorWorker.REMOTE_ENGINE_ID + worker._remote_agents[remote_engine_id] = { + 0: "p_worker_0", + 1: "p_worker_1", + } + + # Track pending transfers + worker._recving_transfers = {"req1": [1]} + worker._pending_transfers_by_engine[remote_engine_id] = {"req1"} + + # Track sent notifications + sent_notifs: list[tuple[str, bytes]] = [] + + def mock_send_notif(agent_name: str, notif_msg: bytes): + sent_notifs.append((agent_name, notif_msg)) + + worker.nixl_wrapper.send_notif = mock_send_notif + + # Send heartbeats + worker._send_lease_heartbeats() + + # Should send to BOTH P workers + assert len(sent_notifs) == 2 + agent_names = {notif[0] for notif in sent_notifs} + assert agent_names == {"p_worker_0", "p_worker_1"} + + # Both should have the same message + for _, msg in sent_notifs: + assert msg == b"HB:req1" + + @patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper, + ) + def test_p_side_heartbeat_extends_lease( + self, default_vllm_config, dist_init, monkeypatch + ): + """Test that P-side extends lease when receiving heartbeat.""" + monkeypatch.setenv("VLLM_NIXL_KV_LEASE_EXTENSION", "30") + + vllm_config = create_vllm_config() + connector = NixlConnector( + vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) + ) + 
connector.connector_worker = FakeNixlConnectorWorker( + vllm_config, connector.engine_id, hand_shake_latency=0 + ) + worker = connector.connector_worker + + # Setup kv_topo needed by _get_new_notifs + backend = get_current_attn_backend(vllm_config) + test_shape = backend.get_kv_cache_shape( + num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 + ) + worker.kv_topo = TpKVTopology( + tp_rank=worker.tp_rank, + engine_id=worker.engine_id, + remote_tp_size=worker._tp_size, + remote_block_size=worker._block_size, + is_mla=worker.use_mla, + total_num_kv_heads=worker.model_config.get_total_num_kv_heads(), + attn_backends=[backend], + tensor_shape=test_shape, + ) + + # Register requests waiting to be sent (P-side tracking) + now = time.perf_counter() + initial_expiry = now + 5.0 # Expires in 5 seconds + worker._reqs_to_send["req1"] = initial_expiry + worker._reqs_to_send["req2"] = initial_expiry + worker._reqs_to_send["req3"] = initial_expiry + worker._reqs_to_process = {"req1", "req2", "req3"} + + # Simulate heartbeat notification from D + heartbeat_msg = b"HB:req1,req2" + + worker.nixl_wrapper.get_new_notifs = lambda: {"agent": [heartbeat_msg]} + + # Process notifications + worker._get_new_notifs() + + # req1 and req2 should have extended leases (now + 30s) + # req3 should still have original expiry + assert worker._reqs_to_send["req1"] > initial_expiry + 20 + assert worker._reqs_to_send["req2"] > initial_expiry + 20 + assert worker._reqs_to_send["req3"] == initial_expiry + + @patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper, + ) + def test_p_side_lease_expiration(self, default_vllm_config, dist_init, monkeypatch): + """Test that P-side expires leases when no heartbeats received.""" + vllm_config = create_vllm_config() + connector = NixlConnector( + vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) + ) + connector.connector_worker = FakeNixlConnectorWorker( + vllm_config, connector.engine_id, 
hand_shake_latency=0 + ) + worker = connector.connector_worker + + # Setup kv_topo needed by get_finished + backend = get_current_attn_backend(vllm_config) + test_shape = backend.get_kv_cache_shape( + num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 + ) + worker.kv_topo = TpKVTopology( + tp_rank=worker.tp_rank, + engine_id=worker.engine_id, + remote_tp_size=worker._tp_size, + remote_block_size=worker._block_size, + is_mla=worker.use_mla, + total_num_kv_heads=worker.model_config.get_total_num_kv_heads(), + attn_backends=[backend], + tensor_shape=test_shape, + ) + + # Register a request with an already-expired lease + now = time.perf_counter() + worker._reqs_to_send["expired_req"] = now - 1.0 # Already expired + worker._reqs_to_process.add("expired_req") + + # Also register a request that's not expired yet + worker._reqs_to_send["valid_req"] = now + 100.0 # Far future + worker._reqs_to_process.add("valid_req") + + # get_finished should return the expired request + done_sending, _ = connector.get_finished(finished_req_ids=set()) + + assert "expired_req" in done_sending + assert "valid_req" not in done_sending + + # Verify expired request is cleaned up + assert "expired_req" not in worker._reqs_to_send + assert "expired_req" not in worker._reqs_to_process + + # Valid request should still be tracked + assert "valid_req" in worker._reqs_to_send + assert "valid_req" in worker._reqs_to_process + + @patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper, + ) + def test_heartbeat_with_empty_requests(self, default_vllm_config, dist_init): + """Test heartbeat handling with empty request string.""" + vllm_config = create_vllm_config() + connector = NixlConnector( + vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) + ) + connector.connector_worker = FakeNixlConnectorWorker( + vllm_config, connector.engine_id, hand_shake_latency=0 + ) + worker = connector.connector_worker + + # Setup kv_topo + backend = 
get_current_attn_backend(vllm_config) + test_shape = backend.get_kv_cache_shape( + num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 + ) + worker.kv_topo = TpKVTopology( + tp_rank=worker.tp_rank, + engine_id=worker.engine_id, + remote_tp_size=worker._tp_size, + remote_block_size=worker._block_size, + is_mla=worker.use_mla, + total_num_kv_heads=worker.model_config.get_total_num_kv_heads(), + attn_backends=[backend], + tensor_shape=test_shape, + ) + + # Empty heartbeat message (just "HB:" with no request IDs) + worker.nixl_wrapper.get_new_notifs = lambda: {"agent": [b"HB:"]} + + # Should handle gracefully without error + notified = worker._get_new_notifs() + assert len(notified) == 0 + + @patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper, + ) + def test_d_side_cleanup_on_transfer_complete(self, default_vllm_config, dist_init): + """Test that D-side removes completed transfers from heartbeat tracking.""" + vllm_config = create_vllm_config() + connector = NixlConnector( + vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) + ) + connector.connector_worker = FakeNixlConnectorWorker( + vllm_config, connector.engine_id, hand_shake_latency=0 + ) + worker = connector.connector_worker + + remote_engine_id = FakeNixlConnectorWorker.REMOTE_ENGINE_ID + + # Register remote engine in shared topology dicts so that + # block_size_ratio_from_engine_id can resolve the remote engine. 
+ worker._tp_size[remote_engine_id] = 1 + worker._block_size[remote_engine_id] = 16 + + # Simulate transfer metadata + from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( + RemoteMeta, + ReqMeta, + ) + + req_id = "test_transfer" + worker._recving_metadata[req_id] = ReqMeta( + local_block_ids=([1, 2, 3],), + local_physical_block_ids=([1, 2, 3],), + tp_size=1, + remote=RemoteMeta( + block_ids=([4, 5, 6],), + host="localhost", + port=1234, + engine_id=remote_engine_id, + request_id=f"prefill-{req_id}", + ), + ) + + # Simulate transfer handle completion. + # _recving_transfers keyed by req_id for completion tracking. + handle = 12345 + worker._recving_transfers[req_id] = [handle] + # _pending_transfers_by_engine keyed by engine_id for heartbeat batching. + worker._pending_transfers_by_engine[remote_engine_id] = {req_id} + + # Mock check_xfer_state to return DONE + worker.nixl_wrapper._cycles_before_xfer_done = 0 + + # get_finished should complete the transfer and clean up + _, done_recving = connector.get_finished(finished_req_ids=set()) + + assert req_id in done_recving + # Should be removed from _recving_transfers + assert req_id not in worker._recving_transfers + # Should be removed from _pending_transfers_by_engine + assert req_id not in worker._pending_transfers_by_engine.get( + remote_engine_id, set() + ) + + @patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper, + ) + def test_heartbeat_send_failure_logged( + self, default_vllm_config, dist_init, caplog + ): + """Test that heartbeat send failures are logged as warnings.""" + import logging + + vllm_config = create_vllm_config() + connector = NixlConnector( + vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) + ) + connector.connector_worker = FakeNixlConnectorWorker( + vllm_config, connector.engine_id, hand_shake_latency=0 + ) + worker = connector.connector_worker + worker._heartbeat_interval = 0 # Send immediately + + 
remote_engine_id = FakeNixlConnectorWorker.REMOTE_ENGINE_ID + worker._remote_agents[remote_engine_id] = {0: "fake_agent"} + worker._recving_transfers = {"req1": [1]} + worker._pending_transfers_by_engine[remote_engine_id] = {"req1"} + + # Make send_notif raise an exception + def failing_send_notif(agent_name: str, notif_msg: bytes): + raise RuntimeError("Network error") + + worker.nixl_wrapper.send_notif = failing_send_notif + + # Capture logs from the nixl_connector logger + nixl_logger = logging.getLogger( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector" + ) + captured_logs: list[logging.LogRecord] = [] + + class LogCapture(logging.Handler): + def emit(self, record): + captured_logs.append(record) + + handler = LogCapture() + handler.setLevel(logging.WARNING) + nixl_logger.addHandler(handler) + + try: + # Should not raise, just log warning + worker._send_lease_heartbeats() + finally: + nixl_logger.removeHandler(handler) + + # Verify warning was logged + warning_logs = [r for r in captured_logs if r.levelno == logging.WARNING] + assert len(warning_logs) >= 1 + assert any("Failed to send heartbeat" in r.message for r in warning_logs) + + @pytest.mark.parametrize( "error_scenario", [ diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index a86a52a6a6fb..a0922a86a824 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -970,15 +970,18 @@ def request_finished( delay_free_blocks = any(len(group) > 0 for group in block_ids) if delay_free_blocks: - # Prefill request on remote. It will be read from D upon completion + # Prefill request on remote. It will be read from D upon completion. + # Use initial lease duration - D will send heartbeats to extend. 
+ initial_lease = envs.VLLM_NIXL_KV_LEASE_DURATION logger.debug( - "NIXLConnector request_finished(%s) waiting for %d seconds " - "for remote decode to fetch blocks", + "NIXLConnector request_finished(%s) waiting for initial lease " + "of %d seconds for remote decode to fetch blocks " + "(will be extended by heartbeats)", request.request_id, - envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT, + initial_lease, ) self._reqs_need_send[request.request_id] = ( - time.perf_counter() + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT + time.perf_counter() + initial_lease ) # NOTE HMA will "mark" empty/null blocks in groups with 0s (eg SWA ones), # trimming down after allocating for the whole sequence length. Empty @@ -1175,9 +1178,14 @@ def __init__( self._registered_descs: list[Any] = [] # In progress transfers. - # [req_id -> list[handle]] + # In-progress transfer tracking (D-side / consumer). + # Keyed by req_id to ensure ALL handles complete before marking done. self._recving_metadata: dict[ReqId, ReqMeta] = {} self._recving_transfers = defaultdict[ReqId, list[TransferHandle]](list) + # Track which engines have pending transfers for each request. + # Used for batched heartbeat sending per P engine. + self._pending_transfers_by_engine = defaultdict[EngineId, set[ReqId]](set) + # Track the expiration time of requests that are waiting to be sent. self._reqs_to_send: dict[ReqId, float] = {} # Set of requests that have been part of a batch, regardless of status. @@ -1188,6 +1196,13 @@ def __init__( # requests that skipped transfer (handshake or transfer failures) self._failed_recv_reqs: set[ReqId] = set() + # Heartbeat/lease management for D-side (consumer). + # Single timestamp suffices - heartbeat interval limits overall send rate, + # not per-engine. New engines get fresh leases on P-side anyway. 
+ self._last_heartbeat_time: float = 0.0 + self._heartbeat_interval: float = float(envs.VLLM_NIXL_KV_HEARTBEAT_INTERVAL) + self._lease_extension = float(envs.VLLM_NIXL_KV_LEASE_EXTENSION) + # Handshake metadata of this worker for NIXL transfers. self.xfer_handshake_metadata: NixlHandshakePayload | None = None # Background thread for initializing new NIXL handshakes. @@ -1486,6 +1501,12 @@ def _background_nixl_handshake( fut = self._handshake_futures.get(remote_engine_id) if fut is None: assert meta.remote is not None + # Opportunistically clean up empty (dead?) remote engine_ids entries from + # _pending_transfers_by_engine. Asymptotically this data structure can only + # grow indefinitely when new P remotes are added. + for k in list(self._pending_transfers_by_engine.keys()): + if not self._pending_transfers_by_engine[k]: + del self._pending_transfers_by_engine[k] fut = self._handshake_initiation_executor.submit( self._nixl_handshake, meta.remote.host, @@ -2295,6 +2316,7 @@ def get_finished(self) -> tuple[set[str], set[str]]: meta = self._recving_metadata.pop(req_id, None) assert meta is not None, f"{req_id} not found in recving_metadata list" assert meta.remote is not None + if self.use_host_buffer: self.sync_recved_kv_to_device(req_id, meta) @@ -2315,7 +2337,9 @@ def get_finished(self) -> tuple[set[str], set[str]]: ) in block_ids_for_blocksize_post_process.items(): self.post_process_device_kv_on_receive(block_size_ratio, block_ids_list) - # Handle timeout to avoid stranding blocks on remote. + # Handle lease expiration to avoid stranding blocks on remote. + # Leases start with VLLM_NIXL_KV_LEASE_DURATION and are extended + # by heartbeats from D. If no heartbeat is received, lease expires. 
now = time.perf_counter() while self._reqs_to_send: req_id, expires = next(iter(self._reqs_to_send.items())) @@ -2325,11 +2349,11 @@ def get_finished(self) -> tuple[set[str], set[str]]: count = self.consumer_notification_counts_by_req.pop(req_id, 0) self.xfer_stats.record_kv_expired_req() logger.warning( - "Releasing expired KV blocks for request %s which were " - "retrieved by %d decode worker(s) within %d seconds.", + "Releasing KV blocks for request %s: lease expired " + "(retrieved by %d decode worker(s)). This may indicate " + "D crashed or network issues preventing heartbeats.", req_id, count, - envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT, ) self._reqs_to_process.remove(req_id) del self._reqs_to_send[req_id] @@ -2342,12 +2366,43 @@ def _get_new_notifs(self) -> set[str]: Get req_ids which got a remote xfer message. When multiple consumers are reading from the same producer (heterogeneous TP scenario), wait for all consumers to be done pulling. + + Also handles heartbeat messages from D (consumer) to extend KV block + leases. Heartbeat format: "HB:,,..." """ assert self.kv_topo is not None notified_req_ids: set[str] = set() + now = time.perf_counter() + for notifs in self.nixl_wrapper.get_new_notifs().values(): for notif in notifs: - req_id, tp_size = notif.decode("utf-8").rsplit(":", 1) + decoded = notif.decode("utf-8") + + # Handle heartbeat messages from D (consumer) to hold KV Cache blocks. + if decoded.startswith("HB:"): + req_ids_str = decoded[3:] # Skip "HB:" prefix + heartbeat_req_ids = [r for r in req_ids_str.split(",") if r] + extended_count = 0 + for req_id in heartbeat_req_ids: + if req_id in self._reqs_to_send: + # Extend the lease for this request. + self._reqs_to_send[req_id] = now + self._lease_extension + extended_count += 1 + else: + logger.warning( + "Received heartbeat message for unknown request %s. 
" + "This may indicate the request has already expired.", + req_id, + ) + if extended_count > 0: + logger.debug( + "Extended lease for %d requests via heartbeat", + extended_count, + ) + continue + + # Handle completion notifications (original format: "req_id:tp_size") + req_id, tp_size = decoded.rsplit(":", 1) if ( req_id not in self._reqs_to_send and req_id not in self._reqs_to_process @@ -2425,8 +2480,12 @@ def _pop_done_transfers(self, transfers: dict[str, list[int]]) -> set[str]: # Only report request as completed when all transfers are done. done_req_ids.add(req_id) del transfers[req_id] + # Clean up from pending_transfers_by_engine + for engine_reqs in self._pending_transfers_by_engine.values(): + engine_reqs.discard(req_id) else: transfers[req_id] = in_progress + return done_req_ids def _handle_failed_transfer(self, req_id: str, handle: int): @@ -2503,12 +2562,64 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): if req_id in self._reqs_to_process: self._reqs_to_send[req_id] = expiration_time + # Send heartbeats to P engines with pending transfers (D-side). + self._send_lease_heartbeats() + + def _send_lease_heartbeats(self) -> None: + """ + Send batched heartbeat notifications to P engines to extend KV block leases. + + This is called periodically from start_load_kv on the D (consumer) side. + Heartbeats are batched per P engine to minimize notification overhead. + Format: "HB:,,..." + """ + now = time.perf_counter() + + # Check if enough time has passed since last heartbeat. + if now - self._last_heartbeat_time < self._heartbeat_interval: + return + + num_notifs = 0 + for engine_id, req_ids in self._pending_transfers_by_engine.items(): + if not req_ids: + continue + + # Get agents for this engine. + remote_agents = self._remote_agents.get(engine_id) + if not remote_agents: + # Handshake not yet complete for this engine. + continue + + # Build batched heartbeat message: "HB:req1,req2,req3,..." 
+ heartbeat_msg = ("HB:" + ",".join(req_ids)).encode() + + # Send to ALL remote agents we handshaked with for this remote. + # Important for P TP > D TP case where we have multiple remote workers. + # For other cases, we actually only heartbeat one remote agent. + for agent_name in remote_agents.values(): + try: + self.nixl_wrapper.send_notif(agent_name, notif_msg=heartbeat_msg) + num_notifs += 1 + except Exception as e: + logger.warning( + "Failed to send heartbeat to engine %s agent %s: %s", + engine_id, + agent_name, + e, + ) + + if num_notifs > 0: + self._last_heartbeat_time = now + logger.debug("Sent %d heartbeat notifications", num_notifs) + def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): assert meta.remote is not None and self.kv_topo is not None + remote_engine_id = meta.remote.engine_id + remote_ranks = self.kv_topo.get_target_remote_ranks_from_engine_id( - meta.remote.engine_id + remote_engine_id ) - tp_ratio = self.kv_topo.tp_ratio_from_engine_id(meta.remote.engine_id) + tp_ratio = self.kv_topo.tp_ratio_from_engine_id(remote_engine_id) # D may have to perform multiple reads from different remote ranks. for i, remote_rank in enumerate(remote_ranks): if self.use_mla and tp_ratio < 0 and i > 0: @@ -2684,8 +2795,10 @@ def _read_blocks( # Begin async xfer. self.nixl_wrapper.transfer(handle) - # Use handle to check completion in future step(). + # Track handle for completion checking (keyed by req_id). self._recving_transfers[request_id].append(handle) + # Track engine for batched heartbeat sending (keyed by remote id). 
+ self._pending_transfers_by_engine[dst_engine_id].add(request_id) except Exception as e: # mark all (logical) blocks for this request as invalid self._log_failure( @@ -2881,6 +2994,7 @@ def shutdown(self): for handle in handles: self.nixl_wrapper.release_xfer_handle(handle) self._recving_transfers.clear() + self._pending_transfers_by_engine.clear() for handle in self.src_xfer_handles_by_block_size.values(): self.nixl_wrapper.release_dlist_handle(handle) self.src_xfer_handles_by_block_size.clear() diff --git a/vllm/envs.py b/vllm/envs.py index ec8d663141a6..bcebba911249 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -196,7 +196,9 @@ ] = "NONE" VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: int | None = None - VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 480 + VLLM_NIXL_KV_LEASE_DURATION: int = 30 + VLLM_NIXL_KV_LEASE_EXTENSION: int = 20 + VLLM_NIXL_KV_HEARTBEAT_INTERVAL: float = 5 VLLM_MORIIO_CONNECTOR_READ_MODE: bool = False VLLM_MORIIO_QP_PER_TRANSFER: int = 1 VLLM_MORIIO_POST_BATCH_SIZE: int = -1 @@ -1409,12 +1411,23 @@ def _get_or_set_default() -> str: "VLLM_USE_NVFP4_CT_EMULATIONS": lambda: bool( int(os.getenv("VLLM_USE_NVFP4_CT_EMULATIONS", "0")) ), - # Time (in seconds) after which the KV cache on the producer side is - # automatically cleared if no READ notification is received from the - # consumer. This is only applicable when using NixlConnector in a - # disaggregated decode-prefill setup. - "VLLM_NIXL_ABORT_REQUEST_TIMEOUT": lambda: int( - os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "480") + # Initial lease duration (in seconds) for KV blocks on the producer (P) + # side. If no heartbeat is received from the consumer (D) within this + # time, the blocks will be freed. Used in disaggregated P/D setup. + # D sends periodic heartbeats to extend the lease (see KV_LEASE_EXTENSION). 
+ "VLLM_NIXL_KV_LEASE_DURATION": lambda: int( + os.getenv("VLLM_NIXL_KV_LEASE_DURATION", "30") + ), + # Lease extension (in seconds) granted when a heartbeat is received from + # the consumer (D). Each heartbeat extends the lease by this amount. + "VLLM_NIXL_KV_LEASE_EXTENSION": lambda: int( + os.getenv("VLLM_NIXL_KV_LEASE_EXTENSION", "20") + ), + # Interval (in seconds) at which the consumer (D) sends heartbeat + # notifications to the producer (P) to extend KV block leases. + # Should be less than VLLM_NIXL_KV_LEASE_EXTENSION to ensure timely renewal. + "VLLM_NIXL_KV_HEARTBEAT_INTERVAL": lambda: float( + os.getenv("VLLM_NIXL_KV_HEARTBEAT_INTERVAL", "5") ), # Controls the read mode for the Mori-IO connector "VLLM_MORIIO_CONNECTOR_READ_MODE": lambda: (