Skip to content

Commit c6a8a80

Browse files
authored
[Disco] Allow allocation that only exists on worker0 (#16993)
The `disco.Session.scatter_from_worker0` function expects a `DRef` which is an `NDArray` on worker 0, and `NullOpt` on all other workers. Prior to this commit, there was no method in the `disco.Session` that could be used to make such a `DRef`. As a result, every use of `scatter_from_worker0` generated an error, stating that non-zero workers should have `NullOpt` as their `send` argument. This commit adds a `worker0_only: bool` argument to `disco.Session.empty`. This can be used to generate an allocation that only exists on worker zero, suitable for use in `scatter_from_worker0`.
1 parent c2d14ae commit c6a8a80

File tree

4 files changed

+53
-13
lines changed

4 files changed

+53
-13
lines changed

python/tvm/runtime/disco/session.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,27 +120,35 @@ def empty(
120120
shape: Sequence[int],
121121
dtype: str,
122122
device: Optional[Device] = None,
123+
worker0_only: bool = False,
123124
) -> DRef:
124125
"""Create an empty NDArray on all workers and attach them to a DRef.
125126
126127
Parameters
127128
----------
128129
shape : tuple of int
129130
The shape of the NDArray.
131+
130132
dtype : str
131133
The data type of the NDArray.
134+
132135
device : Optional[Device] = None
133136
The device of the NDArray.
134137
138+
worker0_only: bool
139+
If False (default), allocate an array on each worker. If
140+
True, only allocate an array on worker0.
141+
135142
Returns
136143
-------
137144
array : DRef
138145
The created NDArray.
146+
139147
"""
140148
if device is None:
141149
device = Device(device_type=0, device_id=0)
142150
func = self._get_cached_method("runtime.disco.empty")
143-
return func(ShapeTuple(shape), dtype, device)
151+
return func(ShapeTuple(shape), dtype, device, worker0_only)
144152

145153
def shutdown(self):
146154
"""Shut down the Disco session"""

src/runtime/disco/builtin.cc

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,17 @@ void SyncWorker() {
108108
}
109109

110110
TVM_REGISTER_GLOBAL("runtime.disco.load_vm_module").set_body_typed(LoadVMModule);
111-
TVM_REGISTER_GLOBAL("runtime.disco.empty").set_body_typed(DiscoEmptyNDArray);
111+
112+
TVM_REGISTER_GLOBAL("runtime.disco.empty")
113+
.set_body_typed([](ShapeTuple shape, DataType dtype, Device device,
114+
bool worker0_only) -> Optional<NDArray> {
115+
if (worker0_only && WorkerId()) {
116+
return NullOpt;
117+
} else {
118+
return DiscoEmptyNDArray(shape, dtype, device);
119+
}
120+
});
121+
112122
TVM_REGISTER_GLOBAL("runtime.disco.allreduce")
113123
.set_body_typed([](NDArray send, ShapeTuple reduce_kind, NDArray recv) {
114124
int kind = IntegerFromShapeTuple(reduce_kind);

src/runtime/disco/nccl/nccl.cc

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -106,14 +106,24 @@ void AllGather(NDArray send, NDArray recv) {
106106
/*datatype=*/AsNCCLDataType(DataType(send->dtype)), ctx->comm, stream));
107107
}
108108

109-
void BroadcastFromWorker0(NDArray send, NDArray recv) {
109+
void BroadcastFromWorker0(Optional<NDArray> send, NDArray recv) {
110110
CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
111-
ICHECK(send.Shape()->Product() == recv.Shape()->Product());
112-
ShapeTuple shape = send.Shape();
113-
int64_t numel = shape->Product();
111+
112+
const void* send_data = [&]() -> const void* {
113+
int worker_id = ctx->worker->worker_id;
114+
if (worker_id == 0) {
115+
CHECK(send.defined());
116+
CHECK(send.value().Shape()->Product() == recv.Shape()->Product());
117+
return send.value()->data;
118+
} else {
119+
return nullptr;
120+
}
121+
}();
122+
int64_t numel = recv.Shape()->Product();
123+
114124
deviceStream_t stream = ctx->GetDefaultStream();
115-
NCCL_CALL(ncclBroadcast(send->data, recv->data, numel,
116-
/*datatype=*/AsNCCLDataType(DataType(send->dtype)),
125+
NCCL_CALL(ncclBroadcast(send_data, recv->data, numel,
126+
/*datatype=*/AsNCCLDataType(DataType(recv->dtype)),
117127
/*root=*/0, ctx->comm, stream));
118128
}
119129

tests/python/disco/test_ccl.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# under the License.
1717
# pylint: disable=missing-docstring
1818
"""Tests for NCCL/RCCL"""
19+
1920
import tempfile
2021

2122
import numpy as np
@@ -108,7 +109,7 @@ def test_broadcast_from_worker0(session_kind, ccl):
108109
sess.init_ccl(ccl, *devices)
109110

110111
array = np.arange(12, dtype="float32").reshape(3, 4)
111-
d_array = sess.empty((3, 4), "float32")
112+
d_array = sess.empty((3, 4), "float32", worker0_only=True)
112113
d_array.debug_copy_from(0, array)
113114
dst_array = sess.empty((3, 4), "float32")
114115
sess.broadcast_from_worker0(d_array, dst_array)
@@ -118,16 +119,17 @@ def test_broadcast_from_worker0(session_kind, ccl):
118119

119120
@pytest.mark.parametrize("session_kind", _all_session_kinds)
120121
@pytest.mark.parametrize("ccl", _ccl)
121-
def test_scatter(session_kind, ccl):
122+
def test_scatter(session_kind, ccl, capfd):
122123
devices = [0, 1]
123124
sess = session_kind(num_workers=len(devices))
124125
sess.init_ccl(ccl, *devices)
125126

126127
array = np.arange(36, dtype="float32").reshape(3, 4, 3)
127-
d_src = sess.empty((3, 4, 3), "float32")
128+
d_src = sess.empty((3, 4, 3), "float32", worker0_only=True)
128129
d_dst = sess.empty((3, 3, 2), "float32")
129130

130131
d_src.debug_copy_from(0, array)
132+
131133
sess.scatter_from_worker0(d_src, d_dst)
132134

133135
np.testing.assert_equal(
@@ -139,17 +141,22 @@ def test_scatter(session_kind, ccl):
139141
array.flat[18:].reshape(3, 3, 2),
140142
)
141143

144+
captured = capfd.readouterr()
145+
assert (
146+
not captured.err
147+
), "No warning messages should be generated from disco.Session.scatter_from_worker0"
148+
142149

143150
@pytest.mark.parametrize("session_kind", _all_session_kinds)
144151
@pytest.mark.parametrize("ccl", _ccl)
145-
def test_gather(session_kind, ccl):
152+
def test_gather(session_kind, ccl, capfd):
146153
devices = [0, 1]
147154
sess = session_kind(num_workers=len(devices))
148155
sess.init_ccl(ccl, *devices)
149156

150157
array = np.arange(36, dtype="float32")
151158
d_src = sess.empty((3, 3, 2), "float32")
152-
d_dst = sess.empty((3, 4, 3), "float32")
159+
d_dst = sess.empty((3, 4, 3), "float32", worker0_only=True)
153160
d_src.debug_copy_from(0, array[:18])
154161
d_src.debug_copy_from(1, array[18:])
155162
sess.gather_to_worker0(d_src, d_dst)
@@ -158,6 +165,11 @@ def test_gather(session_kind, ccl):
158165
array.reshape(3, 4, 3),
159166
)
160167

168+
captured = capfd.readouterr()
169+
assert (
170+
not captured.err
171+
), "No warning messages should be generated from disco.Session.gather_to_worker0"
172+
161173

162174
@pytest.mark.parametrize("session_kind", _all_session_kinds)
163175
@pytest.mark.parametrize("ccl", _ccl)

0 commit comments

Comments
 (0)