Skip to content
113 changes: 113 additions & 0 deletions tests/v1/kv_connector/unit/test_nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2750,3 +2750,116 @@ def test_mla_broadcast_notif_uses_remote_request_id(
f"got {notif!r} (expected {expected_notif!r}, "
f"buggy form would be {bad_notif!r})"
)


class _QueueingFakeNixlWrapper(FakeNixlWrapper):
    """FakeNixlWrapper variant whose notifications are staged by the test.

    Payloads staged with ``queue_notif`` are all handed out, under the single
    key ``"agent"``, by the next ``get_new_notifs`` call; the queue is empty
    again afterwards.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Encoded payloads staged by the test and not yet delivered.
        self._queued: list[bytes] = []

    def queue_notif(self, req_id: str, n_consumers: int = 1) -> None:
        """Stage one ``"<req_id>:<n_consumers>"`` payload for delivery."""
        payload = f"{req_id}:{n_consumers}".encode()
        self._queued.append(payload)

    def get_new_notifs(self) -> dict[str, list[bytes]]:
        """Drain the staged payloads; return ``{}`` when none are staged."""
        if self._queued:
            # Snapshot first, then empty the queue in place.
            drained = self._queued.copy()
            self._queued.clear()
            return {"agent": drained}
        return {}


def _build_producer_worker():
    """Build a WORKER-role NixlConnector and return its fake connector worker.

    The connector's ``connector_worker`` is replaced by a
    ``FakeNixlConnectorWorker`` created with zero handshake latency and a
    KV cache config of two 16-token blocks.
    """
    config = create_vllm_config()
    cache_config = make_kv_cache_config(block_size=16, num_blocks=2)
    connector = NixlConnector(config, KVConnectorRole.WORKER, cache_config)
    fake_worker = FakeNixlConnectorWorker(
        config,
        connector.engine_id,
        hand_shake_latency=0,
        kv_cache_config=cache_config,
    )
    connector.connector_worker = fake_worker
    return connector.connector_worker


@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
    _QueueingFakeNixlWrapper,
)
def test_one_sibling_pull_releases_all_registered_siblings(
    default_vllm_config, dist_init
):
    """One pulled sibling must release every registered sibling.

    The D-side prefix cache pulls only sibling 7; without the fix, siblings
    0..6 strand until the 480s expiry.
    """
    # Arrange: the producer holds all 8 best_of siblings of one parent prompt.
    # Sibling ids follow vllm/v1/engine/parallel_sampling.py:92's
    # f"{i}_{parent}" naming.
    producer = _build_producer_worker()
    parent_id = "cmpl-fb4bc0e0-7dec-4bdc-9774-55b4575bf876-0-9e35b663"
    sibling_ids = {f"{idx}_{parent_id}" for idx in range(8)}
    producer._reqs_to_process = set(sibling_ids)
    producer._reqs_to_send = dict.fromkeys(sibling_ids, 0.0)

    # Act: only sibling 7's pull-complete notification arrives.
    pulled_sibling = f"7_{parent_id}"
    producer.nixl_wrapper.queue_notif(pulled_sibling, n_consumers=1)
    finished_sending, _ = producer.get_finished()

    # Assert: all 8 siblings are released, not just sibling 7.
    assert finished_sending == sibling_ids
    assert producer._reqs_to_process == set()
    assert producer._reqs_to_send == {}


@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
    _QueueingFakeNixlWrapper,
)
def test_sibling_registering_after_parent_pulled_is_freed_at_registration(
    default_vllm_config, dist_init
):
    """A sibling that registers after its parent was pulled must not strand."""
    # Arrange: the producer has already recorded this parent as pulled.
    producer = _build_producer_worker()
    parent_id = "cmpl-0fd4875b-320a-4b95-82ff-289a3d30fd2b-0-ac8bcbf3"
    producer._pulled_bases[parent_id] = None

    # Act: a late sibling of that parent registers through start_load_kv.
    straggler = f"3_{parent_id}"
    batch_metadata = NixlConnectorMetadata()
    batch_metadata.reqs_in_batch = {straggler}
    producer.start_load_kv(batch_metadata)
    finished_sending, _ = producer.get_finished()

    # Assert: get_finished reports the late sibling and it is no longer
    # tracked as pending.
    assert straggler in finished_sending
    assert straggler not in producer._reqs_to_process


@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
    _QueueingFakeNixlWrapper,
)
def test_notification_arriving_before_registration_settles_on_registration(
    default_vllm_config, dist_init
):
    """An early pull-complete notif must settle once the request registers.

    The notif can race start_load_kv; before the fix it was dropped and the
    request stranded until the 480s expiry.
    """
    # Arrange: the consumer's pull-complete notif arrives before the producer
    # has registered the request in _reqs_to_process.
    producer = _build_producer_worker()
    early_req = "7_cmpl-35dba481-9b73-4fce-8636-7a5d40f3167a-0-b48d1be7"
    producer.nixl_wrapper.queue_notif(early_req, n_consumers=1)
    producer.get_finished()
    # The early notif is staged and counted rather than discarded.
    assert early_req in producer._notif_n_consumers
    assert producer.consumer_notification_counts_by_req[early_req] == 1

    # Act: the producer registers the request through start_load_kv.
    batch_metadata = NixlConnectorMetadata()
    batch_metadata.reqs_in_batch = {early_req}
    producer.start_load_kv(batch_metadata)
    finished_sending, _ = producer.get_finished()

    # Assert: the request is released and all staging state is gone.
    assert early_req in finished_sending
    assert early_req not in producer._notif_n_consumers
    assert early_req not in producer._reqs_to_process
92 changes: 88 additions & 4 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,16 @@ def __init__(
# With heterogeneous TP, P must wait for all assigned D TP workers to
# finish reading before safely freeing the blocks.
self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
# FIX(early-notif race): a pull-complete notif can arrive before start_load_kv
# registers the req. Stage tp_size so registration can settle; FIFO-capped so
# never-registering reqs (aborts) can't accumulate.
self._notif_n_consumers: dict[ReqId, int] = {}
self._late_released: set[ReqId] = set()
# FIX(best_of fan-out): n>1 siblings f"{i}_{parent_id}" share one prompt KV;
# D-side prefix cache pulls only one, so the rest must free when any is
# pulled (else strand at VLLM_NIXL_ABORT_REQUEST_TIMEOUT). FIFO-capped.
self._pulled_bases: dict[str, None] = {}
self._fanout_released: set[ReqId] = set()
self.xfer_stats = NixlKVConnectorStats()

self._physical_blocks_per_logical_kv_block = 1
Expand Down Expand Up @@ -1691,6 +1701,12 @@ def post_process_device_kv_on_receive_heterogeneous_attn(
indices=indices,
)

@staticmethod
def _best_of_parent(req_id: str) -> str | None:
# Children are named f"{index}_{parent_id}" (parallel_sampling.py:92).
head, sep, tail = req_id.partition("_")
Copy link
Copy Markdown
Collaborator

@chaunceyjiang chaunceyjiang May 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel this is more like a fix specifically for NIXL. In reality, other kv_connectors are facing the same issue as well, so I think this is currently a limitation in the PD design itself.

I’m more in favor of my PR:
#38900

because it is a more general solution and is not tied to any specific kv_connector.

return tail if sep and head.isdigit() else None

def get_finished(self) -> tuple[set[str], set[str]]:
"""
Get requests that are done sending or recving on this specific worker.
Expand All @@ -1699,6 +1715,14 @@ def get_finished(self) -> tuple[set[str], set[str]]:
"""
assert self.transfer_topo is not None
done_sending = self._get_new_notifs()
# FIX(best_of fan-out): siblings released at registration.
if self._fanout_released:
done_sending |= self._fanout_released
self._fanout_released = set()
# FIX(early-notif race): merge late settlements.
if self._late_released:
done_sending |= self._late_released
self._late_released = set()
done_recving = self._pop_done_transfers(self._recving_transfers)

# Drain queue of requests where handshake or transfer setup failed.
Expand Down Expand Up @@ -1814,11 +1838,21 @@ def _get_new_notifs(self) -> set[str]:
req_id not in self._reqs_to_send
and req_id not in self._reqs_to_process
):
logger.error(
"Potentially invalid KV blocks for "
"unrecognized request %s were retrieved by "
"a decode worker. They may have expired.",
# FIX(early-notif race): record + count, settle on registration.
if (
req_id not in self._notif_n_consumers
and len(self._notif_n_consumers) >= 8192
):
_ev = next(iter(self._notif_n_consumers))
self._notif_n_consumers.pop(_ev)
self.consumer_notification_counts_by_req.pop(_ev, None)
self._notif_n_consumers[req_id] = int(tp_size)
self.consumer_notification_counts_by_req[req_id] += 1
logger.debug(
"Early notif arrived for req %s before registration "
"(n_consumers=%s); will settle on registration.",
req_id,
tp_size,
)
continue

Expand All @@ -1842,6 +1876,32 @@ def _get_new_notifs(self) -> set[str]:
del self.consumer_notification_counts_by_req[req_id]
self._reqs_to_process.remove(req_id)
self._reqs_to_send.pop(req_id, None)
Comment thread
crazyguitar marked this conversation as resolved.
# FIX(early-notif race): drop any staged orphan entry so it
# doesn't linger until FIFO eviction drops an unrelated id.
self._notif_n_consumers.pop(req_id, None)
# FIX(best_of fan-out): release un-pulled siblings of this parent.
parent = self._best_of_parent(req_id)
if parent is not None and parent not in self._pulled_bases:
self._pulled_bases[parent] = None
if len(self._pulled_bases) > 8192:
self._pulled_bases.pop(next(iter(self._pulled_bases)))
_siblings = [
r
for r in list(self._reqs_to_process)
if self._best_of_parent(r) == parent
]
for _sib in _siblings:
notified_req_ids.add(_sib)
self.consumer_notification_counts_by_req.pop(_sib, None)
self._notif_n_consumers.pop(_sib, None)
self._reqs_to_process.discard(_sib)
self._reqs_to_send.pop(_sib, None)
if _siblings:
logger.debug(
"best_of fan-out: parent %s released %d sibling(s)",
parent,
len(_siblings),
)
return notified_req_ids

def _handle_heartbeat(self, payload: str) -> None:
Expand Down Expand Up @@ -1975,6 +2035,30 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
# expiration for requests that have not been read from D yet.
for req_id in metadata.reqs_in_batch:
self._reqs_to_process.add(req_id)
# FIX(early-notif race): settle if a notif arrived before registration.
if req_id in self._notif_n_consumers:
assert self.transfer_topo is not None
_n = self._notif_n_consumers[req_id]
_cpp = -self.transfer_topo.tp_ratio(_n) if _n > self.world_size else 1
if self.consumer_notification_counts_by_req[req_id] >= _cpp:
self._late_released.add(req_id)
self.consumer_notification_counts_by_req.pop(req_id, None)
self._notif_n_consumers.pop(req_id, None)
self._reqs_to_process.discard(req_id)
self._reqs_to_send.pop(req_id, None)
logger.debug(
"Settled early notif for req %s after registration.",
req_id,
)
continue
# FIX(best_of fan-out): late sibling of an already-pulled parent.
_parent = self._best_of_parent(req_id)
if _parent is not None and _parent in self._pulled_bases:
self._fanout_released.add(req_id)
self.consumer_notification_counts_by_req.pop(req_id, None)
self._notif_n_consumers.pop(req_id, None)
self._reqs_to_process.discard(req_id)
self._reqs_to_send.pop(req_id, None)

# Remove all requests that are not to be processed (eg aborted).
for req_id in metadata.reqs_not_processed:
Expand Down
Loading