diff --git a/torchtitan/experiments/simple_fsdp/simple_fsdp.py b/torchtitan/experiments/simple_fsdp/simple_fsdp.py
index c731da3800..391ae74dff 100644
--- a/torchtitan/experiments/simple_fsdp/simple_fsdp.py
+++ b/torchtitan/experiments/simple_fsdp/simple_fsdp.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from collections.abc import Sequence
+from collections.abc import Generator, Sequence
 from contextlib import contextmanager
 from dataclasses import dataclass
 
@@ -27,7 +27,7 @@
 
 
 @contextmanager
-def disable_active_parametrization():
+def disable_active_parametrization() -> Generator[None, None, None]:
     global _active_parametrization
     try:
         _active_parametrization = False
@@ -180,18 +180,18 @@ def _register_parametrization(
 class ReplicateComputation(torch.nn.Module):
     def __init__(
         self,
-        device_mesh,
-        param_sharding,
-        mode,
-        mp_policy,
-        reduction_divide_factor,
-    ):
+        device_mesh: DeviceMesh,
+        param_sharding: tuple[Placement, ...],
+        mode: str,
+        mp_policy: MixedPrecisionPolicy | None,
+        reduction_divide_factor: float | None,
+    ) -> None:
         super().__init__()
         self.device_mesh = device_mesh
         self.param_sharding = param_sharding
         self.mode = mode
-        self.compute_placements = [Replicate()] * self.device_mesh.ndim
-        self.grad_placements = [
+        self.compute_placements: list[Placement] = [Replicate()] * self.device_mesh.ndim
+        self.grad_placements: list[Placement] = [
             _ScaledPartial(
                 reduction_divide_factor=reduction_divide_factor,
             )
@@ -199,8 +199,8 @@ def __init__(
             else Partial(reduce_op="avg")
         ] * self.device_mesh.ndim
         mp_policy = mp_policy or MixedPrecisionPolicy()
-        self.param_dtype = mp_policy.param_dtype
-        self.reduce_dtype = mp_policy.reduce_dtype
+        self.param_dtype: torch.dtype | None = mp_policy.param_dtype
+        self.reduce_dtype: torch.dtype | None = mp_policy.reduce_dtype
 
     def replicate_compute(self, x: DTensor) -> torch.Tensor:
         # data parallel runtime replicate parameters and do local compute
@@ -274,7 +274,8 @@ def data_parallel(
     mp_policy: MixedPrecisionPolicy | None = None,
     shard_dim: int = 0,
     reduction_divide_factor: float | None = None,
-):
+) -> nn.Module:
+    param_sharding: tuple[Placement, ...]
     if mode == "replicate":
         param_sharding = (Replicate(),)
    elif mode == "fully_shard":
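
For reviewers unfamiliar with annotating `@contextmanager` functions: the decorated function is a generator that yields exactly once, is never sent values, and returns nothing, so `Generator[None, None, None]` is the conventional return type (the first hunk adds the `Generator` import for this). Below is a minimal self-contained sketch of the same pattern; the `finally` restore is an assumption, since the hunk above truncates the function body.

```python
from collections.abc import Generator
from contextlib import contextmanager

# Module-level flag, mirroring _active_parametrization in the diff.
_active_parametrization = True


@contextmanager
def disable_active_parametrization() -> Generator[None, None, None]:
    # Generator[YieldType, SendType, ReturnType]: the body yields None once,
    # is never sent values, and returns None, hence all three are None.
    global _active_parametrization
    try:
        _active_parametrization = False
        yield
    finally:
        # Assumed restore; the diff does not show the end of the function.
        _active_parametrization = True


# Usage: the flag is flipped off only for the duration of the `with` block.
with disable_active_parametrization():
    assert _active_parametrization is False
assert _active_parametrization is True
```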