diff --git a/doc/source/ray-core/direct-transport.rst b/doc/source/ray-core/direct-transport.rst
index e3ea64d5bac2..1387dab73393 100644
--- a/doc/source/ray-core/direct-transport.rst
+++ b/doc/source/ray-core/direct-transport.rst
@@ -12,12 +12,12 @@ For example, passing a CUDA ``torch.Tensor`` from one Ray task to another would
 *Ray Direct Transport (RDT)* is a new feature that allows Ray to store and pass objects directly between Ray actors. This feature augments the familiar Ray :class:`ObjectRef ` API by:

-- Keeping GPU data in GPU memory until a transfer is needed
+- Keeping GPU data in GPU memory until a transfer is necessary
 - Avoiding expensive serialization and copies to and from the Ray object store
 - Using efficient data transports like collective communication libraries (`Gloo `__ or `NCCL `__) or point-to-point RDMA (via `NVIDIA's NIXL `__) to transfer data directly between devices, including both CPU and GPUs

 .. note::

-  RDT is currently in **alpha**. Not all Ray Core APIs are supported yet. Future releases may introduce breaking API changes. See the :ref:`limitations ` section for more details.
+  RDT is currently in **alpha** and doesn't support all Ray Core APIs yet. Future releases may introduce breaking API changes. See the :ref:`limitations ` section for more details.

 Getting started
 ===============
@@ -290,12 +290,6 @@ For collective-based tensor transports (Gloo and NCCL):
 * Similarly, the process that created the collective group cannot serialize and pass RDT :class:`ray.ObjectRefs ` to other Ray tasks or actors. Instead, the :class:`ray.ObjectRef`\s can only be passed as direct arguments to other actor tasks, and those actors must be in the same collective group.
 * Each actor can only be in one collective group per tensor transport at a time.
 * No support for :func:`ray.put `.
-* If a system-level error occurs during a collective operation, the collective group will be destroyed and the actors will no longer be able to communicate via the collective group. Note that application-level errors, i.e. exceptions raised by user code, will not destroy the collective group and will instead be propagated to any dependent task(s), as for non-RDT Ray objects. System-level errors include:
-
-  * Errors internal to the third-party transport, e.g., NCCL network errors
-  * Actor and node failure
-  * Tensors returned by the user that are located on an unsupported device, e.g., a CPU tensor when using NCCL
-  * Any unexpected system bugs

 Due to a known issue, for NIXL, we currently do not support storing different GPU objects at the same actor, where the objects contain an overlapping but not equal set of tensors. To support this pattern, ensure that the first `ObjectRef` has gone out of scope before storing the same tensor(s) again in a second object.

@@ -305,6 +299,23 @@ Due to a known issue, for NIXL, we currently do not support storing different GP
   :start-after: __nixl_limitations_start__
   :end-before: __nixl_limitations_end__

+Error handling
+==============
+
+* Application-level errors, i.e., exceptions raised by user code, will not destroy the collective group and will instead be propagated to any dependent task(s), as for non-RDT Ray objects.
+
+* If a system-level error occurs during a Gloo or NCCL collective operation, the collective group will be destroyed and the actors will be killed to prevent any hanging.
+ +* If a system-level error occurs during a NIXL transfer, Ray or NIXL will abort the transfer with an exception and Ray will raise the exception in the dependent task or on the ray.get on the NIXL ref. + +* System-level errors include: + * Errors internal to the third-party transport, e.g., NCCL network errors + * Actor or node failures + * Transport errors due to tensor device / transport mismatches, e.g., a CPU tensor when using NCCL + * Ray object fetch timeouts (can be overridden by setting the ``RAY_fetch_fail_timeout_milliseconds`` environment variable) + * Any unexpected system bugs + + Advanced: RDT Internals ======================= diff --git a/doc/source/ray-core/doc_code/direct_transport_nixl.py b/doc/source/ray-core/doc_code/direct_transport_nixl.py index dd4cd3925524..774fcd6bab7a 100644 --- a/doc/source/ray-core/doc_code/direct_transport_nixl.py +++ b/doc/source/ray-core/doc_code/direct_transport_nixl.py @@ -88,6 +88,6 @@ def sum_dict(self, dict): result2 = receiver.sum_dict.remote(ref2) try: print(ray.get(result2)) -except ActorDiedError as e: +except ValueError as e: print("Error caught:", e) # __nixl_limitations_end__ diff --git a/python/ray/actor.py b/python/ray/actor.py index 8f337e7a8d62..47a22d6aaac2 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -1180,6 +1180,7 @@ def _process_option_dict(actor_options, has_tensor_transport_methods): if _filled_options.get("concurrency_groups", None) is None: _filled_options["concurrency_groups"] = {} _filled_options["concurrency_groups"]["_ray_system"] = 1 + _filled_options["concurrency_groups"]["_ray_system_error"] = 1 return _filled_options diff --git a/python/ray/experimental/collective/collective_tensor_transport.py b/python/ray/experimental/collective/collective_tensor_transport.py index fd02a6645b1a..78f97d64c87a 100644 --- a/python/ray/experimental/collective/collective_tensor_transport.py +++ b/python/ray/experimental/collective/collective_tensor_transport.py @@ -26,6 +26,10 @@ def tensor_transport_backend(self) -> Backend: def is_one_sided() -> bool: return False + @staticmethod + def can_abort_transport() -> bool: + return False + def actor_has_tensor_transport(self, actor: "ray.actor.ActorHandle") -> bool: from ray.experimental.collective import get_collective_groups @@ -137,6 +141,7 @@ def get_communicator_metadata( @staticmethod def recv_multiple_tensors( tensors, + obj_id: str, tensor_transport_metadata: CollectiveTransportMetadata, communicator_metadata: CollectiveCommunicatorMetadata, ): @@ -183,3 +188,12 @@ def garbage_collect( obj_id: str, tensor_transport_meta: CollectiveTransportMetadata ): pass + + @staticmethod + def abort_transport( + obj_id: str, + communicator_metadata: CollectiveCommunicatorMetadata, + ): + raise NotImplementedError( + "Collective transport does not support abort_transport for now." 
+ ) diff --git a/python/ray/experimental/collective/nixl_tensor_transport.py b/python/ray/experimental/collective/nixl_tensor_transport.py index 021f5225e57f..6ab5b9d03ee3 100644 --- a/python/ray/experimental/collective/nixl_tensor_transport.py +++ b/python/ray/experimental/collective/nixl_tensor_transport.py @@ -24,6 +24,10 @@ def tensor_transport_backend(self) -> Backend: def is_one_sided() -> bool: return True + @staticmethod + def can_abort_transport() -> bool: + return True + def actor_has_tensor_transport(self, actor: "ray.actor.ActorHandle") -> bool: def __ray_actor_has_tensor_transport__( self: "ray.actor.ActorHandle", @@ -134,6 +138,7 @@ def get_communicator_metadata( @staticmethod def recv_multiple_tensors( tensors, + obj_id: str, tensor_transport_metadata: NixlTransportMetadata, communicator_metadata: NixlCommunicatorMetadata, ): @@ -152,6 +157,7 @@ def recv_multiple_tensors( g.recv( tensors, + obj_id, tensor_transport_metadata.nixl_serialized_descs, tensor_transport_metadata.nixl_agent_meta, ) @@ -178,3 +184,14 @@ def garbage_collect(obj_id: str, tensor_transport_meta: NixlTransportMetadata): if descs is not None: nixl_backend = get_group_handle(NIXL_GROUP_NAME) nixl_backend.deregister_memory(descs) + + @staticmethod + def abort_transport( + obj_id: str, + communicator_metadata: NixlCommunicatorMetadata, + ): + from ray.util.collective.collective import get_group_handle + + g = get_group_handle(communicator_metadata.communicator_name) + if g: + g.abort(obj_id) diff --git a/python/ray/experimental/collective/tensor_transport_manager.py b/python/ray/experimental/collective/tensor_transport_manager.py index d910f1efec1d..7f0210b7d9ca 100644 --- a/python/ray/experimental/collective/tensor_transport_manager.py +++ b/python/ray/experimental/collective/tensor_transport_manager.py @@ -31,6 +31,17 @@ def is_one_sided() -> bool: bool: True if the backend is one-sided, False otherwise. """ + @staticmethod + @abstractmethod + def can_abort_transport() -> bool: + """ + Whether the backend can abort the transport. + If this returns False, then Ray will kill involved actors upon system errors to avoid hanging. + + Returns: + bool: True if the backend can abort the transport. + """ + @abstractmethod def actor_has_tensor_transport(self, actor: "ray.actor.ActorHandle") -> bool: """Whether the actor has the tensor transport available. @@ -102,6 +113,7 @@ def get_communicator_metadata( @abstractmethod def recv_multiple_tensors( tensors: List["torch.Tensor"], + obj_id: str, tensor_transport_metadata: TensorTransportMetadata, communicator_metadata: CommunicatorMetadata, ): @@ -110,6 +122,7 @@ def recv_multiple_tensors( Args: tensors: The pre-allocated tensor space to receive the tensors. + obj_id: The object ID for related GPU object. tensor_transport_metadata: The tensor transport metadata for the GPU object. communicator_metadata: The communicator metadata for the send/recv operation. @@ -139,3 +152,17 @@ def garbage_collect(obj_id: str, tensor_transport_meta: TensorTransportMetadata) obj_id: The ID of the GPU object to garbage collect. tensor_transport_meta: The tensor transport metadata. """ + + @staticmethod + @abstractmethod + def abort_transport( + obj_id: str, + communicator_metadata: CommunicatorMetadata, + ): + """ + Abort the transport. + + Args: + obj_id: The object ID for related GPU object. + communicator_metadata: The communicator metadata for the send/recv operation. 
+ """ diff --git a/python/ray/experimental/gpu_object_manager/gpu_object_manager.py b/python/ray/experimental/gpu_object_manager/gpu_object_manager.py index cb46a21358ac..3c2cd2526970 100644 --- a/python/ray/experimental/gpu_object_manager/gpu_object_manager.py +++ b/python/ray/experimental/gpu_object_manager/gpu_object_manager.py @@ -47,6 +47,7 @@ class TransferMetadata(NamedTuple): recv_ref: ObjectRef communicator_meta: "CommunicatorMetadata" backend: str + obj_id: str timeout: float @@ -179,28 +180,59 @@ def _abort_transport( Cleans up the ref_info_map, kill the src and dst actors, and destroy the collective group if necessary. """ - from ray.experimental.collective import destroy_collective_group + from ray.experimental.collective import ( + destroy_collective_group, + get_tensor_transport_manager, + ) + from ray.experimental.gpu_object_manager.gpu_object_store import ( + __ray_abort_transport__, + ) from ray.util.collective.types import CollectiveCommunicatorMetadata ref_info = ref_info_map.pop(failed_ref.hex(), None) if ref_info is None: return - logger.error( - "RDT transfer with src actor %s and dst actor %s failed. Killing the actors. " - "Transfer failed with exception: %s", - ref_info.src_actor, - ref_info.dst_actor, - exception, - ) - if ref_info.send_ref: ref_info_map.pop(ref_info.send_ref.hex(), None) ref_info_map.pop(ref_info.recv_ref.hex(), None) - # TODO(#51276): Kill all actors in the collective group when we support more collective operations - ray.kill(ref_info.src_actor) - ray.kill(ref_info.dst_actor) + tensor_transport_manager = get_tensor_transport_manager(ref_info.backend) + if tensor_transport_manager.can_abort_transport(): + if not tensor_transport_manager.is_one_sided(): + # This is dead code until we implement a NCCL abort since NIXL + # is the only abortable transport for now and is one-sided. + ref_info.src_actor.__ray_call__.options( + concurrency_group="_ray_system_error" + ).remote( + __ray_abort_transport__, + ref_info.obj_id, + ref_info.communicator_meta, + ) + ref_info.dst_actor.__ray_call__.options( + concurrency_group="_ray_system_error" + ).remote( + __ray_abort_transport__, + ref_info.obj_id, + ref_info.communicator_meta, + ) + logger.info( + "RDT transfer with src actor %s and dst actor %s failed due to %s.", + ref_info.src_actor, + ref_info.dst_actor, + exception, + ) + else: + # TODO(#51276): Kill all actors in the collective group when we support more collective operations + ray.kill(ref_info.src_actor) + ray.kill(ref_info.dst_actor) + logger.error( + "RDT transfer with src actor %s and dst actor %s failed. Killing the actors. " + "Transfer failed with exception: %s", + ref_info.src_actor, + ref_info.dst_actor, + exception, + ) # isinstance does an implicit cast and makes communicator_name inaccessible # so we have to get communicator_name before the cast. 
@@ -336,7 +368,7 @@ def _fetch_object( __ray_fetch_gpu_object__, obj_id ) ) - self.gpu_object_store.add_object(obj_id, tensors) + self.gpu_object_store.add_object(obj_id, tensors, is_primary=False) else: if isinstance(gpu_object_meta.tensor_transport_meta, ObjectRef): # If the tensor transport meta is an ObjectRef, gpu object manager @@ -358,7 +390,7 @@ def _fetch_object( None, None, tensor_transport_backend ) __ray_recv__( - None, obj_id, gpu_object_meta.tensor_transport_meta, communicator_meta + None, obj_id, [gpu_object_meta.tensor_transport_meta], communicator_meta ) def trigger_out_of_band_tensor_transfer( @@ -474,7 +506,7 @@ def trigger_out_of_band_tensor_transfer( ).remote( __ray_recv__, obj_id, - tensor_transport_meta, + [tensor_transport_meta], communicator_meta, ) @@ -486,6 +518,7 @@ def trigger_out_of_band_tensor_transfer( recv_ref=recv_ref, communicator_meta=communicator_meta, backend=gpu_object_meta.tensor_transport_backend, + obj_id=obj_id, timeout=time.time() + ray_constants.FETCH_FAIL_TIMEOUT_SECONDS, ) ) diff --git a/python/ray/experimental/gpu_object_manager/gpu_object_store.py b/python/ray/experimental/gpu_object_manager/gpu_object_store.py index fd0b0224402d..0242cd55a001 100644 --- a/python/ray/experimental/gpu_object_manager/gpu_object_store.py +++ b/python/ray/experimental/gpu_object_manager/gpu_object_store.py @@ -1,10 +1,12 @@ import threading from collections import defaultdict, deque from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Set +from typing import Any, Dict, List, Optional, Set, Union +import ray import ray.util.collective as collective from ray._private.custom_types import TensorTransportEnum +from ray._raylet import ObjectRef from ray.experimental.collective import get_tensor_transport_manager from ray.experimental.collective.util import device_match_transport from ray.util.collective.types import ( @@ -72,36 +74,56 @@ def __ray_send__( def __ray_recv__( self, obj_id: str, - tensor_transport_meta: TensorTransportMetadata, + tensor_transport_meta: List[Union[ObjectRef, TensorTransportMetadata]], communicator_meta: CommunicatorMetadata, ): """Helper function that runs on the dst actor to receive tensors from the src actor.""" from ray._private.worker import global_worker - backend = collective.get_group_handle(communicator_meta.communicator_name).backend() + gpu_object_store = global_worker.gpu_object_manager.gpu_object_store + try: + tensor_transport_meta: TensorTransportMetadata = ( + ray.get(tensor_transport_meta[0]) + if isinstance(tensor_transport_meta[0], ObjectRef) + else tensor_transport_meta[0] + ) + device = tensor_transport_meta.tensor_device + tensor_meta = tensor_transport_meta.tensor_meta - device = tensor_transport_meta.tensor_device - tensor_meta = tensor_transport_meta.tensor_meta + backend = collective.get_group_handle( + communicator_meta.communicator_name + ).backend() - gpu_object_store = global_worker.gpu_object_manager.gpu_object_store - if tensor_meta and not device_match_transport(device, backend): - raise ValueError( - f"Tensor transport backend {backend} does not support tensor transfer on device {device}." + if tensor_meta and not device_match_transport(device, backend): + raise ValueError( + f"Tensor transport backend {backend} does not support tensor transfer on device {device}." 
+ ) + + tensors = [] + for meta in tensor_meta: + shape, dtype = meta + tensor = torch.empty(shape, dtype=dtype, device=device) + tensors.append(tensor) + + tensor_transport_manager = get_tensor_transport_manager(backend) + tensor_transport_manager.recv_multiple_tensors( + tensors, + obj_id, + tensor_transport_meta, + communicator_meta, ) - tensors = [] - for meta in tensor_meta: - shape, dtype = meta - tensor = torch.empty(shape, dtype=dtype, device=device) - tensors.append(tensor) + gpu_object_store.add_object(obj_id, tensors, is_primary=False) + except Exception as e: + # Store the error as a gpu object if the recv fails, + # so waiters will raise the error. + gpu_object_store.add_object(obj_id, e, is_primary=False) - tensor_transport_manager = get_tensor_transport_manager(backend) - tensor_transport_manager.recv_multiple_tensors( - tensors, - tensor_transport_meta, - communicator_meta, - ) - gpu_object_store.add_object(obj_id, tensors) +def __ray_abort_transport__(self, obj_id: str, communicator_meta: CommunicatorMetadata): + """Helper function that can run on an actor doing a send or recv to abort the transport.""" + backend = collective.get_group_handle(communicator_meta.communicator_name).backend() + tensor_transport_manager = get_tensor_transport_manager(backend) + tensor_transport_manager.abort_transport(obj_id, communicator_meta) def __ray_free__( @@ -145,6 +167,8 @@ class _GPUObject: data: List["torch.Tensor"] # Whether the GPU object is the primary copy. is_primary: bool + # If a recv failed, we store the error here. + error: Optional[Exception] = None class GPUObjectStore: @@ -194,13 +218,15 @@ def has_tensor(self, tensor: "torch.Tensor") -> bool: def get_object(self, obj_id: str) -> Optional[List["torch.Tensor"]]: with self._lock: + if self._gpu_object_store[obj_id][0].error: + raise self._gpu_object_store[obj_id][0].error return self._gpu_object_store[obj_id][0].data def add_object( self, obj_id: str, - gpu_object: List["torch.Tensor"], - is_primary: bool = False, + gpu_object: Union[List["torch.Tensor"], Exception], + is_primary: bool, ): """ Add a GPU object to the GPU object store. @@ -211,15 +237,20 @@ def add_object( is_primary: Whether the GPU object is the primary copy. 
""" with self._object_present_cv: - for tensor in gpu_object: - self._tensor_to_object_ids[tensor.data_ptr()].add(obj_id) - # Append to the queue instead of overwriting - self._gpu_object_store[obj_id].append( - _GPUObject( - gpu_object, - is_primary, + if isinstance(gpu_object, Exception): + self._gpu_object_store[obj_id].append( + _GPUObject([], is_primary, error=gpu_object) + ) + else: + for tensor in gpu_object: + self._tensor_to_object_ids[tensor.data_ptr()].add(obj_id) + # Append to the queue instead of overwriting + self._gpu_object_store[obj_id].append( + _GPUObject( + gpu_object, + is_primary, + ) ) - ) self._object_present_cv.notify_all() def is_primary_copy(self, obj_id: str) -> bool: @@ -355,6 +386,8 @@ def pop_object(self, obj_id: str) -> List["torch.Tensor"]: gpu_object = queue.popleft() if len(queue) == 0: del self._gpu_object_store[obj_id] + if gpu_object.error: + raise gpu_object.error for tensor in gpu_object.data: self._tensor_to_object_ids[tensor.data_ptr()].remove(obj_id) if len(self._tensor_to_object_ids[tensor.data_ptr()]) == 0: diff --git a/python/ray/tests/gpu_objects/test_gpu_objects_nixl.py b/python/ray/tests/gpu_objects/test_gpu_objects_nixl.py index e6539db895f3..f38199611390 100644 --- a/python/ray/tests/gpu_objects/test_gpu_objects_nixl.py +++ b/python/ray/tests/gpu_objects/test_gpu_objects_nixl.py @@ -4,7 +4,7 @@ import torch import ray -from ray._common.test_utils import wait_for_condition +from ray._common.test_utils import SignalActor, wait_for_condition @ray.remote(num_gpus=1, num_cpus=0, enable_tensor_transport=True) @@ -76,6 +76,10 @@ def get_num_managed_meta_nixl(self): gpu_object_manager = ray._private.worker.global_worker.gpu_object_manager return gpu_object_manager.gpu_object_store.get_num_managed_meta_nixl() + @ray.method(concurrency_group="_ray_system") + def block_background_thread(self, signal_actor): + ray.get(signal_actor.wait.remote()) + @pytest.mark.parametrize("ray_start_regular", [{"num_gpus": 1}], indirect=True) def test_ray_get_gpu_ref_created_by_actor_task(ray_start_regular): @@ -193,5 +197,29 @@ def test_send_duplicate_tensor(ray_start_regular): ) +@pytest.mark.parametrize("ray_start_regular", [{"num_gpus": 2}], indirect=True) +def test_nixl_abort(ray_start_regular): + actors = [GPUTestActor.remote() for _ in range(2)] + + # Trigger transfer and kill sender before the receiver starts receiving + signal_actor = SignalActor.remote() + actors[1].block_background_thread.remote(signal_actor) + ref = actors[0].echo.remote(torch.randn((100, 100)), "cuda") + result = actors[1].sum.remote(ref, "cuda") + ray.kill(actors[0]) + signal_actor.send.remote() + + with pytest.raises(ray.exceptions.RayTaskError) as excinfo: + ray.get(result) + + assert "ActorDiedError" in str(excinfo.value) + + # Try a transfer with actor[1] receiving again + new_actor = GPUTestActor.remote() + ref = new_actor.echo.remote(torch.tensor([4, 5, 6]), "cuda") + result = actors[1].sum.remote(ref, "cuda") + assert ray.get(result) == 15 + + if __name__ == "__main__": sys.exit(pytest.main(["-sv", __file__])) diff --git a/python/ray/util/collective/collective.py b/python/ray/util/collective/collective.py index e06f2bb57d2d..5e1681015c25 100644 --- a/python/ray/util/collective/collective.py +++ b/python/ray/util/collective/collective.py @@ -2,6 +2,7 @@ import logging import os +import threading import time from typing import List @@ -82,7 +83,6 @@ class GroupManager(object): def __init__(self): self._name_group_map = {} - self._group_name_map = {} def create_collective_group( 
self, backend, world_size, rank, group_name, gloo_timeout @@ -132,8 +132,6 @@ def create_collective_group( raise RuntimeError(f"Unexpected backend: {backend}") self._name_group_map[group_name] = g - self._group_name_map[g] = group_name - return self._name_group_map[group_name] def is_group_exist(self, group_name): @@ -155,7 +153,6 @@ def destroy_collective_group(self, group_name): # release the collective group resource g = self._name_group_map[group_name] # clean up the dicts - del self._group_name_map[g] del self._name_group_map[group_name] # Release the communicator resources g.destroy_group() @@ -170,11 +167,16 @@ def destroy_collective_group(self, group_name): _group_mgr = GroupManager() +# This lock is used to make external calls to the _group_mgr thread-safe. +_group_mgr_lock = threading.Lock() def is_group_initialized(group_name): """Check if the group is initialized in this process by the group name.""" - return _group_mgr.is_group_exist(group_name) + global _group_mgr + global _group_mgr_lock + with _group_mgr_lock: + return _group_mgr.is_group_exist(group_name) def init_collective_group( @@ -199,19 +201,22 @@ def init_collective_group( backend = types.Backend(backend) _check_backend_availability(backend) global _group_mgr + global _group_mgr_lock + # TODO(Hao): implement a group auto-counter. if not group_name: raise ValueError("group_name '{}' needs to be a string.".format(group_name)) - if _group_mgr.is_group_exist(group_name): - raise RuntimeError("Trying to initialize a group twice.") + with _group_mgr_lock: + if _group_mgr.is_group_exist(group_name): + raise RuntimeError("Trying to initialize a group twice.") - assert world_size > 0 - assert rank >= 0 - assert rank < world_size - _group_mgr.create_collective_group( - backend, world_size, rank, group_name, gloo_timeout - ) + assert world_size > 0 + assert rank >= 0 + assert rank < world_size + _group_mgr.create_collective_group( + backend, world_size, rank, group_name, gloo_timeout + ) def create_collective_group( @@ -284,7 +289,9 @@ def destroy_collective_group(group_name: str = "default") -> None: """Destroy a collective group given its group name.""" _check_inside_actor() global _group_mgr - _group_mgr.destroy_collective_group(group_name) + global _group_mgr_lock + with _group_mgr_lock: + _group_mgr.destroy_collective_group(group_name) def get_rank(group_name: str = "default") -> int: @@ -299,10 +306,14 @@ def get_rank(group_name: str = "default") -> int: not belong to the group. """ _check_inside_actor() - if not is_group_initialized(group_name): - return -1 - g = _group_mgr.get_group_by_name(group_name) - return g.rank + + global _group_mgr + global _group_mgr_lock + with _group_mgr_lock: + if not _group_mgr.is_group_exist(group_name): + return -1 + g = _group_mgr.get_group_by_name(group_name) + return g.rank def get_collective_group_size(group_name: str = "default") -> int: @@ -316,10 +327,13 @@ def get_collective_group_size(group_name: str = "default") -> int: not exist or the process does not belong to the group. 
""" _check_inside_actor() - if not is_group_initialized(group_name): - return -1 - g = _group_mgr.get_group_by_name(group_name) - return g.world_size + global _group_mgr + global _group_mgr_lock + with _group_mgr_lock: + if not _group_mgr.is_group_exist(group_name): + return -1 + g = _group_mgr.get_group_by_name(group_name) + return g.world_size def allreduce(tensor, group_name: str = "default", op=types.ReduceOp.SUM): @@ -747,47 +761,49 @@ def get_group_handle(group_name: str = "default"): if group_name != types.NIXL_GROUP_NAME: _check_inside_actor() global _group_mgr - if not is_group_initialized(group_name): - # try loading from remote info store - try: - if group_name == types.NIXL_GROUP_NAME: - _group_mgr.create_collective_group( - types.Backend.NIXL, None, None, group_name, None - ) - else: - # if the information is stored in an Info object, - # get and create the group. - name = "info_" + group_name - mgr = ray.get_actor(name=name) - ids, world_size, rank, backend, gloo_timeout = ray.get( - mgr.get_info.remote() - ) - worker = ray._private.worker.global_worker - id_ = worker.core_worker.get_actor_id() - r = rank[ids.index(id_)] - _group_mgr.create_collective_group( - backend, world_size, r, group_name, gloo_timeout - ) - except ValueError as exc: - # check if this group is initialized using options() - if ( - "collective_group_name" in os.environ - and os.environ["collective_group_name"] == group_name - ): - rank = int(os.environ["collective_rank"]) - world_size = int(os.environ["collective_world_size"]) - backend = os.environ["collective_backend"] - gloo_timeout = os.getenv("collective_gloo_timeout", 30000) - _group_mgr.create_collective_group( - backend, world_size, rank, group_name, gloo_timeout - ) - else: - raise RuntimeError( - "The collective group '{}' is not " - "initialized in the process.".format(group_name) - ) from exc - g = _group_mgr.get_group_by_name(group_name) - return g + global _group_mgr_lock + with _group_mgr_lock: + if not _group_mgr.is_group_exist(group_name): + # try loading from remote info store + try: + if group_name == types.NIXL_GROUP_NAME: + _group_mgr.create_collective_group( + types.Backend.NIXL, None, None, group_name, None + ) + else: + # if the information is stored in an Info object, + # get and create the group. 
+ name = "info_" + group_name + mgr = ray.get_actor(name=name) + ids, world_size, rank, backend, gloo_timeout = ray.get( + mgr.get_info.remote() + ) + worker = ray._private.worker.global_worker + id_ = worker.core_worker.get_actor_id() + r = rank[ids.index(id_)] + _group_mgr.create_collective_group( + backend, world_size, r, group_name, gloo_timeout + ) + except ValueError as exc: + # check if this group is initialized using options() + if ( + "collective_group_name" in os.environ + and os.environ["collective_group_name"] == group_name + ): + rank = int(os.environ["collective_rank"]) + world_size = int(os.environ["collective_world_size"]) + backend = os.environ["collective_backend"] + gloo_timeout = os.getenv("collective_gloo_timeout", 30000) + _group_mgr.create_collective_group( + backend, world_size, rank, group_name, gloo_timeout + ) + else: + raise RuntimeError( + "The collective group '{}' is not " + "initialized in the process.".format(group_name) + ) from exc + g = _group_mgr.get_group_by_name(group_name) + return g def _check_single_tensor_input(tensor): diff --git a/python/ray/util/collective/collective_group/nixl_backend.py b/python/ray/util/collective/collective_group/nixl_backend.py index c26d3b4b1a9d..beff753b055a 100644 --- a/python/ray/util/collective/collective_group/nixl_backend.py +++ b/python/ray/util/collective/collective_group/nixl_backend.py @@ -1,3 +1,4 @@ +import threading import time from typing import TYPE_CHECKING, Any, List, Tuple @@ -31,6 +32,8 @@ def __init__(self): actor_id = f"RAY-DRIVER-{uuid.uuid4()}" self._nixl_agent = nixl_agent(actor_id, agent_config) + self._aborted_transfer_obj_ids = set() + self._aborted_transfer_obj_ids_lock = threading.Lock() @classmethod def backend(cls): @@ -44,6 +47,7 @@ def backend(cls): def recv( self, tensors: List["torch.Tensor"], + obj_id: str, nixl_serialized_descs: bytes, remote_nixl_agent_meta: bytes, ): @@ -51,44 +55,67 @@ def recv( Args: tensors: List of tensors to receive into. + obj_id: The object ID for related GPU object. nixl_serialized_descs: Serialized NIXL descriptors for the remote tensors. remote_nixl_agent_meta: Metadata about the remote NIXL agent. Raises: RuntimeError: If the NIXL transfer enters an error state. """ - nixl_agent = self._nixl_agent - remote_descs = nixl_agent.deserialize_descs(nixl_serialized_descs) - local_descs = nixl_agent.register_memory(tensors) - remote_name = nixl_agent.add_remote_agent(remote_nixl_agent_meta) - - xfer_handle = nixl_agent.initialize_xfer( - # "UUID" here is just a placeholder, can be any bytes, but without it, - # nixl will fail to transfer multiple times. - "READ", - local_descs.trim(), - remote_descs, - remote_name, - "UUID", - ) - - state = nixl_agent.transfer(xfer_handle) - if state == "ERR": - raise RuntimeError("NIXL transfer got to Error state.") - # Since current nixl does not provide a better way, we need to check the state of - # the transfer continuously. 
- while True: - state = nixl_agent.check_xfer_state(xfer_handle) + with self._aborted_transfer_obj_ids_lock: + if obj_id in self._aborted_transfer_obj_ids: + self._aborted_transfer_obj_ids.remove(obj_id) + raise RuntimeError(f"NIXL transfer aborted for object id: {obj_id}") + + local_descs = None + remote_name = None + xfer_handle = None + try: + nixl_agent = self._nixl_agent + remote_descs = nixl_agent.deserialize_descs(nixl_serialized_descs) + local_descs = nixl_agent.register_memory(tensors) + remote_name = nixl_agent.add_remote_agent(remote_nixl_agent_meta) + + xfer_handle = nixl_agent.initialize_xfer( + # "UUID" here is just a placeholder, can be any bytes, but without it, + # nixl will fail to transfer multiple times. + "READ", + local_descs.trim(), + remote_descs, + remote_name, + "UUID", + ) + + state = nixl_agent.transfer(xfer_handle) if state == "ERR": raise RuntimeError("NIXL transfer got to Error state.") - if state == "PROC": - time.sleep(0.001) # Avoid busy waiting - elif state == "DONE": - break - - nixl_agent.release_xfer_handle(xfer_handle) - nixl_agent.deregister_memory(local_descs) - nixl_agent.remove_remote_agent(remote_name) + # Since current nixl does not provide a better way, we need to check the state of + # the transfer continuously. + while True: + state = nixl_agent.check_xfer_state(xfer_handle) + if state == "ERR": + raise RuntimeError("NIXL transfer got to Error state.") + if state == "PROC": + with self._aborted_transfer_obj_ids_lock: + if obj_id in self._aborted_transfer_obj_ids: + self._aborted_transfer_obj_ids.remove(obj_id) + raise RuntimeError( + f"NIXL transfer aborted for object id: {obj_id}" + ) + time.sleep(0.001) # Avoid busy waiting + elif state == "DONE": + break + finally: + # We could raise errors or NIXL could raise errors like NIXL_ERR_REMOTE_DISCONNECT, + # so doing best effort cleanup. + with self._aborted_transfer_obj_ids_lock: + self._aborted_transfer_obj_ids.discard(obj_id) + if xfer_handle: + nixl_agent.release_xfer_handle(xfer_handle) + if remote_name: + nixl_agent.remove_remote_agent(remote_name) + if local_descs: + nixl_agent.deregister_memory(local_descs) def get_nixl_metadata( self, tensors: List["torch.Tensor"] @@ -114,3 +141,7 @@ def get_nixl_metadata( def deregister_memory(self, descs: Any): self._nixl_agent.deregister_memory(descs) + + def abort(self, obj_id: str): + with self._aborted_transfer_obj_ids_lock: + self._aborted_transfer_obj_ids.add(obj_id)
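
The abort path added to ``NixlBackend.recv`` follows a simple cooperative-cancellation pattern: ``abort(obj_id)`` records the object ID in a lock-protected set, and the receive loop checks that set on every poll and raises if its ID shows up, with a ``finally`` block doing best-effort cleanup. Below is a standalone sketch of the same pattern under stated assumptions: it has no Ray or NIXL dependency, and every name in it (``AbortableReceiver``, ``poll_transfer``) is hypothetical and for illustration only.

# Standalone illustration of the cooperative abort used in NixlBackend.recv:
# the receiver polls transfer progress and a shared "aborted" set, while
# abort() just flips the flag for a given object ID from another thread.
import threading
import time


class AbortableReceiver:
    def __init__(self):
        self._aborted_obj_ids = set()
        self._lock = threading.Lock()

    def abort(self, obj_id: str) -> None:
        # Called from another thread (e.g., a task running in the
        # "_ray_system_error" concurrency group in the real code).
        with self._lock:
            self._aborted_obj_ids.add(obj_id)

    def recv(self, obj_id: str, poll_transfer) -> None:
        try:
            while True:
                with self._lock:
                    if obj_id in self._aborted_obj_ids:
                        raise RuntimeError(f"transfer aborted for object id: {obj_id}")
                if poll_transfer():  # returns True once the transfer is done
                    return
                time.sleep(0.001)  # avoid busy waiting, as in the NIXL backend
        finally:
            # Best-effort cleanup, mirroring the finally block in NixlBackend.recv,
            # so a late abort cannot poison a future transfer of the same object.
            with self._lock:
                self._aborted_obj_ids.discard(obj_id)


# Usage: a transfer that never completes gets unblocked by abort() from another thread.
receiver = AbortableReceiver()
threading.Timer(0.05, receiver.abort, args=("obj-1",)).start()
try:
    receiver.recv("obj-1", poll_transfer=lambda: False)
except RuntimeError as e:
    print(e)

Keying the set by object ID means an abort that arrives before the receive even starts still cancels it, which is why ``NixlBackend.recv`` checks the set once up front before registering memory.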
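To complement the new ``Error handling`` docs and ``test_nixl_abort``, here is a hedged end-to-end sketch of how a failed transfer surfaces to the caller. It is illustrative only: the ``Worker`` actor, its methods, and the ``@ray.method(tensor_transport="nixl")`` decorator spelling are assumptions modeled on the RDT getting-started docs and the test above, and it assumes a cluster with at least two GPUs and NIXL installed.

# Hypothetical sketch (not part of this diff): a sender dies mid-transfer and,
# because NIXL is an abortable one-sided transport, Ray aborts the transfer and
# re-raises the error to the caller instead of killing the receiver.
import ray
import torch


@ray.remote(num_gpus=1, num_cpus=0, enable_tensor_transport=True)
class Worker:
    @ray.method(tensor_transport="nixl")  # assumed RDT decorator spelling
    def produce(self):
        return torch.randn((100, 100), device="cuda")

    def consume(self, tensor):
        return tensor.sum().item()


sender, receiver = Worker.remote(), Worker.remote()
ref = sender.produce.remote()
result = receiver.consume.remote(ref)  # RDT refs are passed as direct arguments

# Simulate a system-level failure by killing the sender.
ray.kill(sender)

try:
    ray.get(result)
except ray.exceptions.RayTaskError as e:
    # The dependent task fails with the underlying error (e.g., ActorDiedError),
    # and the receiver stays alive for future transfers.
    print("RDT transfer failed:", e)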