share apply_strategy method between autounit and autopredictunit
Summary:
## Context:

Both `AutoUnit` and `AutoPredictUnit` use the same code block to apply the strategy to the module and check for any incompatibilities:

```
if strategy:
    if isinstance(strategy, str):
        strategy = _convert_str_to_strategy(strategy)
    if isinstance(strategy, DDPStrategy):
        if torch_compile_params and strategy.static_graph is True:
            # https://dev-discuss.pytorch.org/t/torchdynamo-update-9-making-ddp-work-with-torchdynamo/860
            raise RuntimeError(
                "Torch compile requires DDPStrategy's static_graph to be False"
            )
        module = prepare_ddp(module, self.device, strategy)
    elif isinstance(strategy, FSDPStrategy):
        if swa_params:
            raise RuntimeError(
                "Stochastic Weight Averaging is currently not supported with the FSDP strategy"
            )
        # as stated here https://pytorch.org/get-started/pytorch-2.0/
        rank_zero_warn(
            "We recommend setting FSDPStrategy's use_original_params to True when using torch compile."
        )
        module = prepare_fsdp(module, self.device, strategy)
else:
    module = module.to(self.device)
```
If changes are made to this logic, they must be made in both classes, which is easy to miss.

## This Diff
Creates a helper function, `_apply_strategy_and_check(...)`, that applies the strategy to the module, and calls it from both `AutoUnit` and `AutoPredictUnit` (other name suggestions are welcome). A sketch of the resulting call sites is shown below.
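
A minimal sketch of the two call sites after this change, distilled from the diff below; the surrounding `__init__` bodies are elided, and the claim that one of the units lacks SWA support (and therefore passes `swa_params=None`) is an inference from the diff, not stated in the summary:

```
# Sketch only: surrounding __init__ code elided.
# The unit without SWA support passes swa_params=None explicitly.
module = _apply_strategy_and_check(
    module,
    self.device,
    strategy,
    swa_params=None,
    torch_compile_params=torch_compile_params,
)

# The other unit forwards its own swa_params through to the same helper.
module = _apply_strategy_and_check(
    module, self.device, strategy, swa_params, torch_compile_params
)
```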

Differential Revision: D48612629

fbshipit-source-id: 4c43193b73a83ad5aaabe69582a213684421729f
JKSenthil authored and facebook-github-bot committed Aug 23, 2023
1 parent 43555dd commit 3c1ecc5
Showing 1 changed file with 46 additions and 45 deletions.
91 changes: 46 additions & 45 deletions torchtnt/framework/auto_unit.py
```diff
@@ -186,29 +186,13 @@ def __init__(
             dtype=self.precision,
             enabled=self.precision is not None,
         )
-        if strategy:
-            if isinstance(strategy, str):
-                strategy = _convert_str_to_strategy(strategy)
-            if isinstance(strategy, DDPStrategy):
-                if torch_compile_params and strategy.static_graph is True:
-                    # https://dev-discuss.pytorch.org/t/torchdynamo-update-9-making-ddp-work-with-torchdynamo/860
-                    raise RuntimeError(
-                        "Torch compile requires DDPStrategy's static_graph to be False"
-                    )
-                module = prepare_ddp(module, self.device, strategy)
-            elif isinstance(strategy, FSDPStrategy):
-                if torch_compile_params and strategy.use_orig_params is False:
-                    # as stated here https://pytorch.org/get-started/pytorch-2.0/
-                    rank_zero_warn(
-                        "We recommend setting FSDPStrategy's use_orig_params to True when using torch compile."
-                    )
-                module = prepare_fsdp(
-                    module,
-                    self.device,
-                    strategy,
-                )
-        else:
-            module = module.to(self.device)
+        module = _apply_strategy_and_check(
+            module,
+            self.device,
+            strategy,
+            swa_params=None,
+            torch_compile_params=torch_compile_params,
+        )
         if torch_compile_params:
             try:
                 # use in-place compile to avoid altering the state_dict keys
@@ -426,28 +410,9 @@ def __init__(
             else precision
         )
 
-        if strategy:
-            if isinstance(strategy, str):
-                strategy = _convert_str_to_strategy(strategy)
-            if isinstance(strategy, DDPStrategy):
-                if torch_compile_params and strategy.static_graph is True:
-                    # https://dev-discuss.pytorch.org/t/torchdynamo-update-9-making-ddp-work-with-torchdynamo/860
-                    raise RuntimeError(
-                        "Torch compile requires DDPStrategy's static_graph to be False"
-                    )
-                module = prepare_ddp(module, self.device, strategy)
-            elif isinstance(strategy, FSDPStrategy):
-                if swa_params:
-                    raise RuntimeError(
-                        "Stochastic Weight Averaging is currently not supported with the FSDP strategy"
-                    )
-                # as stated here https://pytorch.org/get-started/pytorch-2.0/
-                rank_zero_warn(
-                    "We recommend setting FSDPStrategy's use_original_params to True when using torch compile."
-                )
-                module = prepare_fsdp(module, self.device, strategy)
-        else:
-            module = module.to(self.device)
+        module = _apply_strategy_and_check(
+            module, self.device, strategy, swa_params, torch_compile_params
+        )
 
         if activation_checkpoint_params:
             checkpoint_impl = activation_checkpoint_params.checkpoint_impl
@@ -850,6 +815,42 @@ def _validate_torch_compile_available() -> None:
         )
 
 
+def _apply_strategy_and_check(
+    module: torch.nn.Module,
+    device: torch.device,
+    strategy: Optional[Union[Strategy, str]],
+    swa_params: Optional[SWAParams],
+    torch_compile_params: Optional[TorchCompileParams],
+) -> torch.nn.Module:
+    """
+    Applies the given strategy on the module, and checks for any
+    incompatibilies between the chosen strategy and swa / torch compile params.
+    """
+    if strategy:
+        if isinstance(strategy, str):
+            strategy = _convert_str_to_strategy(strategy)
+        if isinstance(strategy, DDPStrategy):
+            if torch_compile_params and strategy.static_graph is True:
+                # https://dev-discuss.pytorch.org/t/torchdynamo-update-9-making-ddp-work-with-torchdynamo/860
+                raise RuntimeError(
+                    "Torch compile requires DDPStrategy's static_graph to be False"
+                )
+            module = prepare_ddp(module, device, strategy)
+        elif isinstance(strategy, FSDPStrategy):
+            if swa_params:
+                raise RuntimeError(
+                    "Stochastic Weight Averaging is currently not supported with the FSDP strategy"
+                )
+            # as stated here https://pytorch.org/get-started/pytorch-2.0/
+            rank_zero_warn(
+                "We recommend setting FSDPStrategy's use_original_params to True when using torch compile."
+            )
+            module = prepare_fsdp(module, device, strategy)
+    else:
+        module = module.to(device)
+    return module
+
+
 def _convert_str_to_strategy(strategy: str) -> Union[DDPStrategy, FSDPStrategy]:
     """
     Converts strategy as a string to a default instance of the Strategy dataclass.
```
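
For reference, a hedged usage sketch of the existing `_convert_str_to_strategy` helper that the diff context shows; the accepted strings (`"ddp"`, `"fsdp"`) are assumed here and are not confirmed by this diff:

```
# Hypothetical call: converts a strategy name to a default Strategy instance.
strategy = _convert_str_to_strategy("fsdp")
assert isinstance(strategy, FSDPStrategy)
```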
