diff --git a/vllm/distributed/device_communicators/xpu_communicator.py b/vllm/distributed/device_communicators/xpu_communicator.py
index f3d9262d20cf..6bc26b6f3b1c 100644
--- a/vllm/distributed/device_communicators/xpu_communicator.py
+++ b/vllm/distributed/device_communicators/xpu_communicator.py
@@ -23,23 +23,146 @@ def __init__(
     ):
         super().__init__(cpu_group, device, device_group, unique_name)
         if self.use_all2all:
-            if self.all2all_backend != "naive":  # type: ignore[has-type]
-                logger.warning(
-                    "`%s` all2all manager is not supported on XPU. "
-                    "Falling back to `naive` all2all manager for XPU.",
-                    self.all2all_backend,  # type: ignore[has-type]
-                )
-                self.all2all_backend = "naive"
             if self.all2all_backend == "naive":
                 from .all2all import NaiveAll2AllManager
 
                 self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
                 logger.info("Using naive all2all manager.")
+            elif self.all2all_backend == "allgather_reducescatter":
+                from .all2all import AgRsAll2AllManager
+
+                self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
+                logger.info("Using AgRs manager on XPU device.")
+
+            else:  # type: ignore[has-type]
+                logger.warning(
+                    "`%s` all2all manager is not supported on XPU. "
+                    "Falling back to AgRs manager for XPU, "
+                    "which is the default backend.",
+                    self.all2all_backend,  # type: ignore[has-type]
+                )
+                from .all2all import AgRsAll2AllManager
+
+                self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
+                logger.info("Using AgRs manager on XPU device.")
+
     def all_reduce(self, input_) -> torch.Tensor:
         dist.all_reduce(input_, group=self.device_group)
         return input_
 
+    def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
+        world_size = self.world_size
+
+        if dim < 0:
+            # Convert negative dim to positive.
+            dim += input_.dim()
+
+        # Note: This will produce an incorrect answer if we don't make
+        # the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
+        input_tensor = input_.movedim(0, dim).contiguous()
+
+        assert input_tensor.shape[0] % world_size == 0
+        chunk_size = input_tensor.shape[0] // world_size
+        output_shape = (chunk_size,) + input_tensor.shape[1:]
+
+        output = torch.empty(
+            output_shape, dtype=input_tensor.dtype, device=input_tensor.device
+        )
+
+        dist.reduce_scatter_tensor(output, input_tensor)
+
+        # Reshape before returning
+        return output.movedim(0, dim).contiguous()
+
+    def reduce_scatterv(
+        self, input_: torch.Tensor, dim: int = -1, sizes: list[int] | None = None
+    ):
+        world_size = self.world_size
+
+        if dim < 0:
+            # Convert negative dim to positive.
+            dim += input_.dim()
+
+        # Note: This will produce an incorrect answer if we don't make
+        # the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
+        input_tensor = input_.movedim(0, dim).contiguous()
+
+        if sizes is not None:
+            assert len(sizes) == world_size
+            assert input_tensor.shape[0] == sum(sizes)
+            chunk_size = sizes[self.rank_in_group]
+        else:
+            assert input_tensor.shape[0] % world_size == 0
+            chunk_size = input_tensor.shape[0] // world_size
+        output_shape = (chunk_size,) + input_tensor.shape[1:]
+
+        output = torch.empty(
+            output_shape, dtype=input_tensor.dtype, device=input_tensor.device
+        )
+        if sizes is not None and sizes.count(sizes[0]) != len(sizes):
+            # If input shapes differ across ranks, use reduce_scatter on splits.
+            input_splits = list(input_tensor.split(sizes, dim=0))
+            dist.reduce_scatter(output, input_splits)
+        else:
+            dist.reduce_scatter_tensor(output, input_tensor)
+        # Reshape before returning
+        return output.movedim(0, dim).contiguous()
+
+    def all_gatherv(
+        self,
+        input_: torch.Tensor | list[torch.Tensor],
+        dim: int = 0,
+        sizes: list[int] | None = None,
+    ):
+        if dim != 0:
+            raise NotImplementedError("only dim 0 all-gatherv is supported")
+        world_size = self.world_size
+
+        # 'sizes' is not needed if all inputs in the same group have the same
+        # shape
+        if sizes is not None and all(s == sizes[0] for s in sizes):
+            sizes = None
+
+        def _all_gather_single(input_: torch.Tensor, sizes: list[int] | None = None):
+            input_size = input_.size()
+            if sizes is not None:
+                assert len(sizes) == world_size
+                assert input_.shape[dim] == sizes[self.rank_in_group], (
+                    f"{input_.shape[dim]} != {sizes[self.rank_in_group]}"
+                )
+                output_size = (sum(sizes),) + input_size[1:]
+            else:
+                output_size = (input_size[0] * world_size,) + input_size[1:]
+            # Allocate output tensor.
+            output_tensor = torch.empty(
+                output_size, dtype=input_.dtype, device=input_.device
+            )
+
+            if sizes is not None:
+                all_gather_list = []
+                for size in sizes:
+                    all_gather_list.append(
+                        torch.empty(
+                            (size,) + input_.shape[1:],
+                            dtype=input_.dtype,
+                            device=input_.device,
+                        )
+                    )
+                dist.all_gather(all_gather_list, input_)
+                output_tensor = torch.cat(all_gather_list, dim=0)
+            else:
+                dist.all_gather([output_tensor], input_)
+            return output_tensor
+
+        if isinstance(input_, torch.Tensor):
+            return _all_gather_single(input_, sizes)
+
+        output_list = []
+        for inp in input_:
+            output_list.append(_all_gather_single(inp, sizes=sizes))
+        return output_list
+
     def gather(
         self, input_: torch.Tensor, dst: int = 0, dim: int = -1
     ) -> torch.Tensor | None:
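
For context, here is a minimal, hypothetical sketch (not part of the patch) of the collective pattern that `reduce_scatterv` and `all_gatherv` rely on for uneven per-rank sizes: `dist.reduce_scatter` / `dist.all_gather` with explicit per-rank tensor lists. It assumes a two-or-more-rank launch via torchrun and a backend that supports these list-based collectives (nccl on CUDA in the sketch; xccl on XPU in the patch); all names in the snippet are illustrative only.

# Hypothetical standalone demo; run with: torchrun --nproc-per-node=2 demo.py
import os

import torch
import torch.distributed as dist


def main() -> None:
    # Assumption: CUDA + nccl here; the patch targets XPU + xccl, but the
    # collective pattern is the same.
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)

    # Uneven per-rank row counts, as in the `sizes is not None` path.
    sizes = [2 + r for r in range(world_size)]
    total = sum(sizes)
    inp = torch.full((total, 4), float(rank), device=device)

    # reduce_scatterv-style: split the full input per rank and reduce-scatter
    # the list so each rank keeps only its (differently sized) chunk.
    output = torch.empty((sizes[rank], 4), device=device)
    input_splits = list(inp.split(sizes, dim=0))
    dist.reduce_scatter(output, input_splits)

    # Every element is now the sum of all rank values 0..world_size-1.
    expected = float(sum(range(world_size)))
    assert torch.allclose(output, torch.full_like(output, expected))

    # all_gatherv-style: gather the uneven chunks back into one tensor.
    gather_list = [
        torch.empty((s, 4), dtype=output.dtype, device=device) for s in sizes
    ]
    dist.all_gather(gather_list, output)
    gathered = torch.cat(gather_list, dim=0)
    assert gathered.shape == (total, 4)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()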