Skip to content

Commit ae1be53

Browse files
authored
[Disco] Cross-group and p2p send/receive primitives (#17191)
This PR introduces the disco CCL primitives for cross-group and p2p communication. Specifically, we introduce the send/receive primitives for one group to send a buffer to its next group, where every worker in the first group sends the buffer to the corresponding worker in the second group. The p2p communication refers to the send/receive operations between a worker and a target worker identified by its global worker id.
1 parent 9a07870 commit ae1be53

File tree

5 files changed

+168
-4
lines changed

5 files changed

+168
-4
lines changed

include/tvm/runtime/disco/builtin.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,30 @@ TVM_DLL void GatherToWorker0(NDArray send, bool in_group, Optional<NDArray> recv
114114
* \param buffer The buffer to be received
115115
*/
116116
TVM_DLL void RecvFromWorker0(NDArray buffer);
117+
/*!
118+
* \brief Send a buffer to the corresponding worker in the next group.
119+
* An error is thrown if the worker is already in the last group.
120+
* \param buffer The sending buffer.
121+
*/
122+
TVM_DLL void SendToNextGroup(NDArray buffer);
123+
/*!
124+
* \brief Receive a buffer from the corresponding worker in the previous group.
125+
* An error is thrown if the worker is already in the first group.
126+
* \param buffer The receiving buffer.
127+
*/
128+
TVM_DLL void RecvFromPrevGroup(NDArray buffer);
129+
/*!
130+
* \brief Send a buffer to the target receiver worker (globally across all groups).
131+
* \param buffer The sending buffer.
132+
* \param receiver_id The global receiver worker id.
133+
*/
134+
TVM_DLL void SendToWorker(NDArray buffer, int receiver_id);
135+
/*!
136+
* \brief Receive a buffer from the target sender worker (globally across all groups).
137+
* \param buffer The receiving buffer.
138+
* \param sender_id The global sender worker id.
139+
*/
140+
TVM_DLL void RecvFromWorker(NDArray buffer, int sender_id);
117141
/*! \brief Get the local worker id */
118142
TVM_DLL int WorkerId();
119143
/*!

python/tvm/relax/frontend/nn/core.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -549,16 +549,16 @@ def __init__(self, modules: List[Module]):
549549
def __iter__(self):
550550
return iter(self.modules)
551551

552-
def __getitem__(self, idx):
552+
def __getitem__(self, idx: int) -> Module:
553553
return self.modules[idx]
554554

555-
def __setitem__(self, idx, module):
555+
def __setitem__(self, idx: int, module: Module) -> None:
556556
self.modules[idx] = module
557557

558558
def __len__(self):
559559
return len(self.modules)
560560

561-
def append(self, module):
561+
def append(self, module: Module):
562562
"""Add a module to the end of the ModuleList"""
563563
self.modules.append(module)
564564

src/runtime/disco/builtin.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,18 @@ void GatherToWorker0(NDArray send, bool in_group, Optional<NDArray> recv) {
101101

102102
void RecvFromWorker0(NDArray buffer) { GetCCLFunc("recv_from_worker0")(buffer); }
103103

104+
void SendToNextGroup(NDArray buffer) { GetCCLFunc("send_to_next_group")(buffer); }
105+
106+
void RecvFromPrevGroup(NDArray buffer) { GetCCLFunc("recv_from_prev_group")(buffer); }
107+
108+
void SendToWorker(NDArray buffer, int receiver_id) {
109+
GetCCLFunc("send_to_worker")(buffer, receiver_id);
110+
}
111+
112+
void RecvFromWorker(NDArray buffer, int sender_id) {
113+
GetCCLFunc("recv_from_worker")(buffer, sender_id);
114+
}
115+
104116
int WorkerId() { return DiscoWorker::ThreadLocal()->worker_id; }
105117

106118
void SyncWorker() {
@@ -136,6 +148,10 @@ TVM_REGISTER_GLOBAL("runtime.disco.broadcast_from_worker0").set_body_typed(Broad
136148
TVM_REGISTER_GLOBAL("runtime.disco.scatter_from_worker0").set_body_typed(ScatterFromWorker0);
137149
TVM_REGISTER_GLOBAL("runtime.disco.gather_to_worker0").set_body_typed(GatherToWorker0);
138150
TVM_REGISTER_GLOBAL("runtime.disco.recv_from_worker0").set_body_typed(RecvFromWorker0);
151+
TVM_REGISTER_GLOBAL("runtime.disco.send_to_next_group").set_body_typed(SendToNextGroup);
152+
TVM_REGISTER_GLOBAL("runtime.disco.recv_from_prev_group").set_body_typed(RecvFromPrevGroup);
153+
TVM_REGISTER_GLOBAL("runtime.disco.send_to_worker").set_body_typed(SendToWorker);
154+
TVM_REGISTER_GLOBAL("runtime.disco.recv_from_worker").set_body_typed(RecvFromWorker);
139155
TVM_REGISTER_GLOBAL("runtime.disco.worker_id").set_body_typed([]() -> ShapeTuple {
140156
return ShapeTuple({WorkerId()});
141157
});

src/runtime/disco/nccl/nccl.cc

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,57 @@ void RecvFromWorker0(NDArray buffer) {
254254
NCCL_CALL(ncclGroupEnd());
255255
}
256256

257+
void SendToNextGroup(NDArray buffer) {
258+
CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
259+
deviceStream_t stream = ctx->GetDefaultStream();
260+
int worker_id = ctx->worker->worker_id;
261+
int group_size = ctx->worker->num_workers / ctx->worker->num_groups;
262+
int receiver_id = worker_id + group_size;
263+
CHECK_LT(receiver_id, ctx->worker->num_workers)
264+
<< "The current group is already the last group and there is no such a next group.";
265+
NCCL_CALL(ncclGroupStart());
266+
NCCL_CALL(ncclSend(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()),
267+
receiver_id, ctx->global_comm, stream));
268+
NCCL_CALL(ncclGroupEnd());
269+
}
270+
271+
void RecvFromPrevGroup(NDArray buffer) {
272+
CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
273+
deviceStream_t stream = ctx->GetDefaultStream();
274+
int worker_id = ctx->worker->worker_id;
275+
int group_size = ctx->worker->num_workers / ctx->worker->num_groups;
276+
int sender_id = worker_id - group_size;
277+
CHECK_GE(sender_id, 0)
278+
<< "The current group is already the first group and there is no such a previous group.";
279+
NCCL_CALL(ncclGroupStart());
280+
NCCL_CALL(ncclRecv(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()),
281+
sender_id, ctx->global_comm, stream));
282+
NCCL_CALL(ncclGroupEnd());
283+
}
284+
285+
void SendToWorker(NDArray buffer, int receiver_id) {
286+
CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
287+
deviceStream_t stream = ctx->GetDefaultStream();
288+
int worker_id = ctx->worker->worker_id;
289+
CHECK(receiver_id >= 0 && receiver_id < ctx->worker->num_workers)
290+
<< "Invalid receiver id " << receiver_id << ". The world size is "
291+
<< ctx->worker->num_workers;
292+
CHECK_NE(worker_id, receiver_id) << "Cannot send to worker itself.";
293+
NCCL_CALL(ncclSend(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()),
294+
receiver_id, ctx->global_comm, stream));
295+
}
296+
297+
void RecvFromWorker(NDArray buffer, int sender_id) {
298+
CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
299+
deviceStream_t stream = ctx->GetDefaultStream();
300+
int worker_id = ctx->worker->worker_id;
301+
CHECK(sender_id >= 0 && sender_id < ctx->worker->num_workers)
302+
<< "Invalid sender id " << sender_id << ". The world size is " << ctx->worker->num_workers;
303+
CHECK_NE(worker_id, sender_id) << "Cannot receive from the worker itself.";
304+
NCCL_CALL(ncclRecv(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()),
305+
sender_id, ctx->global_comm, stream));
306+
}
307+
257308
void SyncWorker() {
258309
CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
259310
ICHECK(ctx->worker != nullptr);
@@ -284,8 +335,43 @@ TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".gather_to_worker0")
284335
.set_body_typed(GatherToWorker0);
285336
TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_worker0")
286337
.set_body_typed(RecvFromWorker0);
338+
TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".send_to_next_group")
339+
.set_body_typed(SendToNextGroup);
340+
TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_prev_group")
341+
.set_body_typed(RecvFromPrevGroup);
342+
TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".send_to_worker")
343+
.set_body_typed(SendToWorker);
344+
TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_worker")
345+
.set_body_typed(RecvFromWorker);
287346
TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".sync_worker").set_body_typed(SyncWorker);
288347

348+
TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME
349+
".test_send_to_next_group_recv_from_prev_group")
350+
.set_body_typed([](NDArray buffer) {
351+
CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
352+
CHECK_EQ(ctx->worker->num_workers, 4) << "The test requires the world size to be 4.";
353+
CHECK_EQ(ctx->worker->num_groups, 2) << "The test requires the group size to be 2.";
354+
int group_size = ctx->worker->num_workers / ctx->worker->num_groups;
355+
int group_id = ctx->worker->worker_id / group_size;
356+
if (group_id == 0) {
357+
tvm::runtime::nccl::SendToNextGroup(buffer);
358+
} else {
359+
tvm::runtime::nccl::RecvFromPrevGroup(buffer);
360+
}
361+
});
362+
363+
TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".test_worker2_sends_to_worker0")
364+
.set_body_typed([](NDArray buffer) {
365+
CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
366+
CHECK_EQ(ctx->worker->num_workers, 4) << "The test requires the world size to be 4.";
367+
CHECK_EQ(ctx->worker->num_groups, 2) << "The test requires the group size to be 2.";
368+
if (ctx->worker->worker_id == 2) {
369+
tvm::runtime::nccl::SendToWorker(buffer, 0);
370+
} else if (ctx->worker->worker_id == 0) {
371+
tvm::runtime::nccl::RecvFromWorker(buffer, 2);
372+
}
373+
});
374+
289375
} // namespace nccl
290376
} // namespace runtime
291377
} // namespace tvm

tests/python/disco/test_ccl.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@
2525
import tvm
2626
import tvm.testing
2727
from tvm import dlight as dl
28+
from tvm import get_global_func
2829
from tvm import relax as rx
2930
from tvm.runtime import disco as di
3031
from tvm.runtime.relax_vm import VirtualMachine
3132
from tvm.script import relax as R
32-
from tvm import get_global_func
3333

3434
_all_session_kinds = [di.ThreadedSession, di.ProcessSession]
3535
_ccl = [get_global_func("runtime.disco.compiled_ccl")()]
@@ -391,6 +391,44 @@ def test_group_gather(session_kind, ccl, capfd):
391391
), "No warning messages should be generated from disco.Session.gather_to_worker0"
392392

393393

394+
@pytest.mark.parametrize("session_kind", _all_session_kinds)
395+
@pytest.mark.parametrize("ccl", _ccl)
396+
def test_send_to_next_group_receive_from_prev_group(session_kind, ccl):
397+
devices = [0, 1, 2, 3]
398+
sess = session_kind(num_workers=len(devices), num_groups=2)
399+
sess.init_ccl(ccl, *devices)
400+
401+
array_1 = np.arange(12, dtype="float32").reshape(3, 4)
402+
array_2 = np.arange(start=1, stop=-11, step=-1, dtype="float32").reshape(3, 4)
403+
d_array = sess.empty((3, 4), "float32")
404+
d_array.debug_copy_from(0, array_1)
405+
d_array.debug_copy_from(1, array_2)
406+
sess.get_global_func("runtime.disco." + ccl + ".test_send_to_next_group_recv_from_prev_group")(
407+
d_array
408+
)
409+
410+
result_1 = d_array.debug_get_from_remote(2).numpy()
411+
result_2 = d_array.debug_get_from_remote(3).numpy()
412+
np.testing.assert_equal(result_1, array_1)
413+
np.testing.assert_equal(result_2, array_2)
414+
415+
416+
@pytest.mark.parametrize("session_kind", _all_session_kinds)
417+
@pytest.mark.parametrize("ccl", _ccl)
418+
def test_worker2_send_to_worker0(session_kind, ccl):
419+
devices = [0, 1, 2, 3]
420+
sess = session_kind(num_workers=len(devices), num_groups=2)
421+
sess.init_ccl(ccl, *devices)
422+
423+
array = np.arange(start=1, stop=-11, step=-1, dtype="float32").reshape(3, 4)
424+
d_array = sess.empty((3, 4), "float32")
425+
d_array.debug_copy_from(2, array)
426+
sess.get_global_func("runtime.disco." + ccl + ".test_worker2_sends_to_worker0")(d_array)
427+
428+
result = d_array.debug_get_from_remote(0).numpy()
429+
np.testing.assert_equal(result, array)
430+
431+
394432
@pytest.mark.parametrize("session_kind", _all_session_kinds)
395433
@pytest.mark.parametrize("ccl", _ccl)
396434
def test_mlp(session_kind, ccl): # pylint: disable=too-many-locals

0 commit comments

Comments
 (0)