add precision as an argument to FSDP strategy dataclass (#501)
Summary:
Pull Request resolved: #501

# Context
We want to add `mixed_precision` to `FSDPStrategy` dataclass, and stop populating the `FSDPStrategy` with whatever `precision` the user has passed to the `AutoUnit`'s constructor. See attached task for more details

# This diff
- Add the above.
- Switch from `asdict()` to `__dict__`, since `asdict()` is applied recursively and converts the nested `MixedPrecision` into a plain dict, which then raises an exception when passed to FSDP
- Update callers that use `FSDPStrategy` to pass in `mixed_precision` explicitly (see the sketch below); the appropriate dtype was determined from the `precision` each caller passes to the `AutoUnit` constructor
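
For reference, a minimal sketch of the new usage (assuming a torchtnt build that includes this change and PyTorch >= 1.12 with FSDP available; the bfloat16 dtypes below are illustrative, not part of this diff):

```python
import torch
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
from torchtnt.utils.prepare_module import FSDPStrategy

# Mixed precision is now configured explicitly on the strategy rather than
# being derived from the `precision` argument of the AutoUnit constructor.
strategy = FSDPStrategy(
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    ),
)
# The strategy is then passed to the AutoUnit constructor (strategy=strategy)
# or directly to prepare_fsdp(module, device, strategy).
```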

Reviewed By: JKSenthil

Differential Revision: D48248411

fbshipit-source-id: 2f163cd39b1e23487853ff1d2e8e47018a07acf6
galrotem authored and facebook-github-bot committed Aug 16, 2023
1 parent 8046981 commit 95f6944
Showing 3 changed files with 31 additions and 13 deletions.
27 changes: 27 additions & 0 deletions tests/utils/test_prepare_module.py
@@ -11,6 +11,7 @@
 import torch
 import torch.distributed.launcher as launcher
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torchtnt.utils.env import init_from_env
 from torchtnt.utils.prepare_module import (
@@ -121,3 +122,29 @@ def _test_is_fsdp_module() -> None:
     def test_is_fsdp_module(self) -> None:
         config = get_pet_launch_config(2)
         launcher.elastic_launch(config, entrypoint=self._test_is_fsdp_module)()
+
+    @unittest.skipUnless(
+        torch.distributed.is_available(), reason="Torch distributed is needed to run"
+    )
+    @unittest.skipUnless(
+        condition=cuda_available, reason="This test needs a GPU host to run."
+    )
+    def test_fsdp_precision(self) -> None:
+        config = get_pet_launch_config(2)
+        launcher.elastic_launch(config, entrypoint=self._test_fsdp_precision)()
+
+    @staticmethod
+    def _test_fsdp_precision() -> None:
+        module = torch.nn.Linear(1, 1)
+        device = init_from_env()
+        mixed_precision = MixedPrecision(
+            param_dtype=torch.float64,
+        )
+        fsdp_module = prepare_fsdp(
+            module, device, FSDPStrategy(mixed_precision=mixed_precision)
+        )
+        tc = unittest.TestCase()
+        tc.assertTrue(isinstance(fsdp_module, FSDP))
+        tc.assertEqual(
+            fsdp_module.mixed_precision.param_dtype, mixed_precision.param_dtype
+        )
3 changes: 1 addition & 2 deletions torchtnt/framework/auto_unit.py
@@ -206,7 +206,6 @@ def __init__(
                 module,
                 self.device,
                 strategy,
-                self.precision,
             )
         else:
             module = module.to(self.device)
@@ -446,7 +445,7 @@ def __init__(
             rank_zero_warn(
                 "We recommend setting FSDPStrategy's use_original_params to True when using torch compile."
             )
-            module = prepare_fsdp(module, self.device, strategy, self.precision)
+            module = prepare_fsdp(module, self.device, strategy)
         else:
             module = module.to(self.device)

14 changes: 3 additions & 11 deletions torchtnt/utils/prepare_module.py
Expand Up @@ -72,6 +72,7 @@ class FSDPStrategy(Strategy):
cpu_offload: Optional[CPUOffload] = None
auto_wrap_policy: Optional[Callable[[torch.nn.Module, bool, int], bool]] = None
backward_prefetch: Optional[BackwardPrefetch] = BackwardPrefetch.BACKWARD_PRE
mixed_precision: Optional[MixedPrecision] = None
ignored_modules: Optional[Iterable[torch.nn.Module]] = None
sync_module_states: bool = False
forward_prefetch: bool = False
@@ -135,7 +136,6 @@ def prepare_fsdp(
     module: torch.nn.Module,
     device: torch.device,
     strategy: Optional[FSDPStrategy] = None,
-    precision: Optional[torch.dtype] = None,
 ) -> FSDP:
     """
     Utility to move a module to device and wrap in `FullyShardedDataParallel <https://pytorch.org/docs/stable/fsdp.html>`_.
@@ -144,7 +144,6 @@
         module: module to be wrapped in FSDP
         strategy: an instance of FSDPStrategy which defines the settings of FSDP APIs
         device: device to which module will be moved
-        precision: precision to use when wrapping in FSDP
 
     Examples::
         strategy = FSDPStrategy(limit_all_gathers=True)
@@ -157,15 +156,9 @@
             "Please install PyTorch 1.12 or higher to use FSDP: https://pytorch.org/get-started/locally/"
         )
     strategy = strategy if strategy is not None else FSDPStrategy()
-    mixed_precision = None
-    if precision:
-        mixed_precision = MixedPrecision(
-            param_dtype=precision,
-            reduce_dtype=precision,
-            buffer_dtype=precision,
-        )
 
-    params_dict = asdict(strategy)
+    # we use __dict__ and not asdict() here because asdict() is recursively applied on nested objects
+    params_dict = strategy.__dict__.copy()
 
     # extract params to set state dict type
     state_dict_type = params_dict.pop("state_dict_type")
@@ -176,7 +169,6 @@
     module = FSDP(
         module,
         device_id=device,
-        mixed_precision=mixed_precision,
         **params_dict,
     )

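To illustrate the `asdict()` vs `__dict__` point above, a small standalone sketch, not part of the commit (assumes a torchtnt build that includes this change and a PyTorch version with FSDP):

```python
from dataclasses import asdict

import torch
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
from torchtnt.utils.prepare_module import FSDPStrategy

strategy = FSDPStrategy(mixed_precision=MixedPrecision(param_dtype=torch.float16))

# asdict() recurses into nested dataclasses, so MixedPrecision is flattened
# into a plain dict, which FSDP would later reject.
print(type(asdict(strategy)["mixed_precision"]))   # <class 'dict'>

# __dict__ keeps the field as-is, preserving the MixedPrecision instance.
print(type(strategy.__dict__["mixed_precision"]))  # MixedPrecision
```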
