Merged
Changes from all commits
Commits (40)
18ea0e2
scaffold of the code
Nov 25, 2020
bf1051c
some scratch and options change
DachengLi1 Nov 26, 2020
3c5628a
NCCL mostly done, supporting API#1
Dec 1, 2020
0714c4a
interface 2.1 2.2 scratch
DachengLi1 Dec 1, 2020
20df179
put code into ray and fix some importing issues
Dec 1, 2020
5267df1
add an additional Rendezvous class to safely meet at named actor
Dec 1, 2020
8ff63ad
fix some small bugs in nccl_util
Dec 1, 2020
88fbea1
some small fix
Dec 2, 2020
1e66354
scaffold of the code
Nov 25, 2020
c41f046
some scratch and options change
DachengLi1 Nov 26, 2020
912bd0f
NCCL mostly done, supporting API#1
Dec 1, 2020
5db388f
interface 2.1 2.2 scratch
DachengLi1 Dec 1, 2020
bd91da9
put code into ray and fix some importing issues
Dec 1, 2020
d971237
add an additional Rendezvous class to safely meet at named actor
Dec 1, 2020
3f2f86b
fix some small bugs in nccl_util
Dec 1, 2020
135b9ec
some small fix
Dec 2, 2020
03e49e7
add a Backend class to make Backend string more robust
Dec 2, 2020
ec02002
fix some conflicts
Dec 2, 2020
5588322
add several useful APIs
Dec 2, 2020
49e59a3
add some tests
Dec 4, 2020
be40e84
added allreduce test
DachengLi1 Dec 4, 2020
0133c6a
fix typos
DachengLi1 Dec 4, 2020
cbeaafe
fix several bugs found via unittests
Dec 4, 2020
893142d
fix and update torch test
DachengLi1 Dec 4, 2020
ec1c07a
changed back actor
DachengLi1 Dec 4, 2020
8f15ba4
rearrange a bit before importing distributed test
Dec 5, 2020
5b40ec3
add distributed test
Dec 5, 2020
c76a645
merge master
Dec 5, 2020
793830c
remove scratch code
zhisbug Dec 5, 2020
f8587df
auto-linting
zhisbug Dec 5, 2020
d7e4aee
linting 2
Dec 5, 2020
cd62a50
linting 2
Dec 5, 2020
bdb90de
linting 3
Dec 5, 2020
63973ec
linting 4
Dec 5, 2020
e027891
linting 5
zhisbug Dec 5, 2020
4136fa9
linting 6
Dec 5, 2020
970f99a
fix a small bug in example and linting
Dec 7, 2020
6b46e54
address Richard's comments
Dec 12, 2020
658bbd3
run format checker
zhisbug Dec 12, 2020
175cabe
remove a script
Dec 12, 2020
2 changes: 1 addition & 1 deletion python/ray/util/__init__.py
@@ -10,5 +10,5 @@
__all__ = [
    "ActorPool", "disable_log_once_globally", "enable_periodic_logging",
    "iter", "log_once", "pdb", "placement_group", "placement_group_table",
-   "remove_placement_group", "inspect_serializability"
+   "remove_placement_group", "inspect_serializability", "collective"
]
9 changes: 9 additions & 0 deletions python/ray/util/collective/__init__.py
@@ -0,0 +1,9 @@
from .collective import nccl_available, mpi_available, is_group_initialized, \
init_collective_group, destroy_collective_group, get_rank, \
get_world_size, allreduce, barrier

__all__ = [
"nccl_available", "mpi_available", "is_group_initialized",
"init_collective_group", "destroy_collective_group", "get_rank",
"get_world_size", "allreduce", "barrier"
]
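
For orientation, a short usage sketch of the API exported above — not part of this diff. It assumes a machine with 2 GPUs and cupy installed; the Worker class, the tensor shape, and passing the backend as the string "nccl" are illustrative:

import cupy as cp
import ray
import ray.util.collective as col


@ray.remote(num_gpus=1)
class Worker:
    def __init__(self):
        # one GPU tensor per actor; the shape is arbitrary
        self.buffer = cp.ones((10,), dtype=cp.float32)

    def setup(self, world_size, rank):
        # every member of the group must call this with the same group_name
        col.init_collective_group(world_size, rank,
                                  backend="nccl", group_name="default")

    def compute(self):
        # in-place allreduce; afterwards every rank holds the elementwise sum
        col.allreduce(self.buffer, "default")
        return self.buffer


ray.init(num_gpus=2)
workers = [Worker.remote() for _ in range(2)]
ray.get([w.setup.remote(2, rank) for rank, w in enumerate(workers)])
results = ray.get([w.compute.remote() for w in workers])  # each sums to 2.0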
275 changes: 275 additions & 0 deletions python/ray/util/collective/collective.py
@@ -0,0 +1,275 @@
"""APIs exposed under the namespace ray.util.collective."""
import logging

import numpy as np
import ray
from ray.util.collective import types
from ray.util.collective.const import get_nccl_store_name

_MPI_AVAILABLE = False
_NCCL_AVAILABLE = True

# try:
# from ray.util.collective.collective_group.mpi_collective_group \
# import MPIGroup
# except ImportError:
# _MPI_AVAILABLE = False
try:
from ray.util.collective.collective_group import NCCLGroup
from ray.util.collective.collective_group import nccl_util
except ImportError:
_NCCL_AVAILABLE = False

logger = logging.getLogger(__name__)


def nccl_available():
return _NCCL_AVAILABLE


def mpi_available():
return _MPI_AVAILABLE


class GroupManager(object):
"""
Use this class to manage the collective groups we created so far.

Each process will have an instance of `GroupManager`. Each process
could belong to multiple collective groups. The membership information
and other metadata are stored in the global `_group_mgr` object.
"""

def __init__(self):
self._name_group_map = {}
self._group_name_map = {}

def create_collective_group(self, backend, world_size, rank, group_name):
"""
        The entry point for creating new collective groups and registering
        them in the manager.

        Also records the registration and group information in the manager
        metadata.
"""
backend = types.Backend(backend)
if backend == types.Backend.MPI:
raise NotImplementedError()
elif backend == types.Backend.NCCL:
# create the ncclUniqueID
if rank == 0:
# availability has been checked before entering here.
group_uid = nccl_util.get_nccl_unique_id()
store_name = get_nccl_store_name(group_name)
# Avoid a potential circular dependency in ray/actor.py
from ray.util.collective.util import NCCLUniqueIDStore
store = NCCLUniqueIDStore.options(
name=store_name, lifetime="detached").remote(store_name)
ray.wait([store.set_id.remote(group_uid)])

logger.debug("creating NCCL group: '{}'".format(group_name))
g = NCCLGroup(world_size, rank, group_name)
self._name_group_map[group_name] = g
self._group_name_map[g] = group_name
return self._name_group_map[group_name]

def is_group_exist(self, group_name):
return group_name in self._name_group_map

def get_group_by_name(self, group_name):
"""Get the collective group handle by its name."""
if not self.is_group_exist(group_name):
logger.warning(
"The group '{}' is not initialized.".format(group_name))
return None
return self._name_group_map[group_name]

def destroy_collective_group(self, group_name):
"""Group destructor."""
if not self.is_group_exist(group_name):
logger.warning("The group '{}' does not exist.".format(group_name))
return

# release the collective group resource
g = self._name_group_map[group_name]
rank = g.rank
backend = g.backend()

# clean up the dicts
del self._group_name_map[g]
del self._name_group_map[group_name]
if backend == types.Backend.NCCL:
# release the named actor
if rank == 0:
store_name = get_nccl_store_name(group_name)
store = ray.get_actor(store_name)
ray.wait([store.__ray_terminate__.remote()])
ray.kill(store)
# Release the communicator resources
g.destroy_group()


_group_mgr = GroupManager()


def is_group_initialized(group_name):
"""Check if the group is initialized in this process by the group name."""
return _group_mgr.is_group_exist(group_name)


def init_collective_group(world_size: int,
rank: int,
backend=types.Backend.NCCL,
group_name: str = "default"):
"""
Initialize a collective group inside an actor process.

Args:
        world_size (int): the total number of processes in the group.
rank (int): the rank of the current process.
backend: the CCL backend to use, NCCL or MPI.
group_name (str): the name of the collective group.

Returns:
None
"""
_check_inside_actor()
backend = types.Backend(backend)
_check_backend_availability(backend)
global _group_mgr
# TODO(Hao): implement a group auto-counter.
    if not group_name:
        raise ValueError("group_name '{}' must be a non-empty string."
                         .format(group_name))

if _group_mgr.is_group_exist(group_name):
raise RuntimeError("Trying to initialize a group twice.")

assert (world_size > 0)
assert (rank >= 0)
assert (rank < world_size)
_group_mgr.create_collective_group(backend, world_size, rank, group_name)


def destroy_collective_group(group_name: str = "default") -> None:
"""Destroy a collective group given its group name."""
_check_inside_actor()
global _group_mgr
_group_mgr.destroy_collective_group(group_name)


def get_rank(group_name: str = "default") -> int:
"""
Return the rank of this process in the given group.

Args:
group_name (str): the name of the group to query

Returns:
the rank of this process in the named group,
-1 if the group does not exist or the process does
not belong to the group.
"""
_check_inside_actor()
if not is_group_initialized(group_name):
return -1
g = _group_mgr.get_group_by_name(group_name)
return g.rank


def get_world_size(group_name="default") -> int:
"""
    Return the size of the collective group with the given name.

Args:
group_name: the name of the group to query

Returns:
The world size of the collective group,
-1 if the group does not exist or the process does
not belong to the group.
"""
_check_inside_actor()
if not is_group_initialized(group_name):
return -1
g = _group_mgr.get_group_by_name(group_name)
return g.world_size


def allreduce(tensor, group_name: str, op=types.ReduceOp.SUM):
"""
Collective allreduce the tensor across the group with name group_name.

Args:
tensor: the tensor to be all-reduced on this process.
group_name (str): the collective group name to perform allreduce.
op: The reduce operation.

Returns:
None
"""
_check_single_tensor_input(tensor)
g = _check_and_get_group(group_name)
opts = types.AllReduceOptions
opts.reduceOp = op
Review thread on lines +212 to +213:

Contributor: doesn't this set the global variable? can we instead create an instance?

Contributor Author: This is an exposed user API: it does not write; it only reads from the global variable _group_mgr.

Contributor: doesn't types.AllReduceOptions refer to a global setting? anyways, i think this is a nit :)
g.allreduce(tensor, opts)
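
As an aside on the thread above, a minimal standalone illustration (class and values invented for the example) of the difference the reviewer is pointing at — binding the class object instead of an instance makes the attribute write mutate shared state:

class Options:
    reduceOp = "SUM"            # class-level default

opts = Options                  # binds the class object itself
opts.reduceOp = "PRODUCT"       # rewrites the shared class attribute
print(Options.reduceOp)         # PRODUCT -- visible to every later caller

opts = Options()                # an instance keeps the write local
opts.reduceOp = "MIN"
print(Options.reduceOp)         # still PRODUCT; the class is untouched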


def barrier(group_name):
"""
Barrier all processes in the collective group.

Args:
group_name (str): the name of the group to barrier.

Returns:
None
"""
g = _check_and_get_group(group_name)
g.barrier()


def _check_and_get_group(group_name):
"""Check the existence and return the group handle."""
_check_inside_actor()
if not is_group_initialized(group_name):
raise RuntimeError("The collective group '{}' is not "
"initialized in the process.".format(group_name))
g = _group_mgr.get_group_by_name(group_name)
return g


def _check_backend_availability(backend: types.Backend):
"""Check whether the backend is available."""
if backend == types.Backend.MPI:
if not mpi_available():
raise RuntimeError("MPI is not available.")
elif backend == types.Backend.NCCL:
        # expect some slowdown at the first call,
        # as the import is deferred to invocation.
if not nccl_available():
raise RuntimeError("NCCL is not available.")


def _check_single_tensor_input(tensor):
"""Check if the tensor is with a supported type."""
if isinstance(tensor, np.ndarray):
return
if types.cupy_available():
if isinstance(tensor, types.cp.ndarray):
return
if types.torch_available():
if isinstance(tensor, types.th.Tensor):
return
raise RuntimeError("Unrecognized tensor type '{}'. Supported types are: "
"np.ndarray, torch.Tensor, cupy.ndarray.".format(
type(tensor)))


def _check_inside_actor():
"""Check if currently it is inside a Ray actor/task."""
worker = ray.worker.global_worker
if worker.mode == ray.WORKER_MODE:
return
else:
raise RuntimeError("The collective APIs shall be only used inside "
"a Ray actor or task.")
3 changes: 3 additions & 0 deletions python/ray/util/collective/collective_group/__init__.py
@@ -0,0 +1,3 @@
from .nccl_collective_group import NCCLGroup

__all__ = ["NCCLGroup"]
@@ -0,0 +1,52 @@
"""Abstract class for collective groups."""
from abc import ABCMeta
from abc import abstractmethod

from ray.util.collective.types import AllReduceOptions, BarrierOptions


class BaseGroup(metaclass=ABCMeta):
def __init__(self, world_size, rank, group_name):
"""
Init the process group with basic information.

Args:
world_size (int): The total number of processes in the group.
rank (int): The rank of the current process.
group_name (str): The group name.
"""
self._world_size = world_size
self._rank = rank
self._group_name = group_name

@property
def rank(self):
"""Return the rank of the current process."""
return self._rank

@property
def world_size(self):
"""Return the number of processes in this group."""
return self._world_size

@property
def group_name(self):
"""Return the group name of this group."""
return self._group_name

def destroy_group(self):
"""GC the communicators."""
pass

@classmethod
def backend(cls):
"""The backend of this collective group."""
raise NotImplementedError()

@abstractmethod
def allreduce(self, tensor, allreduce_options=AllReduceOptions()):
raise NotImplementedError()

@abstractmethod
def barrier(self, barrier_options=BarrierOptions()):
raise NotImplementedError()
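
To make the abstract contract concrete, a hedged skeleton of how a backend might subclass BaseGroup; ExampleGroup and its method bodies are invented for illustration, not the NCCLGroup shipped in nccl_collective_group.py:

from ray.util.collective.types import AllReduceOptions, BarrierOptions, Backend


class ExampleGroup(BaseGroup):
    """Illustrative subclass; not the NCCLGroup from this PR."""

    def __init__(self, world_size, rank, group_name):
        super().__init__(world_size, rank, group_name)
        self._comm = None  # placeholder for a backend communicator handle

    @classmethod
    def backend(cls):
        # a real subclass reports its own backend
        return Backend.NCCL

    def allreduce(self, tensor, allreduce_options=AllReduceOptions()):
        raise NotImplementedError("placeholder: reduce tensor across ranks")

    def barrier(self, barrier_options=BarrierOptions()):
        raise NotImplementedError("placeholder: block until all ranks arrive")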
@@ -0,0 +1,5 @@
"""Implementation of the MPI collective group."""
try:
import mpi4py # noqa: F401
except ImportError:
raise