diff --git a/aiter/dist/parallel_state.py b/aiter/dist/parallel_state.py index 1d74e136ac..51aa3ff319 100644 --- a/aiter/dist/parallel_state.py +++ b/aiter/dist/parallel_state.py @@ -389,6 +389,12 @@ def _all_gather_out_place(self, input_: torch.Tensor) -> torch.Tensor: def custom_all_gather(self, input_: torch.Tensor) -> torch.Tensor: return outplace_all_gather(input_, group_name=self.unique_name) + def reduce_scatter(self, input_: torch.Tensor, dim: int = -1): + if self.device_communicator is None: + raise ValueError("No device communicator found") + return self.device_communicator.reduce_scatter(input_, dim) + + def all_gather( self, input_: torch.Tensor, use_custom: bool = False, dim: int = -1 ) -> torch.Tensor: @@ -878,7 +884,6 @@ def get_pp_group() -> GroupCoordinator: return _PP -from typing import Optional _DP: Optional[GroupCoordinator] = None