diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 990c808a9831..9f305c718f9d 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -3,6 +3,8 @@ # ===================== import region ===================== +import threading + import torch import torch.distributed as dist from torch.distributed import ProcessGroup, ReduceOp @@ -145,8 +147,19 @@ def __init__( def destroy(self): if self.available and not self.disabled: - with torch.accelerator.device_index(self.device.index): - self.nccl.ncclCommDestroy(self.comm) + # ncclCommAbort can block until all CUDA graphs that + # captured NCCL ops on this comm are destroyed — and + # those graphs are released later in this same main- + # thread teardown, so a direct call here self-deadlocks. + # Run it in a daemon thread with a timeout: the main + # thread proceeds, the graphs drop, and the abort returns. + def _abort(): + with torch.accelerator.device_index(self.device.index): + self.nccl.ncclCommAbort(self.comm) + + abort_thread = threading.Thread(target=_abort, daemon=True) + abort_thread.start() + abort_thread.join(timeout=5.0) self.available = False self.disabled = True diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py index 57c7397e01b6..5ca8cc7c77f4 100644 --- a/vllm/distributed/device_communicators/pynccl_wrapper.py +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -290,6 +290,12 @@ class NCCLLibrary: # it is better not to call it at all. # ncclResult_t ncclCommDestroy(ncclComm_t comm); Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]), + # ncclCommAbort frees resources associated with the communicator + # without requiring a collective synchronization. Unlike + # ncclCommDestroy, it is safe to call during an uncoordinated + # shutdown when peer ranks may already be gone. + # ncclResult_t ncclCommAbort(ncclComm_t comm); + Function("ncclCommAbort", ncclResult_t, [ncclComm_t]), # ncclResult_t ncclGroupStart(); Function("ncclGroupStart", ncclResult_t, []), # ncclResult_t ncclGroupEnd(); @@ -548,6 +554,9 @@ def ncclBroadcast( def ncclCommDestroy(self, comm: ncclComm_t) -> None: self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm)) + def ncclCommAbort(self, comm: ncclComm_t) -> None: + self.NCCL_CHECK(self._funcs["ncclCommAbort"](comm)) + def ncclGroupStart(self) -> None: self.NCCL_CHECK(self._funcs["ncclGroupStart"]())