def is_model_parallel_parameter(p: torch.Tensor) -> bool:
    """Check whether a parameter participates in pipeline or tensor parallelism.

    Parallelized parameters carry an ``oslo_parallel`` attribute — a mapping
    whose keys are ``ParallelMode`` members; plain (data-parallel-only)
    parameters do not have the attribute at all.

    Args:
        p (torch.Tensor): The parameter to check.

    Returns:
        bool: True if the parameter is registered under the pipeline mode or
        any tensor-parallel mode (1D, 2D, 2.5D, or 3D), False otherwise.
    """
    # A parameter that was never parallelized has no ``oslo_parallel``
    # attribute; fall back to an empty mapping so every membership test below
    # is safely False.
    oslo_parallel = getattr(p, "oslo_parallel", {})
    # Tuple rather than list: the recognized mode set is fixed, and an
    # immutable literal is the idiomatic (and cheaper) container for it.
    parallel_modes = (
        ParallelMode.PIPELINE,
        ParallelMode.TENSOR_1D,
        ParallelMode.TENSOR_2D_ROW,
        ParallelMode.TENSOR_2D_COL,
        ParallelMode.TENSOR_2P5D_ROW,
        ParallelMode.TENSOR_2P5D_COL,
        ParallelMode.TENSOR_2P5D_DEP,
        ParallelMode.TENSOR_3D_INPUT,
        ParallelMode.TENSOR_3D_OUTPUT,
    )
    return any(mode in oslo_parallel for mode in parallel_modes)