@@ -24,7 +24,6 @@
 #include <tvm/runtime/registry.h>
 
 #include <cstring>
-#include <memory>
 #include <mutex>
 #include <sstream>
 #include <vector>
@@ -39,7 +38,6 @@
 #if TVM_NCCL_RCCL_SWITCH == 0
 #include <nccl.h>
 
-#include "../../../../3rdparty/trt-llm-allreduce/include/cuda_allreduce.h"
 #include "../../cuda/cuda_common.h"
 #else
 #include <rccl/rccl.h>
@@ -142,7 +140,6 @@ struct CCLThreadLocalContext {
   int device_id;
   deviceStream_t default_stream = nullptr;
   ncclComm_t comm;
-  std::unique_ptr<CustomAllReduce> custom_allreduce;
 
   void Clear() {
     NCCL_CALL(ncclCommDestroy(comm));
@@ -193,8 +190,6 @@ void InitCCLPerWorker(IntTuple device_ids, std::string unique_id_bytes) {
   worker->ccl = TVM_DISCO_CCL_NAME;
   ctx->worker = worker;
   ctx->device_id = device_id;
-  ctx->custom_allreduce =
-      std::make_unique<CustomAllReduce>(worker->num_workers, worker->worker_id, ctx->comm);
   // Initialize the communicator
   ncclUniqueId id;
   std::memcpy(id.internal, unique_id_bytes.data(), NCCL_UNIQUE_ID_BYTES);
@@ -206,13 +201,6 @@ void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv) {
   ShapeTuple shape = send.Shape();
   int64_t numel = shape->Product();
   deviceStream_t stream = ctx->GetDefaultStream();
-  // TODO(csullivan) make this work
-  // 1. pass type in
-  // 2. src and dest args
-  // 3. some strategy selection outside, if (!enqueue) do nccl?
-  // 4. reduce kind
-  // 5. pass stream in to custom api
-  // ctx->custom_allreduce->enqueue(send->data, numel);
   NCCL_CALL(ncclAllReduce(send->data, recv->data, numel,
                           /*datatype=*/AsNCCLDataType(DataType(send->dtype)),
                           /*op=*/AsNCCLRedOp(reduce_kind), ctx->comm, stream));
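
For context, the removed TODO sketched a dispatch strategy: enqueue on the custom all-reduce kernel when possible, otherwise fall back to ncclAllReduce. Below is a minimal sketch of that idea, assuming a hypothetical CustomAllReduce::Enqueue that takes source/destination buffers, dtype, reduce kind, and stream, and returns false when it declines the request; the actual trt-llm-allreduce interface may differ.

void AllReduceWithFallback(NDArray send, ReduceKind reduce_kind, NDArray recv,
                           CCLThreadLocalContext* ctx) {
  int64_t numel = send.Shape()->Product();
  deviceStream_t stream = ctx->GetDefaultStream();
  // Assumed interface: Enqueue() returns false when the custom kernel cannot
  // handle this size/dtype/reduce kind, leaving the buffers untouched.
  bool handled = ctx->custom_allreduce != nullptr &&
                 ctx->custom_allreduce->Enqueue(send->data, recv->data, numel,
                                                DataType(send->dtype), reduce_kind, stream);
  if (!handled) {
    // Fall back to the plain NCCL path that this commit keeps.
    NCCL_CALL(ncclAllReduce(send->data, recv->data, numel,
                            /*datatype=*/AsNCCLDataType(DataType(send->dtype)),
                            /*op=*/AsNCCLRedOp(reduce_kind), ctx->comm, stream));
  }
}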