@@ -254,6 +254,57 @@ void RecvFromWorker0(NDArray buffer) {
254254 NCCL_CALL (ncclGroupEnd ());
255255}
256256
257+ void SendToNextGroup (NDArray buffer) {
258+ CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get ();
259+ deviceStream_t stream = ctx->GetDefaultStream ();
260+ int worker_id = ctx->worker ->worker_id ;
261+ int group_size = ctx->worker ->num_workers / ctx->worker ->num_groups ;
262+ int receiver_id = worker_id + group_size;
263+ CHECK_LT (receiver_id, ctx->worker ->num_workers )
264+ << " The current group is already the last group and there is no such a next group." ;
265+ NCCL_CALL (ncclGroupStart ());
266+ NCCL_CALL (ncclSend (buffer->data , buffer.Shape ()->Product (), AsNCCLDataType (buffer.DataType ()),
267+ receiver_id, ctx->global_comm , stream));
268+ NCCL_CALL (ncclGroupEnd ());
269+ }
270+
271+ void RecvFromPrevGroup (NDArray buffer) {
272+ CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get ();
273+ deviceStream_t stream = ctx->GetDefaultStream ();
274+ int worker_id = ctx->worker ->worker_id ;
275+ int group_size = ctx->worker ->num_workers / ctx->worker ->num_groups ;
276+ int sender_id = worker_id - group_size;
277+ CHECK_GE (sender_id, 0 )
278+ << " The current group is already the first group and there is no such a previous group." ;
279+ NCCL_CALL (ncclGroupStart ());
280+ NCCL_CALL (ncclRecv (buffer->data , buffer.Shape ()->Product (), AsNCCLDataType (buffer.DataType ()),
281+ sender_id, ctx->global_comm , stream));
282+ NCCL_CALL (ncclGroupEnd ());
283+ }
284+
285+ void SendToWorker (NDArray buffer, int receiver_id) {
286+ CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get ();
287+ deviceStream_t stream = ctx->GetDefaultStream ();
288+ int worker_id = ctx->worker ->worker_id ;
289+ CHECK (receiver_id >= 0 && receiver_id < ctx->worker ->num_workers )
290+ << " Invalid receiver id " << receiver_id << " . The world size is "
291+ << ctx->worker ->num_workers ;
292+ CHECK_NE (worker_id, receiver_id) << " Cannot send to worker itself." ;
293+ NCCL_CALL (ncclSend (buffer->data , buffer.Shape ()->Product (), AsNCCLDataType (buffer.DataType ()),
294+ receiver_id, ctx->global_comm , stream));
295+ }
296+
297+ void RecvFromWorker (NDArray buffer, int sender_id) {
298+ CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get ();
299+ deviceStream_t stream = ctx->GetDefaultStream ();
300+ int worker_id = ctx->worker ->worker_id ;
301+ CHECK (sender_id >= 0 && sender_id < ctx->worker ->num_workers )
302+ << " Invalid sender id " << sender_id << " . The world size is " << ctx->worker ->num_workers ;
303+ CHECK_NE (worker_id, sender_id) << " Cannot receive from the worker itself." ;
304+ NCCL_CALL (ncclRecv (buffer->data , buffer.Shape ()->Product (), AsNCCLDataType (buffer.DataType ()),
305+ sender_id, ctx->global_comm , stream));
306+ }
307+
257308void SyncWorker () {
258309 CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get ();
259310 ICHECK (ctx->worker != nullptr );
@@ -284,8 +335,43 @@ TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".gather_to_worker0")
284335 .set_body_typed(GatherToWorker0);
285336TVM_REGISTER_GLOBAL (" runtime.disco." TVM_DISCO_CCL_NAME " .recv_from_worker0" )
286337 .set_body_typed(RecvFromWorker0);
338+ TVM_REGISTER_GLOBAL (" runtime.disco." TVM_DISCO_CCL_NAME " .send_to_next_group" )
339+ .set_body_typed(SendToNextGroup);
340+ TVM_REGISTER_GLOBAL (" runtime.disco." TVM_DISCO_CCL_NAME " .recv_from_prev_group" )
341+ .set_body_typed(RecvFromPrevGroup);
342+ TVM_REGISTER_GLOBAL (" runtime.disco." TVM_DISCO_CCL_NAME " .send_to_worker" )
343+ .set_body_typed(SendToWorker);
344+ TVM_REGISTER_GLOBAL (" runtime.disco." TVM_DISCO_CCL_NAME " .recv_from_worker" )
345+ .set_body_typed(RecvFromWorker);
287346TVM_REGISTER_GLOBAL (" runtime.disco." TVM_DISCO_CCL_NAME " .sync_worker" ).set_body_typed(SyncWorker);
288347
348+ TVM_REGISTER_GLOBAL (" runtime.disco." TVM_DISCO_CCL_NAME
349+ " .test_send_to_next_group_recv_from_prev_group" )
350+ .set_body_typed([](NDArray buffer) {
351+ CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get ();
352+ CHECK_EQ (ctx->worker ->num_workers , 4 ) << " The test requires the world size to be 4." ;
353+ CHECK_EQ (ctx->worker ->num_groups , 2 ) << " The test requires the group size to be 2." ;
354+ int group_size = ctx->worker ->num_workers / ctx->worker ->num_groups ;
355+ int group_id = ctx->worker ->worker_id / group_size;
356+ if (group_id == 0 ) {
357+ tvm::runtime::nccl::SendToNextGroup (buffer);
358+ } else {
359+ tvm::runtime::nccl::RecvFromPrevGroup (buffer);
360+ }
361+ });
362+
363+ TVM_REGISTER_GLOBAL (" runtime.disco." TVM_DISCO_CCL_NAME " .test_worker2_sends_to_worker0" )
364+ .set_body_typed([](NDArray buffer) {
365+ CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get ();
366+ CHECK_EQ (ctx->worker ->num_workers , 4 ) << " The test requires the world size to be 4." ;
367+ CHECK_EQ (ctx->worker ->num_groups , 2 ) << " The test requires the group size to be 2." ;
368+ if (ctx->worker ->worker_id == 2 ) {
369+ tvm::runtime::nccl::SendToWorker (buffer, 0 );
370+ } else if (ctx->worker ->worker_id == 0 ) {
371+ tvm::runtime::nccl::RecvFromWorker (buffer, 2 );
372+ }
373+ });
374+
289375} // namespace nccl
290376} // namespace runtime
291377} // namespace tvm
0 commit comments