diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py
index 3f92b183dca7..66a6a93e589c 100644
--- a/tests/v1/kv_connector/unit/test_nixl_connector.py
+++ b/tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -2750,3 +2750,131 @@ def test_mla_broadcast_notif_uses_remote_request_id(
         f"got {notif!r} (expected {expected_notif!r}, "
         f"buggy form would be {bad_notif!r})"
     )
+
+
+class _QueueingFakeNixlWrapper(FakeNixlWrapper):
+    """FakeNixlWrapper extension that lets a test queue notifications."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._queued: list[bytes] = []
+
+    def queue_notif(self, req_id: str, n_consumers: int = 1) -> None:
+        self._queued.append(f"{req_id}:{n_consumers}".encode())
+
+    def get_new_notifs(self) -> dict[str, list[bytes]]:
+        if not self._queued:
+            return {}
+        notifs = {"agent": list(self._queued)}
+        self._queued.clear()
+        return notifs
+
+
+def _build_producer_worker():
+    """Return a FakeNixlConnectorWorker (hand_shake_latency=0) for these tests."""
+    vllm_config = create_vllm_config()
+    kv_cache_config = make_kv_cache_config(block_size=16, num_blocks=2)
+    connector = NixlConnector(vllm_config, KVConnectorRole.WORKER, kv_cache_config)
+    connector.connector_worker = FakeNixlConnectorWorker(
+        vllm_config,
+        connector.engine_id,
+        hand_shake_latency=0,
+        kv_cache_config=kv_cache_config,
+    )
+    return connector.connector_worker
+
+
+@patch(
+    "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
+    _QueueingFakeNixlWrapper,
+)
+def test_one_sibling_pull_releases_all_registered_siblings(
+    default_vllm_config, dist_init
+):
+    """D-side prefix cache pulls only sibling 7; without the fix, siblings
+    0..6 strand until the 480s expiry."""
+    # Given a producer holding all 8 best_of siblings of one parent prompt.
+    # Sibling ids follow vllm/v1/engine/parallel_sampling.py:92's f"{i}_{parent}".
+    worker = _build_producer_worker()
+    parent = "cmpl-fb4bc0e0-7dec-4bdc-9774-55b4575bf876-0-9e35b663"
+    siblings = {f"{i}_{parent}" for i in range(8)}
+    worker._reqs_to_process = set(siblings)
+    # Far-future expiry: an already-elapsed deadline (e.g. 0.0) could let an
+    # expiry sweep release the siblings and mask a missing fan-out.
+    worker._reqs_to_send = {sib: float("inf") for sib in siblings}
+
+    # When only sibling 7's pull-complete notification arrives.
+    worker.nixl_wrapper.queue_notif(f"7_{parent}", n_consumers=1)
+    done_sending, _ = worker.get_finished()
+
+    # Then all 8 siblings are released, not just sibling 7.
+    assert done_sending == siblings
+    assert worker._reqs_to_process == set()
+    assert worker._reqs_to_send == {}
+
+
+@patch(
+    "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
+    _QueueingFakeNixlWrapper,
+)
+def test_sibling_registering_after_parent_pulled_is_freed_at_registration(
+    default_vllm_config, dist_init
+):
+    """A late-registering sibling of an already-pulled parent must not strand."""
+    # Given a producer that has already seen one sibling of this parent pulled.
+    worker = _build_producer_worker()
+    parent = "cmpl-0fd4875b-320a-4b95-82ff-289a3d30fd2b-0-ac8bcbf3"
+    worker._pulled_bases[parent] = None
+
+    # When a late sibling of the same parent registers via start_load_kv.
+    late_sibling = f"3_{parent}"
+    metadata = NixlConnectorMetadata()
+    metadata.reqs_in_batch = {late_sibling}
+    worker.start_load_kv(metadata)
+    done_sending, _ = worker.get_finished()
+
+    # Then the late sibling is freed via get_finished, not stranded.
+    assert late_sibling in done_sending
+    assert late_sibling not in worker._reqs_to_process
+
+
+@patch(
+    "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
+    _QueueingFakeNixlWrapper,
+)
+def test_notification_arriving_before_registration_settles_on_registration(
+    default_vllm_config, dist_init
+):
+    """The pull-complete notif can race start_load_kv; before the fix it was
+    dropped and the request stranded until the 480s expiry."""
+    # Given the consumer's pull-complete notif arrives before the producer
+    # has registered the request in _reqs_to_process.
+    worker = _build_producer_worker()
+    req_id = "7_cmpl-35dba481-9b73-4fce-8636-7a5d40f3167a-0-b48d1be7"
+    worker.nixl_wrapper.queue_notif(req_id, n_consumers=1)
+    worker.get_finished()
+    assert req_id in worker._notif_n_consumers
+    assert worker.consumer_notification_counts_by_req[req_id] == 1
+
+    # When the producer registers the request via start_load_kv.
+    metadata = NixlConnectorMetadata()
+    metadata.reqs_in_batch = {req_id}
+    worker.start_load_kv(metadata)
+    done_sending, _ = worker.get_finished()
+
+    # Then the req is released instead of stranding until the 480s expiry.
+    assert req_id in done_sending
+    assert req_id not in worker._notif_n_consumers
+    assert req_id not in worker._reqs_to_process
+
+
+def test_best_of_parent_only_matches_parallel_sampling_child_ids():
+    """Only ASCII "<index>_<parent>" ids with a non-empty parent are children."""
+    best_of_parent = FakeNixlConnectorWorker._best_of_parent
+    assert best_of_parent("7_cmpl-abc") == "cmpl-abc"
+    assert best_of_parent("12_a_b") == "a_b"
+    assert best_of_parent("cmpl-abc") is None
+    assert best_of_parent("x_cmpl-abc") is None
+    assert best_of_parent("_cmpl-abc") is None
+    assert best_of_parent("7_") is None
+    assert best_of_parent("\u00b2_cmpl-abc") is None
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py
index ea8b46c28f9c..016c7ea4670f 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py
@@ -433,6 +433,18 @@ def __init__(
         # With heterogeneous TP, P must wait for all assigned D TP workers to
         # finish reading before safely freeing the blocks.
         self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
+        # FIX(early-notif race): a pull-complete notif can arrive before start_load_kv
+        # registers the req. Stage tp_size so registration can settle; FIFO-capped so
+        # never-registering reqs (aborts) can't accumulate.
+        self._notif_n_consumers: dict[ReqId, int] = {}
+        self._late_released: set[ReqId] = set()
+        # FIX(best_of fan-out): n>1 siblings f"{i}_{parent_id}" share one prompt KV;
+        # D-side prefix cache pulls only one, so the rest must free when any is
+        # pulled (else strand at VLLM_NIXL_ABORT_REQUEST_TIMEOUT). FIFO-capped.
+        self._pulled_bases: dict[str, None] = {}
+        self._fanout_released: set[ReqId] = set()
+        # Max entries kept in each FIFO above (_notif_n_consumers, _pulled_bases).
+        self._fifo_cap: int = 8192
         self.xfer_stats = NixlKVConnectorStats()
 
         self._physical_blocks_per_logical_kv_block = 1
@@ -1691,6 +1703,19 @@ def post_process_device_kv_on_receive_heterogeneous_attn(
             indices=indices,
         )
 
+    @staticmethod
+    def _best_of_parent(req_id: str) -> str | None:
+        """Return the parent id of a parallel-sampling child id, else None.
+
+        Children are named f"{index}_{parent_id}" (parallel_sampling.py:92).
+        The index must be ASCII digits and the parent non-empty; str.isdigit()
+        alone also accepts characters such as superscript digits.
+        """
+        head, _, tail = req_id.partition("_")
+        if tail and head.isascii() and head.isdigit():
+            return tail
+        return None
+
     def get_finished(self) -> tuple[set[str], set[str]]:
         """
         Get requests that are done sending or recving on this specific worker.
@@ -1699,6 +1724,14 @@ def get_finished(self) -> tuple[set[str], set[str]]:
         """
         assert self.transfer_topo is not None
         done_sending = self._get_new_notifs()
+        # FIX(best_of fan-out): siblings released at registration.
+        if self._fanout_released:
+            done_sending |= self._fanout_released
+            self._fanout_released = set()
+        # FIX(early-notif race): merge late settlements.
+        if self._late_released:
+            done_sending |= self._late_released
+            self._late_released = set()
         done_recving = self._pop_done_transfers(self._recving_transfers)
 
         # Drain queue of requests where handshake or transfer setup failed.
@@ -1814,11 +1847,21 @@ def _get_new_notifs(self) -> set[str]:
                     req_id not in self._reqs_to_send
                     and req_id not in self._reqs_to_process
                 ):
-                    logger.error(
-                        "Potentially invalid KV blocks for "
-                        "unrecognized request %s were retrieved by "
-                        "a decode worker. They may have expired.",
+                    # FIX(early-notif race): record + count, settle on registration.
+                    if (
+                        req_id not in self._notif_n_consumers
+                        and len(self._notif_n_consumers) >= self._fifo_cap
+                    ):
+                        _ev = next(iter(self._notif_n_consumers))
+                        self._notif_n_consumers.pop(_ev)
+                        self.consumer_notification_counts_by_req.pop(_ev, None)
+                    self._notif_n_consumers[req_id] = int(tp_size)
+                    self.consumer_notification_counts_by_req[req_id] += 1
+                    logger.debug(
+                        "Early notif arrived for req %s before registration "
+                        "(n_consumers=%s); will settle on registration.",
                         req_id,
+                        tp_size,
                     )
                     continue
 
@@ -1842,6 +1885,32 @@ def _get_new_notifs(self) -> set[str]:
                     del self.consumer_notification_counts_by_req[req_id]
                     self._reqs_to_process.remove(req_id)
                     self._reqs_to_send.pop(req_id, None)
+                    # FIX(early-notif race): drop any staged orphan entry so it
+                    # doesn't linger until FIFO eviction drops an unrelated id.
+                    self._notif_n_consumers.pop(req_id, None)
+                    # FIX(best_of fan-out): release un-pulled siblings of this parent.
+                    parent = self._best_of_parent(req_id)
+                    if parent is not None and parent not in self._pulled_bases:
+                        self._pulled_bases[parent] = None
+                        if len(self._pulled_bases) > self._fifo_cap:
+                            self._pulled_bases.pop(next(iter(self._pulled_bases)))
+                        _siblings = [
+                            r
+                            for r in self._reqs_to_process
+                            if self._best_of_parent(r) == parent
+                        ]
+                        for _sib in _siblings:
+                            notified_req_ids.add(_sib)
+                            self.consumer_notification_counts_by_req.pop(_sib, None)
+                            self._notif_n_consumers.pop(_sib, None)
+                            self._reqs_to_process.discard(_sib)
+                            self._reqs_to_send.pop(_sib, None)
+                        if _siblings:
+                            logger.debug(
+                                "best_of fan-out: parent %s released %d sibling(s)",
+                                parent,
+                                len(_siblings),
+                            )
         return notified_req_ids
 
     def _handle_heartbeat(self, payload: str) -> None:
@@ -1975,6 +2044,31 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
         # expiration for requests that have not been read from D yet.
         for req_id in metadata.reqs_in_batch:
             self._reqs_to_process.add(req_id)
+            # FIX(early-notif race): settle if a notif arrived before registration.
+            # Pop the staged entry unconditionally: once registered, the running
+            # count lives in consumer_notification_counts_by_req, and a lingering
+            # entry could later be FIFO-evicted, wiping that live count.
+            _n = self._notif_n_consumers.pop(req_id, None)
+            if _n is not None:
+                assert self.transfer_topo is not None
+                _cpp = -self.transfer_topo.tp_ratio(_n) if _n > self.world_size else 1
+                if self.consumer_notification_counts_by_req[req_id] >= _cpp:
+                    self._late_released.add(req_id)
+                    self.consumer_notification_counts_by_req.pop(req_id, None)
+                    self._reqs_to_process.discard(req_id)
+                    self._reqs_to_send.pop(req_id, None)
+                    logger.debug(
+                        "Settled early notif for req %s after registration.",
+                        req_id,
+                    )
+                    continue
+            # FIX(best_of fan-out): late sibling of an already-pulled parent.
+            _parent = self._best_of_parent(req_id)
+            if _parent is not None and _parent in self._pulled_bases:
+                self._fanout_released.add(req_id)
+                self.consumer_notification_counts_by_req.pop(req_id, None)
+                self._reqs_to_process.discard(req_id)
+                self._reqs_to_send.pop(req_id, None)
 
         # Remove all requests that are not to be processed (eg aborted).
         for req_id in metadata.reqs_not_processed: