diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py
index 5b32b90f3cfe..abfad9ebfe7d 100644
--- a/tests/distributed/test_pynccl.py
+++ b/tests/distributed/test_pynccl.py
@@ -4,6 +4,7 @@
 import multiprocessing
 import os
 
+import numpy as np
 import pytest
 import torch
 import torch.distributed
@@ -177,6 +178,38 @@ def test_pynccl_all_gather():
     distributed_run(all_gather_worker_fn, 2)
 
 
+@worker_fn_wrapper
+def all_gatherv_worker_fn():
+    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
+                                     device=get_world_group().device)
+
+    rank = pynccl_comm.rank
+    world_size = pynccl_comm.world_size
+    device = f'cuda:{pynccl_comm.rank}'
+
+    assert world_size <= 8
+    sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size]
+    num_elems = sizes[rank]
+    tensor = torch.arange(num_elems, dtype=torch.float32,
+                          device=device) + rank * 100
+    result = torch.zeros(sum(sizes), dtype=torch.float32, device=device)
+
+    expected = torch.cat([
+        torch.arange(sizes[r], dtype=torch.float32) + r * 100
+        for r in range(world_size)
+    ]).to(device)
+
+    pynccl_comm.all_gatherv(result, tensor, sizes=sizes)
+    torch.cuda.synchronize()
+    torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+def test_pynccl_all_gatherv():
+    distributed_run(all_gatherv_worker_fn, 2)
+
+
 @worker_fn_wrapper
 def reduce_scatter_worker_fn():
     pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
@@ -214,6 +247,43 @@ def test_pynccl_reduce_scatter():
     distributed_run(reduce_scatter_worker_fn, 2)
 
 
+@worker_fn_wrapper
+def reduce_scatterv_worker_fn():
+    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
+                                     device=get_world_group().device)
+
+    rank = pynccl_comm.rank
+    world_size = pynccl_comm.world_size
+    device = f'cuda:{pynccl_comm.rank}'
+
+    assert world_size <= 8
+    sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size]
+    num_elems = sum(sizes)
+    tensor = torch.arange(num_elems, dtype=torch.float32,
+                          device=device) + rank * 100
+    result = torch.zeros(sizes[rank], dtype=torch.float32, device=device)
+
+    # Calculate expected result for this rank's chunk
+    all_tensors = [
+        torch.arange(num_elems, dtype=torch.float32) + r * 100
+        for r in range(world_size)
+    ]
+    sizes_cumsum = np.cumsum(sizes)
+    start = 0 if rank == 0 else sizes_cumsum[rank - 1]
+    end = sizes_cumsum[rank]
+    expected = sum(tensor[start:end] for tensor in all_tensors).to(device)
+
+    pynccl_comm.reduce_scatterv(result, tensor, sizes=sizes)
+    torch.cuda.synchronize()
+    torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+def test_pynccl_reduce_scatterv():
+    distributed_run(reduce_scatterv_worker_fn, 2)
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
                     reason="Need at least 2 GPUs to run the test.")
 def test_pynccl_with_cudagraph():
diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py
index eb467bb0736a..dc5923cdc5a0 100644
--- a/vllm/distributed/device_communicators/base_device_communicator.py
+++ b/vllm/distributed/device_communicators/base_device_communicator.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import threading
-from typing import Optional
+from typing import Optional, Union
 from weakref import WeakValueDictionary
 
 import torch
@@ -138,6 +138,14 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
                                               input_size[dim + 1:])
         return output_tensor
 
+    def all_gatherv(
+        self,
+        input_: Union[torch.Tensor, list[torch.Tensor]],
+        dim: int = 0,
+        sizes: Optional[list[int]] = None
+    ) -> Union[torch.Tensor, list[torch.Tensor]]:
+        raise NotImplementedError
+
     def reduce_scatter(self,
                        input_: torch.Tensor,
                        dim: int = -1) -> torch.Tensor:
@@ -172,6 +180,12 @@ def reduce_scatter(self,
         # Reshape before returning
         return output_tensor.movedim(0, dim).contiguous()
 
+    def reduce_scatterv(self,
+                        input_: torch.Tensor,
+                        dim: int = -1,
+                        sizes: Optional[list[int]] = None) -> torch.Tensor:
+        raise NotImplementedError
+
     def gather(self,
                input_: torch.Tensor,
                dst: int = 0,
diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py
index 3958d566b174..e4804691f0f6 100644
--- a/vllm/distributed/device_communicators/cuda_communicator.py
+++ b/vllm/distributed/device_communicators/cuda_communicator.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 from torch.distributed import ProcessGroup
@@ -142,6 +142,42 @@ def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
         # Reshape before returning
         return output.movedim(0, dim).contiguous()
 
+    def reduce_scatterv(self,
+                        input_: torch.Tensor,
+                        dim: int = -1,
+                        sizes: Optional[list[int]] = None):
+        world_size = self.world_size
+        pynccl_comm = self.pynccl_comm
+        assert pynccl_comm is not None
+        if dim < 0:
+            # Convert negative dim to positive.
+            dim += input_.dim()
+
+        # Note: This will produce an incorrect answer if we don't make
+        # the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
+        input_tensor = input_.movedim(0, dim).contiguous()
+
+        if sizes is not None:
+            assert len(sizes) == world_size
+            assert input_tensor.shape[0] == sum(sizes)
+            chunk_size = sizes[self.rank_in_group]
+        else:
+            assert input_tensor.shape[0] % world_size == 0
+            chunk_size = input_tensor.shape[0] // world_size
+        output_shape = (chunk_size, ) + input_tensor.shape[1:]
+
+        output = torch.empty(output_shape,
+                             dtype=input_tensor.dtype,
+                             device=input_tensor.device)
+
+        if sizes is not None:
+            pynccl_comm.reduce_scatterv(output, input_, sizes=sizes)
+        else:
+            pynccl_comm.reduce_scatter(output, input_)
+
+        # Reshape before returning
+        return output.movedim(0, dim).contiguous()
+
     def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
         """Sends a tensor to the destination rank in a non-blocking way"""
         """NOTE: `dst` is the local rank of the destination rank."""
@@ -180,6 +216,51 @@ def destroy(self):
             self.all2all_manager.destroy()
             self.all2all_manager = None
 
+    def all_gatherv(self,
+                    input_: Union[torch.Tensor, list[torch.Tensor]],
+                    dim: int = 0,
+                    sizes: Optional[list[int]] = None):
+        if dim != 0:
+            raise NotImplementedError("only dim 0 all-gatherv is supported")
+        world_size = self.world_size
+        pynccl_comm = self.pynccl_comm
+        assert pynccl_comm is not None and not pynccl_comm.disabled
+
+        # 'sizes' is not needed if all inputs in the same group have the same
+        # shape
+        if sizes is not None and all(s == sizes[0] for s in sizes):
+            sizes = None
+
+        def _all_gather_single(input_: torch.Tensor,
+                               sizes: Optional[list[int]] = None):
+            input_size = input_.size()
+            if sizes is not None:
+                assert len(sizes) == world_size
+                assert input_.shape[dim] == sizes[self.rank_in_group]
+                output_size = (sum(sizes), ) + input_size[1:]
+            else:
+                output_size = (input_size[0] * world_size, ) + input_size[1:]
+            # Allocate output tensor.
+            output_tensor = torch.empty(output_size,
+                                        dtype=input_.dtype,
+                                        device=input_.device)
+            if sizes is not None:
+                pynccl_comm.all_gatherv(output_tensor, input_, sizes=sizes)
+            else:
+                pynccl_comm.all_gather(output_tensor, input_)
+            return output_tensor
+
+        if isinstance(input_, torch.Tensor):
+            return _all_gather_single(input_, sizes)
+
+        output_list = []
+        pynccl_comm.group_start()
+        for inp in input_:
+            output_list.append(_all_gather_single(inp, sizes=sizes))
+        pynccl_comm.group_end()
+
+        return output_list
+
     def dispatch(
             self, hidden_states: torch.Tensor,
             router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py
index 29486292996a..502bfd39005a 100644
--- a/vllm/distributed/device_communicators/pynccl.py
+++ b/vllm/distributed/device_communicators/pynccl.py
@@ -152,6 +152,40 @@ def all_gather(self,
             ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm,
             cudaStream_t(stream.cuda_stream))
 
+    def all_gatherv(
+        self,
+        output_tensor: torch.Tensor,
+        input_tensor: torch.Tensor,
+        sizes: list[int],
+        stream=None,
+    ):
+        if self.disabled:
+            return
+        # nccl communicator created on a specific device
+        # will only work on tensors on the same device
+        # otherwise it will cause "illegal memory access"
+        assert input_tensor.device == self.device, (
+            f"this nccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {input_tensor.device}")
+        if stream is None:
+            stream = current_stream()
+        assert output_tensor.shape[0] == sum(sizes)
+        split_offset = 0
+        self.nccl.ncclGroupStart()
+        for root, split_size in enumerate(sizes):
+            dst_slice = output_tensor[split_offset:split_offset + split_size]
+            self.nccl.ncclBroadcast(
+                buffer_type(input_tensor.data_ptr()),
+                buffer_type(dst_slice.data_ptr()),
+                dst_slice.numel(),
+                ncclDataTypeEnum.from_torch(input_tensor.dtype),
+                root,
+                self.comm,
+                cudaStream_t(stream.cuda_stream),
+            )
+            split_offset += split_size
+        self.nccl.ncclGroupEnd()
+
     def reduce_scatter(self,
                        output_tensor: torch.Tensor,
                        input_tensor: torch.Tensor,
@@ -174,6 +208,38 @@ def reduce_scatter(self,
             ncclRedOpTypeEnum.from_torch(op), self.comm,
             cudaStream_t(stream.cuda_stream))
 
+    def reduce_scatterv(
+        self,
+        output_tensor: torch.Tensor,
+        input_tensor: torch.Tensor,
+        sizes: list[int],
+        op: ReduceOp = ReduceOp.SUM,
+        stream=None,
+    ):
+        if self.disabled:
+            return
+        # nccl communicator created on a specific device
+        # will only work on tensors on the same device
+        # otherwise it will cause "illegal memory access"
+        assert input_tensor.device == self.device, (
+            f"this nccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {input_tensor.device}")
+        if stream is None:
+            stream = current_stream()
+
+        split_offset = 0
+        self.nccl.ncclGroupStart()
+        for root, split_size in enumerate(sizes):
+            chunk = input_tensor[split_offset:split_offset + split_size, ...]
+            self.nccl.ncclReduce(
+                buffer_type(chunk.data_ptr()),
+                buffer_type(output_tensor.data_ptr()), chunk.numel(),
+                ncclDataTypeEnum.from_torch(input_tensor.dtype),
+                ncclRedOpTypeEnum.from_torch(op), root, self.comm,
+                cudaStream_t(stream.cuda_stream))
+            split_offset += split_size
+        self.nccl.ncclGroupEnd()
+
     def send(self, tensor: torch.Tensor, dst: int, stream=None):
         if self.disabled:
             return
@@ -216,3 +282,9 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
         self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
                                 ncclDataTypeEnum.from_torch(tensor.dtype), src,
                                 self.comm, cudaStream_t(stream.cuda_stream))
+
+    def group_start(self):
+        self.nccl.ncclGroupStart()
+
+    def group_end(self):
+        self.nccl.ncclGroupEnd()
diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py
index 3018a92da07c..a930b63bc26f 100644
--- a/vllm/distributed/device_communicators/pynccl_wrapper.py
+++ b/vllm/distributed/device_communicators/pynccl_wrapper.py
@@ -154,6 +154,17 @@ class NCCLLibrary:
             ncclRedOp_t, ncclComm_t, cudaStream_t
         ]),
 
+        # ncclResult_t ncclReduce(
+        #   const void* sendbuff, void* recvbuff, size_t count,
+        #   ncclDataType_t datatype, ncclRedOp_t op, int root,
+        #   ncclComm_t comm, cudaStream_t stream);
+        # note that cudaStream_t is a pointer type, so the last argument
+        # is a pointer
+        Function("ncclReduce", ncclResult_t, [
+            buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
+            ncclRedOp_t, ctypes.c_int, ncclComm_t, cudaStream_t
+        ]),
+
         # ncclResult_t ncclAllGather(
         #   const void* sendbuff, void* recvbuff, size_t count,
         #   ncclDataType_t datatype, ncclComm_t comm,
@@ -207,6 +218,10 @@ class NCCLLibrary:
         # it is better not to call it at all.
         # ncclResult_t ncclCommDestroy(ncclComm_t comm);
         Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
+        # ncclResult_t ncclGroupStart();
+        Function("ncclGroupStart", ncclResult_t, []),
+        # ncclResult_t ncclGroupEnd();
+        Function("ncclGroupEnd", ncclResult_t, []),
     ]
 
     # class attribute to store the mapping from the path to the library
@@ -300,6 +315,18 @@ def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
                                                      datatype, op, comm,
                                                      stream))
 
+    def ncclReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
+                   count: int, datatype: int, op: int, root: int,
+                   comm: ncclComm_t, stream: cudaStream_t) -> None:
+        # `datatype` actually should be `ncclDataType_t`
+        # and `op` should be `ncclRedOp_t`
+        # both are aliases of `ctypes.c_int`
+        # when we pass int to a function, it will be converted to `ctypes.c_int`
+        # by ctypes automatically
+        self.NCCL_CHECK(self._funcs["ncclReduce"](sendbuff, recvbuff, count,
+                                                  datatype, op, root, comm,
+                                                  stream))
+
     def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type,
                           count: int, datatype: int, op: int,
                           comm: ncclComm_t, stream: cudaStream_t) -> None:
@@ -342,6 +369,12 @@ def ncclBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type,
     def ncclCommDestroy(self, comm: ncclComm_t) -> None:
         self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
 
+    def ncclGroupStart(self) -> None:
+        self.NCCL_CHECK(self._funcs["ncclGroupStart"]())
+
+    def ncclGroupEnd(self) -> None:
+        self.NCCL_CHECK(self._funcs["ncclGroupEnd"]())
+
 
 __all__ = [
     "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId",
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 495a758e6069..1bb0ca79cc1d 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -383,6 +383,12 @@ def _all_gather_out_place(self, input_: torch.Tensor,
                               dim: int) -> torch.Tensor:
         return self.device_communicator.all_gather(input_, dim)
 
+    def all_gatherv(self,
+                    input_: Union[torch.Tensor, list[torch.Tensor]],
+                    dim: int = 0,
+                    sizes: Optional[list[int]] = None):
+        return self.device_communicator.all_gatherv(input_, dim, sizes)
+
     def reduce_scatter(self,
                        input_: torch.Tensor,
                        dim: int = -1) -> torch.Tensor:
@@ -401,6 +407,12 @@ def reduce_scatter(self,
         else:
             return self._reduce_scatter_out_place(input_, dim)
 
+    def reduce_scatterv(self,
+                        input_: torch.Tensor,
+                        dim: int = -1,
+                        sizes: Optional[list[int]] = None) -> torch.Tensor:
+        return self.device_communicator.reduce_scatterv(input_, dim, sizes)
+
     def _reduce_scatter_out_place(self, input_: torch.Tensor,
                                   dim: int) -> torch.Tensor:
         return self.device_communicator.reduce_scatter(input_, dim)
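Semantics sketch (not part of the patch): the single-process reference below spells out what the new variable-size collectives compute, matching the expectations asserted in tests/distributed/test_pynccl.py above. Given a per-rank sizes list, all_gatherv concatenates each rank's sizes[r]-length chunk along dim 0 on every rank, while reduce_scatterv element-wise reduces the full-length inputs and hands rank r only its sizes[r]-length slice; in the patch these are realized as one ncclBroadcast (respectively ncclReduce) per root inside an ncclGroupStart/ncclGroupEnd block. The ref_* helper names and the example sizes here are illustrative only, not part of the vLLM API.

import torch


def ref_all_gatherv(chunks: list[torch.Tensor]) -> torch.Tensor:
    # Rank r contributes chunks[r] of length sizes[r]; every rank ends up
    # with the concatenation of all chunks along dim 0.
    return torch.cat(chunks, dim=0)


def ref_reduce_scatterv(inputs: list[torch.Tensor],
                        sizes: list[int]) -> list[torch.Tensor]:
    # Every rank contributes a sum(sizes)-length tensor; the element-wise
    # sum is split so rank r keeps only the sizes[r]-length slice.
    total = torch.stack(inputs, dim=0).sum(dim=0)
    return list(torch.split(total, sizes, dim=0))


if __name__ == "__main__":
    sizes = [3, 1, 2]  # illustrative per-rank sizes
    chunks = [
        torch.arange(n, dtype=torch.float32) + 100 * r
        for r, n in enumerate(sizes)
    ]
    full = [
        torch.arange(sum(sizes), dtype=torch.float32) + 100 * r
        for r in range(len(sizes))
    ]
    print(ref_all_gatherv(chunks))  # what all_gatherv writes on every rank
    print(ref_reduce_scatterv(full, sizes)[1])  # what rank 1 receives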