diff --git a/torchtnt/framework/auto_unit.py b/torchtnt/framework/auto_unit.py
index 4f9261175e..1c3b614698 100644
--- a/torchtnt/framework/auto_unit.py
+++ b/torchtnt/framework/auto_unit.py
@@ -186,29 +186,13 @@ def __init__(
             dtype=self.precision,
             enabled=self.precision is not None,
         )
-        if strategy:
-            if isinstance(strategy, str):
-                strategy = _convert_str_to_strategy(strategy)
-            if isinstance(strategy, DDPStrategy):
-                if torch_compile_params and strategy.static_graph is True:
-                    # https://dev-discuss.pytorch.org/t/torchdynamo-update-9-making-ddp-work-with-torchdynamo/860
-                    raise RuntimeError(
-                        "Torch compile requires DDPStrategy's static_graph to be False"
-                    )
-                module = prepare_ddp(module, self.device, strategy)
-            elif isinstance(strategy, FSDPStrategy):
-                if torch_compile_params and strategy.use_orig_params is False:
-                    # as stated here https://pytorch.org/get-started/pytorch-2.0/
-                    rank_zero_warn(
-                        "We recommend setting FSDPStrategy's use_orig_params to True when using torch compile."
-                    )
-                module = prepare_fsdp(
-                    module,
-                    self.device,
-                    strategy,
-                )
-        else:
-            module = module.to(self.device)
+        module = _apply_strategy_and_check(
+            module,
+            self.device,
+            strategy,
+            swa_params=None,
+            torch_compile_params=torch_compile_params,
+        )
         if torch_compile_params:
             try:
                 # use in-place compile to avoid altering the state_dict keys
@@ -426,28 +410,9 @@ def __init__(
             else precision
         )
 
-        if strategy:
-            if isinstance(strategy, str):
-                strategy = _convert_str_to_strategy(strategy)
-            if isinstance(strategy, DDPStrategy):
-                if torch_compile_params and strategy.static_graph is True:
-                    # https://dev-discuss.pytorch.org/t/torchdynamo-update-9-making-ddp-work-with-torchdynamo/860
-                    raise RuntimeError(
-                        "Torch compile requires DDPStrategy's static_graph to be False"
-                    )
-                module = prepare_ddp(module, self.device, strategy)
-            elif isinstance(strategy, FSDPStrategy):
-                if swa_params:
-                    raise RuntimeError(
-                        "Stochastic Weight Averaging is currently not supported with the FSDP strategy"
-                    )
-                # as stated here https://pytorch.org/get-started/pytorch-2.0/
-                rank_zero_warn(
-                    "We recommend setting FSDPStrategy's use_original_params to True when using torch compile."
-                )
-                module = prepare_fsdp(module, self.device, strategy)
-        else:
-            module = module.to(self.device)
+        module = _apply_strategy_and_check(
+            module, self.device, strategy, swa_params, torch_compile_params
+        )
 
         if activation_checkpoint_params:
             checkpoint_impl = activation_checkpoint_params.checkpoint_impl
@@ -850,6 +815,42 @@ def _validate_torch_compile_available() -> None:
         )
 
 
+def _apply_strategy_and_check(
+    module: torch.nn.Module,
+    device: torch.device,
+    strategy: Optional[Union[Strategy, str]],
+    swa_params: Optional[SWAParams],
+    torch_compile_params: Optional[TorchCompileParams],
+) -> torch.nn.Module:
+    """
+    Applies the given strategy on the module, and checks for any
+    incompatibilities between the chosen strategy and swa / torch compile params.
+    """
+    if strategy:
+        if isinstance(strategy, str):
+            strategy = _convert_str_to_strategy(strategy)
+        if isinstance(strategy, DDPStrategy):
+            if torch_compile_params and strategy.static_graph is True:
+                # https://dev-discuss.pytorch.org/t/torchdynamo-update-9-making-ddp-work-with-torchdynamo/860
+                raise RuntimeError(
+                    "Torch compile requires DDPStrategy's static_graph to be False"
+                )
+            module = prepare_ddp(module, device, strategy)
+        elif isinstance(strategy, FSDPStrategy):
+            if swa_params:
+                raise RuntimeError(
+                    "Stochastic Weight Averaging is currently not supported with the FSDP strategy"
+                )
+            # as stated here https://pytorch.org/get-started/pytorch-2.0/
+            rank_zero_warn(
+                "We recommend setting FSDPStrategy's use_orig_params to True when using torch compile."
+            )
+            module = prepare_fsdp(module, device, strategy)
+    else:
+        module = module.to(device)
+    return module
+
+
 def _convert_str_to_strategy(strategy: str) -> Union[DDPStrategy, FSDPStrategy]:
     """
     Converts strategy as a string to a default instance of the Strategy dataclass.