diff --git a/python/ray/util/__init__.py b/python/ray/util/__init__.py index be876c6491fe..2bead8d1851d 100644 --- a/python/ray/util/__init__.py +++ b/python/ray/util/__init__.py @@ -10,5 +10,5 @@ __all__ = [ "ActorPool", "disable_log_once_globally", "enable_periodic_logging", "iter", "log_once", "pdb", "placement_group", "placement_group_table", - "remove_placement_group", "inspect_serializability" + "remove_placement_group", "inspect_serializability", "collective" ] diff --git a/python/ray/util/collective/__init__.py b/python/ray/util/collective/__init__.py new file mode 100644 index 000000000000..68fcb78d444e --- /dev/null +++ b/python/ray/util/collective/__init__.py @@ -0,0 +1,9 @@ +from .collective import nccl_available, mpi_available, is_group_initialized, \ + init_collective_group, destroy_collective_group, get_rank, \ + get_world_size, allreduce, barrier + +__all__ = [ + "nccl_available", "mpi_available", "is_group_initialized", + "init_collective_group", "destroy_collective_group", "get_rank", + "get_world_size", "allreduce", "barrier" +] diff --git a/python/ray/util/collective/collective.py b/python/ray/util/collective/collective.py new file mode 100644 index 000000000000..343487e718bc --- /dev/null +++ b/python/ray/util/collective/collective.py @@ -0,0 +1,275 @@ +"""APIs exposed under the namespace ray.util.collective.""" +import logging + +import numpy as np +import ray +from ray.util.collective import types +from ray.util.collective.const import get_nccl_store_name + +_MPI_AVAILABLE = False +_NCCL_AVAILABLE = True + +# try: +# from ray.util.collective.collective_group.mpi_collective_group \ +# import MPIGroup +# except ImportError: +# _MPI_AVAILABLE = False +try: + from ray.util.collective.collective_group import NCCLGroup + from ray.util.collective.collective_group import nccl_util +except ImportError: + _NCCL_AVAILABLE = False + +logger = logging.getLogger(__name__) + + +def nccl_available(): + return _NCCL_AVAILABLE + + +def mpi_available(): + return _MPI_AVAILABLE + + +class GroupManager(object): + """ + Use this class to manage the collective groups we created so far. + + Each process will have an instance of `GroupManager`. Each process + could belong to multiple collective groups. The membership information + and other metadata are stored in the global `_group_mgr` object. + """ + + def __init__(self): + self._name_group_map = {} + self._group_name_map = {} + + def create_collective_group(self, backend, world_size, rank, group_name): + """ + The entry to create new collective groups and register in the manager. + + Put the registration and the group information into the manager + metadata as well. + """ + backend = types.Backend(backend) + if backend == types.Backend.MPI: + raise NotImplementedError() + elif backend == types.Backend.NCCL: + # create the ncclUniqueID + if rank == 0: + # availability has been checked before entering here. 
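# How the NCCLUniqueID is shared at this step: the rank-0 process generates the
# ID below and publishes it through a detached named actor (NCCLUniqueIDStore),
# whose name is the MD5 digest of the group name (see const.get_nccl_store_name).
# Every other rank later blocks in Rendezvous.meet()/get_nccl_id() inside
# NCCLGroup.__init__ until that actor and the ID become visible, then builds its
# communicator from it.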
+                group_uid = nccl_util.get_nccl_unique_id()
+                store_name = get_nccl_store_name(group_name)
+                # Avoid a potential circular dependency in ray/actor.py
+                from ray.util.collective.util import NCCLUniqueIDStore
+                store = NCCLUniqueIDStore.options(
+                    name=store_name, lifetime="detached").remote(store_name)
+                ray.wait([store.set_id.remote(group_uid)])
+
+        logger.debug("creating NCCL group: '{}'".format(group_name))
+        g = NCCLGroup(world_size, rank, group_name)
+        self._name_group_map[group_name] = g
+        self._group_name_map[g] = group_name
+        return self._name_group_map[group_name]
+
+    def is_group_exist(self, group_name):
+        return group_name in self._name_group_map
+
+    def get_group_by_name(self, group_name):
+        """Get the collective group handle by its name."""
+        if not self.is_group_exist(group_name):
+            logger.warning(
+                "The group '{}' is not initialized.".format(group_name))
+            return None
+        return self._name_group_map[group_name]
+
+    def destroy_collective_group(self, group_name):
+        """Group destructor."""
+        if not self.is_group_exist(group_name):
+            logger.warning("The group '{}' does not exist.".format(group_name))
+            return
+
+        # release the collective group resource
+        g = self._name_group_map[group_name]
+        rank = g.rank
+        backend = g.backend()
+
+        # clean up the dicts
+        del self._group_name_map[g]
+        del self._name_group_map[group_name]
+        if backend == types.Backend.NCCL:
+            # release the named actor
+            if rank == 0:
+                store_name = get_nccl_store_name(group_name)
+                store = ray.get_actor(store_name)
+                ray.wait([store.__ray_terminate__.remote()])
+                ray.kill(store)
+        # Release the communicator resources
+        g.destroy_group()
+
+
+_group_mgr = GroupManager()
+
+
+def is_group_initialized(group_name):
+    """Check if the group is initialized in this process by the group name."""
+    return _group_mgr.is_group_exist(group_name)
+
+
+def init_collective_group(world_size: int,
+                          rank: int,
+                          backend=types.Backend.NCCL,
+                          group_name: str = "default"):
+    """
+    Initialize a collective group inside an actor process.
+
+    Args:
+        world_size (int): the total number of processes in the group.
+        rank (int): the rank of the current process.
+        backend: the CCL backend to use, NCCL or MPI.
+        group_name (str): the name of the collective group.
+
+    Returns:
+        None
+    """
+    _check_inside_actor()
+    backend = types.Backend(backend)
+    _check_backend_availability(backend)
+    global _group_mgr
+    # TODO(Hao): implement a group auto-counter.
+    if not group_name:
+        raise ValueError("group_name '{}' needs to be a string."
+                         .format(group_name))
+
+    if _group_mgr.is_group_exist(group_name):
+        raise RuntimeError("Trying to initialize a group twice.")
+
+    assert (world_size > 0)
+    assert (rank >= 0)
+    assert (rank < world_size)
+    _group_mgr.create_collective_group(backend, world_size, rank, group_name)
+
+
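# Typical lifecycle from inside a Ray actor (a minimal sketch; see
# ray/util/collective/examples/nccl_allreduce_example.py in this diff for a
# runnable script):
#
#     import ray.util.collective as col
#     col.init_collective_group(world_size=2, rank=0, backend="nccl",
#                               group_name="default")
#     col.allreduce(gpu_tensor, group_name="default")  # cupy or torch.cuda tensor
#     col.destroy_collective_group("default")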
+ """ + _check_inside_actor() + if not is_group_initialized(group_name): + return -1 + g = _group_mgr.get_group_by_name(group_name) + return g.rank + + +def get_world_size(group_name="default") -> int: + """ + Return the size of the collective gropu with the given name. + + Args: + group_name: the name of the group to query + + Returns: + The world size of the collective group, + -1 if the group does not exist or the process does + not belong to the group. + """ + _check_inside_actor() + if not is_group_initialized(group_name): + return -1 + g = _group_mgr.get_group_by_name(group_name) + return g.world_size + + +def allreduce(tensor, group_name: str, op=types.ReduceOp.SUM): + """ + Collective allreduce the tensor across the group with name group_name. + + Args: + tensor: the tensor to be all-reduced on this process. + group_name (str): the collective group name to perform allreduce. + op: The reduce operation. + + Returns: + None + """ + _check_single_tensor_input(tensor) + g = _check_and_get_group(group_name) + opts = types.AllReduceOptions + opts.reduceOp = op + g.allreduce(tensor, opts) + + +def barrier(group_name): + """ + Barrier all processes in the collective group. + + Args: + group_name (str): the name of the group to barrier. + + Returns: + None + """ + g = _check_and_get_group(group_name) + g.barrier() + + +def _check_and_get_group(group_name): + """Check the existence and return the group handle.""" + _check_inside_actor() + if not is_group_initialized(group_name): + raise RuntimeError("The collective group '{}' is not " + "initialized in the process.".format(group_name)) + g = _group_mgr.get_group_by_name(group_name) + return g + + +def _check_backend_availability(backend: types.Backend): + """Check whether the backend is available.""" + if backend == types.Backend.MPI: + if not mpi_available(): + raise RuntimeError("MPI is not available.") + elif backend == types.Backend.NCCL: + # expect some slowdown at the first call + # as I defer the import to invocation. + if not nccl_available(): + raise RuntimeError("NCCL is not available.") + + +def _check_single_tensor_input(tensor): + """Check if the tensor is with a supported type.""" + if isinstance(tensor, np.ndarray): + return + if types.cupy_available(): + if isinstance(tensor, types.cp.ndarray): + return + if types.torch_available(): + if isinstance(tensor, types.th.Tensor): + return + raise RuntimeError("Unrecognized tensor type '{}'. 
Supported types are: " + "np.ndarray, torch.Tensor, cupy.ndarray.".format( + type(tensor))) + + +def _check_inside_actor(): + """Check if currently it is inside a Ray actor/task.""" + worker = ray.worker.global_worker + if worker.mode == ray.WORKER_MODE: + return + else: + raise RuntimeError("The collective APIs shall be only used inside " + "a Ray actor or task.") diff --git a/python/ray/util/collective/collective_group/__init__.py b/python/ray/util/collective/collective_group/__init__.py new file mode 100644 index 000000000000..c8ecc463ea97 --- /dev/null +++ b/python/ray/util/collective/collective_group/__init__.py @@ -0,0 +1,3 @@ +from .nccl_collective_group import NCCLGroup + +__all__ = ["NCCLGroup"] diff --git a/python/ray/util/collective/collective_group/base_collective_group.py b/python/ray/util/collective/collective_group/base_collective_group.py new file mode 100644 index 000000000000..a3f54fa267f8 --- /dev/null +++ b/python/ray/util/collective/collective_group/base_collective_group.py @@ -0,0 +1,52 @@ +"""Abstract class for collective groups.""" +from abc import ABCMeta +from abc import abstractmethod + +from ray.util.collective.types import AllReduceOptions, BarrierOptions + + +class BaseGroup(metaclass=ABCMeta): + def __init__(self, world_size, rank, group_name): + """ + Init the process group with basic information. + + Args: + world_size (int): The total number of processes in the group. + rank (int): The rank of the current process. + group_name (str): The group name. + """ + self._world_size = world_size + self._rank = rank + self._group_name = group_name + + @property + def rank(self): + """Return the rank of the current process.""" + return self._rank + + @property + def world_size(self): + """Return the number of processes in this group.""" + return self._world_size + + @property + def group_name(self): + """Return the group name of this group.""" + return self._group_name + + def destroy_group(self): + """GC the communicators.""" + pass + + @classmethod + def backend(cls): + """The backend of this collective group.""" + raise NotImplementedError() + + @abstractmethod + def allreduce(self, tensor, allreduce_options=AllReduceOptions()): + raise NotImplementedError() + + @abstractmethod + def barrier(self, barrier_options=BarrierOptions()): + raise NotImplementedError() diff --git a/python/ray/util/collective/collective_group/mpi_collective_group.py b/python/ray/util/collective/collective_group/mpi_collective_group.py new file mode 100644 index 000000000000..e045ac7160db --- /dev/null +++ b/python/ray/util/collective/collective_group/mpi_collective_group.py @@ -0,0 +1,5 @@ +"""Implementation of the MPI collective group.""" +try: + import mpi4py # noqa: F401 +except ImportError: + raise diff --git a/python/ray/util/collective/collective_group/nccl_collective_group.py b/python/ray/util/collective/collective_group/nccl_collective_group.py new file mode 100644 index 000000000000..31412b5a4baa --- /dev/null +++ b/python/ray/util/collective/collective_group/nccl_collective_group.py @@ -0,0 +1,219 @@ +import logging +import datetime +import time + +import ray +import cupy + +from ray.util.collective.collective_group import nccl_util +from ray.util.collective.collective_group.base_collective_group \ + import BaseGroup +from ray.util.collective.types import AllReduceOptions, \ + BarrierOptions, Backend +from ray.util.collective.const import get_nccl_store_name + +logger = logging.getLogger(__name__) + +# TODO(Hao): +# (1) stream management, instead of using the default stream, +# 
+    def barrier(self, barrier_options=BarrierOptions()):
+        """
+        Blocks until all processes reach this barrier.
+
+        Args:
+            barrier_options: the barrier options.
+
+        Returns:
+            None
+        """
+        self.allreduce(self._barrier_tensor)
+
+    def _get_nccl_communicator(self):
+        """
+        Create or use a cached NCCL communicator for the collective task.
+
+        """
+        # TODO(Hao): later change this to use device keys and query from cache.
+        # TODO(Hao): implement a thin wrapper
+        if not self._nccl_comm:
+            self._nccl_comm = nccl_util.create_nccl_communicator(
+                self.world_size, self.nccl_uid, self.rank)
+        return self._nccl_comm
+
+    @staticmethod
+    def _get_cuda_stream():
+        """Obtain an idle stream from a stream pool for the collective task."""
+        # TODO: implement a simple stream manager.
+        return cupy.cuda.Stream.null
+
+    # def _collective_call(self, *args):
+    #     """Private method to encapsulate all collective calls"""
+    #     pass
diff --git a/python/ray/util/collective/collective_group/nccl_util.py b/python/ray/util/collective/collective_group/nccl_util.py
new file mode 100644
index 000000000000..4d2fc456fd04
--- /dev/null
+++ b/python/ray/util/collective/collective_group/nccl_util.py
@@ -0,0 +1,117 @@
+"""Code to wrap some NCCL API calls."""
+import numpy
+try:
+    import cupy
+    from cupy.cuda import nccl
+    from cupy.cuda.nccl import get_version
+    from cupy.cuda.nccl import get_build_version
+    from cupy.cuda.nccl import NcclCommunicator
+except ImportError:
+    raise ImportError("NCCL in Ray requires Cupy to be available!")
+
+from ray.util.collective.types import ReduceOp, torch_available
+
+NCCL_REDUCE_OP_MAP = {
+    ReduceOp.SUM: nccl.NCCL_SUM,
+    ReduceOp.PRODUCT: nccl.NCCL_PROD,
+    ReduceOp.MIN: nccl.NCCL_MIN,
+    ReduceOp.MAX: nccl.NCCL_MAX,
+}
+
+# cupy dtypes are the same as numpy dtypes
+NUMPY_NCCL_DTYPE_MAP = {
+    numpy.uint8: nccl.NCCL_UINT8,
+    numpy.float16: nccl.NCCL_FLOAT16,
+    numpy.float32: nccl.NCCL_FLOAT32,
+    numpy.float64: nccl.NCCL_FLOAT64,
+}
+
+if torch_available():
+    import torch
+    TORCH_NCCL_DTYPE_MAP = {
+        torch.uint8: nccl.NCCL_UINT8,
+        torch.float16: nccl.NCCL_FLOAT16,
+        torch.float32: nccl.NCCL_FLOAT32,
+        torch.float64: nccl.NCCL_FLOAT64,
+    }
+
+
+def get_nccl_build_version():
+    return get_build_version()
+
+
+def get_nccl_runtime_version():
+    return get_version()
+
+
+def get_nccl_unique_id():
+    return nccl.get_unique_id()
+
+
+def create_nccl_communicator(world_size, nccl_unique_id, rank):
+    """
+    Create an NCCL communicator using NCCL APIs.
+
+    Args:
+        world_size (int): the number of processes in this communicator group.
+        nccl_unique_id (str): the NCCLUniqueID for this group.
+        rank (int): the rank of this process.
+    Returns:
+        comm (nccl.ncclComm_t): an NCCL communicator.
+    """
+    # TODO(Hao): make this inside the NCCLComm class,
+    # and implement the abort method. Make it RAII.
+    comm = NcclCommunicator(world_size, nccl_unique_id, rank)
+    return comm
+
+
+def get_nccl_reduce_op(reduce_op):
+    """
+    Map the reduce op to NCCL reduce op type.
+
+    Args:
+        reduce_op (ReduceOp): ReduceOp Enum (SUM/PRODUCT/MIN/MAX).
+    Returns:
+        (nccl.ncclRedOp_t): the mapped NCCL reduce op.
+    """
+    if reduce_op not in NCCL_REDUCE_OP_MAP:
+        raise RuntimeError(
+            "NCCL does not support reduce op: '{}'".format(reduce_op))
+    return NCCL_REDUCE_OP_MAP[reduce_op]
+
+
+def get_nccl_tensor_dtype(tensor):
+    """Return the corresponding NCCL dtype given a tensor."""
+    if isinstance(tensor, cupy.ndarray):
+        return NUMPY_NCCL_DTYPE_MAP[tensor.dtype.type]
+    if torch_available():
+        if isinstance(tensor, torch.Tensor):
+            return TORCH_NCCL_DTYPE_MAP[tensor.dtype]
+    raise ValueError("Unsupported tensor type. "
+                     "Got: {}.".format(type(tensor)))
+
+
+def get_tensor_ptr(tensor):
+    """Return the pointer to the underlying memory storage of a tensor."""
+    if isinstance(tensor, cupy.ndarray):
+        return tensor.data.ptr
+    if isinstance(tensor, numpy.ndarray):
+        return tensor.data
+    if torch_available():
+        if isinstance(tensor, torch.Tensor):
+            if not tensor.is_cuda:
+                raise RuntimeError("torch tensor must be on gpu.")
+            return tensor.data_ptr()
+    raise ValueError("Unsupported tensor type. "
+                     "Got: {}.".format(type(tensor)))
+
+
+def get_tensor_n_elements(tensor):
+    """Return the number of elements in a tensor."""
+    if isinstance(tensor, cupy.ndarray) or isinstance(tensor, numpy.ndarray):
+        return tensor.size
+    if torch_available():
+        if isinstance(tensor, torch.Tensor):
+            return torch.numel(tensor)
+    raise ValueError("Unsupported tensor type. "
+                     "Got: {}.".format(type(tensor)))
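# For reference, the three tensor helpers above reduce a framework tensor to the
# (pointer, element count, NCCL dtype) triple that NcclCommunicator.allReduce
# consumes; e.g. for t = cupy.ones(4, dtype=cupy.float32):
#   get_tensor_ptr(t) -> t.data.ptr, get_tensor_n_elements(t) -> 4,
#   get_nccl_tensor_dtype(t) -> nccl.NCCL_FLOAT32.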
+ + """ + if self._nccl_comm is not None: + self.barrier() + # We also need a barrier call here. + stream = self._get_cuda_stream() + stream.synchronize() + # destroy the communicator + self._nccl_comm.destroy() + self._nccl_comm = None + super(NCCLGroup, self).destroy_group() + + @classmethod + def backend(cls): + return Backend.NCCL + + def allreduce(self, tensor, allreduce_options=AllReduceOptions()): + """ + AllReduce a list of tensors following options. + + Args: + tensor: the tensor to be reduced, each tensor locates on a GPU + allreduce_options: + + Returns: + """ + # obtain the communicator + comm = self._get_nccl_communicator() + # obtain the stream: using default stream by now + # TODO(Hao): implement a simple stream manager here + stream = self._get_cuda_stream() + + dtype = nccl_util.get_nccl_tensor_dtype(tensor) + ptr = nccl_util.get_tensor_ptr(tensor) + n_elems = nccl_util.get_tensor_n_elements(tensor) + reduce_op = nccl_util.get_nccl_reduce_op(allreduce_options.reduceOp) + + # in-place allreduce + comm.allReduce(ptr, ptr, n_elems, dtype, reduce_op, stream.ptr) + + def barrier(self, barrier_options=BarrierOptions()): + """ + Blocks until all processes reach this barrier. + + Args: + barrier_options: + + Returns: + """ + self.allreduce(self._barrier_tensor) + + def _get_nccl_communicator(self): + """ + Create or use a cached NCCL communicator for the collective task. + + """ + # TODO(Hao): later change this to use device keys and query from cache. + # TODO(Hao): implement a thin wrapper + if not self._nccl_comm: + self._nccl_comm = nccl_util.create_nccl_communicator( + self.world_size, self.nccl_uid, self.rank) + return self._nccl_comm + + @staticmethod + def _get_cuda_stream(): + """Obtain an idle stream from a stream pool for the collective task.""" + # TODO: implement a simple stream manager. + return cupy.cuda.Stream.null + + # def _collective_call(self, *args): + # """Private method to encapsulate all collective calls""" + # pass diff --git a/python/ray/util/collective/collective_group/nccl_util.py b/python/ray/util/collective/collective_group/nccl_util.py new file mode 100644 index 000000000000..4d2fc456fd04 --- /dev/null +++ b/python/ray/util/collective/collective_group/nccl_util.py @@ -0,0 +1,117 @@ +"""Code to wrap some NCCL API calls.""" +import numpy +try: + import cupy + from cupy.cuda import nccl + from cupy.cuda.nccl import get_version + from cupy.cuda.nccl import get_build_version + from cupy.cuda.nccl import NcclCommunicator +except ImportError: + raise ImportError("NCCL in Ray requires Cupy being available!") + +from ray.util.collective.types import ReduceOp, torch_available + +NCCL_REDUCE_OP_MAP = { + ReduceOp.SUM: nccl.NCCL_SUM, + ReduceOp.PRODUCT: nccl.NCCL_PROD, + ReduceOp.MIN: nccl.NCCL_MIN, + ReduceOp.MAX: nccl.NCCL_MAX, +} + +# cupy types are the same with numpy types +NUMPY_NCCL_DTYPE_MAP = { + numpy.uint8: nccl.NCCL_UINT8, + numpy.float16: nccl.NCCL_FLOAT16, + numpy.float32: nccl.NCCL_FLOAT32, + numpy.float64: nccl.NCCL_FLOAT64, +} + +if torch_available(): + import torch + TORCH_NCCL_DTYPE_MAP = { + torch.uint8: nccl.NCCL_UINT8, + torch.float16: nccl.NCCL_FLOAT16, + torch.float32: nccl.NCCL_FLOAT32, + torch.float64: nccl.NCCL_FLOAT64, + } + + +def get_nccl_build_version(): + return get_build_version() + + +def get_nccl_runtime_version(): + return get_version() + + +def get_nccl_unique_id(): + return nccl.get_unique_id() + + +def create_nccl_communicator(world_size, nccl_unique_id, rank): + """ + Create an NCCL communicator using NCCL APIs. 
diff --git a/python/ray/util/collective/requirements.txt b/python/ray/util/collective/requirements.txt
new file mode 100644
index 000000000000..ce5057b221f1
--- /dev/null
+++ b/python/ray/util/collective/requirements.txt
@@ -0,0 +1 @@
+cupy-cuda100
\ No newline at end of file
diff --git a/python/ray/util/collective/tests/__init__.py b/python/ray/util/collective/tests/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
+ """ + if not group_name: + raise ValueError("group_name is None.") + hexlified_name = hashlib.md5(group_name.encode()).hexdigest() + return hexlified_name diff --git a/python/ray/util/collective/examples/nccl_allreduce_example.py b/python/ray/util/collective/examples/nccl_allreduce_example.py new file mode 100644 index 000000000000..7010d69249f2 --- /dev/null +++ b/python/ray/util/collective/examples/nccl_allreduce_example.py @@ -0,0 +1,42 @@ +import ray +import cupy as cp + +import ray.util.collective as collective + + +@ray.remote(num_gpus=1) +class Worker: + def __init__(self): + self.send = cp.ones((4, ), dtype=cp.float32) + self.recv = cp.zeros((4, ), dtype=cp.float32) + + def setup(self, world_size, rank): + collective.init_collective_group("nccl", world_size, rank, "default") + return True + + def compute(self): + collective.allreduce(self.send, "default") + print(self.send) + return self.send + + def destroy(self): + collective.destroy_group() + + +if __name__ == "__main__": + + send = cp.ones((4, ), dtype=cp.float32) + + ray.init(num_gpus=2) + + num_workers = 2 + workers = [] + init_rets = [] + for i in range(num_workers): + w = Worker.remote() + workers.append(w) + init_rets.append(w.setup.remote(num_workers, i)) + _ = ray.get(init_rets) + results = ray.get([w.compute.remote() for w in workers]) + # print(results) + ray.shutdown() diff --git a/python/ray/util/collective/requirements.txt b/python/ray/util/collective/requirements.txt new file mode 100644 index 000000000000..ce5057b221f1 --- /dev/null +++ b/python/ray/util/collective/requirements.txt @@ -0,0 +1 @@ +cupy-cuda100 \ No newline at end of file diff --git a/python/ray/util/collective/tests/__init__.py b/python/ray/util/collective/tests/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/ray/util/collective/tests/conftest.py b/python/ray/util/collective/tests/conftest.py new file mode 100644 index 000000000000..b84a01742bf8 --- /dev/null +++ b/python/ray/util/collective/tests/conftest.py @@ -0,0 +1,37 @@ +"""Some fixtures for collective tests.""" +import pytest + +import ray +from ray.util.collective.const import get_nccl_store_name + + +def clean_up(): + group_names = ["default", "test", "123?34!", "default2", "random"] + group_names.extend([str(i) for i in range(10)]) + for group_name in group_names: + try: + store_name = get_nccl_store_name(group_name) + actor = ray.get_actor(store_name) + except ValueError: + actor = None + if actor: + ray.kill(actor) + + +@pytest.fixture +def ray_start_single_node_2_gpus(): + # Please start this fixture in a cluster with 2 GPUs. + address_info = ray.init(num_gpus=2) + yield address_info + ray.shutdown() + + +# Hao: this fixture is a bit tricky. +# I use a bash script to start a ray cluster on +# my own on-premise cluster before run this fixture. 
+@pytest.fixture +def ray_start_distributed_2_nodes_4_gpus(): + ray.init("auto") + yield + clean_up() + ray.shutdown() diff --git a/python/ray/util/collective/tests/test_collective_2_nodes_4_gpus.py b/python/ray/util/collective/tests/test_collective_2_nodes_4_gpus.py new file mode 100644 index 000000000000..c35e48b9ad96 --- /dev/null +++ b/python/ray/util/collective/tests/test_collective_2_nodes_4_gpus.py @@ -0,0 +1,276 @@ +"""Test the collective group APIs.""" +from random import shuffle +import pytest +import ray +from ray.util.collective.types import ReduceOp + +import cupy as cp +import torch + +from .util import Worker + + +def get_actors_group(num_workers=2, group_name="default", backend="nccl"): + actors = [Worker.remote() for i in range(num_workers)] + world_size = num_workers + init_results = ray.get([ + actor.init_group.remote(world_size, i, backend, group_name) + for i, actor in enumerate(actors) + ]) + return actors, init_results + + +@pytest.mark.parametrize("world_size", [2, 3, 4]) +@pytest.mark.parametrize("group_name", ["default", "test", "123?34!"]) +def test_init_two_actors(ray_start_distributed_2_nodes_4_gpus, world_size, + group_name): + actors, results = get_actors_group(world_size, group_name) + for i in range(world_size): + assert (results[i]) + + +@pytest.mark.parametrize("world_size", [2, 3, 4]) +def test_init_multiple_groups(ray_start_distributed_2_nodes_4_gpus, + world_size): + num_groups = 1 + actors = [Worker.remote() for _ in range(world_size)] + for i in range(num_groups): + group_name = str(i) + init_results = ray.get([ + actor.init_group.remote(world_size, i, group_name=group_name) + for i, actor in enumerate(actors) + ]) + for j in range(world_size): + assert init_results[j] + + +@pytest.mark.parametrize("world_size", [2, 3, 4]) +def test_get_rank(ray_start_distributed_2_nodes_4_gpus, world_size): + actors, _ = get_actors_group(world_size) + actor0_rank = ray.get(actors[0].report_rank.remote()) + assert actor0_rank == 0 + actor1_rank = ray.get(actors[1].report_rank.remote()) + assert actor1_rank == 1 + + # create a second group with a different name, and different + # orders of ranks. 
+ new_group_name = "default2" + ranks = list(range(world_size)) + shuffle(ranks) + _ = ray.get([ + actor.init_group.remote( + world_size, ranks[i], group_name=new_group_name) + for i, actor in enumerate(actors) + ]) + actor0_rank = ray.get(actors[0].report_rank.remote(new_group_name)) + assert actor0_rank == ranks[0] + actor1_rank = ray.get(actors[1].report_rank.remote(new_group_name)) + assert actor1_rank == ranks[1] + + +@pytest.mark.parametrize("world_size", [2, 3, 4]) +def test_get_world_size(ray_start_distributed_2_nodes_4_gpus, world_size): + actors, _ = get_actors_group(world_size) + actor0_world_size = ray.get(actors[0].report_world_size.remote()) + actor1_world_size = ray.get(actors[1].report_world_size.remote()) + assert actor0_world_size == actor1_world_size == world_size + + +def test_availability(ray_start_distributed_2_nodes_4_gpus): + world_size = 4 + actors, _ = get_actors_group(world_size) + actor0_nccl_availability = ray.get( + actors[0].report_nccl_availability.remote()) + assert actor0_nccl_availability + actor0_mpi_availability = ray.get( + actors[0].report_mpi_availability.remote()) + assert not actor0_mpi_availability + + +def test_is_group_initialized(ray_start_distributed_2_nodes_4_gpus): + world_size = 4 + actors, _ = get_actors_group(world_size) + # check group is_init + actor0_is_init = ray.get(actors[0].report_is_group_initialized.remote()) + assert actor0_is_init + actor0_is_init = ray.get( + actors[0].report_is_group_initialized.remote("random")) + assert not actor0_is_init + actor0_is_init = ray.get( + actors[0].report_is_group_initialized.remote("123")) + assert not actor0_is_init + actor1_is_init = ray.get(actors[0].report_is_group_initialized.remote()) + assert actor1_is_init + actor1_is_init = ray.get( + actors[0].report_is_group_initialized.remote("456")) + assert not actor1_is_init + + +def test_destroy_group(ray_start_distributed_2_nodes_4_gpus): + world_size = 4 + actors, _ = get_actors_group(world_size) + # Now destroy the group at actor0 + ray.wait([actors[0].destroy_group.remote()]) + actor0_is_init = ray.get(actors[0].report_is_group_initialized.remote()) + assert not actor0_is_init + + # should go well as the group `random` does not exist at all + ray.wait([actors[0].destroy_group.remote("random")]) + + actor1_is_init = ray.get(actors[1].report_is_group_initialized.remote()) + assert actor1_is_init + ray.wait([actors[1].destroy_group.remote("random")]) + actor1_is_init = ray.get(actors[1].report_is_group_initialized.remote()) + assert actor1_is_init + ray.wait([actors[1].destroy_group.remote("default")]) + actor1_is_init = ray.get(actors[1].report_is_group_initialized.remote()) + assert not actor1_is_init + for i in [2, 3]: + ray.wait([actors[i].destroy_group.remote("default")]) + + # Now reconstruct the group using the same name + init_results = ray.get([ + actor.init_group.remote(world_size, i) + for i, actor in enumerate(actors) + ]) + for i in range(world_size): + assert init_results[i] + actor0_is_init = ray.get(actors[0].report_is_group_initialized.remote()) + assert actor0_is_init + actor1_is_init = ray.get(actors[0].report_is_group_initialized.remote()) + assert actor1_is_init + + +@pytest.mark.parametrize("group_name", ["default", "test", "123?34!"]) +@pytest.mark.parametrize("world_size", [2, 3, 4]) +def test_allreduce_different_name(ray_start_distributed_2_nodes_4_gpus, + group_name, world_size): + actors, _ = get_actors_group(num_workers=world_size, group_name=group_name) + results = ray.get([a.do_work.remote(group_name) for a in 
actors]) + assert (results[0] == cp.ones((10, ), dtype=cp.float32) * world_size).all() + assert (results[1] == cp.ones((10, ), dtype=cp.float32) * world_size).all() + + +@pytest.mark.parametrize("array_size", [2, 2**5, 2**10, 2**15, 2**20]) +def test_allreduce_different_array_size(ray_start_distributed_2_nodes_4_gpus, + array_size): + world_size = 4 + actors, _ = get_actors_group(world_size) + ray.wait([ + a.set_buffer.remote(cp.ones(array_size, dtype=cp.float32)) + for a in actors + ]) + results = ray.get([a.do_work.remote() for a in actors]) + assert (results[0] == cp.ones( + (array_size, ), dtype=cp.float32) * world_size).all() + assert (results[1] == cp.ones( + (array_size, ), dtype=cp.float32) * world_size).all() + + +def test_allreduce_destroy(ray_start_distributed_2_nodes_4_gpus, + backend="nccl", + group_name="default"): + world_size = 4 + actors, _ = get_actors_group(world_size) + + results = ray.get([a.do_work.remote() for a in actors]) + assert (results[0] == cp.ones((10, ), dtype=cp.float32) * world_size).all() + assert (results[1] == cp.ones((10, ), dtype=cp.float32) * world_size).all() + + # destroy the group and try do work, should fail + ray.wait([a.destroy_group.remote() for a in actors]) + with pytest.raises(RuntimeError): + results = ray.get([a.do_work.remote() for a in actors]) + + # reinit the same group and all reduce + ray.get([ + actor.init_group.remote(world_size, i, backend, group_name) + for i, actor in enumerate(actors) + ]) + results = ray.get([a.do_work.remote() for a in actors]) + assert (results[0] == cp.ones( + (10, ), dtype=cp.float32) * world_size * world_size).all() + assert (results[1] == cp.ones( + (10, ), dtype=cp.float32) * world_size * world_size).all() + + +def test_allreduce_multiple_group(ray_start_distributed_2_nodes_4_gpus, + backend="nccl", + num_groups=5): + world_size = 4 + actors, _ = get_actors_group(world_size) + for group_name in range(1, num_groups): + ray.get([ + actor.init_group.remote(world_size, i, backend, str(group_name)) + for i, actor in enumerate(actors) + ]) + for i in range(num_groups): + group_name = "default" if i == 0 else str(i) + results = ray.get([a.do_work.remote(group_name) for a in actors]) + assert (results[0] == cp.ones( + (10, ), dtype=cp.float32) * (world_size**(i + 1))).all() + + +def test_allreduce_different_op(ray_start_distributed_2_nodes_4_gpus): + world_size = 4 + actors, _ = get_actors_group(world_size) + + # check product + ray.wait([ + a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) + for i, a in enumerate(actors) + ]) + results = ray.get([a.do_work.remote(op=ReduceOp.PRODUCT) for a in actors]) + assert (results[0] == cp.ones((10, ), dtype=cp.float32) * 120).all() + assert (results[1] == cp.ones((10, ), dtype=cp.float32) * 120).all() + + # check min + ray.wait([ + a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) + for i, a in enumerate(actors) + ]) + results = ray.get([a.do_work.remote(op=ReduceOp.MIN) for a in actors]) + assert (results[0] == cp.ones((10, ), dtype=cp.float32) * 2).all() + assert (results[1] == cp.ones((10, ), dtype=cp.float32) * 2).all() + + # check max + ray.wait([ + a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) + for i, a in enumerate(actors) + ]) + results = ray.get([a.do_work.remote(op=ReduceOp.MAX) for a in actors]) + assert (results[0] == cp.ones((10, ), dtype=cp.float32) * 5).all() + assert (results[1] == cp.ones((10, ), dtype=cp.float32) * 5).all() + + +@pytest.mark.parametrize("dtype", + [cp.uint8, cp.float16, cp.float32, cp.float64]) 
+def test_allreduce_different_dtype(ray_start_distributed_2_nodes_4_gpus, + dtype): + world_size = 4 + actors, _ = get_actors_group(world_size) + ray.wait([a.set_buffer.remote(cp.ones(10, dtype=dtype)) for a in actors]) + results = ray.get([a.do_work.remote() for a in actors]) + assert (results[0] == cp.ones((10, ), dtype=dtype) * world_size).all() + assert (results[1] == cp.ones((10, ), dtype=dtype) * world_size).all() + + +def test_allreduce_torch_cupy(ray_start_distributed_2_nodes_4_gpus): + # import torch + world_size = 4 + actors, _ = get_actors_group(world_size) + ray.wait([actors[1].set_buffer.remote(torch.ones(10, ).cuda())]) + results = ray.get([a.do_work.remote() for a in actors]) + assert (results[0] == cp.ones((10, )) * world_size).all() + + ray.wait([actors[0].set_buffer.remote(torch.ones(10, ))]) + ray.wait([actors[1].set_buffer.remote(cp.ones(10, ))]) + with pytest.raises(RuntimeError): + results = ray.get([a.do_work.remote() for a in actors]) + + +if __name__ == "__main__": + import pytest + import sys + + sys.exit(pytest.main(["-v", "-x", __file__])) diff --git a/python/ray/util/collective/tests/test_collective_single_node_2_gpus.py b/python/ray/util/collective/tests/test_collective_single_node_2_gpus.py new file mode 100644 index 000000000000..267375e29eb9 --- /dev/null +++ b/python/ray/util/collective/tests/test_collective_single_node_2_gpus.py @@ -0,0 +1,267 @@ +"""Test the collective group APIs.""" +import pytest +import ray +from ray.util.collective.types import ReduceOp + +import cupy as cp +import torch + +from .util import Worker + + +def get_actors_group(num_workers=2, group_name="default", backend="nccl"): + actors = [Worker.remote() for _ in range(num_workers)] + world_size = num_workers + init_results = ray.get([ + actor.init_group.remote(world_size, i, backend, group_name) + for i, actor in enumerate(actors) + ]) + return actors, init_results + + +@pytest.mark.parametrize("group_name", ["default", "test", "123?34!"]) +def test_init_two_actors(ray_start_single_node_2_gpus, group_name): + world_size = 2 + actors, results = get_actors_group(world_size, group_name) + for i in range(world_size): + assert (results[i]) + + +def test_init_multiple_groups(ray_start_single_node_2_gpus): + world_size = 2 + num_groups = 10 + actors = [Worker.remote() for i in range(world_size)] + for i in range(num_groups): + group_name = str(i) + init_results = ray.get([ + actor.init_group.remote(world_size, i, group_name=group_name) + for i, actor in enumerate(actors) + ]) + for j in range(world_size): + assert init_results[j] + + +def test_get_rank(ray_start_single_node_2_gpus): + world_size = 2 + actors, _ = get_actors_group(world_size) + actor0_rank = ray.get(actors[0].report_rank.remote()) + assert actor0_rank == 0 + actor1_rank = ray.get(actors[1].report_rank.remote()) + assert actor1_rank == 1 + + # create a second group with a different name, + # and different order of ranks. 
+ new_group_name = "default2" + _ = ray.get([ + actor.init_group.remote( + world_size, world_size - 1 - i, group_name=new_group_name) + for i, actor in enumerate(actors) + ]) + actor0_rank = ray.get(actors[0].report_rank.remote(new_group_name)) + assert actor0_rank == 1 + actor1_rank = ray.get(actors[1].report_rank.remote(new_group_name)) + assert actor1_rank == 0 + + +def test_get_world_size(ray_start_single_node_2_gpus): + world_size = 2 + actors, _ = get_actors_group(world_size) + actor0_world_size = ray.get(actors[0].report_world_size.remote()) + actor1_world_size = ray.get(actors[1].report_world_size.remote()) + assert actor0_world_size == actor1_world_size == world_size + + +def test_availability(ray_start_single_node_2_gpus): + world_size = 2 + actors, _ = get_actors_group(world_size) + actor0_nccl_availability = ray.get( + actors[0].report_nccl_availability.remote()) + assert actor0_nccl_availability + actor0_mpi_availability = ray.get( + actors[0].report_mpi_availability.remote()) + assert not actor0_mpi_availability + + +def test_is_group_initialized(ray_start_single_node_2_gpus): + world_size = 2 + actors, _ = get_actors_group(world_size) + # check group is_init + actor0_is_init = ray.get(actors[0].report_is_group_initialized.remote()) + assert actor0_is_init + actor0_is_init = ray.get( + actors[0].report_is_group_initialized.remote("random")) + assert not actor0_is_init + actor0_is_init = ray.get( + actors[0].report_is_group_initialized.remote("123")) + assert not actor0_is_init + actor1_is_init = ray.get(actors[0].report_is_group_initialized.remote()) + assert actor1_is_init + actor1_is_init = ray.get( + actors[0].report_is_group_initialized.remote("456")) + assert not actor1_is_init + + +def test_destroy_group(ray_start_single_node_2_gpus): + world_size = 2 + actors, _ = get_actors_group(world_size) + # Now destroy the group at actor0 + ray.wait([actors[0].destroy_group.remote()]) + actor0_is_init = ray.get(actors[0].report_is_group_initialized.remote()) + assert not actor0_is_init + + # should go well as the group `random` does not exist at all + ray.wait([actors[0].destroy_group.remote("random")]) + + actor1_is_init = ray.get(actors[1].report_is_group_initialized.remote()) + assert actor1_is_init + ray.wait([actors[1].destroy_group.remote("random")]) + actor1_is_init = ray.get(actors[1].report_is_group_initialized.remote()) + assert actor1_is_init + ray.wait([actors[1].destroy_group.remote("default")]) + actor1_is_init = ray.get(actors[1].report_is_group_initialized.remote()) + assert not actor1_is_init + + # Now reconstruct the group using the same name + init_results = ray.get([ + actor.init_group.remote(world_size, i) + for i, actor in enumerate(actors) + ]) + for i in range(world_size): + assert init_results[i] + actor0_is_init = ray.get(actors[0].report_is_group_initialized.remote()) + assert actor0_is_init + actor1_is_init = ray.get(actors[0].report_is_group_initialized.remote()) + assert actor1_is_init + + +@pytest.mark.parametrize("group_name", ["default", "test", "123?34!"]) +# @pytest.mark.parametrize("group_name", ['123?34!']) +def test_allreduce_different_name(ray_start_single_node_2_gpus, group_name): + world_size = 2 + actors, _ = get_actors_group(num_workers=world_size, group_name=group_name) + results = ray.get([a.do_work.remote(group_name) for a in actors]) + assert (results[0] == cp.ones((10, ), dtype=cp.float32) * world_size).all() + assert (results[1] == cp.ones((10, ), dtype=cp.float32) * world_size).all() + + +@pytest.mark.parametrize("array_size", [2, 
2**5, 2**10, 2**15, 2**20]) +def test_allreduce_different_array_size(ray_start_single_node_2_gpus, + array_size): + world_size = 2 + actors, _ = get_actors_group(world_size) + ray.wait([ + a.set_buffer.remote(cp.ones(array_size, dtype=cp.float32)) + for a in actors + ]) + results = ray.get([a.do_work.remote() for a in actors]) + assert (results[0] == cp.ones( + (array_size, ), dtype=cp.float32) * world_size).all() + assert (results[1] == cp.ones( + (array_size, ), dtype=cp.float32) * world_size).all() + + +def test_allreduce_destroy(ray_start_single_node_2_gpus, + backend="nccl", + group_name="default"): + world_size = 2 + actors, _ = get_actors_group(world_size) + + results = ray.get([a.do_work.remote() for a in actors]) + assert (results[0] == cp.ones((10, ), dtype=cp.float32) * world_size).all() + assert (results[1] == cp.ones((10, ), dtype=cp.float32) * world_size).all() + + # destroy the group and try do work, should fail + ray.wait([a.destroy_group.remote() for a in actors]) + with pytest.raises(RuntimeError): + results = ray.get([a.do_work.remote() for a in actors]) + + # reinit the same group and all reduce + ray.get([ + actor.init_group.remote(world_size, i, backend, group_name) + for i, actor in enumerate(actors) + ]) + results = ray.get([a.do_work.remote() for a in actors]) + assert (results[0] == cp.ones( + (10, ), dtype=cp.float32) * world_size * 2).all() + assert (results[1] == cp.ones( + (10, ), dtype=cp.float32) * world_size * 2).all() + + +def test_allreduce_multiple_group(ray_start_single_node_2_gpus, + backend="nccl", + num_groups=5): + world_size = 2 + actors, _ = get_actors_group(world_size) + for group_name in range(1, num_groups): + ray.get([ + actor.init_group.remote(world_size, i, backend, str(group_name)) + for i, actor in enumerate(actors) + ]) + for i in range(num_groups): + group_name = "default" if i == 0 else str(i) + results = ray.get([a.do_work.remote(group_name) for a in actors]) + assert (results[0] == cp.ones( + (10, ), dtype=cp.float32) * (world_size**(i + 1))).all() + + +def test_allreduce_different_op(ray_start_single_node_2_gpus): + world_size = 2 + actors, _ = get_actors_group(world_size) + + # check product + ray.wait([ + a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) + for i, a in enumerate(actors) + ]) + results = ray.get([a.do_work.remote(op=ReduceOp.PRODUCT) for a in actors]) + assert (results[0] == cp.ones((10, ), dtype=cp.float32) * 6).all() + assert (results[1] == cp.ones((10, ), dtype=cp.float32) * 6).all() + + # check min + ray.wait([ + a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) + for i, a in enumerate(actors) + ]) + results = ray.get([a.do_work.remote(op=ReduceOp.MIN) for a in actors]) + assert (results[0] == cp.ones((10, ), dtype=cp.float32) * 2).all() + assert (results[1] == cp.ones((10, ), dtype=cp.float32) * 2).all() + + # check max + ray.wait([ + a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) + for i, a in enumerate(actors) + ]) + results = ray.get([a.do_work.remote(op=ReduceOp.MAX) for a in actors]) + assert (results[0] == cp.ones((10, ), dtype=cp.float32) * 3).all() + assert (results[1] == cp.ones((10, ), dtype=cp.float32) * 3).all() + + +@pytest.mark.parametrize("dtype", + [cp.uint8, cp.float16, cp.float32, cp.float64]) +def test_allreduce_different_dtype(ray_start_single_node_2_gpus, dtype): + world_size = 2 + actors, _ = get_actors_group(world_size) + ray.wait([a.set_buffer.remote(cp.ones(10, dtype=dtype)) for a in actors]) + results = ray.get([a.do_work.remote() for a in actors]) + 
assert (results[0] == cp.ones((10, ), dtype=dtype) * world_size).all() + assert (results[1] == cp.ones((10, ), dtype=dtype) * world_size).all() + + +def test_allreduce_torch_cupy(ray_start_single_node_2_gpus): + # import torch + world_size = 2 + actors, _ = get_actors_group(world_size) + ray.wait([actors[1].set_buffer.remote(torch.ones(10, ).cuda())]) + results = ray.get([a.do_work.remote() for a in actors]) + assert (results[0] == cp.ones((10, )) * world_size).all() + + ray.wait([actors[0].set_buffer.remote(torch.ones(10, ))]) + ray.wait([actors[1].set_buffer.remote(cp.ones(10, ))]) + with pytest.raises(RuntimeError): + results = ray.get([a.do_work.remote() for a in actors]) + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", "-x", __file__])) diff --git a/python/ray/util/collective/tests/util.py b/python/ray/util/collective/tests/util.py new file mode 100644 index 000000000000..d59294d3f5bd --- /dev/null +++ b/python/ray/util/collective/tests/util.py @@ -0,0 +1,51 @@ +import cupy as cp + +import ray +import ray.util.collective as col +from ray.util.collective.types import Backend, ReduceOp + + +@ray.remote(num_gpus=1) +class Worker: + def __init__(self): + self.buffer = cp.ones((10, ), dtype=cp.float32) + + def init_group(self, + world_size, + rank, + backend=Backend.NCCL, + group_name="default"): + col.init_collective_group(world_size, rank, backend, group_name) + return True + + def set_buffer(self, data): + self.buffer = data + return self.buffer + + def do_work(self, group_name="default", op=ReduceOp.SUM): + col.allreduce(self.buffer, group_name, op) + return self.buffer + + def destroy_group(self, group_name="default"): + col.destroy_collective_group(group_name) + return True + + def report_rank(self, group_name="default"): + rank = col.get_rank(group_name) + return rank + + def report_world_size(self, group_name="default"): + ws = col.get_world_size(group_name) + return ws + + def report_nccl_availability(self): + avail = col.nccl_available() + return avail + + def report_mpi_availability(self): + avail = col.mpi_available() + return avail + + def report_is_group_initialized(self, group_name="default"): + is_init = col.is_group_initialized(group_name) + return is_init diff --git a/python/ray/util/collective/types.py b/python/ray/util/collective/types.py new file mode 100644 index 000000000000..ef037373a4f9 --- /dev/null +++ b/python/ray/util/collective/types.py @@ -0,0 +1,64 @@ +"""Types conversion between different backends.""" +from enum import Enum +from dataclasses import dataclass +from datetime import timedelta + +_NUMPY_AVAILABLE = True +_TORCH_AVAILABLE = True +_CUPY_AVAILABLE = True + +try: + import torch as th # noqa: F401 +except ImportError: + _TORCH_AVAILABLE = False + +try: + import cupy as cp # noqa: F401 +except ImportError: + _CUPY_AVAILABLE = False + + +def cupy_available(): + return _CUPY_AVAILABLE + + +def torch_available(): + return _TORCH_AVAILABLE + + +class Backend(object): + """A class to represent different backends.""" + NCCL = "nccl" + MPI = "mpi" + UNRECOGNIZED = "unrecognized" + + def __new__(cls, name: str): + backend = getattr(Backend, name.upper(), Backend.UNRECOGNIZED) + if backend == Backend.UNRECOGNIZED: + raise ValueError("Unrecognized backend: '{}'. 
" + "Only NCCL is supported".format(name)) + if backend == Backend.MPI: + raise NotImplementedError() + return backend + + +# TODO(Hao): extend this to support more MPI types +class ReduceOp(Enum): + SUM = 0 + PRODUCT = 1 + MIN = 2 + MAX = 3 + + +unset_timeout = timedelta(milliseconds=-1) + + +@dataclass +class AllReduceOptions: + reduceOp = ReduceOp.SUM + timeout = unset_timeout + + +@dataclass +class BarrierOptions: + timeout = unset_timeout diff --git a/python/ray/util/collective/util.py b/python/ray/util/collective/util.py new file mode 100644 index 000000000000..e591e9b93f0b --- /dev/null +++ b/python/ray/util/collective/util.py @@ -0,0 +1,42 @@ +"""Some utility class for Collectives.""" +import ray +import logging + +logger = logging.getLogger(__name__) + + +@ray.remote +class NCCLUniqueIDStore: + """NCCLUniqueID Store as a named actor class. + + Args: + name (str): the unique name for this named actor. + + Attributes: + name (str): the unique name for this named actor. + nccl_id (str): the NCCLUniqueID held in this store. + """ + + def __init__(self, name): + self.name = name + self.nccl_id = None + + def set_id(self, uid): + """ + Initialize the NCCL unique ID for this store. + + Args: + uid (str): the unique ID generated via the NCCL get_unique_id API. + + Returns: + None + """ + self.nccl_id = uid + return self.nccl_id + + def get_id(self): + """Get the NCCL unique ID held in this store.""" + if not self.nccl_id: + logger.warning("The NCCL ID has not been " + "set yet for store {}.".format(self.name)) + return self.nccl_id