diff --git a/python/ray/util/collective/__init__.py b/python/ray/util/collective/__init__.py index 4ae88660702f..694698474062 100644 --- a/python/ray/util/collective/__init__.py +++ b/python/ray/util/collective/__init__.py @@ -1,11 +1,15 @@ -from ray.util.collective.collective import nccl_available, mpi_available, \ +from ray.util.collective.collective import nccl_available, gloo_available, \ is_group_initialized, init_collective_group, destroy_collective_group, \ - get_rank, get_world_size, allreduce, barrier, reduce, broadcast, \ - allgather, reducescatter, send, recv + declare_collective_group, get_rank, get_world_size, allreduce, \ + allreduce_multigpu, barrier, reduce, reduce_multigpu, broadcast, \ + broadcast_multigpu, allgather, allgather_multigpu, reducescatter, \ + reducescatter_multigpu, send, send_multigpu, recv, recv_multigpu __all__ = [ - "nccl_available", "mpi_available", "is_group_initialized", - "init_collective_group", "destroy_collective_group", "get_rank", - "get_world_size", "allreduce", "barrier", "reduce", "broadcast", - "allgather", "reducescatter", "send", "recv" + "nccl_available", "gloo_available", "is_group_initialized", + "init_collective_group", "destroy_collective_group", + "declare_collective_group", "get_rank", "get_world_size", "allreduce", + "allreduce_multigpu", "barrier", "reduce", "reduce_multigpu", "broadcast", + "broadcast_multigpu", "allgather", "allgather_multigpu", "reducescatter", + "reducescatter_multigpu", "send", "send_multigpu", "recv", "recv_multigpu" ] diff --git a/python/ray/util/collective/collective.py b/python/ray/util/collective/collective.py index 08f9026b0467..afd523e6bf37 100644 --- a/python/ray/util/collective/collective.py +++ b/python/ray/util/collective/collective.py @@ -7,14 +7,9 @@ import ray from ray.util.collective import types -_MPI_AVAILABLE = False +_GLOO_AVAILABLE = False _NCCL_AVAILABLE = True -# try: -# from ray.util.collective.collective_group.mpi_collective_group \ -# import MPIGroup -# except ImportError: -# _MPI_AVAILABLE = False try: from ray.util.collective.collective_group import NCCLGroup except ImportError: @@ -27,8 +22,8 @@ def nccl_available(): return _NCCL_AVAILABLE -def mpi_available(): - return _MPI_AVAILABLE +def gloo_available(): + return _GLOO_AVAILABLE class GroupManager(object): @@ -51,9 +46,11 @@ def create_collective_group(self, backend, world_size, rank, group_name): """ backend = types.Backend(backend) if backend == types.Backend.MPI: + raise RuntimeError("Ray does not support MPI.") + elif backend == types.Backend.GLOO: raise NotImplementedError() elif backend == types.Backend.NCCL: - logger.debug("creating NCCL group: '{}'".format(group_name)) + logger.debug("Creating NCCL group: '{}'...".format(group_name)) g = NCCLGroup(world_size, rank, group_name) self._name_group_map[group_name] = g self._group_name_map[g] = group_name @@ -100,9 +97,9 @@ def init_collective_group(world_size: int, """Initialize a collective group inside an actor process. Args: - world_size (int): the total number of processed in the group. + world_size (int): the total number of processes in the group. rank (int): the rank of the current process. - backend: the CCL backend to use, NCCL or MPI. + backend: the CCL backend to use, NCCL or GLOO. group_name (str): the name of the collective group. Returns: @@ -137,10 +134,13 @@ def declare_collective_group(actors, Args: actors (list): a list of actors to be set in a collective group. 
- group_options (dict): a dictionary that contains group_name(str), - world_size(int), rank(list of int, e.g. [0,1] - means the first actor is rank 0, and the second - actor is rank 1), backend(str). + world_size (int): the total number of processes in the group. + ranks (List[int]): the rank of each actor. + backend: the CCL backend to use, NCCL or GLOO. + group_name (str): the name of the collective group. + + Returns: + None """ backend = types.Backend(backend) _check_backend_availability(backend) @@ -162,18 +162,25 @@ def declare_collective_group(actors, "Ranks must be a permutation from 0 to '{}'. Got '{}'.".format( len(ranks), "".join([str(r) for r in ranks]))) - assert world_size > 0 - assert all(ranks) >= 0 and all(ranks) < world_size + if world_size <= 0: + raise RuntimeError("World size must be greater than zero. " + "Got '{}'.".format(world_size)) + if not all(ranks) >= 0: + raise RuntimeError("Ranks must be non-negative.") + if not all(ranks) < world_size: + raise RuntimeError("Ranks cannot be greater than world_size.") # avoid a circular dependency from ray.util.collective.util import Info - # store the information into a NamedActor that can be accessed later/ + # store the information into a NamedActor that can be accessed later. name = "info_" + group_name actors_id = [a._ray_actor_id for a in actors] + # TODO (Dacheng): how do we recycle this name actor? info = Info.options(name=name, lifetime="detached").remote() ray.get([info.set_info.remote(actors_id, world_size, ranks, backend)]) +# TODO (we need a declarative destroy() API here.) def destroy_collective_group(group_name: str = "default") -> None: """Destroy a collective group given its group name.""" _check_inside_actor() @@ -206,9 +213,8 @@ def get_world_size(group_name: str = "default") -> int: group_name: the name of the group to query Returns: - The world size of the collective group, - -1 if the group does not exist or the process does - not belong to the group. + The world size of the collective group, -1 if the group does + not exist or the process does not belong to the group. """ _check_inside_actor() if not is_group_initialized(group_name): @@ -232,7 +238,29 @@ def allreduce(tensor, group_name: str = "default", op=types.ReduceOp.SUM): g = _check_and_get_group(group_name) opts = types.AllReduceOptions opts.reduceOp = op - g.allreduce(tensor, opts) + g.allreduce([tensor], opts) + + +def allreduce_multigpu(tensor_list: list, + group_name: str = "default", + op=types.ReduceOp.SUM): + """Collective allreduce a list of tensors across the group. + + Args: + tensor_list (List[tensor]): list of tensors to be allreduced, + each on a GPU. + group_name (str): the collective group name to perform allreduce. + + Returns: + None + """ + if not types.cupy_available(): + raise RuntimeError("Multigpu calls requires NCCL and Cupy.") + _check_tensor_list_input(tensor_list) + g = _check_and_get_group(group_name) + opts = types.AllReduceOptions + opts.reduceOp = op + g.allreduce(tensor_list, opts) def barrier(group_name: str = "default"): @@ -256,8 +284,8 @@ def reduce(tensor, Args: tensor: the tensor to be reduced on this process. - dst_rank: the rank of the destination process. - group_name: the collective group name to perform reduce. + dst_rank (int): the rank of the destination process. + group_name (str): the collective group name to perform reduce. op: The reduce operation. 
Returns: @@ -271,7 +299,42 @@ def reduce(tensor, opts = types.ReduceOptions() opts.reduceOp = op opts.root_rank = dst_rank - g.reduce(tensor, opts) + opts.root_tensor = 0 + g.reduce([tensor], opts) + + +def reduce_multigpu(tensor_list: list, + dst_rank: int = 0, + dst_tensor: int = 0, + group_name: str = "default", + op=types.ReduceOp.SUM): + """Reduce the tensor across the group to the destination rank + and destination tensor. + + Args: + tensor_list: the list of tensors to be reduced on this process; + each tensor located on a GPU. + dst_rank (int): the rank of the destination process. + dst_tensor: the index of GPU at the destination. + group_name (str): the collective group name to perform reduce. + op: The reduce operation. + + Returns: + None + """ + if not types.cupy_available(): + raise RuntimeError("Multigpu calls requires NCCL and Cupy.") + _check_tensor_list_input(tensor_list) + g = _check_and_get_group(group_name) + + # check dst rank + _check_rank_valid(g, dst_rank) + _check_root_tensor_valid(len(tensor_list), dst_tensor) + opts = types.ReduceOptions() + opts.reduceOp = op + opts.root_rank = dst_rank + opts.root_tensor = dst_tensor + g.reduce(tensor_list, opts) def broadcast(tensor, src_rank: int = 0, group_name: str = "default"): @@ -279,8 +342,8 @@ def broadcast(tensor, src_rank: int = 0, group_name: str = "default"): Args: tensor: the tensor to be broadcasted (src) or received (destination). - src_rank: the rank of the source process. - group_name: he collective group name to perform broadcast. + src_rank (int): the rank of the source process. + group_name (str): the collective group name to perform broadcast. Returns: None @@ -292,7 +355,37 @@ def broadcast(tensor, src_rank: int = 0, group_name: str = "default"): _check_rank_valid(g, src_rank) opts = types.BroadcastOptions() opts.root_rank = src_rank - g.broadcast(tensor, opts) + opts.root_tensor = 0 + g.broadcast([tensor], opts) + + +def broadcast_multigpu(tensor_list, + src_rank: int = 0, + src_tensor: int = 0, + group_name: str = "default"): + """Broadcast the tensor from a source GPU to all other GPUs. + + Args: + tensor_list: the tensors to broadcast (src) or receive (dst). + src_rank (int): the rank of the source process. + src_tensor (int): the index of the source GPU on the source process. + group_name (str): the collective group name to perform broadcast. + + Returns: + None + """ + if not types.cupy_available(): + raise RuntimeError("Multigpu calls requires NCCL and Cupy.") + _check_tensor_list_input(tensor_list) + g = _check_and_get_group(group_name) + + # check src rank + _check_rank_valid(g, src_rank) + _check_root_tensor_valid(len(tensor_list), src_tensor) + opts = types.BroadcastOptions() + opts.root_rank = src_rank + opts.root_tensor = src_tensor + g.broadcast(tensor_list, opts) def allgather(tensor_list: list, tensor, group_name: str = "default"): @@ -301,7 +394,7 @@ def allgather(tensor_list: list, tensor, group_name: str = "default"): Args: tensor_list (list): the results, stored as a list of tensors. tensor: the tensor (to be gathered) in the current process - group_name: the name of the collective group. + group_name (str): the name of the collective group. Returns: None @@ -314,9 +407,33 @@ def allgather(tensor_list: list, tensor, group_name: str = "default"): # Here we make it more strict: len(tensor_list) == world_size. 
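For reference, the single-GPU API keeps its one-tensor signature; the list wrapping now happens inside the group implementation. A minimal usage sketch, condensed from the updated nccl_allreduce_example.py later in this patch; the num_gpus=1 actor resource annotation and the two-worker setup are assumptions taken from that example, not new API:

import ray
import cupy as cp
import ray.util.collective as collective


@ray.remote(num_gpus=1)
class Worker:
    def __init__(self):
        self.send = cp.ones((4, ), dtype=cp.float32)

    def setup(self, world_size, rank):
        # New argument order: (world_size, rank, backend, group_name).
        collective.init_collective_group(world_size, rank, "nccl", "default")
        return True

    def compute(self):
        # In-place allreduce (SUM by default); the result lands in self.send.
        collective.allreduce(self.send, "default")
        return self.send


ray.init(num_gpus=2)
workers = [Worker.remote() for _ in range(2)]
ray.get([w.setup.remote(2, i) for i, w in enumerate(workers)])
results = ray.get([w.compute.remote() for w in workers])
# With two workers summing ones, each result is [2., 2., 2., 2.].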
raise RuntimeError( "The length of the tensor list operands to allgather " - "must not be equal to world_size.") + "must be equal to world_size.") + opts = types.AllGatherOptions() + g.allgather([tensor_list], [tensor], opts) + + +def allgather_multigpu(output_tensor_lists: list, + input_tensor_list: list, + group_name: str = "default"): + """Allgather tensors from each gpus of the group into lists. + + Args: + output_tensor_lists (List[List[tensor]]): gathered results, with shape + must be num_gpus * world_size * shape(tensor). + input_tensor_list: (List[tensor]): a list of tensors, with shape + num_gpus * shape(tensor). + group_name (str): the name of the collective group. + + Returns: + None + """ + if not types.cupy_available(): + raise RuntimeError("Multigpu calls requires NCCL and Cupy.") + _check_tensor_lists_input(output_tensor_lists) + _check_tensor_list_input(input_tensor_list) + g = _check_and_get_group(group_name) opts = types.AllGatherOptions() - g.allgather(tensor_list, tensor, opts) + g.allgather(output_tensor_lists, input_tensor_list, opts) def reducescatter(tensor, @@ -346,11 +463,38 @@ def reducescatter(tensor, "must not be equal to world_size.") opts = types.ReduceScatterOptions() opts.reduceOp = op - g.reducescatter(tensor, tensor_list, opts) + g.reducescatter([tensor], [tensor_list], opts) + + +def reducescatter_multigpu(output_tensor_list, + input_tensor_lists, + group_name: str = "default", + op=types.ReduceOp.SUM): + """Reducescatter a list of tensors across all GPUs. + + Args: + output_tensor_list: the resulted list of tensors, with + shape: num_gpus * shape(tensor). + input_tensor_lists: the original tensors, with shape: + num_gpus * world_size * shape(tensor). + group_name (str): the name of the collective group. + op: The reduce operation. + + Returns: + None. + """ + if not types.cupy_available(): + raise RuntimeError("Multigpu calls requires NCCL and Cupy.") + _check_tensor_lists_input(input_tensor_lists) + _check_tensor_list_input(output_tensor_list) + g = _check_and_get_group(group_name) + opts = types.ReduceScatterOptions() + opts.reduceOp = op + g.reducescatter(output_tensor_list, input_tensor_lists, opts) def send(tensor, dst_rank: int, group_name: str = "default"): - """Send a tensor to a remote processes synchronously. + """Send a tensor to a remote process synchronously. Args: tensor: the tensor to send. @@ -366,7 +510,41 @@ def send(tensor, dst_rank: int, group_name: str = "default"): if dst_rank == g.rank: raise RuntimeError( "The destination rank '{}' is self.".format(dst_rank)) - g.send(tensor, dst_rank) + opts = types.SendOptions() + opts.dst_rank = dst_rank + g.send([tensor], opts) + + +def send_multigpu(tensor, + dst_rank: int, + dst_gpu_index: int, + group_name: str = "default"): + """Send a tensor to a remote GPU synchronously. + + The function asssume each process owns >1 GPUs, and the sender + process and receiver process has equal nubmer of GPUs. + + Args: + tensor: the tensor to send, located on a GPU. + dst_rank (int): the rank of the destination process. + dst_gpu_index (int): the destination gpu index. + group_name (str): the name of the collective group. + + Returns: + None + """ + if not types.cupy_available(): + raise RuntimeError("send_multigpu call requires NCCL.") + _check_single_tensor_input(tensor) + g = _check_and_get_group(group_name) + _check_rank_valid(g, dst_rank) + if dst_rank == g.rank: + raise RuntimeError("The dst_rank '{}' is self. 
Considering " + "doing GPU to GPU memcpy instead?".format(dst_rank)) + opts = types.SendOptions() + opts.dst_rank = dst_rank + opts.dst_gpu_index = dst_gpu_index + g.send([tensor], opts) def recv(tensor, src_rank: int, group_name: str = "default"): @@ -386,7 +564,41 @@ def recv(tensor, src_rank: int, group_name: str = "default"): if src_rank == g.rank: raise RuntimeError( "The destination rank '{}' is self.".format(src_rank)) - g.recv(tensor, src_rank) + opts = types.RecvOptions() + opts.src_rank = src_rank + g.recv([tensor], opts) + + +def recv_multigpu(tensor, + src_rank: int, + src_gpu_index: int, + group_name: str = "default"): + """Receive a tensor from a remote GPU synchronously. + + The function asssume each process owns >1 GPUs, and the sender + process and receiver process has equal nubmer of GPUs. + + Args: + tensor: the received tensor, located on a GPU. + src_rank (int): the rank of the source process. + src_gpu_index (int): the index of the source gpu on the src process. + group_name (str): the name of the collective group. + + Returns: + None + """ + if not types.cupy_available(): + raise RuntimeError("recv_multigpu call requires NCCL.") + _check_single_tensor_input(tensor) + g = _check_and_get_group(group_name) + _check_rank_valid(g, src_rank) + if src_rank == g.rank: + raise RuntimeError("The dst_rank '{}' is self. Considering " + "doing GPU to GPU memcpy instead?".format(src_rank)) + opts = types.RecvOptions() + opts.src_rank = src_rank + opts.src_gpu_index = src_gpu_index + g.recv([tensor], opts) def _check_and_get_group(group_name): @@ -423,16 +635,6 @@ def _check_and_get_group(group_name): return g -def _check_backend_availability(backend: types.Backend): - """Check whether the backend is available.""" - if backend == types.Backend.MPI: - if not mpi_available(): - raise RuntimeError("MPI is not available.") - elif backend == types.Backend.NCCL: - if not nccl_available(): - raise RuntimeError("NCCL is not available.") - - def _check_single_tensor_input(tensor): """Check if the tensor is with a supported type.""" if isinstance(tensor, np.ndarray): @@ -448,6 +650,16 @@ def _check_single_tensor_input(tensor): type(tensor))) +def _check_backend_availability(backend: types.Backend): + """Check whether the backend is available.""" + if backend == types.Backend.GLOO: + if not gloo_available(): + raise RuntimeError("GLOO is not available.") + elif backend == types.Backend.NCCL: + if not nccl_available(): + raise RuntimeError("NCCL is not available.") + + def _check_inside_actor(): """Check if currently it is inside a Ray actor/task.""" worker = ray.worker.global_worker @@ -462,8 +674,8 @@ def _check_rank_valid(g, rank: int): """Check the rank: 0 <= rank < world_size.""" if rank < 0: raise ValueError("rank '{}' is negative.".format(rank)) - if rank > g.world_size: - raise ValueError("rank '{}' is greater than world size " + if rank >= g.world_size: + raise ValueError("rank '{}' must be less than world size " "'{}'".format(rank, g.world_size)) @@ -476,3 +688,24 @@ def _check_tensor_list_input(tensor_list): raise RuntimeError("Got an empty list of tensors.") for t in tensor_list: _check_single_tensor_input(t) + + +def _check_tensor_lists_input(tensor_lists): + """Check if the input is a list of lists of supported tensor types.""" + if not isinstance(tensor_lists, list): + raise RuntimeError("The input must be a list of lists of tensors. " + "Got '{}'.".format(type(tensor_lists))) + if not tensor_lists: + raise RuntimeError(f"Did not receive tensors. 
Got: {tensor_lists}") + for t in tensor_lists: + _check_tensor_list_input(t) + + +def _check_root_tensor_valid(length, root_tensor): + """Check the root_tensor device is 0 <= root_tensor < length""" + if root_tensor < 0: + raise ValueError("root_tensor '{}' is negative.".format(root_tensor)) + if root_tensor >= length: + raise ValueError( + "root_tensor '{}' is greater than the number of GPUs: " + "'{}'".format(root_tensor, length)) diff --git a/python/ray/util/collective/collective_group/nccl_collective_group.py b/python/ray/util/collective/collective_group/nccl_collective_group.py index ba8c7d2dbb08..4cc693f11479 100644 --- a/python/ray/util/collective/collective_group/nccl_collective_group.py +++ b/python/ray/util/collective/collective_group/nccl_collective_group.py @@ -11,15 +11,11 @@ from ray.util.collective.const import get_nccl_store_name from ray.util.collective.types import AllReduceOptions, \ BarrierOptions, Backend, ReduceOptions, BroadcastOptions, \ - AllGatherOptions, ReduceScatterOptions + AllGatherOptions, ReduceScatterOptions, SendOptions, \ + RecvOptions logger = logging.getLogger(__name__) -# TODO(Hao): -# (1) stream management, instead of using the default stream, -# using a dedicate stream -# (2) communicator management and support num_gpus > 2 per actor. - class Rendezvous: """A rendezvous class for different actor/task processes to meet. @@ -31,13 +27,18 @@ class Rendezvous: process. Args: - group_name (str): the unique user-specified group name. + store_key (str): the unique store key, usually as a concatanation + of group_name and communicator key. See `get_nccl_communicator` + for more details. """ - def __init__(self, group_name): - if not group_name: - raise ValueError("Invalid group name.") - self._group_name = group_name + def __init__(self, store_key): + if not store_key: + raise ValueError( + "Invalid store_key. The store_key is a concatenation of " + "'group_name' and the 'communicator_key'. See the " + "docstring of `get_nccl_communicator` for details.") + self._store_key = store_key self._store_name = None self._store = None @@ -53,7 +54,7 @@ def meet(self, timeout_s=180): if timeout_s <= 0: raise ValueError("The 'timeout' argument must be positive. " "Got '{}'.".format(timeout_s)) - self._store_name = get_nccl_store_name(self._group_name) + self._store_name = get_nccl_store_name(self._store_key) timeout_delta = datetime.timedelta(seconds=timeout_s) elapsed = datetime.timedelta(seconds=0) start_time = datetime.datetime.now() @@ -72,7 +73,9 @@ def meet(self, timeout_s=180): break if not self._store: raise RuntimeError("Unable to meet other processes " - "at the rendezvous store.") + "at the rendezvous store. If you are using " + "P2P communication, please check if tensors " + "are put in the correct GPU. ") @property def store(self): @@ -83,8 +86,9 @@ def get_nccl_id(self, timeout_s=180): Args: timeout_s: timeout in seconds. + Return: - str: the NCCLUniqueID if successful. + uid (str): the NCCLUniqueID if successful. """ if not self._store: raise ValueError("Rendezvous store is not setup.") @@ -110,55 +114,52 @@ def __init__(self, world_size, rank, group_name): """Init an NCCL collective group.""" super(NCCLGroup, self).__init__(world_size, rank, group_name) - # TODO(Hao): change this to a be a cache - self._collective_comm_cache = None - self._p2p_comm_cache = {} + # communicator and stream cache. + # TODO (Hao): we need a lock here... + self._dev_comm_map = {} + self._dev_streams_map = {} + + # record the used GPU IDs. 
+ self._used_gpu_indices = set() if nccl_util.get_nccl_build_version() < 2000: raise RuntimeError("NCCL in Ray requires NCCL >= 2.0.") - # TODO(Hao): check version here if nccl_util.get_nccl_runtime_version() < 2704: logger.warning("NCCL send/recv calls requires NCCL>=2.7.4") - # Setup a tensor for barrier calls - self._barrier_tensor = cupy.array([1]) - def destroy_group(self): """Destroy the group and release NCCL communicators.""" - if self._collective_comm_cache: - self.barrier() - # We also need a barrier call here. - stream = self._get_cuda_stream() - stream.synchronize() - # destroy the communicator - self._collective_comm_cache.destroy() - self._collective_comm_cache = None - - if self.rank == 0: - self._destroy_store(self.group_name) - - if self._p2p_comm_cache: - for key, comm in self._p2p_comm_cache.items(): - comm.destroy() - min_rank, max_rank = self._parse_p2p_group_key(key) - if self.rank == min_rank: - self._destroy_store(key) - self._p2p_comm_cache[key] = None - for key in list(self._p2p_comm_cache.keys()): - del self._p2p_comm_cache[key] - self._p2p_comm_cache = None - + if len(self._dev_comm_map.keys()) > 0: + + # TODO(Hao): check this barrier call + # self.barrier() + + # Destroy the communicators and streams. + for comm_key, comms in self._dev_comm_map.items(): + for c in comms: + c.destroy() + self._dev_comm_map[comm_key] = None + + if self.rank == 0: + for comm_key in self._dev_comm_map: + assert not self._dev_comm_map[comm_key] + group_key = self._generate_group_key(comm_key) + self._destroy_store(group_key) + self._barrier_tensor = None + self._dev_comm_map = None + self._dev_streams_map = None super(NCCLGroup, self).destroy_group() @classmethod def backend(cls): return Backend.NCCL - def allreduce(self, tensor, allreduce_options=AllReduceOptions()): - """AllReduce the tensor across the collective group following options. + def allreduce(self, tensors, allreduce_options=AllReduceOptions()): + """AllReduce tensors across the collective group following options. Args: - tensor: the tensor to be reduced, each tensor locates on a GPU. + tensors (List): the list of tensors to be reduced. Each tensor must + reside on one GPU of the current process. allreduce_options: allreduce options. Returns: @@ -174,29 +175,41 @@ def collective_fn(input_tensor, output_tensor, comm, stream): nccl_util.get_nccl_reduce_op(allreduce_options.reduceOp), stream.ptr) - self._collective(tensor, tensor, collective_fn) + self._collective(tensors, tensors, collective_fn) def barrier(self, barrier_options=BarrierOptions()): """Blocks until all processes reach this barrier. Args: - barrier_options: + barrier_options: barrier options. Returns: None """ - self.allreduce(self._barrier_tensor) - - def reduce(self, tensor, reduce_options=ReduceOptions()): - """Reduce tensor to a destination process following options. + # Get the device list. + if self._used_gpu_indices: + devices = list(self._used_gpu_indices) + else: + devices = list(range(nccl_util.get_num_gpus())) + barrier_tensors = [None] * len(devices) + for i, d in enumerate(devices): + with nccl_util.Device(d): + barrier_tensors[i] = cupy.array([1]) + self.allreduce(barrier_tensors) + + def reduce(self, tensors, reduce_options=ReduceOptions()): + """Reduce tensors to a destination gpu following options. Args: - tensor: the tensor to be reduced. - reduce_options: reduce options + tensors (List): the list of tensors to be reduced, each tensor + must reside on one gpu of the current process. + reduce_options: reduce options. 
Returns: None """ + root_rank = len(tensors) * reduce_options.root_rank \ + + reduce_options.root_tensor def collective_fn(input_tensor, output_tensor, comm, stream): comm.reduce( @@ -205,40 +218,43 @@ def collective_fn(input_tensor, output_tensor, comm, stream): nccl_util.get_tensor_n_elements(input_tensor), nccl_util.get_nccl_tensor_dtype(input_tensor), nccl_util.get_nccl_reduce_op(reduce_options.reduceOp), - reduce_options.root_rank, stream.ptr) + root_rank, stream.ptr) - self._collective(tensor, tensor, collective_fn) + self._collective(tensors, tensors, collective_fn) - def broadcast(self, tensor, broadcast_options=BroadcastOptions()): - """Broadcast tensor to all other processes following options. + def broadcast(self, tensors, broadcast_options=BroadcastOptions()): + """Broadcast tensors to all other gpus following options. Args: - tensor: the tensor to be broadcasted. + tensors (List): tensors to be broadcast or received. broadcast_options: broadcast options. Returns: None """ + root_rank = len(tensors) * broadcast_options.root_rank \ + + broadcast_options.root_tensor def collective_fn(input_tensor, output_tensor, comm, stream): comm.broadcast( nccl_util.get_tensor_ptr(input_tensor), nccl_util.get_tensor_ptr(output_tensor), nccl_util.get_tensor_n_elements(input_tensor), - nccl_util.get_nccl_tensor_dtype(input_tensor), - broadcast_options.root_rank, stream.ptr) + nccl_util.get_nccl_tensor_dtype(input_tensor), root_rank, + stream.ptr) - self._collective(tensor, tensor, collective_fn) + self._collective(tensors, tensors, collective_fn) def allgather(self, - tensor_list, - tensor, + tensor_lists, + tensors, allgather_options=AllGatherOptions()): - """Allgather tensors across the group into a list of tensors. + """Allgather tensors across gpus into a list of tensors. Args: - tensor_list: the tensor list to store the results. - tensor: the tensor to be allgather-ed across the group. + tensor_lists (List[List[Tensor]]): allgathered tensors. + tensors: the list of tensors to allgather across the group. + Each tensor must lolcate on a GPU of the process. allgather_options: allgather options. Returns: @@ -252,30 +268,36 @@ def collective_fn(input_tensor, output_tensor, comm, stream): nccl_util.get_tensor_n_elements(input_tensor), nccl_util.get_nccl_tensor_dtype(input_tensor), stream.ptr) - _check_inputs_compatibility_for_scatter_gather(tensor, tensor_list) - flattened_output_tensor = _flatten_for_scatter_gather( - tensor_list, copy=False) + _check_inputs_compatibility_for_scatter_gather(tensors, tensor_lists) + output_flattened = [ + _flatten_for_scatter_gather(tensor_list, copy=False) + for tensor_list in tensor_lists + ] def postprocess_fn(stream): - for i, tensor in enumerate(tensor_list): - nccl_util.copy_tensor(tensor, flattened_output_tensor[i]) + # TODO(Hao): designate a copy stream. + for i, tensor_list in enumerate(tensor_lists): + for j, tensor in enumerate(tensor_list): + nccl_util.copy_tensor(tensor, output_flattened[i][j]) self._collective( - tensor, - flattened_output_tensor, + tensors, + output_flattened, collective_fn, postprocess_fn=postprocess_fn) def reducescatter(self, - tensor, - tensor_list, + tensors, + tensor_lists, reducescatter_options=ReduceScatterOptions()): - """Reducescatter a list of tensors across the group. + """Reduce the scatter a list of tensors across the group. Args: - tensor: the output tensor (could be unspecified). - tensor_list: the list of tensor to be reduced then scattered. - reducescatter_options: reducescatter options. 
+ tensors (List): the output tensors (could be unspecified), each + located on a GPU of the current process. + tensor_lists (List[List]): the list of tensors to be reduced then + scattered. + reducescatter_options: reduce-scatter options. Returns: None @@ -290,26 +312,30 @@ def collective_fn(input_tensor, output_tensor, comm, stream): nccl_util.get_nccl_reduce_op(reducescatter_options.reduceOp), stream.ptr) - _check_inputs_compatibility_for_scatter_gather(tensor, tensor_list) - flattened_input_tensor = _flatten_for_scatter_gather( - tensor_list, copy=False) + _check_inputs_compatibility_for_scatter_gather(tensors, tensor_lists) + input_flattened = [ + _flatten_for_scatter_gather(tensor_list, copy=False) + for tensor_list in tensor_lists + ] def preprocess_fn(stream): - for i, tensor in enumerate(tensor_list): - nccl_util.copy_tensor(flattened_input_tensor[i], tensor) + # TODO(Hao): designate a copy stream. + for i, tensor_list in enumerate(tensor_lists): + for j, tensor in enumerate(tensor_list): + nccl_util.copy_tensor(input_flattened[i][j], tensor) self._collective( - flattened_input_tensor, - tensor, + input_flattened, + tensors, collective_fn, preprocess_fn=preprocess_fn) - def send(self, tensor, dst_rank): - """Send tensor to a destination process in the group. + def send(self, tensors, send_options=SendOptions()): + """Send a tensor to a destination gpu in the group. Args: - tensor: the tensor to send. - dst_rank: the rank of the destination process. + tensors (List): the tensor to send. + send_options: send options. Returns: None @@ -321,14 +347,15 @@ def p2p_fn(tensor, comm, stream, peer): nccl_util.get_tensor_n_elements(tensor), nccl_util.get_nccl_tensor_dtype(tensor), peer, stream.ptr) - self._point2point(tensor, p2p_fn, dst_rank) + self._point2point(tensors, p2p_fn, send_options.dst_rank, + send_options.dst_gpu_index) - def recv(self, tensor, src_rank): - """Receive tensor from a source process in the group. + def recv(self, tensors, recv_options=RecvOptions()): + """Receive a tensor from a source gpu in the group. Args: - tensor: the received tensor. - src_rank: the rank of the source process. + tensors (List): the received tensor. + recv_options: Receive options. Returns: None @@ -340,128 +367,218 @@ def p2p_fn(tensor, comm, stream, peer): nccl_util.get_tensor_n_elements(tensor), nccl_util.get_nccl_tensor_dtype(tensor), peer, stream.ptr) - self._point2point(tensor, p2p_fn, src_rank) + self._point2point(tensors, p2p_fn, recv_options.src_rank, + recv_options.src_gpu_index) + + def _get_nccl_collective_communicator(self, comm_key, device_list): + """Create or retrieve an NCCL communicator from cache. + + If the communicator is found in cache, return the communicator. If not, + a communicator and a stream will be created and put in cache. + TODO(Hao): this function is not thread-safe now. - def _get_nccl_collective_communicator(self): - """Create or retrieve a cached NCCL communicator. + Args: + comm_key (str): the key to query the communicator cache. + device_list (List): a list of GPU devices of the current process + that participates into the collective. Returns: - communicator + communicator: the NCCL communicator corresponded to the devices. 
""" - if not self._collective_comm_cache: - # create the communicator - if self.rank == 0: - group_uid = self._generate_nccl_uid(self.group_name) - else: - rendezvous = Rendezvous(self.group_name) - rendezvous.meet() - group_uid = rendezvous.get_nccl_id() - self._collective_comm_cache = \ - nccl_util.create_nccl_communicator(self.world_size, - group_uid, - self.rank) - return self._collective_comm_cache - - def _get_nccl_p2p_communicator(self, rank1, rank2): + if not comm_key: + raise RuntimeError("Got empty communicator key.") + for d in device_list: + self._used_gpu_indices.add(d) + + # TODO(Hao): lock the _dev_comm_map here. + if comm_key in self._dev_comm_map: + return self._dev_comm_map[comm_key] + + group_key = self._generate_group_key(comm_key) + if self.rank == 0: + nccl_uid = self._generate_nccl_uid(group_key) + else: + rendezvous = Rendezvous(group_key) + rendezvous.meet() + nccl_uid = rendezvous.get_nccl_id() + + # Now create the communicators + actual_world_size = len(device_list) * self.world_size + comms = [None] * len(device_list) + streams = [None] * len(device_list) + nccl_util.groupStart() + for i, device in enumerate(device_list): + actual_rank = self.rank * len(device_list) + i + with nccl_util.Device(device): + comms[i] = nccl_util.create_nccl_communicator( + actual_world_size, nccl_uid, actual_rank) + streams[i] = cupy.cuda.Stream.null + # Stream(non_blocking=True) + nccl_util.groupEnd() + self._dev_comm_map[comm_key] = comms + self._dev_streams_map[comm_key] = streams + return comms + + @staticmethod + def _sync_streams(): + """Let NCCL streams wait for current streams for every device.""" + # FIXME: This behavior is different from nccl document. It seems like + # cupy allocate tensors on null streams. + cupy.cuda.Stream.null.synchronize() + + def _get_nccl_p2p_communicator(self, comm_key, my_gpu_idx, peer_rank, + peer_gpu_idx): """Create or retrieve an NCCL communicator for p2p tasks. - Args: - rank1 (int): source rank. - rank2 (int): destination rank. + Note(Hao): this function is not thread-safe now. + Args: + comm_key (str): communicator key. + my_gpu_idx (int): the gpu index on the current process. + peer_rank (int): the rank of the destination process. + peer_gpu_idx (int): the gpu index on the peer process. Returns: communicator """ - min_rank = min(rank1, rank2) - max_rank = max(rank1, rank2) - my_rank = 0 if self.rank == min_rank else 1 - p2p_group_key = self._generate_p2p_group_key(min_rank, max_rank) - comm = self._p2p_comm_cache.get(p2p_group_key) - if not comm: - if self.rank == min_rank: - group_uid = self._generate_nccl_uid(p2p_group_key) - else: - rendezvous = Rendezvous(p2p_group_key) - rendezvous.meet() - group_uid = rendezvous.get_nccl_id() - comm = nccl_util.create_nccl_communicator(2, group_uid, my_rank) - self._p2p_comm_cache[p2p_group_key] = comm - return comm - - def _generate_p2p_group_key(self, min_rank, max_rank): - return self.group_name + "_" + str(min_rank) + "_" + str(max_rank) + if not comm_key: + raise RuntimeError("Got empty communicator key.") + + # TODO(Hao): lock the _dev_comm_map here. + if comm_key in self._dev_comm_map: + return self._dev_comm_map[comm_key] + + # Note (Hao): This is a bit complex so I decide to take a note here. + # Here we need to consider three cases: + # Case 1: src_rank != dst_rank, hence the send and recv happen on + # different process (actors/tasks); each process makes independent + # collective calls and manages corresponding communicators. 
+ # Case 2: src_rank == dst_rank, src_gpu_idx == dst_gpu_idx; for + # this case, we simply throw a RuntimeError; + # Case 3: src_rank == dst_rank, src_gpu_idx != dst_gpu_idx, which + # means the send and recv will be called on the same process. We + # DO NOT support this case for now. We need to properly scope: + # (1) communicators creation, and + # (2) send/recv calls + # using groupStart(( and groupEnd() calls to avoid deadlocks. + if self.rank < peer_rank: + my_p2p_rank = 0 + elif self.rank > peer_rank: + my_p2p_rank = 1 + else: + raise RuntimeError( + "Send and recv happens on the same process! " + "ray.util.collective does not support this case as of now. " + "Alternatively, consider doing GPU to GPU memcpy?") + + group_key = self._generate_group_key(comm_key) + if my_p2p_rank == 0: + nccl_uid = self._generate_nccl_uid(group_key) + else: + rendezvous = Rendezvous(group_key) + rendezvous.meet() + nccl_uid = rendezvous.get_nccl_id() + + # create the p2p communicators + with nccl_util.Device(my_gpu_idx): + comm = nccl_util.create_nccl_communicator(2, nccl_uid, my_p2p_rank) + stream = cupy.cuda.Stream.null + # Stream(non_blocking=True) + self._dev_comm_map[comm_key] = [comm] + self._dev_streams_map[comm_key] = [stream] + return [comm] + + def _generate_group_key(self, comm_key): + """Generate a unique key used to initialize the KV store. + + The group key is a concatenation of the communicator key and + the group name, following: [comm_key]@[group_name]. + """ + return comm_key + "@" + self.group_name @staticmethod - def _parse_p2p_group_key(key): - strs = key.split("_") - return int(strs[-2]), int(strs[-1]) + def _destroy_store(group_key): + """Destroy the KV store (Ray named actor). - @staticmethod - def _destroy_store(group_name): - store_name = get_nccl_store_name(group_name) + Args: + group_key (str): the unique key to retrieve the KV store. + + Returns: + None + """ + store_name = get_nccl_store_name(group_key) store = ray.get_actor(store_name) # ray.get([store.__ray_terminate__.remote()]) ray.kill(store) - def _generate_nccl_uid(self, name): - """Generate an NCCL UID by calling the NCCL API. + def _generate_nccl_uid(self, key): + """Generate an NCCL unique ID for initializing communicators. + + The method will also create a KV store using Ray named actor and store + the NCCLUniqueID in the store. The store needs to be garbage collected + when destroying the collective group. Args: - name: the name of the collective group. + key (str): the key of the . Returns: - str: NCCL uid. + NCCLUniqueID (str): NCCL unique ID. """ group_uid = nccl_util.get_nccl_unique_id() - store_name = get_nccl_store_name(name) + store_name = get_nccl_store_name(key) # Avoid a potential circular dependency in ray/actor.py from ray.util.collective.util import NCCLUniqueIDStore store = NCCLUniqueIDStore.options( name=store_name, lifetime="detached").remote(store_name) - ray.wait([store.set_id.remote(group_uid)]) + ray.get([store.set_id.remote(group_uid)]) return group_uid - @staticmethod - def _get_cuda_stream(): - """Obtain an idle stream from a stream pool for the collective task.""" - # TODO: implement a simple stream manager. - return cupy.cuda.Stream.null - def _collective(self, - input_tensor, - output_tensor, + input_tensors, + output_tensors, collective_fn, preprocess_fn=None, postprocess_fn=None): """A method to encapsulate all collective calls. Args: - input_tensor: the input tensor. - output_tensor: the output tensor. + input_tensors: the list of the input tensors. 
+ output_tensors: the list of the output tensors. collective_fn: the collective function call. - preprocess_fn: preprocess function to call before collectives. - postprocess_fn: postprocess function to call after collectives. + preprocess_fn: preprocess procedures before collective calls. + postprocess_fn: postprocess procedures after collective calls. Returns: None """ - comm = self._get_nccl_collective_communicator() - stream = self._get_cuda_stream() + _check_gpu_tensors(input_tensors) + _check_gpu_tensors(output_tensors) + + devices = nccl_util.get_tensor_device_list(input_tensors) + key = _get_comm_key_from_devices(devices) + comms = self._get_nccl_collective_communicator(key, devices) + streams = self._dev_streams_map[key] + + # TODO(Hao): sync streams and events + self._sync_streams() # Make the collective call if preprocess_fn: - preprocess_fn(stream) - collective_fn(input_tensor, output_tensor, comm, stream) + preprocess_fn(streams) + nccl_util.groupStart() + for i, tensor in enumerate(input_tensors): + collective_fn(tensor, output_tensors[i], comms[i], streams[i]) + nccl_util.groupEnd() if postprocess_fn: - postprocess_fn(stream) + postprocess_fn(streams) - def _point2point(self, tensor, p2p_fn, peer_rank: int): - """A method to encapsulate all p2p calls. + def _point2point(self, tensors, p2p_fn, peer_rank: int, peer_gpu_idx: int): + """A method to encapsulate all peer-to-peer calls (i.e., send/recv). Args: - tensor: the tensor to be sent/received. + tensors: the tensor to send or receive. p2p_fn: the p2p function call. - peer_rank (int): the peer rank of the current process. + peer_rank (int): the rank of the peer process. + peer_gpu_idx (int): the index of the gpu on the peer process. Returns: None @@ -471,13 +588,24 @@ def _point2point(self, tensor, p2p_fn, peer_rank: int): raise RuntimeError("P2p send/recv requires NCCL >= 2.7.4. " "Got '{}'.".format( nccl_util.get_nccl_runtime_version())) + _check_gpu_tensors(tensors) + + # we currently only support single device to single device send/recv. + assert len(tensors) == 1 + my_gpu_idx = nccl_util.get_tensor_device(tensors[0]) + comm_key = _get_comm_key_send_recv(self.rank, my_gpu_idx, peer_rank, + peer_gpu_idx) + comms = self._get_nccl_p2p_communicator(comm_key, my_gpu_idx, + peer_rank, peer_gpu_idx) + streams = self._dev_streams_map[comm_key] + + # TODO(Hao): sync streams and events + self._sync_streams() # We have made sure that self.rank != peer_rank during API check. peer_p2p_rank = 0 if self.rank > peer_rank else 1 - comm = self._get_nccl_p2p_communicator(self.rank, peer_rank) - stream = self._get_cuda_stream() - # Make the p2p call: - p2p_fn(tensor, comm, stream, peer_p2p_rank) + for i, tensor in enumerate(tensors): + p2p_fn(tensors[i], comms[i], streams[i], peer_p2p_rank) def _flatten_for_scatter_gather(tensor_list, copy=False): @@ -496,29 +624,130 @@ def _flatten_for_scatter_gather(tensor_list, copy=False): # note we need a cupy dtype here. 
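_flatten_for_scatter_gather now allocates its staging buffer on the same GPU as the tensors it flattens. A cupy-only sketch of that behavior, assuming cupy inputs; the real helper goes through nccl_util, so it also accepts torch tensors:

import cupy as cp


def flatten_for_scatter_gather(tensor_list, copy=False):
    # One contiguous [len(tensor_list)] + shape buffer, allocated on the same
    # device as the first tensor; optionally copy each tensor into its slot.
    t0 = tensor_list[0]
    with cp.cuda.Device(t0.device.id):
        buffer = cp.empty([len(tensor_list)] + list(t0.shape), dtype=t0.dtype)
    if copy:
        for i, t in enumerate(tensor_list):
            buffer[i] = t
    return buffer


with cp.cuda.Device(0):
    chunks = [cp.ones((4, ), dtype=cp.float32) for _ in range(2)]
flat = flatten_for_scatter_gather(chunks, copy=True)
assert flat.shape == (2, 4)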
dtype = nccl_util.get_cupy_tensor_dtype(t) buffer_shape = [len(tensor_list)] + nccl_util.get_tensor_shape(t) - buffer = cupy.empty(buffer_shape, dtype=dtype) + device = nccl_util.get_tensor_device(t) + with nccl_util.Device(device): + buffer = cupy.empty(buffer_shape, dtype=dtype) if copy: for i, tensor in enumerate(tensor_list): nccl_util.copy_tensor(buffer[i], tensor) return buffer -def _check_inputs_compatibility_for_scatter_gather(tensor, tensor_list): - """Check the compatibility between tensor input and tensor list inputs.""" - if not tensor_list: - raise RuntimeError("Got empty list of tensors.") - dtype = nccl_util.get_nccl_tensor_dtype(tensor) - shape = nccl_util.get_tensor_shape(tensor) - for t in tensor_list: - # check dtype - dt = nccl_util.get_nccl_tensor_dtype(t) +def _check_inputs_compatibility_for_scatter_gather(tensors, tensor_lists): + """Check the compatibility between tensor input and tensor list input.""" + if not tensors or not isinstance(tensors, list): + raise RuntimeError( + "The first argument 'tensors' expects a list of tensors.") + if not tensor_lists or not isinstance(tensor_lists, list): + raise RuntimeError("The second argument 'tensor_lists' " + "expects a list of tensor list.") + dtype = nccl_util.get_nccl_tensor_dtype(tensors[0]) + shape = nccl_util.get_tensor_shape(tensors[0]) + for i, tensor_list in enumerate(tensor_lists): + # check all tensor in `tensors` match. + dt = nccl_util.get_nccl_tensor_dtype(tensors[i]) if dt != dtype: raise RuntimeError("All tensor operands to scatter/gather must " - "have the same dtype. Got '{}' and '{}'" - "".format(dt, dtype)) + "have the same dtype. Got '{}' and '{}'." + .format(dt, dtype)) # Note: typically CCL libraries only requires they have the same - # number of elements; - # Here we make it more strict -- we require exact shape match. - if nccl_util.get_tensor_shape(t) != shape: + # number of elements; Here we make it more strict -- we require + # exact shape match. + s = nccl_util.get_tensor_shape(tensors[i]) + if s != shape: raise RuntimeError("All tensor operands to scatter/gather must " - "have the same shape.") + "have the same shape. Got '{}' and '{}'." + .format(s, shape)) + # check all tensors in `tensor_lists` match. + for t in tensor_lists[i]: + # check dtype + dt = nccl_util.get_nccl_tensor_dtype(t) + if dt != dtype: + raise RuntimeError( + "All tensor operands to scatter/gather must " + "have the same dtype. Got '{}' and '{}'.".format( + dt, dtype)) + s = nccl_util.get_tensor_shape(t) + if s != shape: + raise RuntimeError( + "All tensor operands to scatter/gather must " + "have the same shape. Got '{}' and '{}'.".format(s, shape)) + + +def _check_gpu_tensors(tensors): + """Check all tensors are distributed on different GPUs.""" + if not tensors or not isinstance(tensors, list): + raise RuntimeError("'tensors' must be a nonempty list.") + if len(tensors) > nccl_util.get_num_gpus(): + raise RuntimeError("Tensor list cannot be larger than the number" + "of available GPUs. 
Got {} > {}.".format( + len(tensors), nccl_util.get_num_gpus())) + t0 = tensors[0] + dt = nccl_util.get_nccl_tensor_dtype(t0) + s = nccl_util.get_tensor_shape(t0) + d = nccl_util.get_tensor_device(t0) + for i, t in enumerate(tensors): + if i == 0: + continue + # We need to check the following: + # (1) tensor is cuda (already checked during API) + # (2) tensor dtype + # (3) tensor shape match + # (4) each tensor is on a different GPU + dtype = nccl_util.get_nccl_tensor_dtype(t) + if dt != dtype: + raise RuntimeError("Tensors must have identical dtype. Got: '{}'." + .format(dtype)) + shape = nccl_util.get_tensor_shape(t) + if s != shape: + raise RuntimeError("Tensor must have identical shape. Got: '{}'." + .format(shape)) + device = nccl_util.get_tensor_device(t) + if device == d: + raise RuntimeError("Tensor must be on distinct GPUs.") + + +def _get_comm_key_from_devices(devices): + """Return a key from a list of devices for collective calls. + + For example, if the tensors are on gpus 0, 1, 2, 3, + then the key would be "0,1,2,3". + + Args: + devices(list): a list of GPU device indices + + Returns: + str: a string represents the key to query the communicator cache. + + """ + return ",".join([str(d) for d in devices]) + + +def _get_comm_key_send_recv(my_rank, my_gpu_idx, peer_rank, peer_gpu_idx): + """Return a key given source and destination ranks for p2p tasks. + + The p2p key is in the following form: + [min_rank]_[gpu_index]:[max_rank]_[gpu_index]. + + Args: + my_rank (int): the rank of the source process. + my_gpu_idx (int): the source gpu index on the process. + peer_rank (int): the rank of the destination process. + peer_gpu_idx (int): the destination gpu index on the process. + + Returns: + comm_key (str): a string key to query the communication cache. + """ + if my_rank < peer_rank: + lower_key = str(my_rank) + "_" + str(my_gpu_idx) + higher_key = str(peer_rank) + "_" + str(peer_gpu_idx) + elif my_rank > peer_rank: + lower_key = str(peer_rank) + "_" + str(peer_gpu_idx) + higher_key = str(my_rank) + "_" + str(my_gpu_idx) + else: + raise RuntimeError( + "Send and recv happens on the same process. ray.util.collective " + "does not support this case as of now. Alternatively, consider " + "doing GPU to GPU memcpy?") + comm_key = lower_key + ":" + higher_key + return comm_key diff --git a/python/ray/util/collective/collective_group/nccl_util.py b/python/ray/util/collective/collective_group/nccl_util.py index 889c8c443f36..36895d79b884 100644 --- a/python/ray/util/collective/collective_group/nccl_util.py +++ b/python/ray/util/collective/collective_group/nccl_util.py @@ -3,9 +3,12 @@ try: import cupy from cupy.cuda import nccl + from cupy.cuda import Device # noqa: F401 from cupy.cuda.nccl import get_version from cupy.cuda.nccl import get_build_version from cupy.cuda.nccl import NcclCommunicator + from cupy.cuda.nccl import groupStart # noqa: F401 + from cupy.cuda.nccl import groupEnd # noqa: F401 except ImportError: raise ImportError("NCCL in Ray requires Cupy being available!") @@ -74,6 +77,11 @@ } +def get_num_gpus(): + """Returns the number of compute-capable GPUs.""" + return cupy.cuda.runtime.getDeviceCount() + + def get_nccl_build_version(): return get_build_version() @@ -90,14 +98,12 @@ def create_nccl_communicator(world_size, nccl_unique_id, rank): """Create an NCCL communicator using NCCL APIs. Args: - world_size (int): the number of processes of this communcator group. + world_size (int): the number of processes of this communicator group. 
nccl_unique_id (str): the NCCLUniqueID for this group. rank (int): the rank of this process. Returns: comm (nccl.ncclComm_t): an NCCL communicator. """ - # TODO(Hao): make this inside the NCCLComm class, - # and implement the abort method. Make it RAII. comm = NcclCommunicator(world_size, nccl_unique_id, rank) return comm @@ -149,7 +155,7 @@ def get_tensor_ptr(tensor): if torch_available(): if isinstance(tensor, torch.Tensor): if not tensor.is_cuda: - raise RuntimeError("torch tensor must be on gpu.") + raise RuntimeError("Torch tensor must be on GPU.") return tensor.data_ptr() raise ValueError("Unsupported tensor type. Got: {}. Supported " "GPU tensor types are: torch.Tensor, " @@ -194,6 +200,24 @@ def get_tensor_strides(tensor): "cupy.ndarray.".format(type(tensor))) +def get_tensor_device(tensor): + """Return the GPU index of a tensor.""" + if isinstance(tensor, cupy.ndarray): + try: + device = tensor.device.id + except AttributeError as exec: + raise RuntimeError("The tensor is not on a valid GPU.") \ + from exec + elif torch_available() and isinstance(tensor, torch.Tensor): + device = tensor.device.index + if not isinstance(device, int): + raise RuntimeError("The tensor is not on a valid GPU.") + else: + raise ValueError("Unsupported tensor type. " + "Got: {}.".format(type(tensor))) + return device + + def copy_tensor(dst_tensor, src_tensor): """Copy the content from src_tensor to dst_tensor. @@ -228,3 +252,21 @@ def copy_tensor(dst_tensor, src_tensor): raise ValueError("Unsupported tensor type. Got: {} and {}. Supported " "GPU tensor types are: torch.Tensor, cupy.ndarray." .format(type(dst_tensor), type(src_tensor))) + + +def get_tensor_device_list(tensors): + """Returns the gpu devices of the list of input tensors. + + Args: + tensors(list): a list of tensors, each locates on a GPU. + + Returns: + list: the list of GPU devices. + + """ + if not isinstance(tensors, list): + raise RuntimeError( + "Expect a list of tensors each locates on a GPU device. 
" + "Got: '{}'.".format(type(tensors))) + devices = [get_tensor_device(t) for t in tensors] + return devices diff --git a/python/ray/util/collective/examples/nccl_allreduce_example.py b/python/ray/util/collective/examples/nccl_allreduce_example.py index 7010d69249f2..797924621a52 100644 --- a/python/ray/util/collective/examples/nccl_allreduce_example.py +++ b/python/ray/util/collective/examples/nccl_allreduce_example.py @@ -11,12 +11,11 @@ def __init__(self): self.recv = cp.zeros((4, ), dtype=cp.float32) def setup(self, world_size, rank): - collective.init_collective_group("nccl", world_size, rank, "default") + collective.init_collective_group(world_size, rank, "nccl", "default") return True def compute(self): collective.allreduce(self.send, "default") - print(self.send) return self.send def destroy(self): @@ -24,11 +23,8 @@ def destroy(self): if __name__ == "__main__": - send = cp.ones((4, ), dtype=cp.float32) - ray.init(num_gpus=2) - num_workers = 2 workers = [] init_rets = [] @@ -38,5 +34,4 @@ def destroy(self): init_rets.append(w.setup.remote(num_workers, i)) _ = ray.get(init_rets) results = ray.get([w.compute.remote() for w in workers]) - # print(results) ray.shutdown() diff --git a/python/ray/util/collective/examples/nccl_allreduce_example_declare_collective_group.py b/python/ray/util/collective/examples/nccl_allreduce_example_declare_collective_group.py index 9d0335dbab11..106ea31b2b7f 100644 --- a/python/ray/util/collective/examples/nccl_allreduce_example_declare_collective_group.py +++ b/python/ray/util/collective/examples/nccl_allreduce_example_declare_collective_group.py @@ -30,5 +30,4 @@ def compute(self): } collective.declare_collective_group(workers, **_options) results = ray.get([w.compute.remote() for w in workers]) - print(results) ray.shutdown() diff --git a/python/ray/util/collective/examples/nccl_allreduce_multigpu_example.py b/python/ray/util/collective/examples/nccl_allreduce_multigpu_example.py new file mode 100644 index 000000000000..88b75802e880 --- /dev/null +++ b/python/ray/util/collective/examples/nccl_allreduce_multigpu_example.py @@ -0,0 +1,43 @@ +import ray +import cupy as cp + +import ray.util.collective as collective +from cupy.cuda import Device + + +@ray.remote(num_gpus=2) +class Worker: + def __init__(self): + with Device(0): + self.send1 = cp.ones((4, ), dtype=cp.float32) + with Device(1): + self.send2 = cp.ones((4, ), dtype=cp.float32) * 2 + + self.recv = cp.zeros((4, ), dtype=cp.float32) + + def setup(self, world_size, rank): + collective.init_collective_group(world_size, rank, "nccl", "177") + return True + + def compute(self): + collective.allreduce_multigpu([self.send1, self.send2], "177") + return [self.send1, self.send2], self.send1.device, self.send2.device + + def destroy(self): + collective.destroy_collective_group("177") + + +if __name__ == "__main__": + ray.init(address="auto") + num_workers = 2 + workers = [] + init_rets = [] + for i in range(num_workers): + w = Worker.remote() + workers.append(w) + init_rets.append(w.setup.remote(num_workers, i)) + a = ray.get(init_rets) + results = ray.get([w.compute.remote() for w in workers]) + print(results) + ray.get([w.destroy.remote() for w in workers]) + ray.shutdown() diff --git a/python/ray/util/collective/examples/nccl_p2p_example_multigpu.py b/python/ray/util/collective/examples/nccl_p2p_example_multigpu.py new file mode 100644 index 000000000000..7ff637a5bd68 --- /dev/null +++ b/python/ray/util/collective/examples/nccl_p2p_example_multigpu.py @@ -0,0 +1,53 @@ +import ray +import cupy as cp + 
+import ray.util.collective as collective +from cupy.cuda import Device + + +@ray.remote(num_gpus=2) +class Worker: + def __init__(self): + with Device(0): + self.send1 = cp.ones((4, ), dtype=cp.float32) + with Device(1): + self.send2 = cp.ones((4, ), dtype=cp.float32) * 2 + + with Device(0): + self.recv1 = cp.zeros((4, ), dtype=cp.float32) + with Device(1): + self.recv2 = cp.zeros((4, ), dtype=cp.float32) + self.rank = -1 + + def setup(self, world_size, rank): + self.rank = rank + collective.init_collective_group(world_size, rank, "nccl", "8") + return True + + def compute(self): + if self.rank == 0: + with Device(0): + collective.send_multigpu(self.send1 * 2, 1, 1, "8") + else: + # with Device(1): + collective.recv_multigpu(self.recv2, 0, 0, "8") + return self.recv2 + + def destroy(self): + collective.destroy_collective_group("8") + + +if __name__ == "__main__": + ray.init(address="auto") + num_workers = 2 + workers = [] + init_rets = [] + for i in range(num_workers): + w = Worker.remote() + workers.append(w) + init_rets.append(w.setup.remote(num_workers, i)) + a = ray.get(init_rets) + results = ray.get([w.compute.remote() for w in workers]) + print(results) + ray.get([w.destroy.remote() for w in workers]) + ray.shutdown() diff --git a/python/ray/util/collective/tests/conftest.py b/python/ray/util/collective/tests/conftest.py index ab5b3765d166..341142ec050d 100644 --- a/python/ray/util/collective/tests/conftest.py +++ b/python/ray/util/collective/tests/conftest.py @@ -1,30 +1,41 @@ """Some fixtures for collective tests.""" -import pytest +import logging +import pytest import ray +from ray.util.collective.collective_group.nccl_collective_group \ + import _get_comm_key_from_devices, _get_comm_key_send_recv from ray.util.collective.const import get_nccl_store_name +logger = logging.getLogger(__name__) +logger.setLevel("INFO") + # TODO (Hao): remove this clean_up function as it sometimes crashes Ray. def clean_up(): group_names = ["default", "test", "123?34!", "default2", "random"] group_names.extend([str(i) for i in range(10)]) max_world_size = 4 - p2p_group_names = [] + all_keys = [] for name in group_names: + devices = [[0], [0, 1], [1, 0]] + for d in devices: + collective_communicator_key = _get_comm_key_from_devices(d) + all_keys.append(collective_communicator_key + "@" + name) for i in range(max_world_size): for j in range(max_world_size): - if i <= j: - p2p_group_name = name + "_" + str(i) + "_" + str(j) - p2p_group_names.append(p2p_group_name) - all_names = group_names + p2p_group_names - for group_name in all_names: - store_name = get_nccl_store_name(group_name) + if i < j: + p2p_communicator_key = _get_comm_key_send_recv(i, 0, j, 0) + all_keys.append(p2p_communicator_key + "@" + name) + for group_key in all_keys: + store_name = get_nccl_store_name(group_key) try: actor = ray.get_actor(store_name) except ValueError: actor = None if actor: + logger.debug("Killing actor with group_key: '{}' and store: '{}'." + .format(group_key, store_name)) ray.kill(actor) @@ -41,6 +52,18 @@ def ray_start_single_node_2_gpus(): # my own on-premise cluster before run this fixture. @pytest.fixture def ray_start_distributed_2_nodes_4_gpus(): + # The cluster has a setup of 2 nodes, each node with 2 + # GPUs. Each actor will be allocated 1 GPU. + ray.init("auto") + yield + clean_up() + ray.shutdown() + + +@pytest.fixture +def ray_start_distributed_multigpu_2_nodes_4_gpus(): + # The cluster has a setup of 2 nodes, each node with 2 + # GPUs. Each actor will be allocated 2 GPUs. 
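The multi-GPU test files that follow exercise allgather_multigpu under this fixture's layout of two actors with two GPUs each. A sketch of how per-actor buffers are shaped for that call, based on the allgather_multigpu docstring above (input: one tensor per local GPU; output: for each local GPU, one slot per GPU-level rank); the test helpers create_collective_multigpu_workers and init_tensors_for_gather_scatter_multigpu are not shown in this patch:

import cupy as cp
from cupy.cuda import Device

world_size = 2            # actors in the collective group
num_gpu_per_worker = 2    # GPUs owned by each actor
actual_world_size = world_size * num_gpu_per_worker  # GPU-level ranks: 4

input_tensor_list = []    # one tensor per local GPU
output_tensor_lists = []  # per local GPU: one slot per GPU-level rank
for gpu in range(num_gpu_per_worker):
    with Device(gpu):
        input_tensor_list.append(cp.ones((4, ), dtype=cp.float32))
        output_tensor_lists.append(
            [cp.zeros((4, ), dtype=cp.float32)
             for _ in range(actual_world_size)])
# Inside an actor whose collective group is already initialized:
# collective.allgather_multigpu(output_tensor_lists, input_tensor_list)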
ray.init("auto") yield clean_up() diff --git a/python/ray/util/collective/tests/distributed_multigpu_tests/__init__.py b/python/ray/util/collective/tests/distributed_multigpu_tests/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/ray/util/collective/tests/distributed_multigpu_tests/test_distributed_multigpu_allgather.py b/python/ray/util/collective/tests/distributed_multigpu_tests/test_distributed_multigpu_allgather.py new file mode 100644 index 000000000000..c4cabcd45524 --- /dev/null +++ b/python/ray/util/collective/tests/distributed_multigpu_tests/test_distributed_multigpu_allgather.py @@ -0,0 +1,82 @@ +"""Test the allgather API on a distributed Ray cluster.""" +import pytest +import ray + +import cupy as cp +import torch + +from ray.util.collective.tests.util import \ + create_collective_multigpu_workers, \ + init_tensors_for_gather_scatter_multigpu + + +@pytest.mark.parametrize("tensor_backend", ["cupy", "torch"]) +@pytest.mark.parametrize("array_size", + [2, 2**5, 2**10, 2**15, 2**20, [2, 2], [5, 5, 5]]) +def test_allgather_different_array_size( + ray_start_distributed_multigpu_2_nodes_4_gpus, array_size, + tensor_backend): + world_size = 2 + num_gpu_per_worker = 2 + actual_world_size = world_size * num_gpu_per_worker + actors, _ = create_collective_multigpu_workers(world_size) + init_tensors_for_gather_scatter_multigpu( + actors, array_size=array_size, tensor_backend=tensor_backend) + results = ray.get([a.do_allgather_multigpu.remote() for a in actors]) + for i in range(world_size): + for j in range(num_gpu_per_worker): + for k in range(actual_world_size): + if tensor_backend == "cupy": + assert (results[i][j][k] == cp.ones( + array_size, dtype=cp.float32)).all() + else: + assert (results[i][j][k] == torch.ones( + array_size, dtype=torch.float32).cuda(j)).all() + + +def test_allgather_torch_cupy(ray_start_distributed_multigpu_2_nodes_4_gpus): + world_size = 2 + num_gpu_per_worker = 2 + actual_world_size = world_size * num_gpu_per_worker + shape = [10, 10] + actors, _ = create_collective_multigpu_workers(world_size) + + # tensor is pytorch, list is cupy + for i, a in enumerate(actors): + ray.get([ + a.set_buffer.remote( + shape, tensor_type0="torch", tensor_type1="torch") + ]) + ray.get([ + a.set_list_buffer.remote( + shape, tensor_type0="cupy", tensor_type1="cupy") + ]) + results = ray.get([a.do_allgather_multigpu.remote() for a in actors]) + for i in range(world_size): + for j in range(num_gpu_per_worker): + for k in range(actual_world_size): + assert (results[i][j][k] == cp.ones(shape, + dtype=cp.float32)).all() + + # tensor is cupy, list is pytorch + for i, a in enumerate(actors): + ray.get([ + a.set_buffer.remote( + shape, tensor_type0="cupy", tensor_type1="cupy") + ]) + ray.get([ + a.set_list_buffer.remote( + shape, tensor_type0="torch", tensor_type1="torch") + ]) + results = ray.get([a.do_allgather_multigpu.remote() for a in actors]) + for i in range(world_size): + for j in range(num_gpu_per_worker): + for k in range(actual_world_size): + assert (results[i][j][k] == torch.ones( + shape, dtype=torch.float32).cuda(j)).all() + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", "-x", __file__])) diff --git a/python/ray/util/collective/tests/distributed_multigpu_tests/test_distributed_multigpu_allreduce.py b/python/ray/util/collective/tests/distributed_multigpu_tests/test_distributed_multigpu_allreduce.py new file mode 100644 index 000000000000..b681a08490b0 --- /dev/null +++ 
b/python/ray/util/collective/tests/distributed_multigpu_tests/test_distributed_multigpu_allreduce.py @@ -0,0 +1,160 @@ +"""Test the collective allreduice API on a distributed Ray cluster.""" +import pytest +import logging + +import cupy as cp + +import ray +from ray.util.collective.types import ReduceOp +from ray.util.collective.tests.util import create_collective_multigpu_workers + +logger = logging.getLogger(__name__) +logger.setLevel("DEBUG") + + +@pytest.mark.parametrize("group_name", ["default", "test", "123?34!"]) +def test_allreduce_multigpu_different_name( + ray_start_distributed_multigpu_2_nodes_4_gpus, group_name): + world_size = 2 + num_gpu_per_worker = 2 + actual_world_size = world_size * num_gpu_per_worker + actors, _ = create_collective_multigpu_workers( + num_workers=world_size, group_name=group_name) + results = ray.get( + [a.do_allreduce_multigpu.remote(group_name) for a in actors]) + assert (results[0] == cp.ones( + (10, ), dtype=cp.float32) * actual_world_size).all() + assert (results[1] == cp.ones( + (10, ), dtype=cp.float32) * actual_world_size).all() + + +@pytest.mark.parametrize("array_size", [2, 2**5, 2**10, 2**15, 2**20]) +def test_allreduce_multigpu_different_array_size( + ray_start_distributed_multigpu_2_nodes_4_gpus, array_size): + world_size = 2 + num_gpu_per_worker = 2 + actual_world_size = world_size * num_gpu_per_worker + actors, _ = create_collective_multigpu_workers(world_size) + ray.get([a.set_buffer.remote(array_size) for a in actors]) + results = ray.get([a.do_allreduce_multigpu.remote() for a in actors]) + assert (results[0] == cp.ones( + (array_size, ), dtype=cp.float32) * actual_world_size).all() + assert (results[1] == cp.ones( + (array_size, ), dtype=cp.float32) * actual_world_size).all() + + +def test_allreduce_multigpu_destroy( + ray_start_distributed_multigpu_2_nodes_4_gpus, + backend="nccl", + group_name="default"): + world_size = 2 + num_gpu_per_worker = 2 + actual_world_size = world_size * num_gpu_per_worker + actors, _ = create_collective_multigpu_workers(world_size) + + results = ray.get([a.do_allreduce_multigpu.remote() for a in actors]) + assert (results[0] == cp.ones( + (10, ), dtype=cp.float32) * actual_world_size).all() + assert (results[1] == cp.ones( + (10, ), dtype=cp.float32) * actual_world_size).all() + + # destroy the group and try do work, should fail + ray.get([a.destroy_group.remote() for a in actors]) + with pytest.raises(RuntimeError): + results = ray.get([a.do_allreduce_multigpu.remote() for a in actors]) + + # reinit the same group and all reduce + ray.get([ + actor.init_group.remote(world_size, i, backend, group_name) + for i, actor in enumerate(actors) + ]) + results = ray.get([a.do_allreduce_multigpu.remote() for a in actors]) + assert (results[0] == cp.ones((10, ), dtype=cp.float32) * actual_world_size + * actual_world_size).all() + assert (results[1] == cp.ones((10, ), dtype=cp.float32) * actual_world_size + * actual_world_size).all() + + +def test_allreduce_multigpu_multiple_group( + ray_start_distributed_multigpu_2_nodes_4_gpus, + backend="nccl", + num_groups=5): + world_size = 2 + num_gpu_per_worker = 2 + actual_world_size = world_size * num_gpu_per_worker + actors, _ = create_collective_multigpu_workers(world_size) + for group_name in range(1, num_groups): + ray.get([ + actor.init_group.remote(world_size, i, backend, str(group_name)) + for i, actor in enumerate(actors) + ]) + for i in range(num_groups): + group_name = "default" if i == 0 else str(i) + results = ray.get( + 
[a.do_allreduce_multigpu.remote(group_name) for a in actors]) + assert (results[0] == cp.ones( + (10, ), dtype=cp.float32) * (actual_world_size**(i + 1))).all() + + +def test_allreduce_multigpu_different_op( + ray_start_distributed_multigpu_2_nodes_4_gpus): + world_size = 2 + actors, _ = create_collective_multigpu_workers(world_size) + # check product + ray.get(actors[0].set_buffer.remote([10], value0=2, value1=3)) + ray.get(actors[1].set_buffer.remote([10], value0=4, value1=5)) + results = ray.get( + [a.do_allreduce_multigpu.remote(op=ReduceOp.PRODUCT) for a in actors]) + assert (results[0] == cp.ones((10, ), dtype=cp.float32) * 120).all() + assert (results[1] == cp.ones((10, ), dtype=cp.float32) * 120).all() + + # check min + ray.get(actors[0].set_buffer.remote([10], value0=2, value1=3)) + ray.get(actors[1].set_buffer.remote([10], value0=4, value1=5)) + results = ray.get( + [a.do_allreduce_multigpu.remote(op=ReduceOp.MIN) for a in actors]) + assert (results[0] == cp.ones((10, ), dtype=cp.float32) * 2).all() + assert (results[1] == cp.ones((10, ), dtype=cp.float32) * 2).all() + + # check max + ray.get(actors[0].set_buffer.remote([10], value0=2, value1=3)) + ray.get(actors[1].set_buffer.remote([10], value0=4, value1=5)) + results = ray.get( + [a.do_allreduce_multigpu.remote(op=ReduceOp.MAX) for a in actors]) + assert (results[0] == cp.ones((10, ), dtype=cp.float32) * 5).all() + assert (results[1] == cp.ones((10, ), dtype=cp.float32) * 5).all() + + +@pytest.mark.parametrize("dtype", + [cp.uint8, cp.float16, cp.float32, cp.float64]) +def test_allreduce_multigpu_different_dtype( + ray_start_distributed_multigpu_2_nodes_4_gpus, dtype): + world_size = 2 + num_gpu_per_worker = 2 + actual_world_size = world_size * num_gpu_per_worker + actors, _ = create_collective_multigpu_workers(world_size) + ray.get([a.set_buffer.remote([10], dtype=dtype) for a in actors]) + results = ray.get([a.do_allreduce_multigpu.remote() for a in actors]) + assert (results[0] == cp.ones( + (10, ), dtype=dtype) * actual_world_size).all() + assert (results[1] == cp.ones( + (10, ), dtype=dtype) * actual_world_size).all() + + +def test_allreduce_torch_cupy(ray_start_distributed_multigpu_2_nodes_4_gpus): + # import torch + world_size = 2 + actual_world_size = 4 + actors, _ = create_collective_multigpu_workers(world_size) + ray.get(actors[0].set_buffer.remote([10])) + ray.get(actors[1].set_buffer.remote( + [10], tensor_type0="torch", tensor_type1="torch")) + results = ray.get([a.do_allreduce_multigpu.remote() for a in actors]) + assert (results[0] == cp.ones((10, )) * actual_world_size).all() + + ray.get(actors[0].set_buffer.remote( + [10], tensor_type0="cupy", tensor_type1="torch")) + ray.get(actors[1].set_buffer.remote( + [10], tensor_type0="torch", tensor_type1="cupy")) + results = ray.get([a.do_allreduce_multigpu.remote() for a in actors]) + assert (results[0] == cp.ones((10, )) * actual_world_size).all() diff --git a/python/ray/util/collective/tests/distributed_multigpu_tests/test_distributed_multigpu_basic_apis.py b/python/ray/util/collective/tests/distributed_multigpu_tests/test_distributed_multigpu_basic_apis.py new file mode 100644 index 000000000000..40be55dd2e0b --- /dev/null +++ b/python/ray/util/collective/tests/distributed_multigpu_tests/test_distributed_multigpu_basic_apis.py @@ -0,0 +1,117 @@ +"""Test the collective group APIs.""" +import pytest +import ray +from random import shuffle + +from ray.util.collective.tests.util import create_collective_multigpu_workers + + +@pytest.mark.parametrize("group_name", 
["default", "test", "123?34!"]) +def test_init_two_actors(ray_start_distributed_multigpu_2_nodes_4_gpus, + group_name): + world_size = 2 + actors, results = create_collective_multigpu_workers( + world_size, group_name) + for i in range(world_size): + assert (results[i]) + + +def test_report_num_gpus(ray_start_distributed_multigpu_2_nodes_4_gpus): + world_size = 2 + actors, results = create_collective_multigpu_workers(world_size) + num_gpus = ray.get([actor.report_num_gpus.remote() for actor in actors]) + assert num_gpus == [2, 2] + + +def test_get_rank(ray_start_distributed_multigpu_2_nodes_4_gpus): + world_size = 2 + actors, _ = create_collective_multigpu_workers(world_size) + actor0_rank = ray.get(actors[0].report_rank.remote()) + assert actor0_rank == 0 + actor1_rank = ray.get(actors[1].report_rank.remote()) + assert actor1_rank == 1 + + # create a second group with a different name, and different + # orders of ranks. + new_group_name = "default2" + ranks = list(range(world_size)) + shuffle(ranks) + _ = ray.get([ + actor.init_group.remote( + world_size, ranks[i], group_name=new_group_name) + for i, actor in enumerate(actors) + ]) + actor0_rank = ray.get(actors[0].report_rank.remote(new_group_name)) + assert actor0_rank == ranks[0] + actor1_rank = ray.get(actors[1].report_rank.remote(new_group_name)) + assert actor1_rank == ranks[1] + + +def test_availability(ray_start_distributed_multigpu_2_nodes_4_gpus): + world_size = 2 + actors, _ = create_collective_multigpu_workers(world_size) + actor0_nccl_availability = ray.get( + actors[0].report_nccl_availability.remote()) + assert actor0_nccl_availability + actor0_gloo_availability = ray.get( + actors[0].report_gloo_availability.remote()) + assert not actor0_gloo_availability + + +def test_is_group_initialized(ray_start_distributed_multigpu_2_nodes_4_gpus): + world_size = 2 + actors, _ = create_collective_multigpu_workers(world_size) + # check group is_init + actor0_is_init = ray.get(actors[0].report_is_group_initialized.remote()) + assert actor0_is_init + actor0_is_init = ray.get( + actors[0].report_is_group_initialized.remote("random")) + assert not actor0_is_init + actor0_is_init = ray.get( + actors[0].report_is_group_initialized.remote("123")) + assert not actor0_is_init + actor1_is_init = ray.get(actors[0].report_is_group_initialized.remote()) + assert actor1_is_init + actor1_is_init = ray.get( + actors[0].report_is_group_initialized.remote("456")) + assert not actor1_is_init + + +def test_destroy_group(ray_start_distributed_multigpu_2_nodes_4_gpus): + world_size = 2 + actors, _ = create_collective_multigpu_workers(world_size) + # Now destroy the group at actor0 + ray.wait([actors[0].destroy_group.remote()]) + actor0_is_init = ray.get(actors[0].report_is_group_initialized.remote()) + assert not actor0_is_init + + # should go well as the group `random` does not exist at all + ray.wait([actors[0].destroy_group.remote("random")]) + + actor1_is_init = ray.get(actors[1].report_is_group_initialized.remote()) + assert actor1_is_init + ray.wait([actors[1].destroy_group.remote("random")]) + actor1_is_init = ray.get(actors[1].report_is_group_initialized.remote()) + assert actor1_is_init + ray.wait([actors[1].destroy_group.remote("default")]) + actor1_is_init = ray.get(actors[1].report_is_group_initialized.remote()) + assert not actor1_is_init + + # Now reconstruct the group using the same name + init_results = ray.get([ + actor.init_group.remote(world_size, i) + for i, actor in enumerate(actors) + ]) + for i in range(world_size): + assert 
init_results[i] + actor0_is_init = ray.get(actors[0].report_is_group_initialized.remote()) + assert actor0_is_init + actor1_is_init = ray.get(actors[0].report_is_group_initialized.remote()) + assert actor1_is_init + + +if __name__ == "__main__": + import pytest + import sys + + sys.exit(pytest.main(["-v", "-x", __file__])) diff --git a/python/ray/util/collective/tests/distributed_multigpu_tests/test_distributed_multigpu_broadcast.py b/python/ray/util/collective/tests/distributed_multigpu_tests/test_distributed_multigpu_broadcast.py new file mode 100644 index 000000000000..5ded5bce35e8 --- /dev/null +++ b/python/ray/util/collective/tests/distributed_multigpu_tests/test_distributed_multigpu_broadcast.py @@ -0,0 +1,92 @@ +"""Test the broadcast API.""" +import pytest +import cupy as cp +import ray + +from ray.util.collective.tests.util import create_collective_multigpu_workers + + +@pytest.mark.parametrize("group_name", ["default", "test", "123?34!"]) +@pytest.mark.parametrize("src_rank", [0, 1]) +@pytest.mark.parametrize("src_gpu_index", [0, 1]) +def test_broadcast_different_name( + ray_start_distributed_multigpu_2_nodes_4_gpus, group_name, src_rank, + src_gpu_index): + world_size = 2 + num_gpu_per_worker = 2 + actors, _ = create_collective_multigpu_workers( + num_workers=world_size, group_name=group_name) + ray.get(actors[0].set_buffer.remote([10], value0=2, value1=3)) + ray.get(actors[1].set_buffer.remote([10], value0=4, value1=5)) + + results = ray.get([ + a.do_broadcast_multigpu.remote( + group_name=group_name, + src_rank=src_rank, + src_gpu_index=src_gpu_index) for a in actors + ]) + for i in range(world_size): + for j in range(num_gpu_per_worker): + val = (src_rank + 1) * 2 + src_gpu_index + assert ( + results[i][j] == cp.ones([10], dtype=cp.float32) * val).all() + + +@pytest.mark.parametrize("array_size", [2, 2**5, 2**10, 2**15, 2**20]) +@pytest.mark.parametrize("src_rank", [0, 1]) +@pytest.mark.parametrize("src_gpu_index", [0, 1]) +def test_broadcast_different_array_size( + ray_start_distributed_multigpu_2_nodes_4_gpus, array_size, src_rank, + src_gpu_index): + world_size = 2 + num_gpu_per_worker = 2 + actors, _ = create_collective_multigpu_workers(world_size) + ray.get(actors[0].set_buffer.remote([array_size], value0=2, value1=3)) + ray.get(actors[1].set_buffer.remote([array_size], value0=4, value1=5)) + results = ray.get([ + a.do_broadcast_multigpu.remote( + src_rank=src_rank, src_gpu_index=src_gpu_index) for a in actors + ]) + for i in range(world_size): + for j in range(num_gpu_per_worker): + val = (src_rank + 1) * 2 + src_gpu_index + assert (results[i][j] == cp.ones( + (array_size, ), dtype=cp.float32) * val).all() + + +@pytest.mark.parametrize("src_rank", [0, 1]) +@pytest.mark.parametrize("src_gpu_index", [0, 1]) +def test_broadcast_torch_cupy(ray_start_distributed_multigpu_2_nodes_4_gpus, + src_rank, src_gpu_index): + import torch + world_size = 2 + num_gpu_per_worker = 2 + actors, _ = create_collective_multigpu_workers(world_size) + ray.get(actors[0].set_buffer.remote([10], value0=2, value1=3)) + ray.get(actors[1].set_buffer.remote( + [10], value0=4, value1=5, tensor_type0="torch", tensor_type1="torch")) + results = ray.get([ + a.do_broadcast_multigpu.remote( + src_rank=src_rank, src_gpu_index=src_gpu_index) for a in actors + ]) + for i in range(world_size): + for j in range(num_gpu_per_worker): + val = (src_rank + 1) * 2 + src_gpu_index + if i == 0: + assert (results[i][j] == cp.ones([10], dtype=cp.float32) * + val).all() + else: + assert (results[i][j] == 
torch.ones([10]).cuda(j) * val).all() + + +@pytest.mark.parametrize("src_rank", [3, 4]) +@pytest.mark.parametrize("src_gpu_index", [2, 3]) +def test_broadcast_invalid_rank(ray_start_distributed_multigpu_2_nodes_4_gpus, + src_rank, src_gpu_index): + world_size = 2 + actors, _ = create_collective_multigpu_workers(world_size) + with pytest.raises(ValueError): + _ = ray.get([ + a.do_broadcast_multigpu.remote( + src_rank=src_rank, src_gpu_index=src_gpu_index) for a in actors + ]) diff --git a/python/ray/util/collective/tests/distributed_multigpu_tests/test_distributed_multigpu_reduce.py b/python/ray/util/collective/tests/distributed_multigpu_tests/test_distributed_multigpu_reduce.py new file mode 100644 index 000000000000..8ac5d54c1c12 --- /dev/null +++ b/python/ray/util/collective/tests/distributed_multigpu_tests/test_distributed_multigpu_reduce.py @@ -0,0 +1,173 @@ +"""Test the reduce API.""" +import pytest +import cupy as cp +import ray +from ray.util.collective.types import ReduceOp + +from ray.util.collective.tests.util import create_collective_multigpu_workers + + +@pytest.mark.parametrize("group_name", ["default", "test", "123?34!"]) +@pytest.mark.parametrize("dst_rank", [0, 1]) +@pytest.mark.parametrize("dst_gpu_index", [0, 1]) +def test_reduce_different_name(ray_start_distributed_multigpu_2_nodes_4_gpus, + group_name, dst_rank, dst_gpu_index): + world_size = 2 + num_gpu_per_worker = 2 + actual_world_size = world_size * num_gpu_per_worker + actors, _ = create_collective_multigpu_workers( + num_workers=world_size, group_name=group_name) + results = ray.get([ + a.do_reduce_multigpu.remote( + group_name, dst_rank=dst_rank, dst_gpu_index=dst_gpu_index) + for a in actors + ]) + for i in range(world_size): + for j in range(num_gpu_per_worker): + if i == dst_rank and j == dst_gpu_index: + assert (results[i][j] == cp.ones( + (10, ), dtype=cp.float32) * actual_world_size).all() + else: + assert (results[i][j] == cp.ones((10, ), + dtype=cp.float32)).all() + + +@pytest.mark.parametrize("array_size", [2, 2**5, 2**10, 2**15, 2**20]) +@pytest.mark.parametrize("dst_rank", [0, 1]) +@pytest.mark.parametrize("dst_gpu_index", [0, 1]) +def test_reduce_different_array_size( + ray_start_distributed_multigpu_2_nodes_4_gpus, array_size, dst_rank, + dst_gpu_index): + world_size = 2 + num_gpu_per_worker = 2 + actual_world_size = world_size * num_gpu_per_worker + actors, _ = create_collective_multigpu_workers(num_workers=world_size) + + ray.get(actors[0].set_buffer.remote(array_size)) + ray.get(actors[1].set_buffer.remote(array_size)) + results = ray.get([ + a.do_reduce_multigpu.remote( + dst_rank=dst_rank, dst_gpu_index=dst_gpu_index) for a in actors + ]) + for i in range(world_size): + for j in range(num_gpu_per_worker): + if i == dst_rank and j == dst_gpu_index: + assert (results[i][j] == cp.ones( + (array_size, ), dtype=cp.float32) * + actual_world_size).all() + else: + assert (results[i][j] == cp.ones( + (array_size, ), dtype=cp.float32)).all() + + +@pytest.mark.parametrize("dst_rank", [0, 1]) +@pytest.mark.parametrize("dst_gpu_index", [0, 1]) +def test_reduce_different_op(ray_start_distributed_multigpu_2_nodes_4_gpus, + dst_rank, dst_gpu_index): + world_size = 2 + num_gpu_per_worker = 2 + actors, _ = create_collective_multigpu_workers(world_size) + + # check product + ray.get(actors[0].set_buffer.remote([10], value0=2, value1=3)) + ray.get(actors[1].set_buffer.remote([10], value0=4, value1=5)) + results = ray.get([ + a.do_reduce_multigpu.remote( + dst_rank=dst_rank, + dst_gpu_index=dst_gpu_index, + 
op=ReduceOp.PRODUCT) for a in actors + ]) + for i in range(world_size): + for j in range(num_gpu_per_worker): + if i == dst_rank and j == dst_gpu_index: + assert (results[i][j] == cp.ones( + (10, ), dtype=cp.float32) * 120).all() + else: + val = (i + 1) * 2 + j + assert (results[i][j] == cp.ones( + (10, ), dtype=cp.float32) * val).all() + + # check min + ray.get(actors[0].set_buffer.remote([10], value0=2, value1=3)) + ray.get(actors[1].set_buffer.remote([10], value0=4, value1=5)) + results = ray.get([ + a.do_reduce_multigpu.remote( + dst_rank=dst_rank, dst_gpu_index=dst_gpu_index, op=ReduceOp.MIN) + for a in actors + ]) + for i in range(world_size): + for j in range(num_gpu_per_worker): + if i == dst_rank and j == dst_gpu_index: + assert (results[i][j] == cp.ones( + (10, ), dtype=cp.float32) * 2).all() + else: + val = (i + 1) * 2 + j + assert (results[i][j] == cp.ones( + (10, ), dtype=cp.float32) * val).all() + + # check max + ray.get(actors[0].set_buffer.remote([10], value0=2, value1=3)) + ray.get(actors[1].set_buffer.remote([10], value0=4, value1=5)) + results = ray.get([ + a.do_reduce_multigpu.remote( + dst_rank=dst_rank, dst_gpu_index=dst_gpu_index, op=ReduceOp.MAX) + for a in actors + ]) + for i in range(world_size): + for j in range(num_gpu_per_worker): + if i == dst_rank and j == dst_gpu_index: + assert (results[i][j] == cp.ones( + (10, ), dtype=cp.float32) * 5).all() + else: + val = (i + 1) * 2 + j + assert (results[i][j] == cp.ones( + (10, ), dtype=cp.float32) * val).all() + + +@pytest.mark.parametrize("dst_rank", [0, 1]) +@pytest.mark.parametrize("dst_gpu_index", [0, 1]) +def test_reduce_torch_cupy(ray_start_distributed_multigpu_2_nodes_4_gpus, + dst_rank, dst_gpu_index): + import torch + world_size = 2 + num_gpu_per_worker = 2 + actors, _ = create_collective_multigpu_workers(world_size) + ray.get(actors[0].set_buffer.remote([10], value0=2, value1=3)) + ray.get(actors[1].set_buffer.remote( + [10], value0=4, value1=5, tensor_type0="torch", tensor_type1="torch")) + + results = ray.get([ + a.do_reduce_multigpu.remote( + dst_rank=dst_rank, dst_gpu_index=dst_gpu_index) for a in actors + ]) + + for i in range(world_size): + for j in range(num_gpu_per_worker): + val = (i + 1) * 2 + j + if dst_rank == i and dst_gpu_index == j: + if i == 0: + assert (results[i][j] == cp.ones([10], dtype=cp.float32) * + 14).all() + else: + assert ( + results[i][j] == torch.ones([10]).cuda(j) * 14).all() + else: + if i == 0: + assert (results[i][j] == cp.ones([10], dtype=cp.float32) * + val).all() + else: + assert ( + results[i][j] == torch.ones([10]).cuda(j) * val).all() + + +@pytest.mark.parametrize("dst_rank", [3, 4]) +@pytest.mark.parametrize("dst_gpu_index", [2, 3]) +def test_reduce_invalid_rank(ray_start_distributed_multigpu_2_nodes_4_gpus, + dst_rank, dst_gpu_index): + world_size = 2 + actors, _ = create_collective_multigpu_workers(world_size) + with pytest.raises(ValueError): + _ = ray.get([ + a.do_reduce_multigpu.remote( + dst_rank=dst_rank, dst_gpu_index=dst_gpu_index) for a in actors + ]) diff --git a/python/ray/util/collective/tests/distributed_multigpu_tests/test_distributed_multigpu_reducescatter.py b/python/ray/util/collective/tests/distributed_multigpu_tests/test_distributed_multigpu_reducescatter.py new file mode 100644 index 000000000000..48f72389bf89 --- /dev/null +++ b/python/ray/util/collective/tests/distributed_multigpu_tests/test_distributed_multigpu_reducescatter.py @@ -0,0 +1,82 @@ +"""Test the collective reducescatter API on a distributed Ray cluster.""" +import pytest +import ray 
+ +import cupy as cp +import torch + +from ray.util.collective.tests.util import \ + create_collective_multigpu_workers, \ + init_tensors_for_gather_scatter_multigpu + + +@pytest.mark.parametrize("tensor_backend", ["cupy", "torch"]) +@pytest.mark.parametrize("array_size", + [2, 2**5, 2**10, 2**15, 2**20, [2, 2], [5, 5, 5]]) +def test_reducescatter_different_array_size( + ray_start_distributed_multigpu_2_nodes_4_gpus, array_size, + tensor_backend): + world_size = 2 + num_gpu_per_worker = 2 + actual_world_size = world_size * num_gpu_per_worker + actors, _ = create_collective_multigpu_workers(world_size) + + init_tensors_for_gather_scatter_multigpu( + actors, array_size=array_size, tensor_backend=tensor_backend) + results = ray.get([a.do_reducescatter_multigpu.remote() for a in actors]) + for i in range(world_size): + for j in range(num_gpu_per_worker): + if tensor_backend == "cupy": + assert (results[i][j] == cp.ones(array_size, dtype=cp.float32) + * actual_world_size).all() + else: + assert (results[i][j] == torch.ones( + array_size, dtype=torch.float32).cuda(j) * + actual_world_size).all() + + +def test_reducescatter_torch_cupy( + ray_start_distributed_multigpu_2_nodes_4_gpus): + world_size = 2 + num_gpu_per_worker = 2 + actual_world_size = world_size * num_gpu_per_worker + shape = [10, 10] + actors, _ = create_collective_multigpu_workers(world_size) + + # tensor is pytorch, list is cupy + for i, a in enumerate(actors): + ray.get([ + a.set_buffer.remote( + shape, tensor_type0="torch", tensor_type1="torch") + ]) + ray.get([ + a.set_list_buffer.remote( + shape, tensor_type0="cupy", tensor_type1="cupy") + ]) + results = ray.get([a.do_reducescatter_multigpu.remote() for a in actors]) + for i in range(world_size): + for j in range(num_gpu_per_worker): + assert (results[i][j] == torch.ones( + shape, dtype=torch.float32).cuda(j) * actual_world_size).all() + + # tensor is cupy, list is pytorch + for i, a in enumerate(actors): + ray.get([ + a.set_buffer.remote( + shape, tensor_type0="cupy", tensor_type1="cupy") + ]) + ray.get([ + a.set_list_buffer.remote( + shape, tensor_type0="torch", tensor_type1="torch") + ]) + results = ray.get([a.do_reducescatter_multigpu.remote() for a in actors]) + for i in range(world_size): + for j in range(num_gpu_per_worker): + assert (results[i][j] == cp.ones(shape, dtype=cp.float32) * + actual_world_size).all() + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", "-x", __file__])) diff --git a/python/ray/util/collective/tests/distributed_multigpu_tests/test_distributed_multigpu_sendrecv.py b/python/ray/util/collective/tests/distributed_multigpu_tests/test_distributed_multigpu_sendrecv.py new file mode 100644 index 000000000000..a88fdb34ec8f --- /dev/null +++ b/python/ray/util/collective/tests/distributed_multigpu_tests/test_distributed_multigpu_sendrecv.py @@ -0,0 +1,47 @@ +"""Test the send/recv API.""" +import cupy as cp +import pytest +import ray + +from ray.util.collective.tests.util import create_collective_multigpu_workers + + +# @pytest.mark.parametrize("group_name", ["default", "test", "123?34!"]) +@pytest.mark.parametrize("dst_rank", [0, 1]) +@pytest.mark.parametrize("src_rank", [0, 1]) +@pytest.mark.parametrize("dst_gpu_index", [0, 1]) +@pytest.mark.parametrize("src_gpu_index", [0, 1]) +@pytest.mark.parametrize("array_size", + [2**10, 2**15, 2**20, [2, 2], [5, 9, 10, 85]]) +def test_sendrecv(ray_start_distributed_multigpu_2_nodes_4_gpus, array_size, + src_rank, dst_rank, src_gpu_index, dst_gpu_index): + if src_rank == 
dst_rank: + return + world_size = 2 + actors, _ = create_collective_multigpu_workers(num_workers=world_size) + + ray.get(actors[0].set_buffer.remote(array_size, value0=2, value1=3)) + ray.get(actors[1].set_buffer.remote(array_size, value0=4, value1=5)) + + refs = [] + for i in range(world_size): + refs.append(actors[i].get_buffer.remote()) + refs[src_rank][src_gpu_index] = actors[src_rank].do_send_multigpu.remote( + dst_rank=dst_rank, + dst_gpu_index=dst_gpu_index, + src_gpu_index=src_gpu_index) + refs[dst_rank][dst_gpu_index] = actors[dst_rank].do_recv_multigpu.remote( + src_rank=src_rank, + src_gpu_index=src_gpu_index, + dst_gpu_index=dst_gpu_index) + results = [] + results_flattend = ray.get(refs[0] + refs[1]) + results.append([results_flattend[0], results_flattend[1]]) + results.append([results_flattend[2], results_flattend[3]]) + assert (results[src_rank][src_gpu_index] == cp.ones( + array_size, dtype=cp.float32) * ( + (src_rank + 1) * 2 + src_gpu_index)).all() + assert (results[dst_rank][dst_gpu_index] == cp.ones( + array_size, dtype=cp.float32) * ( + (src_rank + 1) * 2 + src_gpu_index)).all() + ray.get([a.destroy_group.remote() for a in actors]) diff --git a/python/ray/util/collective/tests/distributed_tests/test_distributed_basic_apis.py b/python/ray/util/collective/tests/distributed_tests/test_distributed_basic_apis.py index 0f17b79ba63e..a0dd4508001f 100644 --- a/python/ray/util/collective/tests/distributed_tests/test_distributed_basic_apis.py +++ b/python/ray/util/collective/tests/distributed_tests/test_distributed_basic_apis.py @@ -69,9 +69,9 @@ def test_availability(ray_start_distributed_2_nodes_4_gpus): actor0_nccl_availability = ray.get( actors[0].report_nccl_availability.remote()) assert actor0_nccl_availability - actor0_mpi_availability = ray.get( - actors[0].report_mpi_availability.remote()) - assert not actor0_mpi_availability + actor0_gloo_availability = ray.get( + actors[0].report_gloo_availability.remote()) + assert not actor0_gloo_availability def test_is_group_initialized(ray_start_distributed_2_nodes_4_gpus): diff --git a/python/ray/util/collective/tests/distributed_tests/test_distributed_broadcast.py b/python/ray/util/collective/tests/distributed_tests/test_distributed_broadcast.py index 408ebce76b8a..5c1ecd7f14d8 100644 --- a/python/ray/util/collective/tests/distributed_tests/test_distributed_broadcast.py +++ b/python/ray/util/collective/tests/distributed_tests/test_distributed_broadcast.py @@ -60,7 +60,8 @@ def test_broadcast_torch_cupy(ray_start_distributed_2_nodes_4_gpus, src_rank): assert (results[1] == torch.ones((10, )).cuda() * world_size).all() -def test_broadcast_invalid_rank(ray_start_single_node_2_gpus, src_rank=3): +def test_broadcast_invalid_rank(ray_start_distributed_2_nodes_4_gpus, + src_rank=3): world_size = 2 actors, _ = create_collective_workers(world_size) with pytest.raises(ValueError): diff --git a/python/ray/util/collective/tests/sinlge_node_tests/__init__.py b/python/ray/util/collective/tests/sinlge_node_tests/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/ray/util/collective/tests/test_allgather.py b/python/ray/util/collective/tests/sinlge_node_tests/test_allgather.py similarity index 100% rename from python/ray/util/collective/tests/test_allgather.py rename to python/ray/util/collective/tests/sinlge_node_tests/test_allgather.py diff --git a/python/ray/util/collective/tests/test_allreduce.py b/python/ray/util/collective/tests/sinlge_node_tests/test_allreduce.py similarity index 100% rename from 
python/ray/util/collective/tests/test_allreduce.py rename to python/ray/util/collective/tests/sinlge_node_tests/test_allreduce.py diff --git a/python/ray/util/collective/tests/test_basic_apis.py b/python/ray/util/collective/tests/sinlge_node_tests/test_basic_apis.py similarity index 97% rename from python/ray/util/collective/tests/test_basic_apis.py rename to python/ray/util/collective/tests/sinlge_node_tests/test_basic_apis.py index 8c23442a3b4c..29a3ec3f4a15 100644 --- a/python/ray/util/collective/tests/test_basic_apis.py +++ b/python/ray/util/collective/tests/sinlge_node_tests/test_basic_apis.py @@ -64,9 +64,9 @@ def test_availability(ray_start_single_node_2_gpus): actor0_nccl_availability = ray.get( actors[0].report_nccl_availability.remote()) assert actor0_nccl_availability - actor0_mpi_availability = ray.get( - actors[0].report_mpi_availability.remote()) - assert not actor0_mpi_availability + actor0_gloo_availability = ray.get( + actors[0].report_gloo_availability.remote()) + assert not actor0_gloo_availability def test_is_group_initialized(ray_start_single_node_2_gpus): diff --git a/python/ray/util/collective/tests/test_broadcast.py b/python/ray/util/collective/tests/sinlge_node_tests/test_broadcast.py similarity index 100% rename from python/ray/util/collective/tests/test_broadcast.py rename to python/ray/util/collective/tests/sinlge_node_tests/test_broadcast.py diff --git a/python/ray/util/collective/tests/test_reduce.py b/python/ray/util/collective/tests/sinlge_node_tests/test_reduce.py similarity index 100% rename from python/ray/util/collective/tests/test_reduce.py rename to python/ray/util/collective/tests/sinlge_node_tests/test_reduce.py diff --git a/python/ray/util/collective/tests/test_reducescatter.py b/python/ray/util/collective/tests/sinlge_node_tests/test_reducescatter.py similarity index 100% rename from python/ray/util/collective/tests/test_reducescatter.py rename to python/ray/util/collective/tests/sinlge_node_tests/test_reducescatter.py diff --git a/python/ray/util/collective/tests/test_sendrecv.py b/python/ray/util/collective/tests/sinlge_node_tests/test_sendrecv.py similarity index 100% rename from python/ray/util/collective/tests/test_sendrecv.py rename to python/ray/util/collective/tests/sinlge_node_tests/test_sendrecv.py diff --git a/python/ray/util/collective/tests/util.py b/python/ray/util/collective/tests/util.py index 259ee24c9727..a5fb97a53ad5 100644 --- a/python/ray/util/collective/tests/util.py +++ b/python/ray/util/collective/tests/util.py @@ -1,20 +1,29 @@ import cupy as cp +import logging import ray import ray.util.collective as col from ray.util.collective.types import Backend, ReduceOp +from ray.util.collective.collective_group.nccl_util import get_num_gpus import torch +logger = logging.getLogger(__name__) + @ray.remote(num_gpus=1) class Worker: def __init__(self): + self.buffer = None + self.list_buffer = None + + def init_tensors(self): self.buffer = cp.ones((10, ), dtype=cp.float32) self.list_buffer = [ - cp.ones((10, ), dtype=cp.float32), - cp.ones((10, ), dtype=cp.float32) + cp.ones((10, ), dtype=cp.float32) for _ in range(2) ] + cp.cuda.Stream.null.synchronize() + return True def init_group(self, world_size, @@ -79,8 +88,8 @@ def report_nccl_availability(self): avail = col.nccl_available() return avail - def report_mpi_availability(self): - avail = col.mpi_available() + def report_gloo_availability(self): + avail = col.gloo_available() return avail def report_is_group_initialized(self, group_name="default"): @@ -91,7 +100,11 @@ def 
report_is_group_initialized(self, group_name="default"): def create_collective_workers(num_workers=2, group_name="default", backend="nccl"): - actors = [Worker.remote() for _ in range(num_workers)] + actors = [None] * num_workers + for i in range(num_workers): + actor = Worker.remote() + ray.get([actor.init_tensors.remote()]) + actors[i] = actor world_size = num_workers init_results = ray.get([ actor.init_group.remote(world_size, i, backend, group_name) @@ -112,7 +125,7 @@ def init_tensors_for_gather_scatter(actors, t = torch.ones(array_size, dtype=torch.float32).cuda() * (i + 1) else: raise RuntimeError("Unsupported tensor backend.") - ray.wait([a.set_buffer.remote(t)]) + ray.get([a.set_buffer.remote(t)]) if tensor_backend == "cupy": list_buffer = [ cp.ones(array_size, dtype=dtype) for _ in range(world_size) @@ -125,3 +138,250 @@ def init_tensors_for_gather_scatter(actors, else: raise RuntimeError("Unsupported tensor backend.") ray.get([a.set_list_buffer.remote(list_buffer) for a in actors]) + + +@ray.remote(num_gpus=2) +class MultiGPUWorker: + def __init__(self): + self.buffer0 = None + self.buffer1 = None + self.list_buffer0 = None + self.list_buffer1 = None + + def __del__(self): + self.buffer0 = None + self.buffer1 = None + self.list_buffer0 = None + self.list_buffer1 = None + + def init_tensors(self): + with cp.cuda.Device(0): + self.buffer0 = cp.ones((10, ), dtype=cp.float32) + self.list_buffer0 = [ + cp.ones((10, ), dtype=cp.float32) for _ in range(4) + ] + with cp.cuda.Device(1): + self.buffer1 = cp.ones((10, ), dtype=cp.float32) + self.list_buffer1 = [ + cp.ones((10, ), dtype=cp.float32) for _ in range(4) + ] + cp.cuda.Stream.null.synchronize() + return True + + def init_group(self, + world_size, + rank, + backend=Backend.NCCL, + group_name="default"): + col.init_collective_group(world_size, rank, backend, group_name) + return True + + def set_buffer(self, + size, + value0=1.0, + value1=1.0, + dtype=cp.float32, + tensor_type0="cupy", + tensor_type1="cupy"): + if tensor_type0 == "cupy": + with cp.cuda.Device(0): + self.buffer0 = cp.ones(size, dtype=dtype) * value0 + elif tensor_type0 == "torch": + self.buffer0 = torch.ones( + size, dtype=torch.float32).cuda(0) * value0 + else: + raise RuntimeError() + + if tensor_type1 == "cupy": + with cp.cuda.Device(1): + self.buffer1 = cp.ones(size, dtype=dtype) * value1 + elif tensor_type1 == "torch": + self.buffer1 = torch.ones( + size, dtype=torch.float32).cuda(1) * value1 + else: + raise RuntimeError() + cp.cuda.Device(0).synchronize() + cp.cuda.Device(1).synchronize() + # cp.cuda.Stream.null.synchronize() + return True + + def set_list_buffer(self, + size, + value0=1.0, + value1=1.0, + dtype=cp.float32, + tensor_type0="cupy", + tensor_type1="cupy"): + if tensor_type0 == "cupy": + with cp.cuda.Device(0): + self.list_buffer0 = [ + cp.ones(size, dtype=dtype) * value0 for _ in range(4) + ] + elif tensor_type0 == "torch": + self.list_buffer0 = [ + torch.ones(size, dtype=torch.float32).cuda(0) * value0 + for _ in range(4) + ] + else: + raise RuntimeError() + + if tensor_type1 == "cupy": + with cp.cuda.Device(1): + self.list_buffer1 = [ + cp.ones(size, dtype=dtype) * value1 for _ in range(4) + ] + elif tensor_type1 == "torch": + self.list_buffer1 = [ + torch.ones(size, dtype=torch.float32).cuda(1) * value1 + for _ in range(4) + ] + else: + raise RuntimeError() + cp.cuda.Device(0).synchronize() + cp.cuda.Device(1).synchronize() + return True + + @ray.method(num_returns=2) + def get_buffer(self): + return self.buffer0, self.buffer1 + + def 
do_allreduce_multigpu(self, group_name="default", op=ReduceOp.SUM): + col.allreduce_multigpu([self.buffer0, self.buffer1], group_name, op) + cp.cuda.Device(0).synchronize() + cp.cuda.Device(1).synchronize() + return self.buffer0 + + def do_reduce_multigpu(self, + group_name="default", + dst_rank=0, + dst_gpu_index=0, + op=ReduceOp.SUM): + col.reduce_multigpu([self.buffer0, self.buffer1], dst_rank, + dst_gpu_index, group_name, op) + cp.cuda.Device(0).synchronize() + cp.cuda.Device(1).synchronize() + return self.buffer0, self.buffer1 + + def do_broadcast_multigpu(self, + group_name="default", + src_rank=0, + src_gpu_index=0): + col.broadcast_multigpu([self.buffer0, self.buffer1], src_rank, + src_gpu_index, group_name) + return self.buffer0, self.buffer1 + + def do_allgather_multigpu(self, group_name="default"): + col.allgather_multigpu([self.list_buffer0, self.list_buffer1], + [self.buffer0, self.buffer1], group_name) + cp.cuda.Device(0).synchronize() + cp.cuda.Device(1).synchronize() + return self.list_buffer0, self.list_buffer1 + + def do_reducescatter_multigpu(self, group_name="default", op=ReduceOp.SUM): + col.reducescatter_multigpu([self.buffer0, self.buffer1], + [self.list_buffer0, self.list_buffer1], + group_name, op) + cp.cuda.Device(0).synchronize() + cp.cuda.Device(1).synchronize() + return self.buffer0, self.buffer1 + + def do_send_multigpu(self, + group_name="default", + dst_rank=0, + dst_gpu_index=0, + src_gpu_index=0): + if src_gpu_index == 0: + col.send_multigpu(self.buffer0, dst_rank, dst_gpu_index, + group_name) + cp.cuda.Device(0).synchronize() + return self.buffer0 + elif src_gpu_index == 1: + col.send_multigpu(self.buffer1, dst_rank, dst_gpu_index, + group_name) + cp.cuda.Device(1).synchronize() + return self.buffer1 + else: + raise RuntimeError() + + def do_recv_multigpu(self, + group_name="default", + src_rank=0, + src_gpu_index=0, + dst_gpu_index=0): + if dst_gpu_index == 0: + col.recv_multigpu(self.buffer0, src_rank, src_gpu_index, + group_name) + cp.cuda.Device(0).synchronize() + return self.buffer0 + elif dst_gpu_index == 1: + col.recv_multigpu(self.buffer1, src_rank, src_gpu_index, + group_name) + cp.cuda.Device(1).synchronize() + return self.buffer1 + else: + raise RuntimeError() + + def destroy_group(self, group_name="default"): + col.destroy_collective_group(group_name) + return True + + def report_rank(self, group_name="default"): + rank = col.get_rank(group_name) + return rank + + def report_world_size(self, group_name="default"): + ws = col.get_world_size(group_name) + return ws + + def report_nccl_availability(self): + avail = col.nccl_available() + return avail + + def report_gloo_availability(self): + avail = col.gloo_available() + return avail + + def report_is_group_initialized(self, group_name="default"): + is_init = col.is_group_initialized(group_name) + return is_init + + def report_num_gpus(self): + n_gpus = get_num_gpus() + return n_gpus + + +def create_collective_multigpu_workers(num_workers=2, + group_name="default", + backend="nccl"): + actors = [None] * num_workers + for i in range(num_workers): + actor = MultiGPUWorker.remote() + ray.get([actor.set_buffer.remote([10])], timeout=10) + ray.get([actor.set_list_buffer.remote([10])], timeout=10) + actors[i] = actor + world_size = num_workers + init_results = ray.get([ + actor.init_group.remote(world_size, i, backend, group_name) + for i, actor in enumerate(actors) + ]) + return actors, init_results + + +def init_tensors_for_gather_scatter_multigpu(actors, + array_size=10, + tensor_backend="cupy"): + 
for i, a in enumerate(actors): +        if tensor_backend == "cupy": +            ray.get([a.set_buffer.remote(array_size)]) +            ray.get([a.set_list_buffer.remote(array_size)]) +        elif tensor_backend == "torch": +            ray.get([ +                a.set_buffer.remote( +                    array_size, tensor_type0="torch", tensor_type1="torch") +            ]) +            ray.get([ +                a.set_list_buffer.remote( +                    array_size, tensor_type0="torch", tensor_type1="torch") +            ]) +        else: +            raise RuntimeError("Unsupported tensor backend.") diff --git a/python/ray/util/collective/types.py b/python/ray/util/collective/types.py index c12dde84cb6a..d3e964486f77 100644 --- a/python/ray/util/collective/types.py +++ b/python/ray/util/collective/types.py @@ -30,6 +30,7 @@ class Backend(object): """A class to represent different backends.""" NCCL = "nccl" MPI = "mpi" +    GLOO = "gloo" UNRECOGNIZED = "unrecognized" def __new__(cls, name: str): @@ -38,6 +39,8 @@ def __new__(cls, name: str): raise ValueError("Unrecognized backend: '{}'. " "Only NCCL is supported".format(name)) if backend == Backend.MPI: +        raise RuntimeError("Ray does not support MPI backend.") +    if backend == Backend.GLOO: raise NotImplementedError() return backend @@ -67,6 +70,7 @@ class BarrierOptions: class ReduceOptions: reduceOp = ReduceOp.SUM root_rank = 0 +    root_tensor = 0  # index for multi-gpu reduce operations timeout_ms = unset_timeout_ms @@ -85,6 +89,7 @@ class AllGatherOptions: @dataclass class BroadcastOptions: root_rank = 0 +    root_tensor = 0 timeout_ms = unset_timeout_ms @@ -92,3 +97,17 @@ class BroadcastOptions: class ReduceScatterOptions: reduceOp = ReduceOp.SUM timeout_ms = unset_timeout_ms + + +@dataclass +class SendOptions: +    dst_rank = 0 +    dst_gpu_index = 0 +    timeout_ms = unset_timeout_ms + + +@dataclass +class RecvOptions: +    src_rank = 0 +    src_gpu_index = 0 +    timeout_ms = unset_timeout_ms
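For reference, a minimal sketch of driver code for the allreduce_multigpu API introduced by this change, modeled on the send/recv example file earlier in the diff. The actor name AllreduceWorker, the group name "example", and the assumed cluster layout (two workers, each allocated 2 GPUs) are illustrative only; the collective calls and their signatures are the ones added here.

import cupy as cp

import ray
import ray.util.collective as collective


@ray.remote(num_gpus=2)
class AllreduceWorker:
    # Illustrative actor: one cupy buffer per visible GPU.
    def __init__(self):
        with cp.cuda.Device(0):
            self.buf0 = cp.ones((4, ), dtype=cp.float32)
        with cp.cuda.Device(1):
            self.buf1 = cp.ones((4, ), dtype=cp.float32)

    def setup(self, world_size, rank):
        # One NCCL group across the actors; individual GPUs are addressed
        # per tensor in the multigpu calls.
        collective.init_collective_group(world_size, rank, "nccl", "example")
        return True

    def compute(self):
        # Each actor passes a list with one tensor per local GPU.
        # With 2 actors x 2 GPUs, every element ends up holding 4.0.
        collective.allreduce_multigpu([self.buf0, self.buf1], "example")
        return self.buf0

    def destroy(self):
        collective.destroy_collective_group("example")


if __name__ == "__main__":
    ray.init(address="auto")
    num_workers = 2
    workers = [AllreduceWorker.remote() for _ in range(num_workers)]
    ray.get([w.setup.remote(num_workers, i) for i, w in enumerate(workers)])
    print(ray.get([w.compute.remote() for w in workers]))
    ray.get([w.destroy.remote() for w in workers])
    ray.shutdown()

The effective world size seen by NCCL is num_workers multiplied by the number of GPUs per worker, which is why the all-ones buffers sum to 4.0 in this sketch, matching the actual_world_size convention used throughout the multigpu tests above.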