Introduce RayPPCommunicator for ray-based PP #21660
New file added by the PR (`@@ -0,0 +1,247 @@`):

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import uuid
from typing import Any, Optional

import ray
import torch
from ray.exceptions import RayChannelError
from ray.experimental.channel.communicator import (Communicator,
                                                    TorchTensorAllocator)
from torch.distributed import ReduceOp

from vllm.distributed.device_communicators.base_device_communicator import (
    DeviceCommunicatorBase)
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.utils import current_stream

logger = init_logger(__name__)


class RayPPCommunicator(Communicator):
    """
    Communicator to be used for pipeline parallelism in Ray Compiled Graph.

    This wraps around the vLLM _PP GroupCoordinator.

    This class is not thread-safe.
    """

    _comm: Optional[DeviceCommunicatorBase]

    def __init__(
        self,
        world_size: int,
        comm_id: Any,
        rank: Optional[int],
        actor_handles: list["ray.actor.ActorHandle"],
        cuda_stream: Optional[torch.cuda.Stream],
        use_communication_streams: bool = False,
    ):
        """
        Initialize a RayPPCommunicator that can be used to communicate with
        other Ray Compiled Graph actors for pipeline parallelism.

        Args:
            world_size: The number of participating actors.
            comm_id: A unique communicator ID. This is just to conform with
                the Ray Communicator API and is not used.
            rank: The rank of this actor. If None, then the caller is not a
                participant of the RayPPCommunicator group (e.g., the Ray
                driver).
            actor_handles: A list of actor handles.
            cuda_stream: A CUDA stream to dispatch communication ops to.
                This is ignored since Ray always passes in the current stream,
                and the vLLM device communicator does not allow passing in a
                stream; it uses the current stream by default.
        """
        self._world_size = world_size
        self._rank: Optional[int] = None
        self._actor_handles = actor_handles
        assert not use_communication_streams, (
            "use_communication_streams is not yet supported")

        if rank is not None:
            # Rank is not None, so this is a Ray worker.
            assert ray.get_gpu_ids(), "RayPPCommunicator has no GPUs assigned"

            self._comm = get_pp_group().device_communicator

            # Since we wrap around the vLLM _PP communicator, we use
            # the rank from the vLLM communicator, and ignore the rank
            # passed in from Ray.
            # TODO(rui): refactor the Ray Communicator API so that
            # it also supports no rank passed in.
            self._rank = self._comm.rank_in_group

            self._build_actor_rank_mapping()
        else:
            # Rank is None, so this is the Ray driver.
            self._comm = None

        self._closed = False

    def _build_actor_rank_mapping(self):
        """
        Use collective communication to build a mapping from actor IDs
        to ranks. This should be called once during initialization.
        """
        if self._comm is None:
            return {}

        current_actor = ray.get_runtime_context().current_actor
        actor_id_hash = self._hash_actor(current_actor)

        actor_id_tensor = torch.tensor([actor_id_hash],
                                       dtype=torch.int32,
                                       device=self._comm.device)

        # All-gather actor ID hashes from all actors.
        gathered_ids = self._comm.all_gather(actor_id_tensor, dim=0)

        # Build mapping: actor_id_hash -> device_comm_rank
        self._actor_hash_to_rank = {}
        for rank, actor_id_hash in enumerate(gathered_ids.cpu().tolist()):
            self._actor_hash_to_rank[actor_id_hash] = rank

    def _hash_actor(self, actor: ray.actor.ActorHandle) -> int:
        """
        Hash an actor handle to a 32-bit integer.
        """
        return hash(actor._actor_id) % (2**31)

    def initialize(self, rank: int) -> None:
        # No additional initialization is needed.
        pass

    def get_actor_handles(self) -> list["ray.actor.ActorHandle"]:
        return self._actor_handles

    def get_rank(self, actor: ray.actor.ActorHandle) -> int:
        """
        Return the given actor's rank using device communicator collective ops.
        """
        assert hasattr(self, '_actor_hash_to_rank'), (
            "Actor rank mapping not built. "
            "This should have been done during initialization.")

        actor_hash = self._hash_actor(actor)

        if actor_hash in self._actor_hash_to_rank:
            return self._actor_hash_to_rank[actor_hash]  # type: ignore
        else:
            raise ValueError(f"Actor {actor} not found in communicator group")

    def get_self_rank(self) -> Optional[int]:
        """
        Return this actor's rank.
        """
        return self._rank

    def get_world_size(self) -> int:
        """
        Return the number of ranks in the RayPPCommunicator group.
        """
        return self._world_size

    def send(self, buf: "torch.Tensor", peer_rank: int) -> None:
        """
        Send a torch.Tensor to a peer.

        This returns when the send kernel has been queued, but the kernel may
        not have completed. Therefore, the caller should ensure that there are
        no concurrent writes to the sent `buf` until the send has finished.
        That is, either all writes should be submitted on the current stream
        (self._cuda_stream) or, if on a different stream, that stream should
        synchronize with the current stream.

        Args:
            buf: The torch.Tensor to send. It should already be on this
                actor's default device.
            peer_rank: The rank of the actor to send to.
        """
        if self._closed:
            raise RayChannelError("RayPPCommunicator has been destroyed.")

        assert self._comm is not None
        self._comm.send(buf, peer_rank)

    def recv(
        self,
        shape: tuple[int],
        dtype: "torch.dtype",
        peer_rank: int,
        allocator: TorchTensorAllocator,
    ) -> "torch.Tensor":
        """
        Receive a torch.Tensor from a peer and synchronize the current stream.

        After this call returns, the receive buffer is safe to read from any
        stream. A RayChannelError will be raised if an error occurred (e.g.,
        the remote actor died), and the buffer is not safe to read.

        Args:
            shape: The shape of the tensor to receive.
            dtype: The dtype of the tensor to receive.
            peer_rank: The rank of the actor to receive from.
            allocator: The allocator to use to create the received tensor.
        """
        if self._closed:
            raise RayChannelError("RayPPCommunicator has been destroyed.")

        assert self._comm is not None
        size = torch.Size(shape)
        buf = self._comm.recv(size, dtype, src=peer_rank)

        # Buffer values are undefined if NCCL ops are aborted. Therefore, we
        # need to synchronize here and check that the channel is still
        # open to ensure that the receive buffer is valid.
        # TODO(swang): Avoid CUDA synchronization.
        current_stream().synchronize()

        if self._closed:
            raise RayChannelError("RayPPCommunicator has been destroyed.")
        return buf

    def allgather(
        self,
        send_buf: "torch.Tensor",
        recv_buf: "torch.Tensor",
    ):
        raise NotImplementedError("allgather is not supported")

    def allreduce(
        self,
        send_buf: "torch.Tensor",
        recv_buf: "torch.Tensor",
        op: ReduceOp = ReduceOp.SUM,
    ):
        raise NotImplementedError("allreduce is not supported")

    def reducescatter(
        self,
        send_buf: "torch.Tensor",
        recv_buf: "torch.Tensor",
        op: ReduceOp = ReduceOp.SUM,
    ):
        raise NotImplementedError("reducescatter is not supported")

    @property
    def recv_stream(self):
        return torch.cuda.StreamContext(current_stream())

    @property
    def send_stream(self):
        return torch.cuda.StreamContext(current_stream())

    def destroy(self) -> None:
        # This just sets a flag; vLLM manages the lifecycle of the underlying
        # _PP GroupCoordinator.
        self._closed = True

    def get_transport_name(self) -> str:
        return "nccl"

    @classmethod
    def generate_communicator_id(cls) -> Any:
        return uuid.uuid4()
```
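The least obvious part of the new class is the rank discovery in `_build_actor_rank_mapping()` and `get_rank()`: each worker hashes its own actor ID to a 32-bit integer, the hashes are all-gathered through the vLLM device communicator, and the position of each hash in the gathered tensor becomes that actor's rank. Below is a minimal, self-contained sketch of that idea (not part of the PR): the actor IDs are made up, and the all-gather is simulated with `torch.cat` so everything runs in a single process.

```python
# Standalone sketch of the hash -> rank mapping used by RayPPCommunicator.
# The real code gathers the hashes over NCCL via the vLLM _PP device
# communicator; here the gather is simulated in-process with torch.cat.
import torch


def hash_actor_id(actor_id: bytes) -> int:
    # Mirrors _hash_actor(): fold an actor ID into a 32-bit integer.
    return hash(actor_id) % (2**31)


# Hypothetical actor IDs of a 4-stage pipeline, listed in rank order.
actor_ids = [b"actor-0", b"actor-1", b"actor-2", b"actor-3"]

# Each "actor" contributes a one-element tensor holding its own hash ...
local_tensors = [
    torch.tensor([hash_actor_id(a)], dtype=torch.int64) for a in actor_ids
]
# ... and the (simulated) all-gather concatenates them in rank order.
gathered = torch.cat(local_tensors, dim=0)

# Every actor ends up with the same hash -> rank mapping.
actor_hash_to_rank = {h: rank for rank, h in enumerate(gathered.tolist())}

# get_rank() is then just a dictionary lookup on the peer's hash.
assert actor_hash_to_rank[hash_actor_id(b"actor-2")] == 2
```

Because every rank gathers the identical tensor, all workers reconstruct the same mapping without exchanging any extra metadata, and a peer's rank can later be looked up purely from its actor handle.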
The PR also adds a new environment variable, `VLLM_USE_RAY_WRAPPED_PP_COMM` (on by default), in a second file:

`@@ -55,6 +55,7 @@`

```python
    VLLM_USE_RAY_COMPILED_DAG: bool = False
    VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "auto"
    VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False
    VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True
    VLLM_XLA_USE_SPMD: bool = False
    VLLM_WORKER_MULTIPROC_METHOD: str = "fork"
    VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets")
```

`@@ -492,6 +493,13 @@ def get_vllm_port() -> Optional[int]:`

```python
    lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM", "0"))
                 ),

    # If the env var is set, it uses a Ray Communicator wrapping
    # vLLM's pipeline parallelism communicator to interact with Ray's
    # Compiled Graph. Otherwise, it uses Ray's NCCL communicator.
    # This flag is ignored if VLLM_USE_RAY_COMPILED_DAG is not set.
    "VLLM_USE_RAY_WRAPPED_PP_COMM":
    lambda: bool(int(os.getenv("VLLM_USE_RAY_WRAPPED_PP_COMM", "1"))),

    # Use dedicated multiprocess context for workers.
    # Both spawn and fork work
    "VLLM_WORKER_MULTIPROC_METHOD":
```

Review discussion on the new flag:

**Contributor:** Do we need this flag? Can we just always use vLLM's PP communicator?

**ruisearch42 (Collaborator, PR author):** The change in this PR is on by default. I wanted to keep this config guard for a while so that users can fall back if needed.
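To make the fallback discussed in the review thread concrete, here is a hypothetical opt-out sketch. It assumes vLLM is installed and that `vllm.envs` resolves attributes lazily from `os.environ` through the lambdas shown above; the variable names come from this diff, everything else is illustrative.

```python
# Hypothetical opt-out: set the variable before vLLM reads it, to fall back to
# Ray's own NCCL communicator instead of the wrapped vLLM _PP communicator.
import os

os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"     # pipeline parallelism via Ray Compiled Graph
os.environ["VLLM_USE_RAY_WRAPPED_PP_COMM"] = "0"  # disable the new wrapper

import vllm.envs as envs

# With "0", the lambda above evaluates bool(int("0")) -> False.
assert envs.VLLM_USE_RAY_WRAPPED_PP_COMM is False
```

Since the flag defaults to `"1"`, doing nothing keeps the new wrapped communicator, and the flag is only consulted when `VLLM_USE_RAY_COMPILED_DAG` is set.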