diff --git a/python/tvm/runtime/disco/session.py b/python/tvm/runtime/disco/session.py index b8f74bacb00d..7eb4924e047a 100644 --- a/python/tvm/runtime/disco/session.py +++ b/python/tvm/runtime/disco/session.py @@ -120,6 +120,7 @@ def empty( shape: Sequence[int], dtype: str, device: Optional[Device] = None, + worker0_only: bool = False, ) -> DRef: """Create an empty NDArray on all workers and attach them to a DRef. @@ -127,20 +128,27 @@ def empty( ---------- shape : tuple of int The shape of the NDArray. + dtype : str The data type of the NDArray. + device : Optional[Device] = None The device of the NDArray. + worker0_only: bool + If False (default), allocate an array on each worker. If + True, only allocate an array on worker0. + Returns ------- array : DRef The created NDArray. + """ if device is None: device = Device(device_type=0, device_id=0) func = self._get_cached_method("runtime.disco.empty") - return func(ShapeTuple(shape), dtype, device) + return func(ShapeTuple(shape), dtype, device, worker0_only) def get_global_func(self, name: str) -> DRef: """Get a global function on workers. diff --git a/src/runtime/disco/builtin.cc b/src/runtime/disco/builtin.cc index 906cea1e323e..26d1c22ee975 100644 --- a/src/runtime/disco/builtin.cc +++ b/src/runtime/disco/builtin.cc @@ -108,7 +108,17 @@ void SyncWorker() { } TVM_REGISTER_GLOBAL("runtime.disco.load_vm_module").set_body_typed(LoadVMModule); -TVM_REGISTER_GLOBAL("runtime.disco.empty").set_body_typed(DiscoEmptyNDArray); + +TVM_REGISTER_GLOBAL("runtime.disco.empty") + .set_body_typed([](ShapeTuple shape, DataType dtype, Device device, + bool worker0_only) -> Optional { + if (worker0_only && WorkerId()) { + return NullOpt; + } else { + return DiscoEmptyNDArray(shape, dtype, device); + } + }); + TVM_REGISTER_GLOBAL("runtime.disco.allreduce") .set_body_typed([](NDArray send, ShapeTuple reduce_kind, NDArray recv) { int kind = IntegerFromShapeTuple(reduce_kind); diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc index b5fc1053b227..7b943cf83f1f 100644 --- a/src/runtime/disco/nccl/nccl.cc +++ b/src/runtime/disco/nccl/nccl.cc @@ -106,14 +106,24 @@ void AllGather(NDArray send, NDArray recv) { /*datatype=*/AsNCCLDataType(DataType(send->dtype)), ctx->comm, stream)); } -void BroadcastFromWorker0(NDArray send, NDArray recv) { +void BroadcastFromWorker0(Optional send, NDArray recv) { CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); - ICHECK(send.Shape()->Product() == recv.Shape()->Product()); - ShapeTuple shape = send.Shape(); - int64_t numel = shape->Product(); + + const void* send_data = [&]() -> const void* { + int worker_id = ctx->worker->worker_id; + if (worker_id == 0) { + CHECK(send.defined()); + CHECK(send.value().Shape()->Product() == recv.Shape()->Product()); + return send.value()->data; + } else { + return nullptr; + } + }(); + int64_t numel = recv.Shape()->Product(); + deviceStream_t stream = ctx->GetDefaultStream(); - NCCL_CALL(ncclBroadcast(send->data, recv->data, numel, - /*datatype=*/AsNCCLDataType(DataType(send->dtype)), + NCCL_CALL(ncclBroadcast(send_data, recv->data, numel, + /*datatype=*/AsNCCLDataType(DataType(recv->dtype)), /*root=*/0, ctx->comm, stream)); } diff --git a/tests/python/disco/test_ccl.py b/tests/python/disco/test_ccl.py index 4ecc14babc9b..b94bfdb2bb59 100644 --- a/tests/python/disco/test_ccl.py +++ b/tests/python/disco/test_ccl.py @@ -16,6 +16,7 @@ # under the License. # pylint: disable=missing-docstring """Tests for NCCL/RCCL""" + import tempfile import numpy as np @@ -108,7 +109,7 @@ def test_broadcast_from_worker0(session_kind, ccl): sess.init_ccl(ccl, *devices) array = np.arange(12, dtype="float32").reshape(3, 4) - d_array = sess.empty((3, 4), "float32") + d_array = sess.empty((3, 4), "float32", worker0_only=True) d_array.debug_copy_from(0, array) dst_array = sess.empty((3, 4), "float32") sess.broadcast_from_worker0(d_array, dst_array) @@ -118,16 +119,17 @@ def test_broadcast_from_worker0(session_kind, ccl): @pytest.mark.parametrize("session_kind", _all_session_kinds) @pytest.mark.parametrize("ccl", _ccl) -def test_scatter(session_kind, ccl): +def test_scatter(session_kind, ccl, capfd): devices = [0, 1] sess = session_kind(num_workers=len(devices)) sess.init_ccl(ccl, *devices) array = np.arange(36, dtype="float32").reshape(3, 4, 3) - d_src = sess.empty((3, 4, 3), "float32") + d_src = sess.empty((3, 4, 3), "float32", worker0_only=True) d_dst = sess.empty((3, 3, 2), "float32") d_src.debug_copy_from(0, array) + sess.scatter_from_worker0(d_src, d_dst) np.testing.assert_equal( @@ -139,17 +141,22 @@ def test_scatter(session_kind, ccl): array.flat[18:].reshape(3, 3, 2), ) + captured = capfd.readouterr() + assert ( + not captured.err + ), "No warning messages should be generated from disco.Session.scatter_from_worker0" + @pytest.mark.parametrize("session_kind", _all_session_kinds) @pytest.mark.parametrize("ccl", _ccl) -def test_gather(session_kind, ccl): +def test_gather(session_kind, ccl, capfd): devices = [0, 1] sess = session_kind(num_workers=len(devices)) sess.init_ccl(ccl, *devices) array = np.arange(36, dtype="float32") d_src = sess.empty((3, 3, 2), "float32") - d_dst = sess.empty((3, 4, 3), "float32") + d_dst = sess.empty((3, 4, 3), "float32", worker0_only=True) d_src.debug_copy_from(0, array[:18]) d_src.debug_copy_from(1, array[18:]) sess.gather_to_worker0(d_src, d_dst) @@ -158,6 +165,11 @@ def test_gather(session_kind, ccl): array.reshape(3, 4, 3), ) + captured = capfd.readouterr() + assert ( + not captured.err + ), "No warning messages should be generated from disco.Session.gather_to_worker0" + @pytest.mark.parametrize("session_kind", _all_session_kinds) @pytest.mark.parametrize("ccl", _ccl)