Skip to content

Commit 3db94d5

Browse files
committed
[Disco][QoL] Implement broadcast/scatter methods for Session
Prior to this commit, use of the `disco.Session` API to broadcast or scatter an array required several steps from the caller. 1. Allocate memory on worker0 2. Transfer data from the controller to worker0 3. Allocate memory on each worker 4. Broadcast/scatter data from worker0 to all workers While exposing these steps is necessary for performance, especially when used repeatedly, it can be tedious/error-prone to use for initialization that is only performed once. This commit adds utility methods `Session.broadcast` and `Session.scatter`, which are implemented in terms of the existing lower-level methods `Session.broadcast_from_worker0` and `Session.scatter_from_worker0`. These methods perform the transfer from the controller to worker0, and from worker0 to all other workers.
1 parent 3cd6673 commit 3db94d5

File tree

2 files changed

+161
-15
lines changed

2 files changed

+161
-15
lines changed

python/tvm/runtime/disco/session.py

Lines changed: 99 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from ..._ffi.runtime_ctypes import Device
3030
from ..container import ShapeTuple
3131
from ..ndarray import NDArray
32-
from ..ndarray import array as _as_NDArray
32+
from ..ndarray import array as _as_NDArray, from_dlpack as _tvm_array_from_dlpack
3333
from ..object import Object
3434
from . import _ffi_api, process_pool # pylint: disable=unused-import
3535

@@ -86,6 +86,8 @@ def __call__(self, *args) -> DRef:
8686
return self.session.call_packed(self, *args)
8787

8888

89+
90+
8991
class DModule(DRef):
9092
"""A Module in a Disco session."""
9193

@@ -249,17 +251,34 @@ def copy_from_worker_0(self, host_array: NDArray, remote_array: DRef) -> None:
249251
"""
250252
return _ffi_api.SessionCopyFromWorker0(self, host_array, remote_array) # type: ignore # pylint: disable=no-member
251253

252-
def copy_to_worker_0(self, host_array: NDArray, remote_array: DRef) -> None:
254+
def copy_to_worker_0(self, host_array: NDArray, remote_array: Optional[DRef] = None) -> DRef:
253255
"""Copy the controller-side NDArray to worker-0.
254256
255257
Parameters
256258
----------
257-
host_array : numpy.ndarray
258-
The array to be copied from worker-0.
259-
remote_array : NDArray
260-
The NDArray on worker-0.
259+
host_array : NDArray
260+
261+
The array to be copied to worker-0.
262+
263+
remote_array : Optiona[DRef]
264+
265+
The destination NDArray on worker-0.
266+
267+
Returns
268+
-------
269+
output_array: DRef
270+
271+
The DRef containing the copied data on worker0, and
272+
NullOpt on all other workers. If `remote_array` was
273+
provided, this return value is the same as `remote_array`.
274+
Otherwise, it is the newly allocated space.
275+
261276
"""
262-
return _ffi_api.SessionCopyToWorker0(self, host_array, remote_array) # type: ignore # pylint: disable=no-member
277+
if remote_array is None:
278+
remote_array = self.empty(host_array.shape, host_array.dtype, worker0_only=True)
279+
280+
_ffi_api.SessionCopyToWorker0(self, host_array, remote_array) # type: ignore # pylint: disable=no-member
281+
return remote_array
263282

264283
def load_vm_module(
265284
self,
@@ -302,6 +321,40 @@ def init_ccl(self, ccl: str, *device_ids):
302321
_ffi_api.SessionInitCCL(self, ccl, ShapeTuple(device_ids)) # type: ignore # pylint: disable=no-member
303322
self._clear_ipc_memory_pool()
304323

324+
def broadcast(self, src: Union[np.ndarray, NDArray], dst: Optional[DRef] = None) -> DRef:
325+
"""Broadcast an array to all workers
326+
327+
Parameters
328+
----------
329+
src: Union[np.ndarray, NDArray]
330+
331+
The array to be broadcasted.
332+
333+
dst: Optional[DRef]
334+
335+
The output array. If None, an array matching the shape
336+
and dtype of `src` will be allocated on each worker.
337+
338+
Returns
339+
-------
340+
output_array: DRef
341+
342+
The DRef containing the broadcasted data on all workers.
343+
If `dst` was provided, this return value is the same as
344+
`dst`. Otherwise, it is the newly allocated space.
345+
346+
"""
347+
if not isinstance(src, NDArray):
348+
src = _as_NDArray(src)
349+
350+
if dst is None:
351+
dst = self.empty(src.shape, src.dtype)
352+
353+
src_dref = self.copy_to_worker_0(src)
354+
self.broadcast_from_worker0(src_dref, dst)
355+
356+
return dst
357+
305358
def broadcast_from_worker0(self, src: DRef, dst: DRef) -> DRef:
306359
"""Broadcast an array from worker-0 to all other workers.
307360
@@ -313,6 +366,45 @@ def broadcast_from_worker0(self, src: DRef, dst: DRef) -> DRef:
313366
func = self._get_cached_method("runtime.disco.broadcast_from_worker0")
314367
func(src, dst)
315368

369+
def scatter(self, src: Union[np.ndarray, NDArray], dst: Optional[DRef] = None) -> DRef:
370+
"""Scatter an array across all workers
371+
372+
Parameters
373+
----------
374+
src: Union[np.ndarray, NDArray]
375+
376+
The array to be scattered. The first dimension of this
377+
array, `src.shape[0]`, must be equal to the number of
378+
workers.
379+
380+
dst: Optional[DRef]
381+
382+
The output array. If None, an array with compatible shape
383+
and the same dtype as `src` will be allocated on each
384+
worker.
385+
386+
Returns
387+
-------
388+
output_array: DRef
389+
390+
The DRef containing the scattered data on all workers.
391+
If `dst` was provided, this return value is the same as
392+
`dst`. Otherwise, it is the newly allocated space.
393+
394+
"""
395+
assert src.shape[0] == self.num_workers
396+
397+
if not isinstance(src, NDArray):
398+
src = _as_NDArray(src)
399+
400+
if dst is None:
401+
dst = self.empty(src.shape[1:], src.dtype)
402+
403+
src_dref = self.copy_to_worker_0(src)
404+
self.scatter_from_worker0(src_dref, dst)
405+
406+
return dst
407+
316408
def scatter_from_worker0(self, from_array: DRef, to_array: DRef) -> None:
317409
"""Scatter an array from worker-0 to all other workers.
318410

tests/python/disco/test_ccl.py

Lines changed: 62 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -103,33 +103,87 @@ def test_allgather(session_kind, ccl):
103103

104104
@pytest.mark.parametrize("session_kind", _all_session_kinds)
105105
@pytest.mark.parametrize("ccl", _ccl)
106-
def test_broadcast_from_worker0(session_kind, ccl):
106+
@pytest.mark.parametrize("use_explicit_output", [True, False])
107+
def test_broadcast_from_worker0(session_kind, ccl, use_explicit_output):
107108
devices = [0, 1]
108109
sess = session_kind(num_workers=len(devices))
109110
sess.init_ccl(ccl, *devices)
110111

111112
array = np.arange(12, dtype="float32").reshape(3, 4)
112-
d_array = sess.empty((3, 4), "float32", worker0_only=True)
113-
d_array.debug_copy_from(0, array)
114-
dst_array = sess.empty((3, 4), "float32")
115-
sess.broadcast_from_worker0(d_array, dst_array)
113+
114+
if use_explicit_output:
115+
src_array = sess.empty((3, 4), "float32", worker0_only=True)
116+
src_array.debug_copy_from(0, array)
117+
dst_array = sess.empty((3, 4), "float32")
118+
sess.broadcast_from_worker0(src_array, dst_array)
119+
else:
120+
dst_array = sess.broadcast(array)
121+
116122
result = dst_array.debug_get_from_remote(1).numpy()
117123
np.testing.assert_equal(result, array)
118124

119125

120126
@pytest.mark.parametrize("session_kind", _all_session_kinds)
121127
@pytest.mark.parametrize("ccl", _ccl)
122-
def test_scatter(session_kind, ccl, capfd):
128+
@pytest.mark.parametrize("use_explicit_output", [True, False])
129+
def test_scatter(session_kind, ccl, use_explicit_output, capfd):
130+
devices = [0, 1]
131+
sess = session_kind(num_workers=len(devices))
132+
sess.init_ccl(ccl, *devices)
133+
134+
array = np.arange(36, dtype="float32").reshape(2, 6, 3)
135+
136+
if use_explicit_output:
137+
d_src = sess.empty((2, 6, 3), "float32", worker0_only=True)
138+
d_dst = sess.empty((6, 3), "float32")
139+
d_src.debug_copy_from(0, array)
140+
sess.scatter_from_worker0(d_src, d_dst)
141+
else:
142+
d_dst = sess.scatter(array)
143+
144+
np.testing.assert_equal(
145+
d_dst.debug_get_from_remote(0).numpy(),
146+
array[0, :, :],
147+
)
148+
np.testing.assert_equal(
149+
d_dst.debug_get_from_remote(1).numpy(),
150+
array[1, :, :],
151+
)
152+
153+
captured = capfd.readouterr()
154+
assert (
155+
not captured.err
156+
), "No warning messages should be generated from disco.Session.scatter_from_worker0"
157+
158+
159+
@pytest.mark.parametrize("session_kind", _all_session_kinds)
160+
@pytest.mark.parametrize("ccl", _ccl)
161+
def test_scatter_with_implicit_reshape(session_kind, ccl, capfd):
162+
"""Scatter may perform an implicit reshape
163+
164+
Scattering elements to the workers requires the total number of
165+
elements to be divisible by the number of workers. It does not
166+
necessarily correspond to scattering across the outermost
167+
dimension. Here, the number of workers (2) and the outermost
168+
dimension (3) are not divisible, but the scatter may still be
169+
performed.
170+
171+
This is only allowed when the caller explicitly uses the
172+
`sess.scatter_from_worker0` method, and is not allowed in
173+
`sess.scatter` method. Because the `sess.scatter` method may
174+
perform an allocation on the disco workers, it requires that the
175+
scatter occur across the outermost dimension.
176+
177+
"""
123178
devices = [0, 1]
124179
sess = session_kind(num_workers=len(devices))
125180
sess.init_ccl(ccl, *devices)
126181

127182
array = np.arange(36, dtype="float32").reshape(3, 4, 3)
183+
128184
d_src = sess.empty((3, 4, 3), "float32", worker0_only=True)
129185
d_dst = sess.empty((3, 3, 2), "float32")
130-
131186
d_src.debug_copy_from(0, array)
132-
133187
sess.scatter_from_worker0(d_src, d_dst)
134188

135189
np.testing.assert_equal(

0 commit comments

Comments
 (0)