Skip to content

Commit 7c2c0d9

Browse files
authored
[Disco][QoL] Implement broadcast/scatter methods for Session (#17035)
* [Disco][QoL] Implement broadcast/scatter methods for Session

  Prior to this commit, use of the `disco.Session` API to broadcast or scatter an array required several steps from the caller:

  1. Allocate memory on worker0.
  2. Transfer data from the controller to worker0.
  3. Allocate memory on each worker.
  4. Broadcast/scatter data from worker0 to all workers.

  While exposing these steps is necessary for performance, especially when used repeatedly, it can be tedious and error-prone to use for initialization that is only performed once.

  This commit adds utility methods `Session.broadcast` and `Session.scatter`, which are implemented in terms of the existing lower-level methods `Session.broadcast_from_worker0` and `Session.scatter_from_worker0`. These methods perform the transfer from the controller to worker0, and from worker0 to all other workers.

* lint fix
1 parent f6aab98 commit 7c2c0d9

File tree

2 files changed

+158
-14
lines changed

2 files changed

+158
-14
lines changed

python/tvm/runtime/disco/session.py

Lines changed: 96 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -249,17 +249,34 @@ def copy_from_worker_0(self, host_array: NDArray, remote_array: DRef) -> None:
249249
"""
250250
return _ffi_api.SessionCopyFromWorker0(self, host_array, remote_array) # type: ignore # pylint: disable=no-member
251251

252-
def copy_to_worker_0(self, host_array: NDArray, remote_array: DRef) -> None:
252+
def copy_to_worker_0(self, host_array: NDArray, remote_array: Optional[DRef] = None) -> DRef:
253253
"""Copy the controller-side NDArray to worker-0.
254254
255255
Parameters
256256
----------
257-
host_array : numpy.ndarray
258-
The array to be copied from worker-0.
259-
remote_array : NDArray
260-
The NDArray on worker-0.
257+
host_array : NDArray
258+
259+
The array to be copied to worker-0.
260+
261+
remote_array : Optional[DRef]
262+
263+
The destination NDArray on worker-0.
264+
265+
Returns
266+
-------
267+
output_array: DRef
268+
269+
The DRef containing the copied data on worker0, and
270+
NullOpt on all other workers. If `remote_array` was
271+
provided, this return value is the same as `remote_array`.
272+
Otherwise, it is the newly allocated space.
273+
261274
"""
262-
return _ffi_api.SessionCopyToWorker0(self, host_array, remote_array) # type: ignore # pylint: disable=no-member
275+
if remote_array is None:
276+
remote_array = self.empty(host_array.shape, host_array.dtype, worker0_only=True)
277+
278+
_ffi_api.SessionCopyToWorker0(self, host_array, remote_array) # type: ignore # pylint: disable=no-member
279+
return remote_array
263280

264281
def load_vm_module(
265282
self,
@@ -302,6 +319,40 @@ def init_ccl(self, ccl: str, *device_ids):
302319
_ffi_api.SessionInitCCL(self, ccl, ShapeTuple(device_ids)) # type: ignore # pylint: disable=no-member
303320
self._clear_ipc_memory_pool()
304321

322+
def broadcast(self, src: Union[np.ndarray, NDArray], dst: Optional[DRef] = None) -> DRef:
    """Send a controller-side array to every worker.

    This is a convenience wrapper: the array is first copied to
    worker-0 with `copy_to_worker_0`, and then distributed with
    `broadcast_from_worker0`.

    Parameters
    ----------
    src: Union[np.ndarray, NDArray]

        The array to broadcast.  Anything that is not already an
        NDArray is converted with `_as_NDArray` first.

    dst: Optional[DRef]

        The destination array.  If None, a destination with the
        shape and dtype of `src` is allocated through `self.empty`.

    Returns
    -------
    output_array: DRef

        The DRef that received the broadcast data.  This is `dst`
        itself when `dst` was supplied, and the newly allocated
        array otherwise.

    """
    host_array = src if isinstance(src, NDArray) else _as_NDArray(src)

    # Allocate the destination before staging the data on worker-0,
    # keeping the same order of session calls as the explicit workflow.
    if dst is None:
        dst = self.empty(host_array.shape, host_array.dtype)

    # Stage the data on worker-0, then fan it out to the destination.
    staged_on_worker0 = self.copy_to_worker_0(host_array)
    self.broadcast_from_worker0(staged_on_worker0, dst)

    return dst
355+
305356
def broadcast_from_worker0(self, src: DRef, dst: DRef) -> DRef:
306357
"""Broadcast an array from worker-0 to all other workers.
307358
@@ -313,6 +364,45 @@ def broadcast_from_worker0(self, src: DRef, dst: DRef) -> DRef:
313364
func = self._get_cached_method("runtime.disco.broadcast_from_worker0")
314365
func(src, dst)
315366

367+
def scatter(self, src: Union[np.ndarray, NDArray], dst: Optional[DRef] = None) -> DRef:
    """Scatter an array across all workers

    This is a convenience wrapper: the array is first copied to
    worker-0 with `copy_to_worker_0`, and then distributed with
    `scatter_from_worker0`.

    Parameters
    ----------
    src: Union[np.ndarray, NDArray]

        The array to be scattered.  The first dimension of this
        array, `src.shape[0]`, must be equal to the number of
        workers.

    dst: Optional[DRef]

        The output array.  If None, an array of shape
        `src.shape[1:]` with the same dtype as `src` is allocated
        through `self.empty`.

    Returns
    -------
    output_array: DRef

        The DRef containing the scattered data on all workers.
        If `dst` was provided, this return value is the same as
        `dst`.  Otherwise, it is the newly allocated space.

    Raises
    ------
    ValueError

        If `src` has no dimensions, or if `src.shape[0]` is not
        equal to `self.num_workers`.

    """
    # Validate with an explicit raise rather than `assert`: assertions
    # are removed when Python runs with `-O`, which would let a
    # mismatched array reach `scatter_from_worker0` unchecked.
    if len(src.shape) == 0 or src.shape[0] != self.num_workers:
        raise ValueError(
            "Session.scatter requires src.shape[0] to equal the number of workers "
            f"({self.num_workers}), but src has shape {src.shape}"
        )

    if not isinstance(src, NDArray):
        src = _as_NDArray(src)

    if dst is None:
        # Each worker receives one slice along the outermost dimension.
        dst = self.empty(src.shape[1:], src.dtype)

    src_dref = self.copy_to_worker_0(src)
    self.scatter_from_worker0(src_dref, dst)

    return dst
405+
316406
def scatter_from_worker0(self, from_array: DRef, to_array: DRef) -> None:
317407
"""Scatter an array from worker-0 to all other workers.
318408

tests/python/disco/test_ccl.py

Lines changed: 62 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -103,33 +103,87 @@ def test_allgather(session_kind, ccl):
103103

104104
@pytest.mark.parametrize("session_kind", _all_session_kinds)
105105
@pytest.mark.parametrize("ccl", _ccl)
106-
def test_broadcast_from_worker0(session_kind, ccl):
106+
@pytest.mark.parametrize("use_explicit_output", [True, False])
107+
def test_broadcast_from_worker0(session_kind, ccl, use_explicit_output):
107108
devices = [0, 1]
108109
sess = session_kind(num_workers=len(devices))
109110
sess.init_ccl(ccl, *devices)
110111

111112
array = np.arange(12, dtype="float32").reshape(3, 4)
112-
d_array = sess.empty((3, 4), "float32", worker0_only=True)
113-
d_array.debug_copy_from(0, array)
114-
dst_array = sess.empty((3, 4), "float32")
115-
sess.broadcast_from_worker0(d_array, dst_array)
113+
114+
if use_explicit_output:
115+
src_array = sess.empty((3, 4), "float32", worker0_only=True)
116+
src_array.debug_copy_from(0, array)
117+
dst_array = sess.empty((3, 4), "float32")
118+
sess.broadcast_from_worker0(src_array, dst_array)
119+
else:
120+
dst_array = sess.broadcast(array)
121+
116122
result = dst_array.debug_get_from_remote(1).numpy()
117123
np.testing.assert_equal(result, array)
118124

119125

120126
@pytest.mark.parametrize("session_kind", _all_session_kinds)
121127
@pytest.mark.parametrize("ccl", _ccl)
122-
def test_scatter(session_kind, ccl, capfd):
128+
@pytest.mark.parametrize("use_explicit_output", [True, False])
129+
def test_scatter(session_kind, ccl, use_explicit_output, capfd):
130+
devices = [0, 1]
131+
sess = session_kind(num_workers=len(devices))
132+
sess.init_ccl(ccl, *devices)
133+
134+
array = np.arange(36, dtype="float32").reshape(2, 6, 3)
135+
136+
if use_explicit_output:
137+
d_src = sess.empty((2, 6, 3), "float32", worker0_only=True)
138+
d_dst = sess.empty((6, 3), "float32")
139+
d_src.debug_copy_from(0, array)
140+
sess.scatter_from_worker0(d_src, d_dst)
141+
else:
142+
d_dst = sess.scatter(array)
143+
144+
np.testing.assert_equal(
145+
d_dst.debug_get_from_remote(0).numpy(),
146+
array[0, :, :],
147+
)
148+
np.testing.assert_equal(
149+
d_dst.debug_get_from_remote(1).numpy(),
150+
array[1, :, :],
151+
)
152+
153+
captured = capfd.readouterr()
154+
assert (
155+
not captured.err
156+
), "No warning messages should be generated from disco.Session.scatter_from_worker0"
157+
158+
159+
@pytest.mark.parametrize("session_kind", _all_session_kinds)
160+
@pytest.mark.parametrize("ccl", _ccl)
161+
def test_scatter_with_implicit_reshape(session_kind, ccl, capfd):
162+
"""Scatter may perform an implicit reshape
163+
164+
Scattering elements to the workers requires the total number of
165+
elements to be divisible by the number of workers. It does not
166+
necessarily correspond to scattering across the outermost
167+
dimension. Here, the number of workers (2) and the outermost
168+
dimension (3) are not divisible, but the scatter may still be
169+
performed.
170+
171+
This is only allowed when the caller explicitly uses the
172+
`sess.scatter_from_worker0` method, and is not allowed in
173+
`sess.scatter` method. Because the `sess.scatter` method may
174+
perform an allocation on the disco workers, it requires that the
175+
scatter occur across the outermost dimension.
176+
177+
"""
123178
devices = [0, 1]
124179
sess = session_kind(num_workers=len(devices))
125180
sess.init_ccl(ccl, *devices)
126181

127182
array = np.arange(36, dtype="float32").reshape(3, 4, 3)
183+
128184
d_src = sess.empty((3, 4, 3), "float32", worker0_only=True)
129185
d_dst = sess.empty((3, 3, 2), "float32")
130-
131186
d_src.debug_copy_from(0, array)
132-
133187
sess.scatter_from_worker0(d_src, d_dst)
134188

135189
np.testing.assert_equal(

0 commit comments

Comments
 (0)