From d29b2f68f8c8cf222a3114c5804f84878c871a32 Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Thu, 4 Sep 2025 13:54:35 -0400 Subject: [PATCH 1/3] make deletion atomic in nixl timeout handling Signed-off-by: Will Eaton --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index de9cbc660666..31ed82193009 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1044,12 +1044,14 @@ def get_finished(self) -> tuple[set[str], set[str]]: if now < expires: break count = self.consumer_notification_counts_by_req.pop(req_id, 0) - logger.warning( - "Releasing expired KV blocks for request %s which were " - "retrieved by %d decode worker(s) within %d seconds.", req_id, - count, envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT) - del self._reqs_to_send[req_id] - done_sending.add(req_id) + # pop because it's possible for request to complete at the + # same time as timeout, creating a race condition + if self._reqs_to_send.pop(req_id, None) is not None: + logger.warning( + "Releasing expired KV blocks for request %s which were " + "retrieved by %d decode worker(s) within %d seconds.", + req_id, count, envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT) + done_sending.add(req_id) return done_sending, done_recving From 857953c920fa6a04b9947507721d1c67512a2652 Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Thu, 4 Sep 2025 14:19:49 -0400 Subject: [PATCH 2/3] pop in both places by refactoring Signed-off-by: Will Eaton --- .../kv_connector/v1/nixl_connector.py | 40 +++++++++++++------ 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 31ed82193009..842b16338496 100644 --- 
a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1044,14 +1044,12 @@ def get_finished(self) -> tuple[set[str], set[str]]: if now < expires: break count = self.consumer_notification_counts_by_req.pop(req_id, 0) - # pop because it's possible for request to complete at the - # same time as timeout, creating a race condition - if self._reqs_to_send.pop(req_id, None) is not None: + if self.try_remove_request(req_id, "timeout"): + done_sending.add(req_id) logger.warning( "Releasing expired KV blocks for request %s which were " "retrieved by %d decode worker(s) within %d seconds.", req_id, count, envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT) - done_sending.add(req_id) return done_sending, done_recving @@ -1065,20 +1063,18 @@ def _get_new_notifs(self) -> set[str]: for notifs in self.nixl_wrapper.get_new_notifs().values(): for notif in notifs: req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1) - if req_id not in self._reqs_to_send: - logger.error( - "Potentially invalid KV blocks for " - "unrecognized request %s were retrieved by " - "a decode worker. They may have expired.", req_id) - continue - self.consumer_notification_counts_by_req[req_id] += 1 # Wait all consumers (D) to be done reading before freeing. if self.consumer_notification_counts_by_req[req_id] == int( tp_ratio): - notified_req_ids.add(req_id) del self.consumer_notification_counts_by_req[req_id] - del self._reqs_to_send[req_id] + if self.try_remove_request(req_id, "consumer_complete"): + notified_req_ids.add(req_id) + else: + logger.debug( + "Request %s completed by all consumers but was " + "already removed (likely timed out)", req_id) + return notified_req_ids def _pop_done_transfers( @@ -1300,6 +1296,24 @@ def get_backend_aware_kv_block_len(self): block_len = self.block_len return block_len + def try_remove_request(self, req_id: str, reason: str) -> bool: + """ + Safely remove a request from pending sends. 
+ + Returns: + True if the request was removed, False if already gone. + """ + timeout_value = self._reqs_to_send.pop(req_id, None) + + if timeout_value is not None: + logger.debug("Removed request %s (reason: %s, was due at: %.2f)", + req_id, reason, timeout_value) + return True + else: + logger.debug("Request %s already removed when attempting %s", + req_id, reason) + return False + @contextlib.contextmanager def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: From aa1029450c0e7fb823bf70d0f351e0568d1430ed Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Thu, 4 Sep 2025 16:10:51 -0400 Subject: [PATCH 3/3] add back check even though it's racy to prevent memory leak Signed-off-by: Will Eaton --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 842b16338496..1929037b41c7 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1063,6 +1063,13 @@ def _get_new_notifs(self) -> set[str]: for notifs in self.nixl_wrapper.get_new_notifs().values(): for notif in notifs: req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1) + if req_id not in self._reqs_to_send: + logger.error( + "Potentially invalid KV blocks for " + "unrecognized request %s were retrieved by " + "a decode worker. They may have expired.", req_id) + continue + self.consumer_notification_counts_by_req[req_id] += 1 # Wait all consumers (D) to be done reading before freeing. if self.consumer_notification_counts_by_req[req_id] == int(