diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc index 7b943cf83f1f..bba42ed3bdfe 100644 --- a/src/runtime/disco/nccl/nccl.cc +++ b/src/runtime/disco/nccl/nccl.cc @@ -67,9 +67,21 @@ void InitCCLPerWorker(IntTuple device_ids, std::string unique_id_bytes) { CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); DiscoWorker* worker = DiscoWorker::ThreadLocal(); ICHECK(worker != nullptr); + CHECK_EQ(unique_id_bytes.size(), NCCL_UNIQUE_ID_BYTES) << "ValueError: The length of unique_id must be " << NCCL_UNIQUE_ID_BYTES << ", but got " << unique_id_bytes.size() << "."; + + CHECK(!ctx->comm) << "Cannot initialize CCL, " + << "the previous thread-global comm still exists, " + << "and has not been destructed"; + CHECK(!ctx->default_stream) << "Cannot initialize CCL, " + << "the previous thread-global stream still exists, " + << "and has not been destructed"; + CHECK(!ctx->worker) << "Cannot initialize CCL, " + << "the previous thread-global worker still exists, " + << "and has not been destructed"; + // Step up local context of NCCL int device_id = device_ids[worker->worker_id]; SetDevice(device_id); diff --git a/src/runtime/disco/nccl/nccl_context.h b/src/runtime/disco/nccl/nccl_context.h index 9d1b8b933a83..3fb281f2cb7c 100644 --- a/src/runtime/disco/nccl/nccl_context.h +++ b/src/runtime/disco/nccl/nccl_context.h @@ -118,16 +118,23 @@ inline ncclDataType_t AsNCCLDataType(runtime::DataType dtype) { } struct CCLThreadLocalContext { - DiscoWorker* worker; + DiscoWorker* worker = nullptr; int device_id; deviceStream_t default_stream = nullptr; - ncclComm_t comm; + ncclComm_t comm = nullptr; + + ~CCLThreadLocalContext() { Clear(); } void Clear() { - NCCL_CALL(ncclCommDestroy(comm)); - if (default_stream != nullptr) { + if (comm) { + NCCL_CALL(ncclCommDestroy(comm)); + comm = nullptr; + } + if (default_stream) { StreamDestroy(default_stream); + default_stream = nullptr; } + worker = nullptr; } deviceStream_t GetDefaultStream() {