From 3c1ecc5629a14f7f77865a586396225017d68f4a Mon Sep 17 00:00:00 2001
From: Jason Senthil
Date: Wed, 23 Aug 2023 11:44:24 -0700
Subject: [PATCH] share `apply_strategy` method between autounit and autopredictunit

Summary:
## Context:
Both `AutoUnit` and `AutoPredictUnit` use the same code block to apply the strategy to the module and check for any incompatibilities:

```
if strategy:
    if isinstance(strategy, str):
        strategy = _convert_str_to_strategy(strategy)
    if isinstance(strategy, DDPStrategy):
        if torch_compile_params and strategy.static_graph is True:
            # https://dev-discuss.pytorch.org/t/torchdynamo-update-9-making-ddp-work-with-torchdynamo/860
            raise RuntimeError(
                "Torch compile requires DDPStrategy's static_graph to be False"
            )
        module = prepare_ddp(module, self.device, strategy)
    elif isinstance(strategy, FSDPStrategy):
        if swa_params:
            raise RuntimeError(
                "Stochastic Weight Averaging is currently not supported with the FSDP strategy"
            )
        # as stated here https://pytorch.org/get-started/pytorch-2.0/
        rank_zero_warn(
            "We recommend setting FSDPStrategy's use_original_params to True when using torch compile."
        )
        module = prepare_fsdp(module, self.device, strategy)
else:
    module = module.to(self.device)
```

If changes are made to this logic, they must be made in both of those classes, which is easy to miss.

## This Diff
Creates a helper function `_apply_strategy_and_check(...)` that applies the strategy to the module, and calls this function in both `AutoUnit` and `AutoPredictUnit` (other name suggestions are also welcome).
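
For illustration, here is a sketch of the two call sites after this change (distilled from the hunks below; the call shapes and parameter names come from this diff, and the unit without SWA support passes `swa_params=None`):

```
# Sketch based on the helper introduced in this diff; `self`, `strategy`,
# `swa_params`, and `torch_compile_params` are the unit's own attributes/args.

# In AutoUnit.__init__:
module = _apply_strategy_and_check(
    module, self.device, strategy, swa_params, torch_compile_params
)

# In AutoPredictUnit.__init__ (prediction has no SWA, so None is passed):
module = _apply_strategy_and_check(
    module,
    self.device,
    strategy,
    swa_params=None,
    torch_compile_params=torch_compile_params,
)
```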

Differential Revision: D48612629

fbshipit-source-id: 4c43193b73a83ad5aaabe69582a213684421729f
---
 torchtnt/framework/auto_unit.py | 91 +++++++++++++++++----------
 1 file changed, 46 insertions(+), 45 deletions(-)

diff --git a/torchtnt/framework/auto_unit.py b/torchtnt/framework/auto_unit.py
index 4f9261175e..1c3b614698 100644
--- a/torchtnt/framework/auto_unit.py
+++ b/torchtnt/framework/auto_unit.py
@@ -186,29 +186,13 @@ def __init__(
             dtype=self.precision,
             enabled=self.precision is not None,
         )
-        if strategy:
-            if isinstance(strategy, str):
-                strategy = _convert_str_to_strategy(strategy)
-            if isinstance(strategy, DDPStrategy):
-                if torch_compile_params and strategy.static_graph is True:
-                    # https://dev-discuss.pytorch.org/t/torchdynamo-update-9-making-ddp-work-with-torchdynamo/860
-                    raise RuntimeError(
-                        "Torch compile requires DDPStrategy's static_graph to be False"
-                    )
-                module = prepare_ddp(module, self.device, strategy)
-            elif isinstance(strategy, FSDPStrategy):
-                if torch_compile_params and strategy.use_orig_params is False:
-                    # as stated here https://pytorch.org/get-started/pytorch-2.0/
-                    rank_zero_warn(
-                        "We recommend setting FSDPStrategy's use_orig_params to True when using torch compile."
-                    )
-                module = prepare_fsdp(
-                    module,
-                    self.device,
-                    strategy,
-                )
-        else:
-            module = module.to(self.device)
+        module = _apply_strategy_and_check(
+            module,
+            self.device,
+            strategy,
+            swa_params=None,
+            torch_compile_params=torch_compile_params,
+        )
         if torch_compile_params:
             try:
                 # use in-place compile to avoid altering the state_dict keys
@@ -426,28 +410,9 @@ def __init__(
             else precision
         )
 
-        if strategy:
-            if isinstance(strategy, str):
-                strategy = _convert_str_to_strategy(strategy)
-            if isinstance(strategy, DDPStrategy):
-                if torch_compile_params and strategy.static_graph is True:
-                    # https://dev-discuss.pytorch.org/t/torchdynamo-update-9-making-ddp-work-with-torchdynamo/860
-                    raise RuntimeError(
-                        "Torch compile requires DDPStrategy's static_graph to be False"
-                    )
-                module = prepare_ddp(module, self.device, strategy)
-            elif isinstance(strategy, FSDPStrategy):
-                if swa_params:
-                    raise RuntimeError(
-                        "Stochastic Weight Averaging is currently not supported with the FSDP strategy"
-                    )
-                # as stated here https://pytorch.org/get-started/pytorch-2.0/
-                rank_zero_warn(
-                    "We recommend setting FSDPStrategy's use_original_params to True when using torch compile."
-                )
-                module = prepare_fsdp(module, self.device, strategy)
-        else:
-            module = module.to(self.device)
+        module = _apply_strategy_and_check(
+            module, self.device, strategy, swa_params, torch_compile_params
+        )
 
         if activation_checkpoint_params:
             checkpoint_impl = activation_checkpoint_params.checkpoint_impl
@@ -850,6 +815,42 @@ def _validate_torch_compile_available() -> None:
         )
 
 
+def _apply_strategy_and_check(
+    module: torch.nn.Module,
+    device: torch.device,
+    strategy: Optional[Union[Strategy, str]],
+    swa_params: Optional[SWAParams],
+    torch_compile_params: Optional[TorchCompileParams],
+) -> torch.nn.Module:
+    """
+    Applies the given strategy on the module, and checks for any
+    incompatibilies between the chosen strategy and swa / torch compile params.
+    """
+    if strategy:
+        if isinstance(strategy, str):
+            strategy = _convert_str_to_strategy(strategy)
+        if isinstance(strategy, DDPStrategy):
+            if torch_compile_params and strategy.static_graph is True:
+                # https://dev-discuss.pytorch.org/t/torchdynamo-update-9-making-ddp-work-with-torchdynamo/860
+                raise RuntimeError(
+                    "Torch compile requires DDPStrategy's static_graph to be False"
+                )
+            module = prepare_ddp(module, device, strategy)
+        elif isinstance(strategy, FSDPStrategy):
+            if swa_params:
+                raise RuntimeError(
+                    "Stochastic Weight Averaging is currently not supported with the FSDP strategy"
+                )
+            # as stated here https://pytorch.org/get-started/pytorch-2.0/
+            rank_zero_warn(
+                "We recommend setting FSDPStrategy's use_original_params to True when using torch compile."
+            )
+            module = prepare_fsdp(module, device, strategy)
+    else:
+        module = module.to(device)
+    return module
+
+
 def _convert_str_to_strategy(strategy: str) -> Union[DDPStrategy, FSDPStrategy]:
     """
     Converts strategy as a string to a default instance of the Strategy dataclass.