diff --git a/include/tvm/relax/attrs/ccl.h b/include/tvm/relax/attrs/ccl.h index 42cec88de673..de043f92be82 100644 --- a/include/tvm/relax/attrs/ccl.h +++ b/include/tvm/relax/attrs/ccl.h @@ -32,14 +32,32 @@ namespace relax { /*! \brief Attributes used in allreduce operators */ struct AllReduceAttrs : public tvm::AttrsNode { String op_type; + bool in_group; TVM_DECLARE_ATTRS(AllReduceAttrs, "relax.attrs.AllReduceAttrs") { TVM_ATTR_FIELD(op_type).describe( "The type of reduction operation to be applied to the input data. Now only sum is " "supported."); + TVM_ATTR_FIELD(in_group).describe( + "Whether the reduction operation performs in group or globally or in group as default."); } }; // struct AllReduceAttrs +/*! \brief Attributes used in allgather operators */ +struct AllGatherAttrs : public tvm::AttrsNode { + int num_workers; + bool in_group; + + TVM_DECLARE_ATTRS(AllGatherAttrs, "relax.attrs.AllGatherAttrs") { + TVM_ATTR_FIELD(num_workers) + .describe( + "The number of workers, also the number of parts the given buffer should be chunked " + "into."); + TVM_ATTR_FIELD(in_group).describe( + "Whether the allgather operation performs in group or globally or in group as default."); + } +}; // struct AllGatherAttrs + /*! \brief Attributes used in scatter operators */ struct ScatterCollectiveAttrs : public tvm::AttrsNode { int num_workers; diff --git a/include/tvm/runtime/disco/builtin.h b/include/tvm/runtime/disco/builtin.h index cf9967dbfe76..7d15e35fbdbc 100644 --- a/include/tvm/runtime/disco/builtin.h +++ b/include/tvm/runtime/disco/builtin.h @@ -75,35 +75,40 @@ TVM_DLL NDArray DiscoEmptyNDArray(ShapeTuple shape, DataType dtype, Device devic * \brief Perform an allreduce operation using the underlying communication library * \param send The array send to perform allreduce on * \param reduce_kind The kind of reduction operation (e.g. sum, avg, min, max) + * \param in_group Whether the allreduce operation performs globally or in group as default. 
* \param recv The array receives the outcome of allreduce */ -TVM_DLL void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv); +TVM_DLL void AllReduce(NDArray send, ReduceKind reduce_kind, bool in_group, NDArray recv); /*! * \brief Perform an allgather operation using the underlying communication library * \param send The array send to perform allgather on + * \param in_group Whether the allgather operation performs globally or in group as default. * \param recv The array receives the outcome of allgather */ -TVM_DLL void AllGather(NDArray send, NDArray recv); +TVM_DLL void AllGather(NDArray send, bool in_group, NDArray recv); /*! * \brief Perform a broadcast operation from worker-0 * \param send The buffer to be broadcasted + * \param in_group Whether the broadcast operation performs globally or in group as default. * \param recv The buffer receives the broadcasted array */ -TVM_DLL void BroadcastFromWorker0(NDArray send, NDArray recv); +TVM_DLL void BroadcastFromWorker0(NDArray send, bool in_group, NDArray recv); /*! * \brief Perform a scatter operation from worker-0, chunking the given buffer into equal parts. * \param send For worker-0, it must be provided, and otherwise, the buffer must be None. * The buffer will be divided into equal parts and sent to each worker accordingly. + * \param in_group Whether the scatter operation performs globally or in group as default. * \param recv The receiving buffer, which must not be None. */ -TVM_DLL void ScatterFromWorker0(Optional send, NDArray recv); +TVM_DLL void ScatterFromWorker0(Optional send, bool in_group, NDArray recv); /*! * \brief Perform a gather operation to worker-0. * \param send The sending buffer, which must not be None. + * \param in_group Whether the gather operation performs globally or in group as default. * \param recv For worker-0, it must be provided, and otherwise, the buffer must be None. The * receiving buffer will be divided into equal parts and receive from each worker accordingly. 
*/ -TVM_DLL void GatherToWorker0(NDArray send, Optional recv); +TVM_DLL void GatherToWorker0(NDArray send, bool in_group, Optional recv); /*! * \brief Receive a buffer from worker-0. No-op if the current worker is worker-0. * \param buffer The buffer to be received diff --git a/include/tvm/runtime/disco/disco_worker.h b/include/tvm/runtime/disco/disco_worker.h index 14f8f238074f..301b5b8d626b 100644 --- a/include/tvm/runtime/disco/disco_worker.h +++ b/include/tvm/runtime/disco/disco_worker.h @@ -44,14 +44,16 @@ class DiscoWorker { * \brief Construct a worker. * \param worker_id The id of the worker. * \param num_workers The number of the workers. + * \param num_groups The number of the worker groups. * \param worker_zero_data The data shared between worker-0 and the controler. It's a nullptr if * the worker is not worker-0. * \param channel The communication channel between the worker and the controler. */ - explicit DiscoWorker(int worker_id, int num_workers, WorkerZeroData* worker_zero_data, - DiscoChannel* channel) + explicit DiscoWorker(int worker_id, int num_workers, int num_groups, + WorkerZeroData* worker_zero_data, DiscoChannel* channel) : worker_id(worker_id), num_workers(num_workers), + num_groups(num_groups), default_device(Device{DLDeviceType::kDLCPU, 0}), worker_zero_data(worker_zero_data), channel(channel), @@ -68,6 +70,8 @@ class DiscoWorker { int worker_id; /*! \brief Total number of workers */ int num_workers; + /*! \brief Total number of worker groups */ + int num_groups; /*! \brief The default device to allocate data if not specified */ Device default_device; /*! \brief The name of the underlying collective communication library. */ diff --git a/include/tvm/runtime/disco/session.h b/include/tvm/runtime/disco/session.h index 71fcce75b292..97fa79096d63 100644 --- a/include/tvm/runtime/disco/session.h +++ b/include/tvm/runtime/disco/session.h @@ -264,11 +264,13 @@ class Session : public ObjectRef { /*! 
* \brief Create a session backed by a thread pool of workers * \param num_workers The number of workers. + * \param num_groups The number of worker groups. */ - TVM_DLL static Session ThreadedSession(int num_workers); + TVM_DLL static Session ThreadedSession(int num_workers, int num_groups); /*! * \brief Create a session backed by pipe-based multiprocessing * \param num_workers The number of workers. + * \param num_groups The number of worker groups. * \param process_pool_creator The name of a global function that takes `num_workers` as an input, * and returns a PackedFunc, which takes an integer `worker_id` as the input and returns None. * When `worker-id` is 0, it shuts down the process pool; Otherwise, it retursn a tuple @@ -277,8 +279,8 @@ class Session : public ObjectRef { * \note Worker-0 is always co-located with the controler as a separate thread, and therefore * worker-0 does not exist in the process pool. */ - TVM_DLL static Session ProcessSession(int num_workers, String process_pool_creator, - String entrypoint); + TVM_DLL static Session ProcessSession(int num_workers, int num_groups, + String process_pool_creator, String entrypoint); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Session, ObjectRef, SessionObj); }; diff --git a/python/tvm/exec/disco_worker.py b/python/tvm/exec/disco_worker.py index 76ce0ff9936f..b1f1554b56f9 100644 --- a/python/tvm/exec/disco_worker.py +++ b/python/tvm/exec/disco_worker.py @@ -99,22 +99,23 @@ def fget_item(param_name: str, param_index: int) -> NDArray: def main(): """Main worker function""" - if len(sys.argv) != 5: - print("Usage: ") + if len(sys.argv) != 6: + print("Usage: ") return worker_id = int(sys.argv[1]) num_workers = int(sys.argv[2]) + num_groups = int(sys.argv[3]) if sys.platform == "win32": import msvcrt # pylint: disable=import-outside-toplevel,import-error - reader = msvcrt.open_osfhandle(int(sys.argv[3]), os.O_BINARY) - writer = msvcrt.open_osfhandle(int(sys.argv[4]), os.O_BINARY) + reader = 
msvcrt.open_osfhandle(int(sys.argv[4]), os.O_BINARY) + writer = msvcrt.open_osfhandle(int(sys.argv[5]), os.O_BINARY) else: - reader = int(sys.argv[3]) - writer = int(sys.argv[4]) + reader = int(sys.argv[4]) + writer = int(sys.argv[5]) worker_func = get_global_func("runtime.disco.WorkerProcess") - worker_func(worker_id, num_workers, reader, writer) + worker_func(worker_id, num_workers, num_groups, reader, writer) if __name__ == "__main__": diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index 725a930fd680..dbeba4ca7273 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -1668,16 +1668,21 @@ def interpolate( ) -def ccl_allreduce(x: Tensor, op_type: str = "sum", name="ccl_allreduce"): +def ccl_allreduce(x: Tensor, op_type: str = "sum", in_group: bool = True, name="ccl_allreduce"): """CCL Allreduce operator Parameters ---------- - x : Tensor + x : relax.Expr The input tensor. - op_type: str + + op_type : str The type of reduction operation to be applied to the input data. Now "sum", "prod", "min", "max" and "avg" are supported. + + in_group : bool + Whether the reduction operation performs globally or in group as default. + name : str Name hint for this operation. @@ -1686,7 +1691,7 @@ def ccl_allreduce(x: Tensor, op_type: str = "sum", name="ccl_allreduce"): result : Tensor The result tensor of allreduce. """ - return wrap_nested(_op.ccl.allreduce(x._expr, op_type), name) + return wrap_nested(_op.ccl.allreduce(x._expr, op_type, in_group), name) def ccl_broadcast_from_worker0(x: Tensor, name="broadcast_from_worker"): diff --git a/python/tvm/relax/op/ccl/ccl.py b/python/tvm/relax/op/ccl/ccl.py index 21c7946120a7..982c04802156 100644 --- a/python/tvm/relax/op/ccl/ccl.py +++ b/python/tvm/relax/op/ccl/ccl.py @@ -15,25 +15,26 @@ # specific language governing permissions and limitations # under the License. 
"""Relax Collective Communications Library (CCL) operators""" -from typing import Union -from tvm.relax import PrimValue from . import _ffi_api from ...expr import Expr -from ....ir import PrimExpr -def allreduce(x, op_type: str = "sum"): # pylint: disable=invalid-name +def allreduce(x, op_type: str = "sum", in_group: bool = True): # pylint: disable=invalid-name """Allreduce operator Parameters ---------- x : relax.Expr The input tensor. - op_type: str + + op_type : str The type of reduction operation to be applied to the input data. Now "sum", "prod", "min", "max" and "avg" are supported. + in_group : bool + Whether the reduction operation performs globally or in group as default. + Returns ------- result : relax.Expr @@ -44,10 +45,10 @@ def allreduce(x, op_type: str = "sum"): # pylint: disable=invalid-name "Allreduce only supports limited reduction operations, " f"including {supported_op_types}, but got {op_type}." ) - return _ffi_api.allreduce(x, op_type) # type: ignore # pylint: disable=no-member + return _ffi_api.allreduce(x, op_type, in_group) # type: ignore # pylint: disable=no-member -def allgather(x, num_workers: Union[int, PrimExpr, PrimValue]): # pylint: disable=invalid-name +def allgather(x, num_workers: int, in_group: bool = True): # pylint: disable=invalid-name """AllGather operator Parameters @@ -55,17 +56,18 @@ def allgather(x, num_workers: Union[int, PrimExpr, PrimValue]): # pylint: disab x : relax.Expr The input tensor. - num_worker : Union[int, PrimExpr, PrimValue] + num_workers : int The number of workers to gather data from. + in_group : bool + Whether the gather operation performs globally or in group as default. + Returns ------- result : relax.Expr The result of allgather. 
""" - if not isinstance(num_workers, PrimValue): - num_workers = PrimValue(num_workers) - return _ffi_api.allgather(x, num_workers) # type: ignore # pylint: disable=no-member + return _ffi_api.allgather(x, num_workers, in_group) # type: ignore # pylint: disable=no-member def broadcast_from_worker0(x: Expr) -> Expr: diff --git a/python/tvm/relax/transform/legalize_ops/ccl.py b/python/tvm/relax/transform/legalize_ops/ccl.py index ae0be3c228f5..364dee750e8b 100644 --- a/python/tvm/relax/transform/legalize_ops/ccl.py +++ b/python/tvm/relax/transform/legalize_ops/ccl.py @@ -41,7 +41,7 @@ def _allreduce(_bb: BlockBuilder, call: Call) -> Expr: ) return call_dps_packed( "runtime.disco.allreduce", - [call.args[0], ShapeExpr([op_type_map[op_type_str]])], + [call.args[0], ShapeExpr([op_type_map[op_type_str]]), call.attrs.in_group], out_sinfo=call.args[0].struct_info, ) @@ -57,12 +57,12 @@ def _allgather(_bb: BlockBuilder, call: Call) -> Expr: arg_shape = arg_sinfo.shape.struct_info for i, shape_value in enumerate(arg_shape.values): if i == 0: - output_shape.append(shape_value * call.args[1].value) + output_shape.append(shape_value * call.attrs.num_workers) else: output_shape.append(shape_value) return call_dps_packed( "runtime.disco.allgather", - call.args[0], + [call.args[0], call.attrs.in_group], out_sinfo=TensorStructInfo( shape=output_shape, dtype=arg_sinfo.dtype, @@ -75,7 +75,7 @@ def _allgather(_bb: BlockBuilder, call: Call) -> Expr: def _broadcast_from_worker0(_bb: BlockBuilder, call: Call) -> Expr: return call_dps_packed( "runtime.disco.broadcast_from_worker0", - call.args[0], + [call.args[0], False], out_sinfo=call.args[0].struct_info, ) @@ -116,7 +116,7 @@ def _scatter_from_worker0(_bb: BlockBuilder, call: Call) -> Expr: output_shape = output_shape[1:] return call_dps_packed( "runtime.disco.scatter_from_worker0", - transpose_var, + [transpose_var, False], out_sinfo=TensorStructInfo( shape=output_shape, dtype=call.args[0].struct_info.dtype, diff --git 
a/python/tvm/runtime/disco/process_pool.py b/python/tvm/runtime/disco/process_pool.py index 1ad8659d6088..95969e038e0f 100644 --- a/python/tvm/runtime/disco/process_pool.py +++ b/python/tvm/runtime/disco/process_pool.py @@ -38,6 +38,9 @@ class DiscoPopenWorker: num_workers : int The total number of workers. + num_groups : int + The total number of worker groups. + stdout: Union[None, int, IO[Any]] The standard output streams handler specified for the popen process. @@ -49,12 +52,14 @@ def __init__( # pylint: disable=too-many-arguments self, worker_id: int, num_workers: int, + num_groups: int, entrypoint: str = "tvm.exec.disco_worker", stdout=None, stderr=None, ): self.worker_id = worker_id self.num_workers = num_workers + self.num_groups = num_groups self.entrypoint = entrypoint self._proc = None self._stdout = stdout @@ -118,6 +123,7 @@ def start(self): self.entrypoint, str(self.worker_id), str(self.num_workers), + str(self.num_groups), ] if sys.platform == "win32": import msvcrt # pylint: disable=import-error,import-outside-toplevel @@ -172,9 +178,9 @@ def _kill_child_processes(pid): @register_func("runtime.disco.create_process_pool") -def _create_process_pool(num_workers: int, entrypoint: str): +def _create_process_pool(num_workers: int, num_groups: int, entrypoint: str): """Create a process pool where the workers' are [1, num_workers).""" - pool = [DiscoPopenWorker(i, num_workers, entrypoint) for i in range(1, num_workers)] + pool = [DiscoPopenWorker(i, num_workers, num_groups, entrypoint) for i in range(1, num_workers)] def result_func(worker_id: int): nonlocal pool diff --git a/python/tvm/runtime/disco/session.py b/python/tvm/runtime/disco/session.py index ddde1bc1f323..38c4f2a2354c 100644 --- a/python/tvm/runtime/disco/session.py +++ b/python/tvm/runtime/disco/session.py @@ -66,6 +66,7 @@ def debug_copy_from( ---------- worker_id : int The id of the worker to be copied to. + value : Union[numpy.ndarray, NDArray] The value to be copied. 
""" @@ -121,6 +122,7 @@ def empty( dtype: str, device: Optional[Device] = None, worker0_only: bool = False, + in_group: bool = True, ) -> DRef: """Create an empty NDArray on all workers and attach them to a DRef. @@ -139,6 +141,11 @@ def empty( If False (default), allocate an array on each worker. If True, only allocate an array on worker0. + in_group: bool + Takes effect only when `worker0_only` is True. If True (default), + allocate an array on the first worker of each group. If + False, only allocate an array on worker0 globally. + Returns ------- array : DRef @@ -148,7 +155,7 @@ def empty( if device is None: device = Device(device_type=0, device_id=0) func = self._get_cached_method("runtime.disco.empty") - return func(ShapeTuple(shape), dtype, device, worker0_only) + return func(ShapeTuple(shape), dtype, device, worker0_only, in_group) def shutdown(self): """Shut down the Disco session""" @@ -244,6 +251,7 @@ def copy_from_worker_0(self, host_array: NDArray, remote_array: DRef) -> None: ---------- host_array : numpy.ndarray The array to be copied to worker-0. + remote_array : NDArray The NDArray on worker-0. """ @@ -255,11 +263,9 @@ def copy_to_worker_0(self, host_array: NDArray, remote_array: Optional[DRef] = N Parameters ---------- host_array : NDArray - The array to be copied to worker-0. remote_array : Optiona[DRef] - The destination NDArray on worker-0. Returns @@ -289,6 +295,7 @@ def load_vm_module( ---------- path : str The path to the VM module file. + device : Optional[Device] = None The device to load the VM module to. Default to the default device of each worker. @@ -312,6 +319,7 @@ def init_ccl(self, ccl: str, *device_ids): - nccl - rccl - mpi + *device_ids : int The device IDs to be used by the underlying communication library. 
""" @@ -319,20 +327,23 @@ def init_ccl(self, ccl: str, *device_ids): _ffi_api.SessionInitCCL(self, ccl, ShapeTuple(device_ids)) # type: ignore # pylint: disable=no-member self._clear_ipc_memory_pool() - def broadcast(self, src: Union[np.ndarray, NDArray], dst: Optional[DRef] = None) -> DRef: + def broadcast( + self, src: Union[np.ndarray, NDArray], dst: Optional[DRef] = None, in_group: bool = True + ) -> DRef: """Broadcast an array to all workers Parameters ---------- src: Union[np.ndarray, NDArray] - The array to be broadcasted. dst: Optional[DRef] - The output array. If None, an array matching the shape and dtype of `src` will be allocated on each worker. + in_group: bool + Whether the broadcast operation performs globally or in group as default. + Returns ------- output_array: DRef @@ -349,38 +360,48 @@ def broadcast(self, src: Union[np.ndarray, NDArray], dst: Optional[DRef] = None) dst = self.empty(src.shape, src.dtype) src_dref = self.copy_to_worker_0(src) - self.broadcast_from_worker0(src_dref, dst) + self.broadcast_from_worker0(src_dref, dst, in_group) return dst - def broadcast_from_worker0(self, src: DRef, dst: DRef) -> DRef: + def broadcast_from_worker0(self, src: DRef, dst: DRef, in_group: bool = True) -> DRef: """Broadcast an array from worker-0 to all other workers. Parameters ---------- - array : DRef - The array to be broadcasted in-place + src: Union[np.ndarray, NDArray] + The array to be broadcasted. + + dst: Optional[DRef] + The output array. If None, an array matching the shape + and dtype of `src` will be allocated on each worker. + + in_group: bool + Whether the broadcast operation performs globally or in group as default. 
""" func = self._get_cached_method("runtime.disco.broadcast_from_worker0") - func(src, dst) + func(src, in_group, dst) - def scatter(self, src: Union[np.ndarray, NDArray], dst: Optional[DRef] = None) -> DRef: + def scatter( + self, src: Union[np.ndarray, NDArray], dst: Optional[DRef] = None, in_group: bool = True + ) -> DRef: """Scatter an array across all workers Parameters ---------- src: Union[np.ndarray, NDArray] - The array to be scattered. The first dimension of this array, `src.shape[0]`, must be equal to the number of workers. dst: Optional[DRef] - The output array. If None, an array with compatible shape and the same dtype as `src` will be allocated on each worker. + in_group: bool + Whether the scatter operation performs globally or in group as default. + Returns ------- output_array: DRef @@ -399,41 +420,54 @@ def scatter(self, src: Union[np.ndarray, NDArray], dst: Optional[DRef] = None) - dst = self.empty(src.shape[1:], src.dtype) src_dref = self.copy_to_worker_0(src) - self.scatter_from_worker0(src_dref, dst) + self.scatter_from_worker0(src_dref, dst, in_group) return dst - def scatter_from_worker0(self, from_array: DRef, to_array: DRef) -> None: + def scatter_from_worker0(self, from_array: DRef, to_array: DRef, in_group: bool = True) -> None: """Scatter an array from worker-0 to all other workers. Parameters ---------- - from_array : DRef - The array to be scattered from. - to_array : DRef - The array to be scattered to. + src: Union[np.ndarray, NDArray] + The array to be scattered. The first dimension of this + array, `src.shape[0]`, must be equal to the number of + workers. + + dst: Optional[DRef] + The output array. If None, an array with compatible shape + and the same dtype as `src` will be allocated on each + worker. + + in_group: bool + Whether the scatter operation performs globally or in group as default. 
""" func = self._get_cached_method("runtime.disco.scatter_from_worker0") - func(from_array, to_array) + func(from_array, in_group, to_array) - def gather_to_worker0(self, from_array: DRef, to_array: DRef) -> None: + def gather_to_worker0(self, from_array: DRef, to_array: DRef, in_group: bool = True) -> None: """Gather an array from all other workers to worker-0. Parameters ---------- from_array : DRef The array to be gathered from. + to_array : DRef The array to be gathered to. + + in_group: bool + Whether the gather operation performs globally or in group as default. """ func = self._get_cached_method("runtime.disco.gather_to_worker0") - func(from_array, to_array) + func(from_array, in_group, to_array) def allreduce( self, src: DRef, dst: DRef, op: str = "sum", # pylint: disable=invalid-name + in_group: bool = True, ) -> DRef: """Perform an allreduce operation on an array. @@ -441,6 +475,7 @@ def allreduce( ---------- array : DRef The array to be reduced. + op : str = "sum" The reduce operation to be performed. Available options are: - "sum" @@ -448,17 +483,21 @@ def allreduce( - "min" - "max" - "avg" + + in_group : bool + Whether the reduce operation performs globally or in group as default. """ if op not in REDUCE_OPS: raise ValueError(f"Unsupported reduce op: {op}. Available ops are: {REDUCE_OPS.keys()}") op = ShapeTuple([REDUCE_OPS[op]]) func = self._get_cached_method("runtime.disco.allreduce") - func(src, op, dst) + func(src, op, in_group, dst) def allgather( self, src: DRef, dst: DRef, + in_group: bool = True, ) -> DRef: """Perform an allgather operation on an array. @@ -466,11 +505,15 @@ def allgather( ---------- src : DRef The array to be gathered from. + dst : DRef The array to be gathered to. + + in_group : bool + Whether the allgather operation performs globally or in group as default. 
""" func = self._get_cached_method("runtime.disco.allgather") - func(src, dst) + func(src, in_group, dst) def _clear_ipc_memory_pool(self): # Clear the IPC memory allocator when the allocator exists. @@ -483,11 +526,12 @@ def _clear_ipc_memory_pool(self): class ThreadedSession(Session): """A Disco session backed by multi-threading.""" - def __init__(self, num_workers: int) -> None: + def __init__(self, num_workers: int, num_groups: int = 1) -> None: """Create a disco session backed by multiple threads in the same process.""" self.__init_handle_by_constructor__( _ffi_api.SessionThreaded, # type: ignore # pylint: disable=no-member num_workers, + num_groups, ) @@ -495,10 +539,13 @@ def __init__(self, num_workers: int) -> None: class ProcessSession(Session): """A Disco session backed by pipe-based multi-processing.""" - def __init__(self, num_workers: int, entrypoint: str = "tvm.exec.disco_worker") -> None: + def __init__( + self, num_workers: int, num_groups: int = 1, entrypoint: str = "tvm.exec.disco_worker" + ) -> None: self.__init_handle_by_constructor__( _ffi_api.SessionProcess, # type: ignore # pylint: disable=no-member num_workers, + num_groups, "runtime.disco.create_process_pool", entrypoint, ) diff --git a/src/relax/op/ccl/ccl.cc b/src/relax/op/ccl/ccl.cc index c0fe6f4d88d7..092727cb5115 100644 --- a/src/relax/op/ccl/ccl.cc +++ b/src/relax/op/ccl/ccl.cc @@ -27,9 +27,10 @@ namespace relax { /* relax.ccl.allreduce */ TVM_REGISTER_NODE_TYPE(AllReduceAttrs); -Expr allreduce(Expr x, String op_type) { +Expr allreduce(Expr x, String op_type, bool in_group) { ObjectPtr attrs = make_object(); attrs->op_type = std::move(op_type); + attrs->in_group = std::move(in_group); static const Op& op = Op::Get("relax.ccl.allreduce"); return Call(op, {std::move(x)}, Attrs{attrs}, {}); @@ -51,19 +52,24 @@ TVM_REGISTER_OP("relax.ccl.allreduce") .set_attr("FPurity", Bool(true)); /* relax.ccl.allgather */ -Expr allgather(Expr x, Expr num_workers) { 
+TVM_REGISTER_NODE_TYPE(AllGatherAttrs); + +Expr allgather(Expr x, int num_workers, bool in_group) { + ObjectPtr attrs = make_object(); + attrs->num_workers = std::move(num_workers); + attrs->in_group = std::move(in_group); + static const Op& op = Op::Get("relax.ccl.allgather"); - return Call(op, {std::move(x), std::move(num_workers)}); + return Call(op, {std::move(x)}, Attrs{attrs}, {}); } TVM_REGISTER_GLOBAL("relax.op.ccl.allgather").set_body_typed(allgather); StructInfo InferStructInfoAllGather(const Call& call, const BlockBuilder& ctx) { - CHECK_EQ(call->args.size(), 2); - auto input_sinfo = Downcast(call->args[0]->struct_info_); - auto num_workers_sinfo = Downcast(call->args[1]->struct_info_); + TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); - auto num_workers = num_workers_sinfo->value; + const auto* attrs = call->attrs.as(); + int num_workers = attrs->num_workers; DataType output_dtype = input_sinfo->dtype; auto input_shape = input_sinfo->GetShape(); @@ -71,7 +77,7 @@ StructInfo InferStructInfoAllGather(const Call& call, const BlockBuilder& ctx) { return input_sinfo; } Array output_shape = input_shape.value(); - output_shape.Set(0, floor(output_shape[0] * num_workers.value())); + output_shape.Set(0, floor(output_shape[0] * num_workers)); return TensorStructInfo(ShapeExpr(output_shape), output_dtype, input_sinfo->vdevice); } diff --git a/src/relax/op/ccl/ccl.h b/src/relax/op/ccl/ccl.h index 3e7f0220c9dc..82ea3935675d 100644 --- a/src/relax/op/ccl/ccl.h +++ b/src/relax/op/ccl/ccl.h @@ -33,10 +33,10 @@ namespace tvm { namespace relax { /*! \brief AllReduce. */ -Expr allreduce(Expr data, String op_type); +Expr allreduce(Expr data, String op_type, bool in_group); /*! \brief AllGather. */ -Expr allgather(Expr data, Expr num_workers); +Expr allgather(Expr data, int num_workers, bool in_group); /*! \brief Broadcast data from worker-0 to all other workers. 
*/ Expr broadcast_from_worker0(Expr data); diff --git a/src/runtime/disco/builtin.cc b/src/runtime/disco/builtin.cc index 26d1c22ee975..0cb2ee6f5d6b 100644 --- a/src/runtime/disco/builtin.cc +++ b/src/runtime/disco/builtin.cc @@ -79,22 +79,24 @@ const PackedFunc& GetCCLFunc(const char* name) { return *pf; } -void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv) { - GetCCLFunc("allreduce")(send, static_cast(reduce_kind), recv); +void AllReduce(NDArray send, ReduceKind reduce_kind, bool in_group, NDArray recv) { + GetCCLFunc("allreduce")(send, static_cast(reduce_kind), in_group, recv); } -void AllGather(NDArray send, NDArray recv) { GetCCLFunc("allgather")(send, recv); } +void AllGather(NDArray send, bool in_group, NDArray recv) { + GetCCLFunc("allgather")(send, in_group, recv); +} -TVM_DLL void BroadcastFromWorker0(NDArray send, NDArray recv) { - GetCCLFunc("broadcast_from_worker0")(send, recv); +TVM_DLL void BroadcastFromWorker0(NDArray send, bool in_group, NDArray recv) { + GetCCLFunc("broadcast_from_worker0")(send, in_group, recv); } -TVM_DLL void ScatterFromWorker0(Optional send, NDArray recv) { - GetCCLFunc("scatter_from_worker0")(send, recv); +TVM_DLL void ScatterFromWorker0(Optional send, bool in_group, NDArray recv) { + GetCCLFunc("scatter_from_worker0")(send, in_group, recv); } -void GatherToWorker0(NDArray send, Optional recv) { - GetCCLFunc("gather_to_worker0")(send, recv); +void GatherToWorker0(NDArray send, bool in_group, Optional recv) { + GetCCLFunc("gather_to_worker0")(send, in_group, recv); } void RecvFromWorker0(NDArray buffer) { GetCCLFunc("recv_from_worker0")(buffer); } @@ -110,9 +112,13 @@ void SyncWorker() { TVM_REGISTER_GLOBAL("runtime.disco.load_vm_module").set_body_typed(LoadVMModule); TVM_REGISTER_GLOBAL("runtime.disco.empty") - .set_body_typed([](ShapeTuple shape, DataType dtype, Device device, - bool worker0_only) -> Optional { - if (worker0_only && WorkerId()) { + .set_body_typed([](ShapeTuple shape, DataType dtype, Device 
device, bool worker0_only, + bool in_group) -> Optional { + int worker_id = WorkerId(); + int group_size = + DiscoWorker::ThreadLocal()->num_workers / DiscoWorker::ThreadLocal()->num_groups; + bool is_worker0 = (worker_id == 0 && !in_group) || (in_group && worker_id % group_size == 0); + if (worker0_only && !is_worker0) { return NullOpt; } else { return DiscoEmptyNDArray(shape, dtype, device); @@ -120,10 +126,10 @@ TVM_REGISTER_GLOBAL("runtime.disco.empty") }); TVM_REGISTER_GLOBAL("runtime.disco.allreduce") - .set_body_typed([](NDArray send, ShapeTuple reduce_kind, NDArray recv) { + .set_body_typed([](NDArray send, ShapeTuple reduce_kind, bool in_group, NDArray recv) { int kind = IntegerFromShapeTuple(reduce_kind); CHECK(0 <= kind && kind <= 4) << "ValueError: Unknown ReduceKind: " << kind; - AllReduce(send, static_cast(kind), recv); + AllReduce(send, static_cast(kind), in_group, recv); }); TVM_REGISTER_GLOBAL("runtime.disco.allgather").set_body_typed(AllGather); TVM_REGISTER_GLOBAL("runtime.disco.broadcast_from_worker0").set_body_typed(BroadcastFromWorker0); diff --git a/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc b/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc index fec5abec86b0..490217d62c79 100644 --- a/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc +++ b/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc @@ -47,8 +47,8 @@ std::vector AllGatherIPCHandles(nccl::CCLThreadLocalContext* CUDA_CALL(cudaMalloc(&d_src, CUDA_IPC_HANDLE_SIZE)); CUDA_CALL(cudaMalloc(&d_dst, CUDA_IPC_HANDLE_SIZE * ctx->worker->num_workers)); CUDA_CALL(cudaMemcpy(d_src, &local_handle, CUDA_IPC_HANDLE_SIZE, cudaMemcpyHostToDevice)); - NCCL_CALL( - ncclAllGather(d_src, d_dst, CUDA_IPC_HANDLE_SIZE, ncclChar, ctx->comm, /*stream=*/nullptr)); + NCCL_CALL(ncclAllGather(d_src, d_dst, CUDA_IPC_HANDLE_SIZE, ncclChar, ctx->global_comm, + /*stream=*/nullptr)); std::vector serial_handles(CUDA_IPC_HANDLE_SIZE * ctx->worker->num_workers, 0); CUDA_CALL(cudaMemcpy(serial_handles.data(), d_dst, 
CUDA_IPC_HANDLE_SIZE * ctx->worker->num_workers, cudaMemcpyDefault)); diff --git a/src/runtime/disco/cuda_ipc/custom_allreduce.cc b/src/runtime/disco/cuda_ipc/custom_allreduce.cc index 98fd777b8364..d969005f9476 100644 --- a/src/runtime/disco/cuda_ipc/custom_allreduce.cc +++ b/src/runtime/disco/cuda_ipc/custom_allreduce.cc @@ -65,6 +65,8 @@ inline bool CanApplyTwoShotAllReduce(int64_t num_elements, DLDataType dtype, int void CustomAllReduce(DLTensor* send, int strategy, DLTensor* recv) { int64_t num_elements = TensorSize(send); nccl::CCLThreadLocalContext* ctx = nccl::CCLThreadLocalContext::Get(); + CHECK_EQ(ctx->worker->num_groups, 1) + << "Custom AllReduce for multiple group is not yet implemented."; tensorrt_llm::AllReduceStrategyType strategy_ = static_cast(strategy); @@ -79,7 +81,7 @@ void CustomAllReduce(DLTensor* send, int strategy, DLTensor* recv) { deviceStream_t stream = ctx->GetDefaultStream(); NCCL_CALL(ncclAllReduce(send->data, recv->data, num_elements, /*datatype=*/nccl::AsNCCLDataType(DataType(send->dtype)), - /*op=*/ncclSum, ctx->comm, stream)); + /*op=*/ncclSum, ctx->global_comm, stream)); return; } diff --git a/src/runtime/disco/disco_worker_thread.h b/src/runtime/disco/disco_worker_thread.h index 67742cdd0408..8d6b44396f4d 100644 --- a/src/runtime/disco/disco_worker_thread.h +++ b/src/runtime/disco/disco_worker_thread.h @@ -47,12 +47,14 @@ class DiscoWorkerThread { * \brief Construct a worker thread. * \param worker_id The id of the worker. * \param num_workers The total number of workers. + * \param num_groups The total number of worker groups. * \param worker_zero_data_ The data shared between worker-0 and the controler. It's a nullptr if * the worker is not worker-0. * \note This method is implemented in threaded worker, because it depends on creation of a * sub-class of DiscoChannel, DiscoThreadChannel, which is hidden from the public interface. 
*/ - explicit DiscoWorkerThread(int worker_id, int num_workers, WorkerZeroData* worker_zero_data_); + explicit DiscoWorkerThread(int worker_id, int num_workers, int num_groups, + WorkerZeroData* worker_zero_data_); /*! \brief Move constructor. */ explicit DiscoWorkerThread(DiscoWorkerThread&& other) diff --git a/src/runtime/disco/loader.cc b/src/runtime/disco/loader.cc index 7a5d97894680..efe42539cb56 100644 --- a/src/runtime/disco/loader.cc +++ b/src/runtime/disco/loader.cc @@ -326,19 +326,19 @@ NDArray ShardLoaderObj::Load(int weight_index) const { for (const ShardInfo::ShardFunc& shard_func : param_info.shard_info.funcs) { w = this->ApplyShardFunc(shard_func, w); } - ScatterFromWorker0(w, recv); + ScatterFromWorker0(w, /*in_group=*/false, recv); } else { - ScatterFromWorker0(NullOpt, recv); + ScatterFromWorker0(NullOpt, /*in_group=*/false, recv); } return recv; } else { if (worker_id == 0) { NDArray w = LoadDirect(weight_index); - BroadcastFromWorker0(w, w); + BroadcastFromWorker0(w, /*in_group=*/false, w); return w; } else { NDArray w = NDArray::Empty(param->shape, param->dtype, device); - BroadcastFromWorker0(w, w); + BroadcastFromWorker0(w, /*in_group=*/false, w); return w; } } diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc index bba42ed3bdfe..2d2c528b5291 100644 --- a/src/runtime/disco/nccl/nccl.cc +++ b/src/runtime/disco/nccl/nccl.cc @@ -72,9 +72,12 @@ void InitCCLPerWorker(IntTuple device_ids, std::string unique_id_bytes) { << "ValueError: The length of unique_id must be " << NCCL_UNIQUE_ID_BYTES << ", but got " << unique_id_bytes.size() << "."; - CHECK(!ctx->comm) << "Cannot initialize CCL, " - << "the previous thread-global comm still exists, " - << "and has not been destructed"; + CHECK(!ctx->global_comm) << "Cannot initialize CCL, " + << "the previous thread-global comm still exists, " + << "and has not been destructed"; + CHECK(!ctx->group_comm) << "Cannot initialize CCL, " + << "the previous thread-group comm still 
exists, " + << "and has not been destructed"; CHECK(!ctx->default_stream) << "Cannot initialize CCL, " << "the previous thread-global stream still exists, " << "and has not been destructed"; @@ -96,34 +99,41 @@ void InitCCLPerWorker(IntTuple device_ids, std::string unique_id_bytes) { // Initialize the communicator ncclUniqueId id; std::memcpy(id.internal, unique_id_bytes.data(), NCCL_UNIQUE_ID_BYTES); - NCCL_CALL(ncclCommInitRank(&ctx->comm, worker->num_workers, id, worker->worker_id)); + int group_size = worker->num_workers / worker->num_groups; + NCCL_CALL(ncclCommInitRank(&ctx->global_comm, worker->num_workers, id, worker->worker_id)); + NCCL_CALL(ncclCommSplit(ctx->global_comm, worker->worker_id / group_size, + worker->worker_id % group_size, &ctx->group_comm, NULL)); } -void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv) { +void AllReduce(NDArray send, ReduceKind reduce_kind, bool in_group, NDArray recv) { CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); ShapeTuple shape = send.Shape(); int64_t numel = shape->Product(); deviceStream_t stream = ctx->GetDefaultStream(); NCCL_CALL(ncclAllReduce(send->data, recv->data, numel, /*datatype=*/AsNCCLDataType(DataType(send->dtype)), - /*op=*/AsNCCLRedOp(reduce_kind), ctx->comm, stream)); + /*op=*/AsNCCLRedOp(reduce_kind), + in_group ? ctx->group_comm : ctx->global_comm, stream)); } -void AllGather(NDArray send, NDArray recv) { +void AllGather(NDArray send, bool in_group, NDArray recv) { CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); ShapeTuple shape = send.Shape(); int64_t numel = shape->Product(); deviceStream_t stream = ctx->GetDefaultStream(); NCCL_CALL(ncclAllGather(send->data, recv->data, numel, - /*datatype=*/AsNCCLDataType(DataType(send->dtype)), ctx->comm, stream)); + /*datatype=*/AsNCCLDataType(DataType(send->dtype)), + in_group ? 
ctx->group_comm : ctx->global_comm, stream)); } -void BroadcastFromWorker0(Optional send, NDArray recv) { +void BroadcastFromWorker0(Optional send, bool in_group, NDArray recv) { CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); + int worker_id = ctx->worker->worker_id; + int group_size = ctx->worker->num_workers / ctx->worker->num_groups; + bool is_sender = (worker_id == 0 && !in_group) || (in_group && worker_id % group_size == 0); const void* send_data = [&]() -> const void* { - int worker_id = ctx->worker->worker_id; - if (worker_id == 0) { + if (is_sender) { CHECK(send.defined()); CHECK(send.value().Shape()->Product() == recv.Shape()->Product()); return send.value()->data; @@ -136,25 +146,28 @@ void BroadcastFromWorker0(Optional send, NDArray recv) { deviceStream_t stream = ctx->GetDefaultStream(); NCCL_CALL(ncclBroadcast(send_data, recv->data, numel, /*datatype=*/AsNCCLDataType(DataType(recv->dtype)), - /*root=*/0, ctx->comm, stream)); + /*root=*/0, in_group ? ctx->group_comm : ctx->global_comm, stream)); } -void ScatterFromWorker0(Optional send, NDArray recv) { +void ScatterFromWorker0(Optional send, bool in_group, NDArray recv) { CHECK(recv.defined()) << "ValueError: buffer `recv` must not be None"; CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); int worker_id = ctx->worker->worker_id; int num_workers = ctx->worker->num_workers; + int group_size = num_workers / ctx->worker->num_groups; + bool is_sender = (worker_id == 0 && !in_group) || (in_group && worker_id % group_size == 0); + int num_receiver = in_group ? 
group_size : num_workers; deviceStream_t stream = ctx->GetDefaultStream(); - if (worker_id == 0) { + if (is_sender) { CHECK(send.defined()) << "ValueError: buffer `send` must be provided when worker_id == 0."; NDArray buffer = send.value(); int64_t numel = buffer.Shape()->Product(); - CHECK_EQ(numel % num_workers, 0) << "ValueError: Scattering evenly requires that the number " - "of elements in the buffer to be " - "divisible by the number of workers, but got numel = " - << numel << " and " << num_workers << " workers."; + CHECK_EQ(numel % num_receiver, 0) << "ValueError: Scattering evenly requires that the number " + "of elements in the buffer to be " + "divisible by the number of workers, but got numel = " + << numel << " and " << num_receiver << " workers."; DataType dtype(buffer->dtype); - int64_t numel_per_shard = numel / num_workers; + int64_t numel_per_shard = numel / num_receiver; int64_t bytes_per_shard = numel_per_shard * dtype.bytes(); CHECK_EQ(numel_per_shard, recv.Shape()->Product()) << "ValueError: The number of elements in buffer `recv` must be the same as each shard " @@ -163,40 +176,45 @@ void ScatterFromWorker0(Optional send, NDArray recv) { << numel << ", but `recv.size` is " << recv.Shape()->Product() << "."; NCCL_CALL(ncclGroupStart()); uint8_t* data = static_cast(buffer->data); - for (int i = 0; i < num_workers; ++i) { - NCCL_CALL(ncclSend(data, numel_per_shard, AsNCCLDataType(dtype), i, ctx->comm, stream)); + for (int i = 0; i < num_receiver; ++i) { + NCCL_CALL(ncclSend(data, numel_per_shard, AsNCCLDataType(dtype), i, + in_group ? ctx->group_comm : ctx->global_comm, stream)); data += bytes_per_shard; } } else { if (send.defined()) { - LOG(WARNING) << "Buffer `send` must be None when worker_id != 0, but got " - "send = " + LOG(WARNING) << "ValueError: buffer `send` must be None when (worker_id != 0 && !in_group) " + "or (worker_id % group_size != 0 && in_group). However, got send = " << send.get() << ". 
This will be ignored."; } NCCL_CALL(ncclGroupStart()); } int64_t numel = recv.Shape()->Product(); DataType dtype(recv->dtype); - NCCL_CALL(ncclRecv(recv->data, numel, AsNCCLDataType(dtype), 0, ctx->comm, stream)); + NCCL_CALL(ncclRecv(recv->data, numel, AsNCCLDataType(dtype), 0, + in_group ? ctx->group_comm : ctx->global_comm, stream)); NCCL_CALL(ncclGroupEnd()); } -void GatherToWorker0(NDArray send, Optional recv) { +void GatherToWorker0(NDArray send, bool in_group, Optional recv) { CHECK(send.defined()) << "ValueError: buffer `send` must not be None"; CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); int worker_id = ctx->worker->worker_id; int num_workers = ctx->worker->num_workers; + int group_size = num_workers / ctx->worker->num_groups; + bool is_sender = (worker_id == 0 && !in_group) || (in_group && worker_id % group_size == 0); + int num_receiver = in_group ? group_size : num_workers; deviceStream_t stream = ctx->GetDefaultStream(); - if (worker_id == 0) { + if (is_sender) { CHECK(recv.defined()) << "ValueError: buffer `recv` must be provided when worker_id == 0."; NDArray buffer = recv.value(); int64_t numel = buffer.Shape()->Product(); - CHECK_EQ(numel % num_workers, 0) << "ValueError: Gathering evenly requires that the number " - "of elements in the buffer to be " - "divisible by the number of workers, but got numel = " - << numel << " and " << num_workers << " workers."; + CHECK_EQ(numel % num_receiver, 0) << "ValueError: Gathering evenly requires that the number " + "of elements in the buffer to be " + "divisible by the number of workers, but got numel = " + << numel << " and " << num_receiver << " workers."; DataType dtype(buffer->dtype); - int64_t numel_per_shard = numel / num_workers; + int64_t numel_per_shard = numel / num_receiver; int64_t bytes_per_shard = numel_per_shard * dtype.bytes(); CHECK_EQ(numel_per_shard, send.Shape()->Product()) << "ValueError: The number of elements in buffer `send` must be the same as each shard " @@ -205,21 
+223,23 @@ void GatherToWorker0(NDArray send, Optional recv) { << numel << ", but `send.size` is " << send.Shape()->Product() << "."; NCCL_CALL(ncclGroupStart()); uint8_t* data = static_cast(buffer->data); - for (int i = 0; i < num_workers; ++i) { - NCCL_CALL(ncclRecv(data, numel_per_shard, AsNCCLDataType(dtype), i, ctx->comm, stream)); + for (int i = 0; i < num_receiver; ++i) { + NCCL_CALL(ncclRecv(data, numel_per_shard, AsNCCLDataType(dtype), i, + in_group ? ctx->group_comm : ctx->global_comm, stream)); data += bytes_per_shard; } } else { if (recv.defined()) { - LOG(WARNING) << "ValueError: buffer `recv` must be None when worker_id != 0. However, got " - "recv = " + LOG(WARNING) << "ValueError: buffer `recv` must be None when (worker_id != 0 && !in_group) " + "or (worker_id % group_size != 0 && in_group). However, got recv = " << recv.get() << ". This will be ignored."; } NCCL_CALL(ncclGroupStart()); } int64_t numel = send.Shape()->Product(); DataType dtype(send->dtype); - NCCL_CALL(ncclSend(send->data, numel, AsNCCLDataType(dtype), 0, ctx->comm, stream)); + NCCL_CALL(ncclSend(send->data, numel, AsNCCLDataType(dtype), 0, + in_group ? ctx->group_comm : ctx->global_comm, stream)); NCCL_CALL(ncclGroupEnd()); } @@ -230,7 +250,7 @@ void RecvFromWorker0(NDArray buffer) { << "ValueError: Worker 0 is not allowed to call RecvFromWorker0."; NCCL_CALL(ncclGroupStart()); NCCL_CALL(ncclRecv(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()), 0, - ctx->comm, stream)); + ctx->global_comm, stream)); NCCL_CALL(ncclGroupEnd()); } @@ -248,12 +268,14 @@ TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".init_ccl").set_body_ty TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".init_ccl_per_worker") .set_body_typed(InitCCLPerWorker); TVM_REGISTER_GLOBAL("runtime.disco." 
TVM_DISCO_CCL_NAME ".allreduce") - .set_body_typed([](NDArray send, int kind, NDArray recv) { + .set_body_typed([](NDArray send, int kind, bool in_group, NDArray recv) { CHECK(0 <= kind && kind <= 4) << "ValueError: Unknown ReduceKind: " << kind; - nccl::AllReduce(send, static_cast(kind), recv); + nccl::AllReduce(send, static_cast(kind), in_group, recv); }); TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".allgather") - .set_body_typed([](NDArray send, NDArray recv) { nccl::AllGather(send, recv); }); + .set_body_typed([](NDArray send, bool in_group, NDArray recv) { + nccl::AllGather(send, in_group, recv); + }); TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".broadcast_from_worker0") .set_body_typed(BroadcastFromWorker0); TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".scatter_from_worker0") diff --git a/src/runtime/disco/nccl/nccl_context.h b/src/runtime/disco/nccl/nccl_context.h index 3fb281f2cb7c..730479b61ac0 100644 --- a/src/runtime/disco/nccl/nccl_context.h +++ b/src/runtime/disco/nccl/nccl_context.h @@ -121,14 +121,19 @@ struct CCLThreadLocalContext { DiscoWorker* worker = nullptr; int device_id; deviceStream_t default_stream = nullptr; - ncclComm_t comm = nullptr; + ncclComm_t global_comm = nullptr; + ncclComm_t group_comm = nullptr; ~CCLThreadLocalContext() { Clear(); } void Clear() { - if (comm) { - NCCL_CALL(ncclCommDestroy(comm)); - comm = nullptr; + if (group_comm) { + NCCL_CALL(ncclCommDestroy(group_comm)); + group_comm = nullptr; + } + if (global_comm) { + NCCL_CALL(ncclCommDestroy(global_comm)); + global_comm = nullptr; } if (default_stream) { StreamDestroy(default_stream); diff --git a/src/runtime/disco/process_session.cc b/src/runtime/disco/process_session.cc index 179010db8a23..7c8d0796dd81 100644 --- a/src/runtime/disco/process_session.cc +++ b/src/runtime/disco/process_session.cc @@ -154,9 +154,10 @@ class DiscoProcessChannel final : public DiscoChannel { class ProcessSessionObj final : public BcastSessionObj { 
public: - explicit ProcessSessionObj(int num_workers, PackedFunc process_pool) + explicit ProcessSessionObj(int num_workers, int num_groups, PackedFunc process_pool) : process_pool_(process_pool), - worker_0_(std::make_unique(0, num_workers, &worker_zero_data_)) { + worker_0_( + std::make_unique(0, num_workers, num_groups, &worker_zero_data_)) { std::vector read_fds; std::vector write_fds; read_fds.reserve(num_workers - 1); @@ -258,18 +259,24 @@ class ProcessSessionObj final : public BcastSessionObj { TVM_REGISTER_OBJECT_TYPE(DiscoDebugObject); TVM_REGISTER_OBJECT_TYPE(ProcessSessionObj); -Session Session::ProcessSession(int num_workers, String process_pool_creator, String entrypoint) { +Session Session::ProcessSession(int num_workers, int num_group, String process_pool_creator, + String entrypoint) { + CHECK_EQ(num_workers % num_group, 0) + << "The number of workers should be divisible by the number of worker group."; const PackedFunc* pf = Registry::Get(process_pool_creator); CHECK(pf) << "ValueError: Cannot find function " << process_pool_creator << " in the registry. 
Please check if it is registered."; - PackedFunc process_pool = (*pf)(num_workers, entrypoint); - auto n = make_object(num_workers, process_pool); + PackedFunc process_pool = (*pf)(num_workers, num_group, entrypoint); + auto n = make_object(num_workers, num_group, process_pool); return Session(n); } -void WorkerProcess(int worker_id, int num_workers, int64_t read_fd, int64_t write_fd) { +void WorkerProcess(int worker_id, int num_workers, int num_group, int64_t read_fd, + int64_t write_fd) { + CHECK_EQ(num_workers % num_group, 0) + << "The number of workers should be divisible by the number of worker group."; DiscoProcessChannel channel(read_fd, write_fd); - DiscoWorker worker(worker_id, num_workers, nullptr, &channel); + DiscoWorker worker(worker_id, num_workers, num_group, nullptr, &channel); worker.MainLoop(); } diff --git a/src/runtime/disco/threaded_session.cc b/src/runtime/disco/threaded_session.cc index 22f906b809d2..cc9a311a6b3f 100644 --- a/src/runtime/disco/threaded_session.cc +++ b/src/runtime/disco/threaded_session.cc @@ -133,20 +133,20 @@ class DiscoThreadChannel final : public DiscoChannel { DiscoThreadedMessageQueue worker_to_controler_; }; -DiscoWorkerThread::DiscoWorkerThread(int worker_id, int num_workers, +DiscoWorkerThread::DiscoWorkerThread(int worker_id, int num_workers, int num_groups, WorkerZeroData* worker_zero_data_) : channel(std::make_unique()), - worker( - std::make_unique(worker_id, num_workers, worker_zero_data_, channel.get())), + worker(std::make_unique(worker_id, num_workers, num_groups, worker_zero_data_, + channel.get())), thread(std::make_unique([worker = this->worker.get()] { worker->MainLoop(); })) { } class ThreadedSessionObj final : public BcastSessionObj { public: - explicit ThreadedSessionObj(int num_workers) { + explicit ThreadedSessionObj(int num_workers, int num_groups) { for (int i = 0; i < num_workers; ++i) { WorkerZeroData* data = (i == 0) ? 
&worker_zero_data_ : nullptr; - workers_.emplace_back(i, num_workers, data); + workers_.emplace_back(i, num_workers, num_groups, data); } } @@ -185,8 +185,10 @@ class ThreadedSessionObj final : public BcastSessionObj { TVM_REGISTER_OBJECT_TYPE(ThreadedSessionObj); -Session Session::ThreadedSession(int num_workers) { - ObjectPtr n = make_object(num_workers); +Session Session::ThreadedSession(int num_workers, int num_group) { + CHECK_EQ(num_workers % num_group, 0) + << "The number of workers should be divisible by the number of worker group."; + ObjectPtr n = make_object(num_workers, num_group); return Session(std::move(n)); } diff --git a/tests/python/disco/test_callback.py b/tests/python/disco/test_callback.py index 6e2dc9b7470c..3f8d5e9e525b 100644 --- a/tests/python/disco/test_callback.py +++ b/tests/python/disco/test_callback.py @@ -30,16 +30,17 @@ @tvm.testing.requires_nccl def test_callback(): + """Simulate lazy loading of parameters in a callback + + The output of a lazy parameter loading, which would accept a + callback to load the parameters. + """ + @R.function def transform_params( rank_arg: R.Prim(value="rank"), fget_item: R.Callable([R.Object, R.Prim("int64")], R.Object), ): - """Simulate lazy loading of parameters in a callback - - The output of a lazy parameter loading, which would accept a - callback to load the parameters. 
- """ rank = T.int64() A = fget_item(R.str("A"), R.prim_value(0)) diff --git a/tests/python/disco/test_ccl.py b/tests/python/disco/test_ccl.py index 5831f245dfaf..6c63f64554a3 100644 --- a/tests/python/disco/test_ccl.py +++ b/tests/python/disco/test_ccl.py @@ -78,6 +78,42 @@ def test_allreduce(session_kind, ccl): np.testing.assert_equal(result, expected) +@pytest.mark.parametrize("session_kind", _all_session_kinds) +@pytest.mark.parametrize("ccl", _ccl) +def test_group_allreduce(session_kind, ccl): + devices = [0, 1, 2, 3] + sess = session_kind(num_workers=len(devices), num_groups=2) + sess.init_ccl(ccl, *devices) + + array_1 = np.arange(12, dtype="float32").reshape(3, 4) + array_2 = np.arange(start=1, stop=-11, step=-1, dtype="float32").reshape(3, 4) + array_3 = np.arange(30, dtype="float32").reshape(5, 6) + array_4 = np.arange(start=1, stop=-29, step=-1, dtype="float32").reshape(5, 6) + d_array_1 = sess.empty((3, 4), "float32") + d_array_2 = sess.empty((5, 6), "float32") + d_array_1.debug_copy_from(0, array_1) + d_array_1.debug_copy_from(1, array_2) + d_array_2.debug_copy_from(2, array_3) + d_array_2.debug_copy_from(3, array_4) + for op, np_op in [ # pylint: disable=invalid-name + ("sum", np.add), + ("prod", np.multiply), + ("min", np.minimum), + ("max", np.maximum), + ("avg", lambda a, b: (a + b) * 0.5), + ]: + dst_array_1 = sess.empty((3, 4), "float32") + dst_array_2 = sess.empty((5, 6), "float32") + sess.allreduce(d_array_1, dst_array_1, op=op, in_group=True) + sess.allreduce(d_array_2, dst_array_2, op=op, in_group=True) + result_1 = dst_array_1.debug_get_from_remote(0).numpy() + result_2 = dst_array_2.debug_get_from_remote(2).numpy() + expected_1 = np_op(array_1, array_2) + expected_2 = np_op(array_3, array_4) + np.testing.assert_equal(result_1, expected_1) + np.testing.assert_equal(result_2, expected_2) + + @pytest.mark.parametrize("session_kind", _all_session_kinds) @pytest.mark.parametrize("ccl", _ccl) def test_allgather(session_kind, ccl): @@ -101,10 
+137,47 @@ def test_allgather(session_kind, ccl): ) +@pytest.mark.parametrize("session_kind", _all_session_kinds) +@pytest.mark.parametrize("ccl", _ccl) +def test_group_allgather(session_kind, ccl): + devices = [0, 1, 2, 3] + sess = session_kind(num_workers=len(devices), num_groups=2) + sess.init_ccl(ccl, *devices) + + array_1 = np.arange(36, dtype="float32") + array_2 = np.arange(48, dtype="float32") + d_src_1 = sess.empty((3, 3, 2), "float32") + d_dst_1 = sess.empty((3, 4, 3), "float32") + d_src_2 = sess.empty((2, 4, 3), "float32") + d_dst_2 = sess.empty((2, 6, 4), "float32") + d_src_1.debug_copy_from(0, array_1[:18]) + d_src_1.debug_copy_from(1, array_1[18:]) + d_src_2.debug_copy_from(2, array_2[:24]) + d_src_2.debug_copy_from(3, array_2[24:]) + sess.allgather(d_src_1, d_dst_1, in_group=True) + sess.allgather(d_src_2, d_dst_2, in_group=True) + np.testing.assert_equal( + d_dst_1.debug_get_from_remote(0).numpy(), + array_1.reshape(3, 4, 3), + ) + np.testing.assert_equal( + d_dst_1.debug_get_from_remote(1).numpy(), + array_1.reshape(3, 4, 3), + ) + np.testing.assert_equal( + d_dst_2.debug_get_from_remote(2).numpy(), + array_2.reshape(2, 6, 4), + ) + np.testing.assert_equal( + d_dst_2.debug_get_from_remote(3).numpy(), + array_2.reshape(2, 6, 4), + ) + + @pytest.mark.parametrize("session_kind", _all_session_kinds) @pytest.mark.parametrize("ccl", _ccl) @pytest.mark.parametrize("use_explicit_output", [True, False]) -def test_broadcast_from_worker0(session_kind, ccl, use_explicit_output): +def test_broadcast(session_kind, ccl, use_explicit_output): devices = [0, 1] sess = session_kind(num_workers=len(devices)) sess.init_ccl(ccl, *devices) @@ -123,6 +196,29 @@ def test_broadcast_from_worker0(session_kind, ccl, use_explicit_output): np.testing.assert_equal(result, array) +@pytest.mark.parametrize("session_kind", _all_session_kinds) +@pytest.mark.parametrize("ccl", _ccl) +def test_group_broadcast(session_kind, ccl): + devices = [0, 1, 2, 3] + sess = 
session_kind(num_workers=len(devices), num_groups=2) + sess.init_ccl(ccl, *devices) + + array_1 = np.arange(12, dtype="float32").reshape(3, 4) + array_2 = np.multiply(array_1, -1) + + src_array = sess.empty((3, 4), "float32", worker0_only=True, in_group=True) + src_array.debug_copy_from(0, array_1) + src_array.debug_copy_from(2, array_2) + dst_array = sess.empty((3, 4), "float32") + sess.broadcast_from_worker0(src_array, dst_array) + + result_1 = dst_array.debug_get_from_remote(1).numpy() + np.testing.assert_equal(result_1, array_1) + + result_3 = dst_array.debug_get_from_remote(3).numpy() + np.testing.assert_equal(result_3, array_2) + + @pytest.mark.parametrize("session_kind", _all_session_kinds) @pytest.mark.parametrize("ccl", _ccl) @pytest.mark.parametrize("use_explicit_output", [True, False]) @@ -156,6 +252,45 @@ def test_scatter(session_kind, ccl, use_explicit_output, capfd): ), "No warning messages should be generated from disco.Session.scatter_from_worker0" +@pytest.mark.parametrize("session_kind", _all_session_kinds) +@pytest.mark.parametrize("ccl", _ccl) +def test_group_scatter(session_kind, ccl, capfd): + devices = [0, 1, 2, 3] + sess = session_kind(num_workers=len(devices), num_groups=2) + sess.init_ccl(ccl, *devices) + + array_1 = np.arange(36, dtype="float32").reshape(2, 6, 3) + array_2 = np.multiply(array_1, -1) + + d_src = sess.empty((2, 6, 3), "float32", worker0_only=True, in_group=True) + d_src.debug_copy_from(0, array_1) + d_src.debug_copy_from(2, array_2) + d_dst = sess.empty((6, 3), "float32") + sess.scatter_from_worker0(d_src, d_dst) + + np.testing.assert_equal( + d_dst.debug_get_from_remote(0).numpy(), + array_1[0, :, :], + ) + np.testing.assert_equal( + d_dst.debug_get_from_remote(1).numpy(), + array_1[1, :, :], + ) + np.testing.assert_equal( + d_dst.debug_get_from_remote(2).numpy(), + array_2[0, :, :], + ) + np.testing.assert_equal( + d_dst.debug_get_from_remote(3).numpy(), + array_2[1, :, :], + ) + + captured = capfd.readouterr() + assert ( 
+ not captured.err + ), "No warning messages should be generated from disco.Session.scatter_from_worker0" + + @pytest.mark.parametrize("session_kind", _all_session_kinds) @pytest.mark.parametrize("ccl", _ccl) def test_scatter_with_implicit_reshape(session_kind, ccl, capfd): @@ -225,6 +360,37 @@ def test_gather(session_kind, ccl, capfd): ), "No warning messages should be generated from disco.Session.gather_to_worker0" +@pytest.mark.parametrize("session_kind", _all_session_kinds) +@pytest.mark.parametrize("ccl", _ccl) +def test_group_gather(session_kind, ccl, capfd): + devices = [0, 1, 2, 3] + sess = session_kind(num_workers=len(devices), num_groups=2) + sess.init_ccl(ccl, *devices) + + array_1 = np.arange(36, dtype="float32") + array_2 = np.multiply(array_1, -1) + d_src = sess.empty((3, 3, 2), "float32") + d_dst = sess.empty((3, 4, 3), "float32", worker0_only=True, in_group=True) + d_src.debug_copy_from(0, array_1[:18]) + d_src.debug_copy_from(1, array_1[18:]) + d_src.debug_copy_from(2, array_2[:18]) + d_src.debug_copy_from(3, array_2[18:]) + sess.gather_to_worker0(d_src, d_dst) + np.testing.assert_equal( + d_dst.debug_get_from_remote(0).numpy(), + array_1.reshape(3, 4, 3), + ) + np.testing.assert_equal( + d_dst.debug_get_from_remote(2).numpy(), + array_2.reshape(3, 4, 3), + ) + + captured = capfd.readouterr() + assert ( + not captured.err + ), "No warning messages should be generated from disco.Session.gather_to_worker0" + + @pytest.mark.parametrize("session_kind", _all_session_kinds) @pytest.mark.parametrize("ccl", _ccl) def test_mlp(session_kind, ccl): # pylint: disable=too-many-locals diff --git a/tests/python/disco/test_loader.py b/tests/python/disco/test_loader.py index 502cbe0b811a..b4e2440857e6 100644 --- a/tests/python/disco/test_loader.py +++ b/tests/python/disco/test_loader.py @@ -22,6 +22,7 @@ import numpy as np import tvm +import tvm.testing from tvm import dlight as dl from tvm import relax as rx from tvm._ffi import register_func @@ -246,7 +247,7 @@ 
class Module: # pylint: disable=too-few-public-methods @R.function def main( loader: R.Object, - ) -> R.Tuple(R.Tensor((64, 64), "float32"), R.Tensor((16, 128), "float32"),): + ) -> R.Tuple(R.Tensor((64, 64), "float32"), R.Tensor((16, 128), "float32")): R.func_attr({"global_symbol": "main"}) with R.dataflow(): lv0: R.Tensor((64, 64), "float32") = R.call_pure_packed( diff --git a/tests/python/disco/test_session.py b/tests/python/disco/test_session.py index ef8ea2e70a25..837b3a14f271 100644 --- a/tests/python/disco/test_session.py +++ b/tests/python/disco/test_session.py @@ -22,13 +22,14 @@ import pytest import tvm +import tvm.testing from tvm import relax as rx from tvm.runtime import ShapeTuple, String from tvm.runtime import disco as di from tvm.script import ir as I from tvm.script import relax as R from tvm.script import tir as T -from tvm.testing import disco as _ +from tvm.exec import disco_worker as _ def _numpy_to_worker_0(sess: di.Session, np_array: np.array, device): @@ -168,14 +169,14 @@ class TestMod: @T.prim_func def t1(A: T.Buffer((8, 16), "float32"), B: T.Buffer((16, 8), "float32")): for i, j in T.grid(16, 8): - with T.block("transpose"): + with T.block("t1"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vj, vi] @T.prim_func def t2(A: T.Buffer((16, 8), "float32"), B: T.Buffer((8, 16), "float32")): for i, j in T.grid(8, 16): - with T.block("transpose"): + with T.block("t2"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vj, vi] @@ -183,7 +184,7 @@ def t2(A: T.Buffer((16, 8), "float32"), B: T.Buffer((8, 16), "float32")): def transpose_1( A: R.Tensor((8, 16), dtype="float32") ) -> R.Tensor((16, 8), dtype="float32"): - R.func_attr({"global_symbol": "main"}) + R.func_attr({"global_symbol": "transpose_1"}) cls = TestMod with R.dataflow(): B = R.call_tir(cls.t1, (A,), out_sinfo=R.Tensor((16, 8), dtype="float32")) @@ -194,7 +195,7 @@ def transpose_1( def transpose_2( A: R.Tensor((16, 8), dtype="float32") ) -> R.Tensor((8, 16), dtype="float32"): - 
R.func_attr({"global_symbol": "main"}) + R.func_attr({"global_symbol": "transpose_2"}) cls = TestMod with R.dataflow(): B = R.call_tir(cls.t2, (A,), out_sinfo=R.Tensor((8, 16), dtype="float32")) @@ -228,11 +229,4 @@ def test_num_workers(session_kind, num_workers): if __name__ == "__main__": - test_int(di.ProcessSession) - test_float(di.ProcessSession) - test_string(di.ProcessSession) - test_string_obj(di.ProcessSession) - test_shape_tuple(di.ProcessSession) - test_ndarray(di.ProcessSession) - test_vm_module(di.ProcessSession) - test_vm_multi_func(di.ProcessSession) + tvm.testing.main() diff --git a/tests/python/relax/distributed/test_distributed_transform_lower_global_to_local_view.py b/tests/python/relax/distributed/test_distributed_transform_lower_global_to_local_view.py index 3a76f535d76b..6ee64a18156d 100644 --- a/tests/python/relax/distributed/test_distributed_transform_lower_global_to_local_view.py +++ b/tests/python/relax/distributed/test_distributed_transform_lower_global_to_local_view.py @@ -220,7 +220,7 @@ def foo( out_sinfo=R.DTensor((128, 128), "float32", "mesh[0]", "R"), ) lv3: R.DTensor((128, 128), "float32", "mesh[0]", "R") = R.ccl.allreduce( - gv, op_type="sum" + gv, op_type="sum", in_group=False ) return lv3 @@ -1559,7 +1559,7 @@ def foo( out_sinfo=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R"), ) lv43: R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R") = R.ccl.allreduce( - gv, op_type="sum" + gv, op_type="sum", in_group=False ) lv44: R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R") = R.dist.call_tir_local_view( cls.add, diff --git a/tests/python/relax/test_transform_legalize_ops_ccl.py b/tests/python/relax/test_transform_legalize_ops_ccl.py index 63563ee3c95d..9ea4d21d610d 100644 --- a/tests/python/relax/test_transform_legalize_ops_ccl.py +++ b/tests/python/relax/test_transform_legalize_ops_ccl.py @@ -40,11 +40,11 @@ def main(x: R.Tensor((10, 10), "float32")) -> R.Tensor((10, 10), "float32"): class Expected: @R.function def main(x: 
R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"): - gv0: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([0])], out_sinfo=R.Tensor((10, 10), dtype="float32")) - gv1: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([1])], out_sinfo=R.Tensor((10, 10), dtype="float32")) - gv2: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([2])], out_sinfo=R.Tensor((10, 10), dtype="float32")) - gv3: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([3])], out_sinfo=R.Tensor((10, 10), dtype="float32")) - gv4: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([4])], out_sinfo=R.Tensor((10, 10), dtype="float32")) + gv0: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([0]), True], out_sinfo=R.Tensor((10, 10), dtype="float32")) + gv1: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([1]), True], out_sinfo=R.Tensor((10, 10), dtype="float32")) + gv2: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([2]), True], out_sinfo=R.Tensor((10, 10), dtype="float32")) + gv3: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([3]), True], out_sinfo=R.Tensor((10, 10), dtype="float32")) + gv4: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([4]), True], out_sinfo=R.Tensor((10, 10), dtype="float32")) return x # fmt: on @@ -66,8 +66,8 @@ def main(x: R.Tensor((10, 10), "float32")) -> R.Tensor((10, 10), "float32"): class Expected: @R.function def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"): - gv0: R.Tensor((20, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allgather", [x], 
out_sinfo=R.Tensor((20, 10), dtype="float32")) - gv1: R.Tensor((20, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allgather", [x], out_sinfo=R.Tensor((20, 10), dtype="float32")) + gv0: R.Tensor((20, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allgather", [x, True], out_sinfo=R.Tensor((20, 10), dtype="float32")) + gv1: R.Tensor((20, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allgather", [x, True], out_sinfo=R.Tensor((20, 10), dtype="float32")) return x # fmt: on @@ -88,7 +88,7 @@ def main(x: R.Tensor((10, 10), "float32")) -> R.Tensor((10, 10), "float32"): class Expected: @R.function def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"): - gv0: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.broadcast_from_worker0", x, out_sinfo=R.Tensor((10, 10), dtype="float32")) + gv0: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.broadcast_from_worker0", [x, False], out_sinfo=R.Tensor((10, 10), dtype="float32")) return x # fmt: on @@ -134,7 +134,7 @@ def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 5), dtype="flo cls = Expected gv = R.call_tir(cls.reshape, (x,), out_sinfo=R.Tensor((10, 2, 5), dtype="float32")) gv1 = R.call_tir(cls.transpose, (gv,), out_sinfo=R.Tensor((2, 10, 5), dtype="float32")) - gv0 = R.call_dps_packed("runtime.disco.scatter_from_worker0", (gv1,), out_sinfo=R.Tensor((10, 5), dtype="float32")) + gv0 = R.call_dps_packed("runtime.disco.scatter_from_worker0", (gv1, False), out_sinfo=R.Tensor((10, 5), dtype="float32")) return gv0 # fmt: on