Skip to content

Commit 00fcfc9

Browse files
committed
Move special case handling outside
Signed-off-by: Trevor Morris <[email protected]>
1 parent 0343099 commit 00fcfc9

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

vllm/distributed/device_communicators/cuda_communicator.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,17 +208,18 @@ def all_gatherv(self,
208208
pynccl_comm = self.pynccl_comm
209209
assert pynccl_comm is not None and not pynccl_comm.disabled
210210

211+
# 'sizes' is not needed if all inputs in the same group have the same
212+
# shape
213+
if sizes is not None and all(s == sizes[0] for s in sizes):
214+
sizes = None
215+
211216
def _all_gather_single(input_: torch.Tensor,
212217
sizes: Optional[list[int]] = None):
213218
input_size = input_.size()
214219
if sizes is not None:
215220
assert len(sizes) == world_size
216221
assert input_.shape[dim] == sizes[self.rank_in_group]
217222
output_size = (sum(sizes), ) + input_size[1:]
218-
# 'sizes' is not needed if all inputs in the same group have the
219-
# same shape
220-
if all(s == sizes[0] for s in sizes):
221-
sizes = None
222223
else:
223224
output_size = (input_size[0] * world_size, ) + input_size[1:]
224225
# Allocate output tensor.

0 commit comments

Comments (0)