Merged
70 changes: 70 additions & 0 deletions tests/distributed/test_pynccl.py
@@ -4,6 +4,7 @@
import multiprocessing
import os

import numpy as np
import pytest
import torch
import torch.distributed
@@ -177,6 +178,38 @@ def test_pynccl_all_gather():
distributed_run(all_gather_worker_fn, 2)


@worker_fn_wrapper
def all_gatherv_worker_fn():
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
device=get_world_group().device)

rank = pynccl_comm.rank
world_size = pynccl_comm.world_size
device = f'cuda:{pynccl_comm.rank}'

assert world_size <= 8
sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size]
num_elems = sizes[rank]
tensor = torch.arange(num_elems, dtype=torch.float32,
device=device) + rank * 100
result = torch.zeros(sum(sizes), dtype=torch.float32, device=device)

expected = torch.cat([
torch.arange(sizes[r], dtype=torch.float32) + r * 100
for r in range(world_size)
]).to(device)

pynccl_comm.all_gatherv(result, tensor, sizes=sizes)
torch.cuda.synchronize()
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
def test_pynccl_all_gatherv():
distributed_run(all_gatherv_worker_fn, 2)
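
For intuition, a small worked example of the variable-size all-gather checked above, using a hypothetical 2-rank run with sizes [3, 2] and the same arange + rank * 100 pattern as the test:

# Hypothetical 2-rank illustration (sizes = [3, 2]):
#   rank 0 contributes [0., 1., 2.]      (3 elements)
#   rank 1 contributes [100., 101.]      (2 elements)
# After all_gatherv, every rank holds the rank-ordered concatenation
#   [0., 1., 2., 100., 101.]             (sum(sizes) == 5 elements)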


@worker_fn_wrapper
def reduce_scatter_worker_fn():
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
@@ -214,6 +247,43 @@ def test_pynccl_reduce_scatter():
distributed_run(reduce_scatter_worker_fn, 2)


@worker_fn_wrapper
def reduce_scatterv_worker_fn():
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
device=get_world_group().device)

rank = pynccl_comm.rank
world_size = pynccl_comm.world_size
device = f'cuda:{pynccl_comm.rank}'

assert world_size <= 8
sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size]
num_elems = sum(sizes)
tensor = torch.arange(num_elems, dtype=torch.float32,
device=device) + rank * 100
result = torch.zeros(sizes[rank], dtype=torch.float32, device=device)

# Calculate expected result for this rank's chunk
all_tensors = [
torch.arange(num_elems, dtype=torch.float32) + r * 100
for r in range(world_size)
]
sizes_cumsum = np.cumsum(sizes)
start = 0 if rank == 0 else sizes_cumsum[rank - 1]
end = sizes_cumsum[rank]
expected = sum(tensor[start:end] for tensor in all_tensors).to(device)

pynccl_comm.reduce_scatterv(result, tensor, sizes=sizes)
torch.cuda.synchronize()
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
def test_pynccl_reduce_scatterv():
distributed_run(reduce_scatterv_worker_fn, 2)
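
Similarly, a hypothetical 2-rank walk-through of reduce_scatterv with sizes [3, 2], following the same input pattern as the test above:

# Hypothetical 2-rank illustration (sizes = [3, 2], num_elems = 5):
#   rank 0 input: [0., 1., 2., 3., 4.]
#   rank 1 input: [100., 101., 102., 103., 104.]
#   element-wise sum over ranks: [100., 102., 104., 106., 108.]
#   rank 0 receives the first 3 elements -> [100., 102., 104.]
#   rank 1 receives the last 2 elements  -> [106., 108.]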


@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
def test_pynccl_with_cudagraph():
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
from typing import Optional
from typing import Optional, Union
from weakref import WeakValueDictionary

import torch
@@ -138,6 +138,14 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
input_size[dim + 1:])
return output_tensor

def all_gatherv(
self,
input_: Union[torch.Tensor, list[torch.Tensor]],
dim: int = 0,
sizes: Optional[list[int]] = None
) -> Union[torch.Tensor, list[torch.Tensor]]:
raise NotImplementedError

def reduce_scatter(self,
input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
@@ -172,6 +180,12 @@ def reduce_scatter(self,
# Reshape before returning
return output_tensor.movedim(0, dim).contiguous()

def reduce_scatterv(self,
input_: torch.Tensor,
dim: int = -1,
sizes: Optional[list[int]] = None) -> torch.Tensor:
raise NotImplementedError

def gather(self,
input_: torch.Tensor,
dst: int = 0,
83 changes: 82 additions & 1 deletion vllm/distributed/device_communicators/cuda_communicator.py
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional
from typing import Optional, Union

import torch
from torch.distributed import ProcessGroup
@@ -142,6 +142,42 @@ def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
# Reshape before returning
return output.movedim(0, dim).contiguous()

def reduce_scatterv(self,
input_: torch.Tensor,
dim: int = -1,
sizes: Optional[list[int]] = None):
world_size = self.world_size
pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()

# Note: This will produce an incorrect answer if we don't make
# the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
input_tensor = input_.movedim(0, dim).contiguous()

if sizes is not None:
assert len(sizes) == world_size
assert input_tensor.shape[0] == sum(sizes)
chunk_size = sizes[self.rank_in_group]
else:
assert input_tensor.shape[0] % world_size == 0
chunk_size = input_tensor.shape[0] // world_size
output_shape = (chunk_size, ) + input_tensor.shape[1:]

output = torch.empty(output_shape,
dtype=input_tensor.dtype,
device=input_tensor.device)

if sizes is not None:
pynccl_comm.reduce_scatterv(output, input_, sizes=sizes)
else:
pynccl_comm.reduce_scatter(output, input_)

# Reshape before returning
return output.movedim(0, dim).contiguous()

def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
"""Sends a tensor to the destination rank in a non-blocking way"""
"""NOTE: `dst` is the local rank of the destination rank."""
@@ -180,6 +216,51 @@ def destroy(self):
self.all2all_manager.destroy()
self.all2all_manager = None

def all_gatherv(self,
input_: Union[torch.Tensor, list[torch.Tensor]],
dim: int = 0,
sizes: Optional[list[int]] = None):
if dim != 0:
raise NotImplementedError("only dim 0 all-gatherv is supported")
world_size = self.world_size
pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None and not pynccl_comm.disabled

# 'sizes' is not needed if all inputs in the same group have the same
# shape
if sizes is not None and all(s == sizes[0] for s in sizes):
sizes = None

def _all_gather_single(input_: torch.Tensor,
sizes: Optional[list[int]] = None):
input_size = input_.size()
if sizes is not None:
assert len(sizes) == world_size
assert input_.shape[dim] == sizes[self.rank_in_group]
output_size = (sum(sizes), ) + input_size[1:]
else:
output_size = (input_size[0] * world_size, ) + input_size[1:]
# Allocate output tensor.
output_tensor = torch.empty(output_size,
dtype=input_.dtype,
device=input_.device)
if sizes is not None:
pynccl_comm.all_gatherv(output_tensor, input_, sizes=sizes)
else:
pynccl_comm.all_gather(output_tensor, input_)
return output_tensor

if isinstance(input_, torch.Tensor):
return _all_gather_single(input_, sizes)

output_list = []
pynccl_comm.group_start()
for inp in input_:
output_list.append(_all_gather_single(inp, sizes=sizes))
pynccl_comm.group_end()

return output_list
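
The list branch above wraps one all-gatherv per tensor in a single NCCL group. A rough call-pattern sketch; only all_gatherv, dim and sizes come from the diff, while the receiver and tensor names are illustrative:

# full_a, full_b = communicator.all_gatherv([a_local, b_local], dim=0,
#                                           sizes=rows_per_rank)
# group_start()/group_end() bracket the per-tensor collectives so NCCL defers
# them and launches them together at group_end rather than issuing each
# collective separately.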

def dispatch(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
72 changes: 72 additions & 0 deletions vllm/distributed/device_communicators/pynccl.py

Member
One last comment - I think it would be cleaner and clearer if the all_gather and all_gatherv implementations were completely separate. Right now it's slightly awkward that all_gatherv calls pynccl all_gather with a list of sizes. Ditto for reduce_scatter/reduce_scatterv.

Otherwise looks good to me, thanks!

Contributor Author
Thanks, done!

Contributor Author
@tlrmchlsmth Could you take a look?

@@ -152,6 +152,40 @@ def all_gather(self,
ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm,
cudaStream_t(stream.cuda_stream))

def all_gatherv(
self,
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
sizes: list[int],
stream=None,
):
if self.disabled:
return
# nccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert input_tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}")
if stream is None:
stream = current_stream()
assert output_tensor.shape[0] == sum(sizes)
split_offset = 0
self.nccl.ncclGroupStart()
for root, split_size in enumerate(sizes):
dst_slice = output_tensor[split_offset:split_offset + split_size]
self.nccl.ncclBroadcast(
buffer_type(input_tensor.data_ptr()),
buffer_type(dst_slice.data_ptr()),
dst_slice.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype),
root,
self.comm,
cudaStream_t(stream.cuda_stream),
)
split_offset += split_size
self.nccl.ncclGroupEnd()
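
For reference, a minimal single-process sketch of what the grouped-broadcast loop above computes. This is plain torch for intuition only, not part of the diff or of the vLLM API, and the function name is invented:

import torch

def all_gatherv_reference(inputs: list[torch.Tensor],
                          sizes: list[int]) -> torch.Tensor:
    # One ncclBroadcast per root r copies rank r's sizes[r] leading rows into
    # the matching slice of every rank's output buffer, so the result
    # (identical on every rank) is the rank-ordered concatenation.
    assert [t.shape[0] for t in inputs] == sizes
    return torch.cat(inputs, dim=0)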

def reduce_scatter(self,
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
@@ -174,6 +208,38 @@ def reduce_scatter(self,
ncclRedOpTypeEnum.from_torch(op), self.comm,
cudaStream_t(stream.cuda_stream))

def reduce_scatterv(
self,
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
sizes: list[int],
op: ReduceOp = ReduceOp.SUM,
stream=None,
):
if self.disabled:
return
# nccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert input_tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}")
if stream is None:
stream = current_stream()

split_offset = 0
self.nccl.ncclGroupStart()
for root, split_size in enumerate(sizes):
chunk = input_tensor[split_offset:split_offset + split_size, ...]
self.nccl.ncclReduce(
buffer_type(chunk.data_ptr()),
buffer_type(output_tensor.data_ptr()), chunk.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype),
ncclRedOpTypeEnum.from_torch(op), root, self.comm,
cudaStream_t(stream.cuda_stream))
split_offset += split_size
self.nccl.ncclGroupEnd()
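
Likewise, a minimal single-process sketch of the per-root ncclReduce loop above (plain torch, invented function name, intuition only):

import torch

def reduce_scatterv_reference(inputs: list[torch.Tensor], sizes: list[int],
                              rank: int) -> torch.Tensor:
    # Every rank passes a full sum(sizes)-row tensor; one ncclReduce per root r
    # sums the r-th slice across ranks and deposits the result only on rank r.
    assert all(t.shape[0] == sum(sizes) for t in inputs)
    summed = torch.stack(inputs, dim=0).sum(dim=0)
    start = sum(sizes[:rank])
    return summed[start:start + sizes[rank]]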

def send(self, tensor: torch.Tensor, dst: int, stream=None):
if self.disabled:
return
@@ -216,3 +282,9 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), src,
self.comm, cudaStream_t(stream.cuda_stream))

def group_start(self):
self.nccl.ncclGroupStart()

def group_end(self):
self.nccl.ncclGroupEnd()
33 changes: 33 additions & 0 deletions vllm/distributed/device_communicators/pynccl_wrapper.py
@@ -154,6 +154,17 @@ class NCCLLibrary:
ncclRedOp_t, ncclComm_t, cudaStream_t
]),

# ncclResult_t ncclReduce(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclRedOp_t op, int root,
# ncclComm_t comm, cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument
# is a pointer
Function("ncclReduce", ncclResult_t, [
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
ncclRedOp_t, ctypes.c_int, ncclComm_t, cudaStream_t
]),

# ncclResult_t ncclAllGather(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclComm_t comm,
@@ -207,6 +218,10 @@ class NCCLLibrary:
# it is better not to call it at all.
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
# ncclResult_t ncclGroupStart();
Function("ncclGroupStart", ncclResult_t, []),
# ncclResult_t ncclGroupEnd();
Function("ncclGroupEnd", ncclResult_t, []),
]

# class attribute to store the mapping from the path to the library
@@ -300,6 +315,18 @@ def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
datatype, op, comm,
stream))

def ncclReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
count: int, datatype: int, op: int, root: int,
comm: ncclComm_t, stream: cudaStream_t) -> None:
# `datatype` actually should be `ncclDataType_t`
# and `op` should be `ncclRedOp_t`
# both are aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self.NCCL_CHECK(self._funcs["ncclReduce"](sendbuff, recvbuff, count,
datatype, op, root, comm,
stream))
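
For orientation, the semantic difference from ncclAllReduce that the reduce_scatterv emulation relies on (a summary of standard NCCL behavior, not taken from the diff):

# ncclAllReduce: every rank receives the reduced result in its recvbuff.
# ncclReduce:    only the `root` rank receives the reduced result; the other
#                ranks' recvbuff is not written (and may even be NULL).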

def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type,
count: int, datatype: int, op: int, comm: ncclComm_t,
stream: cudaStream_t) -> None:
@@ -342,6 +369,12 @@ def ncclBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type,
def ncclCommDestroy(self, comm: ncclComm_t) -> None:
self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))

def ncclGroupStart(self) -> None:
self.NCCL_CHECK(self._funcs["ncclGroupStart"]())

def ncclGroupEnd(self) -> None:
self.NCCL_CHECK(self._funcs["ncclGroupEnd"]())
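
A rough sketch of how these two wrappers are intended to be used together; `lib` stands for an illustrative NCCLLibrary instance and the calls in between are placeholders:

# lib.ncclGroupStart()
# ... several ncclBroadcast / ncclReduce calls on the same communicator ...
# lib.ncclGroupEnd()   # the grouped operations are launched here as one batch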


__all__ = [
"NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId",