diff --git a/python/ray/dag/dag_node.py b/python/ray/dag/dag_node.py index 287c7143f7a1..23be0787143a 100644 --- a/python/ray/dag/dag_node.py +++ b/python/ray/dag/dag_node.py @@ -166,7 +166,7 @@ def with_tensor_transport( _static_shape=_static_shape, _direct_return=_direct_return, ) - elif transport == "nccl": + elif transport == "nccl" or transport == "hccl": self._type_hint = TorchTensorType( transport=transport, _static_shape=_static_shape, @@ -175,7 +175,7 @@ def with_tensor_transport( else: if not isinstance(transport, Communicator): raise ValueError( - "transport must be 'auto', 'nccl' or a Communicator type" + "transport must be 'auto', 'nccl', 'hccl' or a Communicator type" ) self._type_hint = TorchTensorType( transport=transport, diff --git a/python/ray/dag/tests/experimental/test_torch_tensor_dag.py b/python/ray/dag/tests/experimental/test_torch_tensor_dag.py index 310a13a4e5ab..5fc0bf5ffe83 100644 --- a/python/ray/dag/tests/experimental/test_torch_tensor_dag.py +++ b/python/ray/dag/tests/experimental/test_torch_tensor_dag.py @@ -27,6 +27,8 @@ ) from ray.tests.conftest import * # noqa +from ray.air._internal.device_manager.npu import NPU_TORCH_PACKAGE_AVAILABLE + from ray.experimental.util.types import ReduceOp logger = logging.getLogger(__name__) @@ -1637,6 +1639,87 @@ def recv(self, tensor): compiled_dag.teardown() +NPU_DEVICES = "0,1,2,3,4,5,6,7" + + +@ray.remote(resources={"NPU": 1}) +class TorchTensorWorkerNPU: + # NOTE(zhilong): To run NPU test, we need to change + # "from ray.experimental.channel.nccl_group import _NcclGroup" + # to "from ray.experimental.channel.hccl_group import _HcclGroup" + # in "python/ray/experimental/channel/torch_tensor_nccl_channel.py" + # and also disable All GPU device check. + + # TODO(zhilong): Refactor the aDAG channel so it support different + # XPUs. + + def __init__(self, rank): + import torch # noqa: F401 + + os.environ["ASCEND_RT_VISIBLE_DEVICES"] = NPU_DEVICES + import torch_npu + + self.rank = rank + torch_npu.npu.set_device(rank) + + def send(self, shape, dtype, value: int): + import torch + + os.environ["ASCEND_RT_VISIBLE_DEVICES"] = NPU_DEVICES + import torch_npu + + # May need to import twice to keep the context, + # otherwise it will lose the ctx. + # Different from nccl with cupy, NPU channel relies on torch, + # so we need to keep the torch ctx. + # Create and return a tensor filled with 'value' on the current NPU + torch_npu.npu.set_device(self.rank) + tensor = torch.ones(shape, dtype=dtype) * value + return tensor.to(f"npu:{self.rank}") + + def recv(self, tensor): + # Verify the tensor is on the correct device and return it as CPU tensor + tensor = tensor.cpu() + return (tensor[0].item(), tensor.shape, tensor.dtype) + + +@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True) +def test_torch_tensor_npu_communication(ray_start_regular): + if not NPU_TORCH_PACKAGE_AVAILABLE: + pytest.skip("This test requires NPUs.") + + assert ( + sum(node["Resources"].get("NPU", 0) for node in ray.nodes()) > 1 + ), "This test requires at least 2 NPUs" + + # Initialize actor class with NPU support + actor_cls = TorchTensorWorkerNPU + sender = actor_cls.remote(0) + receiver = actor_cls.remote(1) + + shape = (10,) + dtype = torch.float16 + + # Define the DAG with NPU actors + with InputNode() as inp: + dag = sender.send.bind(shape, dtype, inp) + # Can use with hccl after PR 47845 merged + dag = dag.with_type_hint( + TorchTensorType(shape, dtype, transport="hccl", _direct_return=True) + ) + dag = receiver.recv.bind(dag) + + compiled_dag = dag.experimental_compile() + + # Test tensor sending and receiving on NPUs + for i in range(3): + ref = compiled_dag.execute(i) + result = ray.get(ref) + assert result == (i, shape, dtype) + + compiled_dag.teardown() + + class TestTorchTensorTypeHintCustomSerializer: # All tests inside this file are running in the same process, so we need to # manually deregister the custom serializer for `torch.Tensor` before and diff --git a/python/ray/experimental/channel/hccl_group.py b/python/ray/experimental/channel/hccl_group.py new file mode 100644 index 000000000000..3495b696c8ea --- /dev/null +++ b/python/ray/experimental/channel/hccl_group.py @@ -0,0 +1,226 @@ +import logging +import os +from typing import Optional + +import torch +import torch.distributed as dist +import torch_npu # The torch_npu for communicate + +import ray +from ray.exceptions import RayChannelError + +from ray.experimental.channel.communicator import ( + Communicator, + TorchTensorAllocator, +) +from ray.experimental.util.types import ReduceOp + +# Set ASCEND_RT_VISIBLE_DEVICES environment variable to ensure all NPUs are visible +# This enables NPU to NPU communication across devices. +# Explaination: Since currently the worker can only see the GPU/NPU asign to +# that worker, the NPU needs to see all NPUs to enable the communication channel. +os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7" + +logger = logging.getLogger(__name__) + + +class _HcclGroup(Communicator): + """ + Represents an actor's HCCL communicator using NPUs. + + This is the default HCCL communicator to be used in aDAG if a + custom communicator is not provided. + + This class is not thread-safe. + """ + + def __init__( + self, + world_size: int, + comm_id: int, + rank: int, + actor_handles: list, + cuda_stream: Optional[int], + ): + # TODO(zhilong): Change cuda_stream to more general name like "stream". + """ + Initialize an HCCL communicator that can be used to communicate p2p with + other NPU actors. + + This method blocks until the same call has been made on all other + actors in the group, with the same arguments for world_size and comm_id. + + Args: + world_size: The number of participating actors/devices. + comm_id: A unique communicator ID. + rank: The rank of this actor. If None, then the caller is not a + participant of the HCCL group. + actor_handles: A list of actor handles, in rank order. + cuda_stream: Not used here but to keep same agrs with nccl_group. + """ + self._world_size = world_size + self._comm_id = comm_id + self._rank = rank + self._actor_handles = actor_handles + self._closed = False + self.real_rank = None + self.rank_map = {} + if self._rank is not None: + if not torch.distributed.is_initialized(): + self._init_dist_hccl(rank, world_size) + self.real_rank = dist.get_rank() + + def initialize(self, rank: int) -> None: + # No additional initialization is needed. + pass + + def _init_dist_hccl(self, rank, world_size): + """ + Initialize the HCCL communication group on NPUs. + + Args: + rank: The rank of the current process. + world_size: The total number of processes participating + in the communication. + """ + # Set environment variables if not already set + os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "127.0.0.1") + os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500") + os.environ["HCCL_WHITELIST_DISABLE"] = os.environ.get( + "HCCL_WHITELIST_DISABLE", "1" + ) + torch_npu.npu.set_device(rank) # Set the NPU device according to the rank + self.ctx = dist.init_process_group( + backend="hccl", world_size=world_size, rank=rank + ) + dist.barrier() + + def get_actor_handles(self) -> list: + """ + Return the list of actor handles. + + Returns: + list: Actor handles in rank order. + """ + return self._actor_handles + + def get_rank(self, actor: "ray.actor.ActorHandle") -> int: + """ + Return the given actor's rank in the HCCL communicator. + + Args: + actor: The actor handle to look up. + + Returns: + int: The rank of the actor. + """ + actor_ids = [a._ray_actor_id for a in self._actor_handles] + try: + rank = actor_ids.index(actor._ray_actor_id) + except ValueError: + raise ValueError("Actor is not in the HCCL group.") + return rank + + def get_self_rank(self) -> int: + """ + Return this actor's rank. + + Returns: + int: The rank of this actor in the HCCL group. + """ + return self._rank + + def get_world_size(self) -> int: + """ + Return the number of ranks in the HCCL communicator. + + Returns: + int: The world size of the HCCL group. + """ + return self._world_size + + def send(self, tensor: "torch.Tensor", peer_rank: int) -> None: + """ + Send a tensor to a peer using HCCL. + + Args: + tensor: The tensor to be sent. + peer_rank: The rank of the peer to send the tensor to. + """ + real_self_rank = self.rank_map[self._rank] + real_peer_rank = self.rank_map[peer_rank] + if self._closed: + raise RuntimeError("HCCL group has been destroyed.") + logger.info( + f"Start to send to:{real_peer_rank}, real_self_rank: {real_self_rank} " + ) + dist.send(tensor, dst=real_peer_rank) + logger.info( + f"Finished to send to:{real_peer_rank}, real_self_rank : {real_self_rank} " + ) + + def recv( + self, + shape: tuple, + dtype: "torch.dtype", + peer_rank: int, + allocator: Optional[TorchTensorAllocator], + ) -> "torch.Tensor": + """ + Receive a tensor from a peer using HCCL. + + Args: + shape: The shape of the tensor to receive. + dtype: The data type of the tensor. + peer_rank: The rank of the peer to receive the tensor from. + allocator: Optional allocator to allocate memory for the tensor. + + Returns: + torch.Tensor: The received tensor. + """ + real_self_rank = self.rank_map[self._rank] + real_peer_rank = self.rank_map[peer_rank] + logger.info( + f"Start to receive, real_self_rank : {real_self_rank}, real_peer_rank:{real_peer_rank} " + ) + if self._closed: + raise RuntimeError("HCCL group has been destroyed.") + torch_npu.npu.set_device(f"npu:{real_self_rank}") + tensor = torch.zeros(*shape, dtype=dtype).to(f"npu:{real_self_rank}") + dist.recv(tensor, src=real_peer_rank) + logger.info( + f"Finished to receive, real_self_rank: {real_self_rank}, real_peer_rank:{real_peer_rank} " + ) + + if self._closed: + raise RayChannelError("HCCL group has been destroyed.") + return tensor + + def recv_stream(self): + pass + + def send_stream(self): + pass + + def allreduce( + self, + send_buf: "torch.Tensor", + recv_buf: "torch.Tensor", + op: ReduceOp, + ) -> None: + pass + + def destroy(self) -> None: + """ + Destroy the HCCL group and clean up resources. + """ + self._closed = True + dist.destroy_process_group() + if self._rank is not None: + logger.info( + "Destructing HCCL group on actor: " + f"{ray.get_runtime_context().current_actor}" + ) + + def get_transport_name(self) -> str: + return "hccl" diff --git a/python/ray/experimental/channel/torch_tensor_nccl_channel.py b/python/ray/experimental/channel/torch_tensor_nccl_channel.py index 400126895621..2faf058fd622 100644 --- a/python/ray/experimental/channel/torch_tensor_nccl_channel.py +++ b/python/ray/experimental/channel/torch_tensor_nccl_channel.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from types import ModuleType from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union - +import os import ray import ray.util.serialization from ray.experimental.channel import ChannelContext, utils @@ -12,7 +12,6 @@ from ray.experimental.channel.communicator import Communicator from ray.experimental.channel.cpu_communicator import CPUCommunicator from ray.experimental.channel.intra_process_channel import IntraProcessChannel -from ray.experimental.channel.nccl_group import _NcclGroup from ray.experimental.channel.shared_memory_channel import SharedMemoryType from ray.experimental.channel.torch_tensor_type import TorchTensorType from ray.util.annotations import DeveloperAPI @@ -27,6 +26,22 @@ # into the program using Ray. Ray provides a default configuration at # entry/init points. logger = logging.getLogger(__name__) +USE_GPU = True +USE_NPU = False +if os.getenv("ASCEND_RT_VISIBLE_DEVICES"): + try: + from ray.experimental.channel.hccl_group import _HcclGroup as _NcclGroup + + USE_GPU = False + USE_NPU = True + except Exception as e: + logger.warning( + f"Failed in import hccl_group, use nccl_group instead with exception {e}" + ) + from ray.experimental.channel.nccl_group import _NcclGroup + +else: + from ray.experimental.channel.nccl_group import _NcclGroup @dataclass @@ -259,12 +274,18 @@ def write(self, value: Any, timeout: Optional[float] = None) -> None: "return a CUDA torch.Tensor, instead found value " f"`{value}`. DAG will shut down." ) - elif not value.is_cuda: + if USE_GPU and (not value.is_cuda): raise ValueError( "Task annotated with _direct_return=True must " "return a CUDA torch.Tensor, instead found CPU tensor. " "DAG will shut down." ) + elif USE_NPU and (not value.is_npu): + raise ValueError( + "Task annotated with _direct_return=True must " + "return a NPU torch.Tensor, instead found CPU tensor. " + "DAG will shut down." + ) self._gpu_data_channel.write([value], timeout=timeout) else: self._send_cpu_and_gpu_data(value, timeout) @@ -448,13 +469,13 @@ def ensure_registered_as_writer(self): assert self._nccl_group is not None, "Actor is not part of a NCCL group" assert self._writer_registered ctx = ChannelContext.get_current() - assert ctx.torch_device.type == "cuda" + assert ctx.torch_device.type in ["cuda", "npu"] def ensure_registered_as_reader(self) -> bool: assert self._nccl_group is not None, "Actor is not part of a NCCL group" assert self._reader_registered ctx = ChannelContext.get_current() - assert ctx.torch_device.type == "cuda" + assert ctx.torch_device.type in ["cuda", "npu"] def __reduce__(self): return ( @@ -632,6 +653,12 @@ def close(self) -> None: del ctx.communicators[self._nccl_group_id] +def _do_set_rank(self, group_id, rank_map): + ctx = ChannelContext.get_current() + ctx.communicators[group_id].rank_map = rank_map + return ctx.communicators[group_id] + + def _do_init_communicator( self, group_id, @@ -645,15 +672,15 @@ def _do_init_communicator( import torch if not custom_communicator: - assert ( - ray.get_gpu_ids() - ), "Actors participating in NCCL group must have at least one GPU assigned" + assert bool(ray.get_gpu_ids()) or bool( + "NPU" in ray.cluster_resources() + ), "Actors participating in Communicator group must have at least one XPU assigned" ctx = ChannelContext.get_current() if custom_communicator is not None: custom_communicator.initialize(rank) ctx.communicators[group_id] = custom_communicator - else: + elif USE_GPU: # default to NcclGroup ctx.communicators[group_id] = _NcclGroup( world_size, @@ -663,6 +690,16 @@ def _do_init_communicator( torch.cuda.current_stream().cuda_stream, use_communication_streams, ) + else: + ctx.communicators[group_id] = _NcclGroup( + world_size, + comm_id, + rank, + actor_handles, + None, + ) + # Need return this actor in NPU to get the rank map + return ctx.communicators[group_id] def _do_destroy_communicator(self, group_id): @@ -676,13 +713,17 @@ def _do_destroy_communicator(self, group_id): def _do_check_has_gpu(self) -> bool: - return bool(ray.get_gpu_ids()) + return bool(ray.get_gpu_ids()) or bool("NPU" in ray.cluster_resources()) def _do_get_unique_nccl_id(self) -> bool: - from cupy.cuda import nccl + if "NPU" in ray.cluster_resources(): + # NPU doesn't have get_unique_id + return uuid.uuid4() + else: + from cupy.cuda import nccl - return nccl.get_unique_id() + return nccl.get_unique_id() def _get_ranks( @@ -787,7 +828,23 @@ def _init_communicator( for rank, actor in zip(ranks, actors) ] try: - ray.get(init_tasks, timeout=30) + if USE_GPU: + ray.get(init_tasks, timeout=30) + else: + # Since the NPU use torch distributed for communication. + # If the init is call outside hccl group, i.e. in vLLM, + # The rank in dist is not the same the rank in rank. + # We need a rank map to map the rank in ray to correct + # rank in dist for device. + tmp_actors = ray.get(init_tasks, timeout=30) + rank_map = dict() + for rank, tmp_actor in zip(ranks, tmp_actors): + rank_map[rank] = tmp_actor.real_rank + set_tasks = [ + actor.__ray_call__.remote(_do_set_rank, group_id, rank_map) + for actor in actors + ] + ray.get(set_tasks, timeout=30) except ray.exceptions.GetTimeoutError: logger.warning( "NCCL group creation not done after 30s. NCCL group creation may be hung." diff --git a/python/ray/experimental/channel/torch_tensor_type.py b/python/ray/experimental/channel/torch_tensor_type.py index 004aa2b43d06..a6cea184d635 100644 --- a/python/ray/experimental/channel/torch_tensor_type.py +++ b/python/ray/experimental/channel/torch_tensor_type.py @@ -18,6 +18,8 @@ class TorchTensorType(ChannelOutputType): AUTO = "auto" NCCL = "nccl" CPU = "cpu" + HCCL = "hccl" + COMMUNICATOR_TYPES = [NCCL, HCCL] def __init__( self, @@ -70,7 +72,7 @@ def __init__( self._communicator = transport transport = transport.get_transport_name() - if transport not in [self.AUTO, self.NCCL, self.CPU]: + if transport not in [self.AUTO, self.NCCL, self.CPU, self.HCCL]: raise ValueError( "`transport` must be TorchTensorType.AUTO, TorchTensorType.NCCL, " "or TorchTensorType.CPU" @@ -146,7 +148,7 @@ def create_channel( return typ.create_channel(writer, reader_and_node_list, driver_actor_id) def requires_nccl(self) -> bool: - return self.transport == self.NCCL + return self.transport in self.COMMUNICATOR_TYPES def get_custom_communicator(self) -> Optional[Communicator]: """