Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion python/tvm/runtime/disco/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,27 +120,35 @@ def empty(
shape: Sequence[int],
dtype: str,
device: Optional[Device] = None,
worker0_only: bool = False,
) -> DRef:
"""Create an empty NDArray on all workers and attach them to a DRef.

Parameters
----------
shape : tuple of int
The shape of the NDArray.

dtype : str
The data type of the NDArray.

device : Optional[Device] = None
The device of the NDArray.

worker0_only: bool
If False (default), allocate an array on each worker. If
True, only allocate an array on worker0.

Returns
-------
array : DRef
The created NDArray.

"""
if device is None:
device = Device(device_type=0, device_id=0)
func = self._get_cached_method("runtime.disco.empty")
return func(ShapeTuple(shape), dtype, device)
return func(ShapeTuple(shape), dtype, device, worker0_only)

def get_global_func(self, name: str) -> DRef:
"""Get a global function on workers.
Expand Down
12 changes: 11 additions & 1 deletion src/runtime/disco/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,17 @@ void SyncWorker() {
}

TVM_REGISTER_GLOBAL("runtime.disco.load_vm_module").set_body_typed(LoadVMModule);
TVM_REGISTER_GLOBAL("runtime.disco.empty").set_body_typed(DiscoEmptyNDArray);

TVM_REGISTER_GLOBAL("runtime.disco.empty")
.set_body_typed([](ShapeTuple shape, DataType dtype, Device device,
bool worker0_only) -> Optional<NDArray> {
if (worker0_only && WorkerId()) {
return NullOpt;
} else {
return DiscoEmptyNDArray(shape, dtype, device);
}
});

TVM_REGISTER_GLOBAL("runtime.disco.allreduce")
.set_body_typed([](NDArray send, ShapeTuple reduce_kind, NDArray recv) {
int kind = IntegerFromShapeTuple(reduce_kind);
Expand Down
22 changes: 16 additions & 6 deletions src/runtime/disco/nccl/nccl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,24 @@ void AllGather(NDArray send, NDArray recv) {
/*datatype=*/AsNCCLDataType(DataType(send->dtype)), ctx->comm, stream));
}

void BroadcastFromWorker0(NDArray send, NDArray recv) {
void BroadcastFromWorker0(Optional<NDArray> send, NDArray recv) {
CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
ICHECK(send.Shape()->Product() == recv.Shape()->Product());
ShapeTuple shape = send.Shape();
int64_t numel = shape->Product();

const void* send_data = [&]() -> const void* {
int worker_id = ctx->worker->worker_id;
if (worker_id == 0) {
CHECK(send.defined());
CHECK(send.value().Shape()->Product() == recv.Shape()->Product());
return send.value()->data;
} else {
return nullptr;
}
}();
int64_t numel = recv.Shape()->Product();

deviceStream_t stream = ctx->GetDefaultStream();
NCCL_CALL(ncclBroadcast(send->data, recv->data, numel,
/*datatype=*/AsNCCLDataType(DataType(send->dtype)),
NCCL_CALL(ncclBroadcast(send_data, recv->data, numel,
/*datatype=*/AsNCCLDataType(DataType(recv->dtype)),
/*root=*/0, ctx->comm, stream));
}

Expand Down
22 changes: 17 additions & 5 deletions tests/python/disco/test_ccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
# pylint: disable=missing-docstring
"""Tests for NCCL/RCCL"""

import tempfile

import numpy as np
Expand Down Expand Up @@ -108,7 +109,7 @@ def test_broadcast_from_worker0(session_kind, ccl):
sess.init_ccl(ccl, *devices)

array = np.arange(12, dtype="float32").reshape(3, 4)
d_array = sess.empty((3, 4), "float32")
d_array = sess.empty((3, 4), "float32", worker0_only=True)
d_array.debug_copy_from(0, array)
dst_array = sess.empty((3, 4), "float32")
sess.broadcast_from_worker0(d_array, dst_array)
Expand All @@ -118,16 +119,17 @@ def test_broadcast_from_worker0(session_kind, ccl):

@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_scatter(session_kind, ccl):
def test_scatter(session_kind, ccl, capfd):
devices = [0, 1]
sess = session_kind(num_workers=len(devices))
sess.init_ccl(ccl, *devices)

array = np.arange(36, dtype="float32").reshape(3, 4, 3)
d_src = sess.empty((3, 4, 3), "float32")
d_src = sess.empty((3, 4, 3), "float32", worker0_only=True)
d_dst = sess.empty((3, 3, 2), "float32")

d_src.debug_copy_from(0, array)

sess.scatter_from_worker0(d_src, d_dst)

np.testing.assert_equal(
Expand All @@ -139,17 +141,22 @@ def test_scatter(session_kind, ccl):
array.flat[18:].reshape(3, 3, 2),
)

captured = capfd.readouterr()
assert (
not captured.err
), "No warning messages should be generated from disco.Session.scatter_from_worker0"


@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_gather(session_kind, ccl):
def test_gather(session_kind, ccl, capfd):
devices = [0, 1]
sess = session_kind(num_workers=len(devices))
sess.init_ccl(ccl, *devices)

array = np.arange(36, dtype="float32")
d_src = sess.empty((3, 3, 2), "float32")
d_dst = sess.empty((3, 4, 3), "float32")
d_dst = sess.empty((3, 4, 3), "float32", worker0_only=True)
d_src.debug_copy_from(0, array[:18])
d_src.debug_copy_from(1, array[18:])
sess.gather_to_worker0(d_src, d_dst)
Expand All @@ -158,6 +165,11 @@ def test_gather(session_kind, ccl):
array.reshape(3, 4, 3),
)

captured = capfd.readouterr()
assert (
not captured.err
), "No warning messages should be generated from disco.Session.gather_to_worker0"


@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
Expand Down