Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions vllm/distributed/device_communicators/pynccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@


# ===================== import region =====================
import threading

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup, ReduceOp
Expand Down Expand Up @@ -145,8 +147,19 @@ def __init__(

def destroy(self):
if self.available and not self.disabled:
with torch.accelerator.device_index(self.device.index):
self.nccl.ncclCommDestroy(self.comm)
# ncclCommAbort can block until all CUDA graphs that
# captured NCCL ops on this comm are destroyed — and
# those graphs are released later in this same main-
# thread teardown, so a direct call here self-deadlocks.
# Run it in a daemon thread with a timeout: the main
# thread proceeds, the graphs drop, and the abort returns.
def _abort():
with torch.accelerator.device_index(self.device.index):
self.nccl.ncclCommAbort(self.comm)

abort_thread = threading.Thread(target=_abort, daemon=True)
abort_thread.start()
abort_thread.join(timeout=5.0)
self.available = False
self.disabled = True

Expand Down
9 changes: 9 additions & 0 deletions vllm/distributed/device_communicators/pynccl_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,12 @@ class NCCLLibrary:
# it is better not to call it at all.
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
# ncclCommAbort frees resources associated with the communicator
# without requiring a collective synchronization. Unlike
# ncclCommDestroy, it is safe to call during an uncoordinated
# shutdown when peer ranks may already be gone.
# ncclResult_t ncclCommAbort(ncclComm_t comm);
Function("ncclCommAbort", ncclResult_t, [ncclComm_t]),
# ncclResult_t ncclGroupStart();
Function("ncclGroupStart", ncclResult_t, []),
# ncclResult_t ncclGroupEnd();
Expand Down Expand Up @@ -548,6 +554,9 @@ def ncclBroadcast(
def ncclCommDestroy(self, comm: ncclComm_t) -> None:
self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))

def ncclCommAbort(self, comm: ncclComm_t) -> None:
self.NCCL_CHECK(self._funcs["ncclCommAbort"](comm))

def ncclGroupStart(self) -> None:
self.NCCL_CHECK(self._funcs["ncclGroupStart"]())

Expand Down
Loading