38 changes: 23 additions & 15 deletions vllm/distributed/device_communicators/pynccl.py
@@ -25,6 +25,7 @@ def __init__(
group: Union[ProcessGroup, StatelessProcessGroup],
device: Union[int, str, torch.device],
library_path: Optional[str] = None,
unique_id: Optional[ncclUniqueId] = None,
):
"""
Args:
@@ -34,6 +35,9 @@ def __init__(
                it will be bound to f"cuda:{local_rank}".
library_path: the path to the NCCL library. If None, it will
use the default library path.
unique_id: the unique id of the communicator. If None, it will
be generated at rank 0 and broadcast to all ranks. If
                provided, it should be the same for all ranks.
        It is the caller's responsibility to make sure each communicator
        is bound to a unique device.
"""
@@ -69,23 +73,27 @@ def __init__(

logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion())

if self.rank == 0:
# get the unique id from NCCL
self.unique_id = self.nccl.ncclGetUniqueId()
if unique_id is not None:
self.unique_id = unique_id
else:
# construct an empty unique id
self.unique_id = ncclUniqueId()
if self.rank == 0:
# get the unique id from NCCL
self.unique_id = self.nccl.ncclGetUniqueId()
else:
# construct an empty unique id
self.unique_id = ncclUniqueId()

if not isinstance(group, StatelessProcessGroup):
tensor = torch.ByteTensor(list(self.unique_id.internal))
ranks = dist.get_process_group_ranks(group)
# arg `src` in `broadcast` is the global rank
dist.broadcast(tensor, src=ranks[0], group=group)
byte_list = tensor.tolist()
for i, byte in enumerate(byte_list):
self.unique_id.internal[i] = byte
else:
self.unique_id = group.broadcast_obj(self.unique_id, src=0)

if not isinstance(group, StatelessProcessGroup):
tensor = torch.ByteTensor(list(self.unique_id.internal))
ranks = dist.get_process_group_ranks(group)
# arg `src` in `broadcast` is the global rank
dist.broadcast(tensor, src=ranks[0], group=group)
byte_list = tensor.tolist()
for i, byte in enumerate(byte_list):
self.unique_id.internal[i] = byte
else:
self.unique_id = group.broadcast_obj(self.unique_id, src=0)
if isinstance(device, int):
device = torch.device(f"cuda:{device}")
elif isinstance(device, str):
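The new `unique_id` parameter lets a caller generate the NCCL unique id once and hand the same id to every rank, in which case `PyNcclCommunicator` uses it directly and skips its internal rank-0 generation and broadcast. A minimal sketch of that usage follows; the rendezvous host/port are placeholders and the import paths follow those used in pynccl.py, so treat it as an illustration rather than code from this PR.

```python
# Illustrative sketch (not part of this PR): pre-generating the NCCL unique
# id on rank 0 and passing the same id to every rank, so the communicator
# uses it directly instead of generating and broadcasting one internally.
import torch

from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
from vllm.distributed.utils import StatelessProcessGroup


def make_communicator(rank: int, world_size: int, local_rank: int):
    # Any out-of-band channel works for sharing the id; a
    # StatelessProcessGroup (placeholder host/port) is used here only
    # for illustration.
    group = StatelessProcessGroup.create(
        host="127.0.0.1", port=29500, rank=rank, world_size=world_size)
    unique_id = NCCLLibrary().ncclGetUniqueId() if rank == 0 else None
    unique_id = group.broadcast_obj(unique_id, src=0)

    return PyNcclCommunicator(
        group=group,
        device=torch.device(f"cuda:{local_rank}"),
        unique_id=unique_id,
    )
```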
247 changes: 247 additions & 0 deletions vllm/distributed/device_communicators/ray_communicator.py
@@ -0,0 +1,247 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import uuid
from typing import Any, Optional

import ray
import torch
from ray.exceptions import RayChannelError
from ray.experimental.channel.communicator import (Communicator,
TorchTensorAllocator)
from torch.distributed import ReduceOp

from vllm.distributed.device_communicators.base_device_communicator import (
DeviceCommunicatorBase)
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.utils import current_stream

logger = init_logger(__name__)


class RayPPCommunicator(Communicator):
"""
Communicator to be used for pipeline parallelism in Ray Compiled Graph.
    This wraps around the vLLM _PP GroupCoordinator.
This class is not thread-safe.
"""

_comm: Optional[DeviceCommunicatorBase]

def __init__(
self,
world_size: int,
comm_id: Any,
rank: Optional[int],
actor_handles: list["ray.actor.ActorHandle"],
cuda_stream: Optional[torch.cuda.Stream],
use_communication_streams: bool = False,
):
"""
Initialize a RayPPCommunicator that can be used to communicate with
other Ray Compiled Graph actors for pipeline parallelism.
Args:
world_size: The number of participating actors.
comm_id: A unique communicator ID. This is just to conform with
the Ray Communicator API and is not used.
rank: The rank of this actor. If None, then the caller is not a
participant of the RayPPCommunicator group (e.g., the Ray
driver).
actor_handles: A list of actor handles.
cuda_stream: A CUDA stream to dispatch communication ops to.
                This is ignored since Ray always passes in the current
                stream, and the vLLM device communicator does not allow
                passing in a stream; it uses the current stream by default.
"""
self._world_size = world_size
self._rank: Optional[int] = None
self._actor_handles = actor_handles
assert not use_communication_streams, (
"use_communication_streams is not yet supported")

if rank is not None:
# Rank is not None, this is Ray worker
assert ray.get_gpu_ids(), "RayPPCommunicator has no GPUs assigned"

self._comm = get_pp_group().device_communicator

# Since we wrap around the vLLM _PP communicator, we use
# the rank from the vLLM communicator, and ignore the rank
# passed in from Ray.
# TODO(rui): refactor the Ray Communicator API so that
# it also supports no rank passed in.
self._rank = self._comm.rank_in_group

self._build_actor_rank_mapping()
else:
# rank is None, this is Ray driver
self._comm = None

self._closed = False

def _build_actor_rank_mapping(self):
"""
Use collective communication to build a mapping from actor IDs to ranks.
This should be called once during initialization.
"""
if self._comm is None:
return {}

current_actor = ray.get_runtime_context().current_actor
actor_id_hash = self._hash_actor(current_actor)

actor_id_tensor = torch.tensor([actor_id_hash],
dtype=torch.int32,
device=self._comm.device)

# All-gather actor ID hashes from all actors
gathered_ids = self._comm.all_gather(actor_id_tensor, dim=0)

# Build mapping: actor_id_hash -> device_comm_rank
self._actor_hash_to_rank = {}
for rank, actor_id_hash in enumerate(gathered_ids.cpu().tolist()):
self._actor_hash_to_rank[actor_id_hash] = rank

def _hash_actor(self, actor: ray.actor.ActorHandle) -> int:
"""
Hash an actor handle to a 32-bit integer.
"""
return hash(actor._actor_id) % (2**31)

def initialize(self, rank: int) -> None:
# No additional initialization is needed.
pass

def get_actor_handles(self) -> list["ray.actor.ActorHandle"]:
return self._actor_handles

def get_rank(self, actor: ray.actor.ActorHandle) -> int:
"""
Return the given actor's rank using device communicator collective ops.
"""
assert hasattr(self, '_actor_hash_to_rank'), (
"Actor rank mapping not built. "
"This should have been done during initialization.")

actor_hash = self._hash_actor(actor)

if actor_hash in self._actor_hash_to_rank:
return self._actor_hash_to_rank[actor_hash] # type: ignore
else:
raise ValueError(f"Actor {actor} not found in communicator group")

def get_self_rank(self) -> Optional[int]:
"""
Return this actor's rank.
"""
return self._rank

def get_world_size(self) -> int:
"""
Return the number of ranks in the RayPPCommunicator group.
"""
return self._world_size

def send(self, buf: "torch.Tensor", peer_rank: int) -> None:
"""
Send a torch.Tensor to a peer.
This returns when the send kernel has been queued, but the kernel may
not have completed. Therefore, the caller should ensure that there are
no concurrent writes to the sent `buf` until the send has finished.
That is, either all writes should be submitted on the current stream
(self._cuda_stream) or, if on a different stream, that stream should
synchronize with the current stream.
Args:
buf: The torch.Tensor to send. It should already be on this
actor's default device.
peer_rank: The rank of the actor to send to.
"""
if self._closed:
raise RayChannelError("RayPPCommunicator has been destroyed.")

assert self._comm is not None
self._comm.send(buf, peer_rank)

def recv(
self,
shape: tuple[int],
dtype: "torch.dtype",
peer_rank: int,
allocator: TorchTensorAllocator,
) -> "torch.Tensor":
"""
Receive a torch.Tensor from a peer and synchronize the current stream.
        After this call returns, the receive buffer is safe to read from
        any stream. A RayChannelError will be raised if an error occurred
(e.g., remote actor died), and the buffer is not safe to read.
Args:
shape: The shape of the tensor to receive.
dtype: The dtype of the tensor to receive.
peer_rank: The rank of the actor to receive from.
allocator: The allocator to use to create the received tensor.
"""
if self._closed:
raise RayChannelError("RayPPCommunicator has been destroyed.")

assert self._comm is not None
size = torch.Size(shape)
buf = self._comm.recv(size, dtype, src=peer_rank)

# Buffer values are undefined if NCCL ops are aborted. Therefore, we
# need to synchronize here and check that the channel is still
# open to ensure that the receive buffer is valid.
# TODO(swang): Avoid CUDA synchronization.
current_stream().synchronize()

if self._closed:
raise RayChannelError("RayPPCommunicator has been destroyed.")
return buf

def allgather(
self,
send_buf: "torch.Tensor",
recv_buf: "torch.Tensor",
):
raise NotImplementedError("allgather is not supported")

def allreduce(
self,
send_buf: "torch.Tensor",
recv_buf: "torch.Tensor",
op: ReduceOp = ReduceOp.SUM,
):
raise NotImplementedError("allreduce is not supported")

def reducescatter(
self,
send_buf: "torch.Tensor",
recv_buf: "torch.Tensor",
op: ReduceOp = ReduceOp.SUM,
):
raise NotImplementedError("reducescatter is not supported")

@property
def recv_stream(self):
return torch.cuda.StreamContext(current_stream())

@property
def send_stream(self):
return torch.cuda.StreamContext(current_stream())

def destroy(self) -> None:
        # Just sets a flag; vLLM manages the lifecycle of the underlying
        # _PP GroupCoordinator.
self._closed = True

def get_transport_name(self) -> str:
return "nccl"

@classmethod
def generate_communicator_id(cls) -> Any:
return uuid.uuid4()
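Because the ranks in this mapping come from the wrapped vLLM device communicator rather than from Ray's own actor ordering, the mapping is built collectively: every actor all-gathers a 31-bit hash of its actor id, and the position of a hash in the gathered result is that actor's rank. The toy sketch below restates the idea in plain Python, with the all-gather replaced by a local list and the actor ids invented for illustration.

```python
# Toy illustration of the actor-id -> rank mapping built in
# _build_actor_rank_mapping. The all_gather over the device communicator is
# replaced here by a plain list; the actor ids are invented stand-ins.
actor_ids = ["actor-a", "actor-b", "actor-c"]

# What each rank would contribute to the all-gather: a 31-bit hash of its id.
hashes = [hash(aid) % (2**31) for aid in actor_ids]

# After the all-gather, every rank sees the same ordered list of hashes, so
# the index of a hash is the contributing actor's rank in the communicator.
hash_to_rank = {h: rank for rank, h in enumerate(hashes)}

# get_rank(actor) then reduces to hashing the handle and looking it up.
assert hash_to_rank[hash("actor-b") % (2**31)] == 1
```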
8 changes: 8 additions & 0 deletions vllm/envs.py
@@ -55,6 +55,7 @@
VLLM_USE_RAY_COMPILED_DAG: bool = False
VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "auto"
VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False
VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True
VLLM_XLA_USE_SPMD: bool = False
VLLM_WORKER_MULTIPROC_METHOD: str = "fork"
VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets")
@@ -492,6 +493,13 @@ def get_vllm_port() -> Optional[int]:
lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM", "0"))
),

Review comment (Contributor): Do we need this flag? Can we just always use vLLM's PP communicator?

Reply (ruisearch42, Collaborator, Author, Jul 29, 2025): The change in this PR is on by default. Wanted to have this config guard for a while so that users can fall back if needed. This will be cleaned up altogether after it is rolled out for a while.

    # If the env var is set, it uses a Ray Communicator wrapping
    # vLLM's pipeline parallelism communicator to interact with Ray's
    # Compiled Graph. Otherwise, it uses Ray's NCCL communicator.
    # This flag is ignored if VLLM_USE_RAY_COMPILED_DAG is not set.
"VLLM_USE_RAY_WRAPPED_PP_COMM":
lambda: bool(int(os.getenv("VLLM_USE_RAY_WRAPPED_PP_COMM", "1"))),

# Use dedicated multiprocess context for workers.
# Both spawn and fork work
"VLLM_WORKER_MULTIPROC_METHOD":
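Because the default is 1, falling back to Ray's NCCL communicator is just a matter of setting the flag to 0 before vLLM reads it. A small sketch is below; it assumes the variables are set before the engine (and its Ray executor) is created.

```python
# Sketch: opting back into Ray's NCCL communicator for Compiled Graph
# communication. Set the env vars before creating the engine, e.g. at the
# top of the launcher script.
import os

os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"      # use Ray Compiled Graph
os.environ["VLLM_USE_RAY_WRAPPED_PP_COMM"] = "0"   # fall back to Ray's NCCL
```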
14 changes: 14 additions & 0 deletions vllm/executor/ray_distributed_executor.py
@@ -608,6 +608,20 @@ def _compiled_ray_dag(self, enable_asyncio: bool):

forward_dag = MultiOutputNode(outputs)

if envs.VLLM_USE_RAY_WRAPPED_PP_COMM:
from ray.experimental.channel.accelerator_context import (
register_accelerator_context)

from vllm.distributed.device_communicators.ray_communicator import (
RayPPCommunicator)
register_accelerator_context(torch_module_name="cuda",
communicator_cls=RayPPCommunicator)
logger.info(
"Using vLLM PyNCCL for Ray Compiled Graph communication.")
else:
logger.info("Using Ray's NCCL communicator for "
"Ray Compiled Graph communication.")

return forward_dag.experimental_compile(
enable_asyncio=enable_asyncio,
_overlap_gpu_communication=envs.
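The registration above only takes effect on the Ray Compiled Graph path, i.e. a pipeline-parallel deployment on the Ray backend. The sketch below shows one such setup; the model name and parallel sizes are placeholders, not part of this PR, and the exact engine arguments should be checked against your vLLM version.

```python
# Sketch (placeholders, not from this PR): a pipeline-parallel setup on the
# Ray backend, which is where RayPPCommunicator is registered by default
# (VLLM_USE_RAY_WRAPPED_PP_COMM=1).
import os

os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"

from vllm import LLM

llm = LLM(
    model="facebook/opt-125m",        # placeholder model
    tensor_parallel_size=1,
    pipeline_parallel_size=2,         # PP > 1 exercises the PP communicator
    distributed_executor_backend="ray",
)
print(llm.generate("Hello, my name is")[0].outputs[0].text)
```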