diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index de9cbc660666..1929037b41c7 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1044,12 +1044,12 @@ def get_finished(self) -> tuple[set[str], set[str]]: if now < expires: break count = self.consumer_notification_counts_by_req.pop(req_id, 0) - logger.warning( - "Releasing expired KV blocks for request %s which were " - "retrieved by %d decode worker(s) within %d seconds.", req_id, - count, envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT) - del self._reqs_to_send[req_id] - done_sending.add(req_id) + if self.try_remove_request(req_id, "timeout"): + done_sending.add(req_id) + logger.warning( + "Releasing expired KV blocks for request %s which were " + "retrieved by %d decode worker(s) within %d seconds.", + req_id, count, envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT) return done_sending, done_recving @@ -1074,9 +1074,14 @@ def _get_new_notifs(self) -> set[str]: # Wait all consumers (D) to be done reading before freeing. if self.consumer_notification_counts_by_req[req_id] == int( tp_ratio): - notified_req_ids.add(req_id) del self.consumer_notification_counts_by_req[req_id] - del self._reqs_to_send[req_id] + if self.try_remove_request(req_id, "consumer_complete"): + notified_req_ids.add(req_id) + else: + logger.debug( + "Request %s completed by all consumers but was" + "already removed (likely timed out)", req_id) + return notified_req_ids def _pop_done_transfers( @@ -1298,6 +1303,24 @@ def get_backend_aware_kv_block_len(self): block_len = self.block_len return block_len + def try_remove_request(self, req_id: str, reason: str) -> bool: + """ + Safely remove a request from pending sends. + + Returns: + True if the request was removed, False if already gone. + """ + timeout_value = self._reqs_to_send.pop(req_id, None) + + if timeout_value is not None: + logger.debug("Removed request %s (reason: %s, was due at: %.2f)", + req_id, reason, timeout_value) + return True + else: + logger.debug("Request %s already removed when attempting %s", + req_id, reason) + return False + @contextlib.contextmanager def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: