-
-
Notifications
You must be signed in to change notification settings - Fork 13.2k
[XPU]Support AgRsAll2AllManager on XPU device #32654
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -23,23 +23,146 @@ def __init__( | |||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||
| super().__init__(cpu_group, device, device_group, unique_name) | ||||||||||||||||||||||||||
| if self.use_all2all: | ||||||||||||||||||||||||||
| if self.all2all_backend != "naive": # type: ignore[has-type] | ||||||||||||||||||||||||||
| logger.warning( | ||||||||||||||||||||||||||
| "`%s` all2all manager is not supported on XPU. " | ||||||||||||||||||||||||||
| "Falling back to `naive` all2all manager for XPU.", | ||||||||||||||||||||||||||
| self.all2all_backend, # type: ignore[has-type] | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| self.all2all_backend = "naive" | ||||||||||||||||||||||||||
| if self.all2all_backend == "naive": | ||||||||||||||||||||||||||
| from .all2all import NaiveAll2AllManager | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| self.all2all_manager = NaiveAll2AllManager(self.cpu_group) | ||||||||||||||||||||||||||
| logger.info("Using naive all2all manager.") | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| elif self.all2all_backend == "allgather_reducescatter": | ||||||||||||||||||||||||||
| from .all2all import AgRsAll2AllManager | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| self.all2all_manager = AgRsAll2AllManager(self.cpu_group) | ||||||||||||||||||||||||||
| logger.info("Using AgRs manager on XPU device.") | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| else: # type: ignore[has-type] | ||||||||||||||||||||||||||
| logger.warning( | ||||||||||||||||||||||||||
| "`%s` all2all manager is not supported on XPU. " | ||||||||||||||||||||||||||
| "Falling back to AgRs manager for XPU, " | ||||||||||||||||||||||||||
| "which is the Default backend", | ||||||||||||||||||||||||||
| self.all2all_backend, # type: ignore[has-type] | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| from .all2all import AgRsAll2AllManager | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| self.all2all_manager = AgRsAll2AllManager(self.cpu_group) | ||||||||||||||||||||||||||
| logger.info("Using AgRs manager on XPU device.") | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def all_reduce(self, input_) -> torch.Tensor: | ||||||||||||||||||||||||||
| dist.all_reduce(input_, group=self.device_group) | ||||||||||||||||||||||||||
| return input_ | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def reduce_scatter(self, input_: torch.Tensor, dim: int = -1): | ||||||||||||||||||||||||||
| world_size = self.world_size | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| if dim < 0: | ||||||||||||||||||||||||||
| # Convert negative dim to positive. | ||||||||||||||||||||||||||
| dim += input_.dim() | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Note: This will produce an incorrect answer if we don't make | ||||||||||||||||||||||||||
| # the input_tensor contiguous. Possible bug in reduce_scatter_tensor? | ||||||||||||||||||||||||||
| input_tensor = input_.movedim(0, dim).contiguous() | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| assert input_tensor.shape[0] % world_size == 0 | ||||||||||||||||||||||||||
| chunk_size = input_tensor.shape[0] // world_size | ||||||||||||||||||||||||||
| output_shape = (chunk_size,) + input_tensor.shape[1:] | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| output = torch.empty( | ||||||||||||||||||||||||||
| output_shape, dtype=input_tensor.dtype, device=input_tensor.device | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| dist.reduce_scatter_tensor(output, input_tensor) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Reshape before returning | ||||||||||||||||||||||||||
| return output.movedim(0, dim).contiguous() | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def reduce_scatterv( | ||||||||||||||||||||||||||
| self, input_: torch.Tensor, dim: int = -1, sizes: list[int] | None = None | ||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||
| world_size = self.world_size | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| if dim < 0: | ||||||||||||||||||||||||||
| # Convert negative dim to positive. | ||||||||||||||||||||||||||
| dim += input_.dim() | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Note: This will produce an incorrect answer if we don't make | ||||||||||||||||||||||||||
| # the input_tensor contiguous. Possible bug in reduce_scatter_tensor? | ||||||||||||||||||||||||||
| input_tensor = input_.movedim(0, dim).contiguous() | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| if sizes is not None: | ||||||||||||||||||||||||||
| assert len(sizes) == world_size | ||||||||||||||||||||||||||
| assert input_tensor.shape[0] == sum(sizes) | ||||||||||||||||||||||||||
| chunk_size = sizes[self.rank_in_group] | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| assert input_tensor.shape[0] % world_size == 0 | ||||||||||||||||||||||||||
| chunk_size = input_tensor.shape[0] // world_size | ||||||||||||||||||||||||||
| output_shape = (chunk_size,) + input_tensor.shape[1:] | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| output = torch.empty( | ||||||||||||||||||||||||||
| output_shape, dtype=input_tensor.dtype, device=input_tensor.device | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| if sizes is not None and sizes.count(sizes[0]) != len(sizes): | ||||||||||||||||||||||||||
| # if inputs shape in different ranks is not the same using reduce_scatter | ||||||||||||||||||||||||||
| input_splits = list(input_tensor.split(sizes, dim=0)) | ||||||||||||||||||||||||||
| dist.reduce_scatter(output, input_splits) | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| dist.reduce_scatter_tensor(output, input_tensor) | ||||||||||||||||||||||||||
|
Comment on lines
+103
to
+108
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||||||||||||||||||||||
| # Reshape before returning | ||||||||||||||||||||||||||
| return output.movedim(0, dim).contiguous() | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def all_gatherv( | ||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||
| input_: torch.Tensor | list[torch.Tensor], | ||||||||||||||||||||||||||
| dim: int = 0, | ||||||||||||||||||||||||||
| sizes: list[int] | None = None, | ||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||
| if dim != 0: | ||||||||||||||||||||||||||
| raise NotImplementedError("only dim 0 all-gatherv is supported") | ||||||||||||||||||||||||||
| world_size = self.world_size | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # 'sizes' is not needed if all inputs in the same group have the same | ||||||||||||||||||||||||||
| # shape | ||||||||||||||||||||||||||
| if sizes is not None and all(s == sizes[0] for s in sizes): | ||||||||||||||||||||||||||
| sizes = None | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def _all_gather_single(input_: torch.Tensor, sizes: list[int] | None = None): | ||||||||||||||||||||||||||
| input_size = input_.size() | ||||||||||||||||||||||||||
| if sizes is not None: | ||||||||||||||||||||||||||
| assert len(sizes) == world_size | ||||||||||||||||||||||||||
| assert input_.shape[dim] == sizes[self.rank_in_group], ( | ||||||||||||||||||||||||||
| f"{input_.shape[dim]} != {sizes[self.rank_in_group]}" | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| output_size = (sum(sizes),) + input_size[1:] | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| output_size = (input_size[0] * world_size,) + input_size[1:] | ||||||||||||||||||||||||||
| # Allocate output tensor. | ||||||||||||||||||||||||||
| output_tensor = torch.empty( | ||||||||||||||||||||||||||
| output_size, dtype=input_.dtype, device=input_.device | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| if sizes is not None: | ||||||||||||||||||||||||||
| all_gather_list = [] | ||||||||||||||||||||||||||
| for size in sizes: | ||||||||||||||||||||||||||
| all_gather_list.append( | ||||||||||||||||||||||||||
| torch.empty( | ||||||||||||||||||||||||||
| (size,) + input_.shape[1:], | ||||||||||||||||||||||||||
| dtype=input_.dtype, | ||||||||||||||||||||||||||
| device=input_.device, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| dist.all_gather(all_gather_list, input_) | ||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||||||||||||||||||||||
| output_tensor = torch.cat(all_gather_list, dim=0) | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| dist.all_gather([output_tensor], input_) | ||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The use of
Suggested change
|
||||||||||||||||||||||||||
| return output_tensor | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| if isinstance(input_, torch.Tensor): | ||||||||||||||||||||||||||
| return _all_gather_single(input_, sizes) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| output_list = [] | ||||||||||||||||||||||||||
| for inp in input_: | ||||||||||||||||||||||||||
| output_list.append(_all_gather_single(inp, sizes=sizes)) | ||||||||||||||||||||||||||
| return output_list | ||||||||||||||||||||||||||
|
Comment on lines
+161
to
+164
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When |
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def gather( | ||||||||||||||||||||||||||
| self, input_: torch.Tensor, dst: int = 0, dim: int = -1 | ||||||||||||||||||||||||||
| ) -> torch.Tensor | None: | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The call to
dist.reduce_scatter_tensoris missing thegroupargument. This will cause it to use the default process group, which is likely incorrect for tensor parallelism. You should passgroup=self.device_groupto ensure communication occurs within the intended process group.