Commit 7fcb48e

Add pynccl RS and AGv
1 parent 5c27d34 commit 7fcb48e

File tree: 6 files changed, +127 −89 lines changed

vllm/distributed/device_communicators/base_device_communicator.py

Lines changed: 10 additions & 2 deletions

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import threading
-from typing import Optional, Union, List
+from typing import List, Optional, Union
 from weakref import WeakValueDictionary

 import torch
@@ -144,9 +144,17 @@ def all_gatherv(self,
                     sizes: Optional[List[int]] = None):
         assert False, "not implemented"

+    def all_gatherv(self,
+                    input_: Union[torch.Tensor, List[torch.Tensor]],
+                    dim: int = 0,
+                    sizes: Optional[List[int]] = None):
+        assert False, "not implemented"
+
     def reduce_scatter(self,
                        input_: torch.Tensor,
-                       dim: int = -1) -> torch.Tensor:
+                       dim: int = -1,
+                       sizes: Optional[List[int]] = None) -> torch.Tensor:
+        assert sizes is None, "Varying size reduce scatter not supported with base device communicator"
         world_size = self.world_size
         # Bypass the function if we are using only 1 GPU.
         if world_size == 1:
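
For reference, the `sizes` argument threaded through these signatures means each rank owns a variable number of rows: every rank passes the full tensor with sum(sizes) rows, and rank r ends up with the cross-rank reduction of its sizes[r]-row slice. A minimal single-process sketch of that semantics (the helper and the sample values are illustrative, not part of the commit):

```python
import torch
from typing import List


def reduce_scatter_v_reference(inputs: List[torch.Tensor],
                               sizes: List[int]) -> List[torch.Tensor]:
    """inputs[r] is rank r's full (sum(sizes), ...) tensor; outputs[r] is what
    rank r would hold after the variable-size reduce-scatter: the sum over
    ranks of its sizes[r]-row slice."""
    assert all(t.shape[0] == sum(sizes) for t in inputs)
    total = torch.stack(inputs).sum(dim=0)  # element-wise sum across "ranks"
    return list(total.split(sizes, dim=0))  # rank r keeps sizes[r] rows


# Two "ranks", uneven split of 3 rows: rank 0 keeps 2 rows, rank 1 keeps 1.
a = torch.arange(6.0).reshape(3, 2)
b = torch.ones(3, 2)
out0, out1 = reduce_scatter_v_reference([a, b], sizes=[2, 1])
assert out0.shape == (2, 2) and out1.shape == (1, 2)
```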

vllm/distributed/device_communicators/cuda_communicator.py

Lines changed: 26 additions & 13 deletions

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from typing import Optional, Union, List
+from typing import List, Optional, Union

 import torch
 from torch.distributed import ProcessGroup
@@ -99,7 +99,10 @@ def all_reduce(self, input_):
         torch.distributed.all_reduce(out, group=self.device_group)
         return out

-    def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
+    def reduce_scatter(self,
+                       input_: torch.Tensor,
+                       dim: int = -1,
+                       sizes: Optional[List[int]] = None):
         world_size = self.world_size
         pynccl_comm = self.pynccl_comm
         assert pynccl_comm is not None
@@ -111,15 +114,20 @@ def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
         # the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
         input_tensor = input_.movedim(0, dim).contiguous()

-        assert input_tensor.shape[0] % world_size == 0
-        chunk_size = input_tensor.shape[0] // world_size
+        if sizes is not None:
+            assert len(sizes) == world_size
+            assert input_tensor.shape[0] == sum(sizes)
+            chunk_size = sizes[self.rank_in_group]
+        else:
+            assert input_tensor.shape[0] % world_size == 0
+            chunk_size = input_tensor.shape[0] // world_size
         output_shape = (chunk_size, ) + input_tensor.shape[1:]

         output = torch.empty(output_shape,
                              dtype=input_tensor.dtype,
                              device=input_tensor.device)

-        pynccl_comm.reduce_scatter(output, input_)
+        pynccl_comm.reduce_scatter(output, input_, sizes=sizes)

         # Reshape before returning
         return output.movedim(0, dim).contiguous()
@@ -170,28 +178,34 @@ def destroy(self):
     Use this:
     ... = get_dp_group().all_gatherv([topk_weights, topk_ids, a1q, a1q_scale], dim=0, sizes=get_forward_context().dp_metadata.num_tokens_across_dp_cpu)
     """
-    def all_gatherv(self, input_: Union[torch.Tensor, List[torch.Tensor]], dim: int = 0, sizes: Optional[List[int]] = None):
+
+    def all_gatherv(self,
+                    input_: Union[torch.Tensor, List[torch.Tensor]],
+                    dim: int = 0,
+                    sizes: Optional[List[int]] = None):
         assert dim == 0, "only dim 0 all-gather is supported"
         world_size = self.world_size
         pynccl_comm = self.pynccl_comm
         assert pynccl_comm is not None and not pynccl_comm.disabled

-        def _all_gather_single(input_: torch.Tensor, sizes: Optional[List[int]] = None):
+        def _all_gather_single(input_: torch.Tensor,
+                               sizes: Optional[List[int]] = None):
             input_size = input_.size()
             if sizes is not None:
                 assert len(sizes) == world_size
                 assert input_.shape[dim] == sizes[self.rank_in_group]
+                output_size = (sum(sizes), ) + input_size[1:]
                 # 'sizes' is not needed if all inputs in the same group have the same shape
                 if all(s == sizes[0] for s in sizes):
                     sizes = None
-                output_size = (sum(sizes),) + input_size[1:]
             else:
-                output_size = (input_size[0] * world_size,) + input_size[1:]
+                output_size = (input_size[0] * world_size, ) + input_size[1:]
             # Allocate output tensor.
-            output_tensor = torch.empty(
-                output_size, dtype=input_.dtype, device=input_.device
-            )
+            output_tensor = torch.empty(output_size,
+                                        dtype=input_.dtype,
+                                        device=input_.device)
             pynccl_comm.all_gather(output_tensor, input_, sizes=sizes)
+            return output_tensor

         if isinstance(input_, torch.Tensor):
             return _all_gather_single(input_, sizes)
@@ -201,7 +215,6 @@ def _all_gather_single(input_: torch.Tensor, sizes: Optional[List[int]] = None):
         for inp in input_:
             output_list.append(_all_gather_single(inp, sizes=sizes))
         pynccl_comm.group_end()
-
         return output_list

     def dispatch(
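
A hedged usage sketch matching the docstring above: gathering variable-length per-DP-rank rows for a fused-MoE style dispatch. Only the `all_gatherv` signature and the forward-context expression come from the diff; the import paths and the wrapper function are assumptions.

```python
import torch
from vllm.distributed.parallel_state import get_dp_group  # assumed path
from vllm.forward_context import get_forward_context      # assumed path


def gather_across_dp(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                     a1q: torch.Tensor, a1q_scale: torch.Tensor):
    # Per-rank token counts; each rank contributes sizes[rank] rows on dim 0.
    sizes = get_forward_context().dp_metadata.num_tokens_across_dp_cpu
    # One pynccl group batches the four gathers; each output has sum(sizes)
    # rows along dim 0.
    return get_dp_group().all_gatherv(
        [topk_weights, topk_ids, a1q, a1q_scale], dim=0, sizes=sizes)
```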

vllm/distributed/device_communicators/pynccl.py

Lines changed: 32 additions & 12 deletions

@@ -1,8 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from typing import Optional, Union, List
+from typing import List, Optional, Union

+import numpy as np
 # ===================== import region =====================
 import torch
 import torch.distributed as dist
@@ -147,13 +148,14 @@ def all_gather(self,
                 f"but the input tensor is on {input_tensor.device}")
         if stream is None:
             stream = current_stream()
-        if sizes:
+        if sizes is not None:
             assert output_tensor.shape[0] == sum(sizes)
             numel_base = int(np.prod(output_tensor.shape[1:]))
             split_offset = 0
             self.nccl.ncclGroupStart()
             for root, split_size in enumerate(sizes):
-                dst_slice = output_tensor[split_offset:split_offset + split_size]
+                dst_slice = output_tensor[split_offset:split_offset +
+                                          split_size]
                 self.nccl.ncclBroadcast(
                     buffer_type(input_tensor.data_ptr()),
                     buffer_type(dst_slice.data_ptr()),
@@ -176,7 +178,8 @@ def reduce_scatter(self,
                        output_tensor: torch.Tensor,
                        input_tensor: torch.Tensor,
                        op: ReduceOp = ReduceOp.SUM,
-                       stream=None):
+                       stream=None,
+                       sizes: Optional[List[int]] = None):
         if self.disabled:
             return
         # nccl communicator created on a specific device
@@ -187,12 +190,29 @@ def reduce_scatter(self,
                 f"but the input tensor is on {input_tensor.device}")
         if stream is None:
             stream = current_stream()
-        self.nccl.ncclReduceScatter(
-            buffer_type(input_tensor.data_ptr()),
-            buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
-            ncclDataTypeEnum.from_torch(input_tensor.dtype),
-            ncclRedOpTypeEnum.from_torch(op), self.comm,
-            cudaStream_t(stream.cuda_stream))
+
+        if sizes is not None:
+            numel_base = int(np.prod(input_tensor.shape[1:]))
+            split_offset = 0
+            self.nccl.ncclGroupStart()
+            for root, split_size in enumerate(sizes):
+                chunk = input_tensor[split_offset:split_offset + split_size, :]
+                self.nccl.ncclReduce(
+                    buffer_type(chunk.data_ptr()),
+                    buffer_type(output_tensor.data_ptr()),
+                    split_size * numel_base,
+                    ncclDataTypeEnum.from_torch(input_tensor.dtype),
+                    ncclRedOpTypeEnum.from_torch(op), root, self.comm,
+                    cudaStream_t(stream.cuda_stream))
+                split_offset += split_size
+            self.nccl.ncclGroupEnd()
+        else:
+            self.nccl.ncclReduceScatter(
+                buffer_type(input_tensor.data_ptr()),
+                buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
+                ncclDataTypeEnum.from_torch(input_tensor.dtype),
+                ncclRedOpTypeEnum.from_torch(op), self.comm,
+                cudaStream_t(stream.cuda_stream))

     def send(self, tensor: torch.Tensor, dst: int, stream=None):
         if self.disabled:
@@ -236,9 +256,9 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
         self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
                                 ncclDataTypeEnum.from_torch(tensor.dtype), src,
                                 self.comm, cudaStream_t(stream.cuda_stream))
-
+
     def group_start(self):
         self.nccl.ncclGroupStart()

     def group_end(self):
-        self.nccl.ncclGroupEnd()
+        self.nccl.ncclGroupEnd()
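
The `sizes` branch above emulates a variable-size reduce-scatter with one ncclReduce per destination rank inside an ncclGroupStart/ncclGroupEnd pair: rank r is the root for the reduction of its own sizes[r]-row chunk. The same semantics written against plain torch.distributed rather than the raw NCCL bindings, as a correctness reference only (unbatched, default process group, helper name is ours, not vLLM's):

```python
import torch
import torch.distributed as dist
from typing import List


def reduce_scatter_v(input_: torch.Tensor, sizes: List[int]) -> torch.Tensor:
    """Every rank in the default process group passes the full
    (sum(sizes), ...) tensor; rank r gets back the cross-rank sum of rows
    [offset_r, offset_r + sizes[r])."""
    assert input_.shape[0] == sum(sizes)
    rank = dist.get_rank()
    output = None
    offset = 0
    for root, split_size in enumerate(sizes):
        # Clone so dist.reduce can overwrite the buffer on the root rank
        # without touching the caller's input.
        chunk = input_[offset:offset + split_size].clone()
        dist.reduce(chunk, dst=root, op=dist.ReduceOp.SUM)
        if root == rank:
            output = chunk  # the reduced result is only valid on the root
        offset += split_size
    return output
```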

vllm/distributed/device_communicators/pynccl_wrapper.py

Lines changed: 25 additions & 1 deletion

@@ -154,6 +154,16 @@ class NCCLLibrary:
             ncclRedOp_t, ncclComm_t, cudaStream_t
         ]),

+        # ncclResult_t ncclReduce(
+        #   const void* sendbuff, void* recvbuff, size_t count,
+        #   ncclDataType_t datatype, ncclRedOp_t op, int root,
+        #   ncclComm_t comm, cudaStream_t stream);
+        # note that cudaStream_t is a pointer type, so the last argument
+        # is a pointer
+        Function("ncclReduce", ncclResult_t, [
+            buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
+            ncclRedOp_t, ctypes.c_int, ncclComm_t, cudaStream_t
+        ]),
         # ncclResult_t ncclAllGather(
         #   const void* sendbuff, void* recvbuff, size_t count,
         #   ncclDataType_t datatype, ncclComm_t comm,
@@ -207,7 +217,7 @@ class NCCLLibrary:
         # it is better not to call it at all.
         # ncclResult_t ncclCommDestroy(ncclComm_t comm);
         Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
-        # ncclResult_t ncclGroupStart();
+        # ncclResult_t ncclGroupStart();
         Function("ncclGroupStart", ncclResult_t, []),
         # ncclResult_t ncclGroupEnd();
         Function("ncclGroupEnd", ncclResult_t, []),
@@ -304,6 +314,18 @@ def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
                                                      datatype, op, comm,
                                                      stream))

+    def ncclReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
+                   count: int, datatype: int, op: int, root: int,
+                   comm: ncclComm_t, stream: cudaStream_t) -> None:
+        # `datatype` actually should be `ncclDataType_t`
+        # and `op` should be `ncclRedOp_t`
+        # both are aliases of `ctypes.c_int`
+        # when we pass int to a function, it will be converted to `ctypes.c_int`
+        # by ctypes automatically
+        self.NCCL_CHECK(self._funcs["ncclReduce"](sendbuff, recvbuff, count,
+                                                  datatype, op, root, comm,
+                                                  stream))
+
     def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type,
                           count: int, datatype: int, op: int, comm: ncclComm_t,
                           stream: cudaStream_t) -> None:
@@ -348,9 +370,11 @@ def ncclCommDestroy(self, comm: ncclComm_t) -> None:

     def ncclGroupStart(self) -> None:
         self.NCCL_CHECK(self._funcs["ncclGroupStart"]())
+
     def ncclGroupEnd(self) -> None:
         self.NCCL_CHECK(self._funcs["ncclGroupEnd"]())

+
 __all__ = [
     "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId",
     "ncclComm_t", "cudaStream_t", "buffer_type"

vllm/distributed/parallel_state.py

Lines changed: 15 additions & 10 deletions

@@ -30,7 +30,7 @@
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
 from multiprocessing import shared_memory
-from typing import Any, Callable, Optional, Union, List
+from typing import Any, Callable, List, Optional, Union
 from unittest.mock import patch

 import torch
@@ -380,16 +380,17 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
     def _all_gather_out_place(self, input_: torch.Tensor,
                               dim: int) -> torch.Tensor:
         return self.device_communicator.all_gather(input_, dim)
-
-    def all_gatherv(self,
+
+    def all_gatherv(self,
                     input_: Union[torch.Tensor, List[torch.Tensor]],
                     dim: int = 0,
                     sizes: Optional[List[int]] = None):
         return self.device_communicator.all_gatherv(input_, dim, sizes)
-
+
     def reduce_scatter(self,
                        input_: torch.Tensor,
-                       dim: int = -1) -> torch.Tensor:
+                       dim: int = -1,
+                       sizes: Optional[List[int]] = None) -> torch.Tensor:
         world_size = self.world_size
         # Bypass the function if we are using only 1 GPU.
         if world_size == 1:
@@ -398,16 +399,20 @@ def reduce_scatter(self,
             f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")

         if self.use_custom_op_call:
+            assert sizes is None, "Varying size reduce scatter not supported with vllm custom op"
             return torch.ops.vllm.reduce_scatter(input_,
                                                  dim,
                                                  world_size,
                                                  group_name=self.unique_name)
         else:
-            return self._reduce_scatter_out_place(input_, dim)
-
-    def _reduce_scatter_out_place(self, input_: torch.Tensor,
-                                  dim: int) -> torch.Tensor:
-        return self.device_communicator.reduce_scatter(input_, dim)
+            return self._reduce_scatter_out_place(input_, dim, sizes)
+
+    def _reduce_scatter_out_place(
+            self,
+            input_: torch.Tensor,
+            dim: int,
+            sizes: Optional[List[int]] = None) -> torch.Tensor:
+        return self.device_communicator.reduce_scatter(input_, dim, sizes)

     def gather(self,
                input_: torch.Tensor,
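
End-to-end, the GroupCoordinator now exposes both variable-size collectives. A hedged round-trip sketch of the API (tensor names, counts, and the helper are illustrative; the import path is an assumption; and per the assert above, `sizes` is only accepted on the non-custom-op path):

```python
import torch
from typing import List
from vllm.distributed.parallel_state import get_dp_group  # assumed path


def dp_round_trip(local_rows: torch.Tensor, sizes: List[int]) -> torch.Tensor:
    """local_rows has sizes[rank] rows on this rank."""
    group = get_dp_group()
    # Gather every rank's variable-length rows: result has sum(sizes) rows.
    full = group.all_gatherv(local_rows, dim=0, sizes=sizes)
    # Reduce-scatter back: this rank gets a (sizes[rank], ...) tensor holding
    # the sum over ranks of their copies of its slice.
    return group.reduce_scatter(full, dim=0, sizes=sizes)
```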
