vllm/distributed/device_communicators (1 file changed, +5 -4 lines changed)

@@ -208,17 +208,18 @@ def all_gatherv(self,
         pynccl_comm = self.pynccl_comm
         assert pynccl_comm is not None and not pynccl_comm.disabled

+        # 'sizes' is not needed if all inputs in the same group have the same
+        # shape
+        if sizes is not None and all(s == sizes[0] for s in sizes):
+            sizes = None
+
         def _all_gather_single(input_: torch.Tensor,
                                sizes: Optional[list[int]] = None):
             input_size = input_.size()
             if sizes is not None:
                 assert len(sizes) == world_size
                 assert input_.shape[dim] == sizes[self.rank_in_group]
                 output_size = (sum(sizes), ) + input_size[1:]
-                # 'sizes' is not needed if all inputs in the same group have the
-                # same shape
-                if all(s == sizes[0] for s in sizes):
-                    sizes = None
             else:
                 output_size = (input_size[0] * world_size, ) + input_size[1:]
             # Allocate output tensor.
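For context, the sketch below illustrates the fast path this change hoists out of the nested helper. It is not the vLLM source, and the helper name `normalize_sizes` is hypothetical; the point is only that a uniform `sizes` list carries no information beyond the tensor shape itself, so it can be normalized to `None` once, up front.

```python
# Minimal sketch (assumed helper, not the vLLM source): collapse a uniform
# 'sizes' list to None so callers fall back to the fixed-size all-gather path.
from typing import Optional


def normalize_sizes(sizes: Optional[list[int]]) -> Optional[list[int]]:
    # 'sizes' is not needed if all inputs in the same group have the same
    # shape; returning None selects the simpler fixed-size code path.
    if sizes is not None and all(s == sizes[0] for s in sizes):
        return None
    return sizes


assert normalize_sizes([4, 4, 4]) is None         # uniform: drop the list
assert normalize_sizes([4, 2, 4]) == [4, 2, 4]    # ragged: keep explicit sizes
assert normalize_sizes(None) is None              # already the fixed-size path
```

Doing this check at the top of `all_gatherv`, rather than inside `_all_gather_single`, means the decision is made once per call instead of being re-derived inside the per-tensor helper.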