replace warning w/ hard error in AutoUnit's fsdp + torch compile (#460)
Summary: Pull Request resolved: #460

Reviewed By: daniellepintz, rohan-varma

Differential Revision: D47473960

fbshipit-source-id: 728a57f5df5af78318be76d2c8830630fabf3e67
JKSenthil authored and facebook-github-bot committed Jul 17, 2023
1 parent 3cbec3c commit bed9206
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions torchtnt/framework/auto_unit.py
@@ -43,7 +43,6 @@
     prepare_fsdp,
     Strategy,
 )
-from torchtnt.utils.rank_zero_log import rank_zero_warn
 from torchtnt.utils.version import is_torch_version_ge_1_13_1
 from typing_extensions import Literal

@@ -198,14 +197,14 @@ def __init__(
                 if torch_compile_params and strategy.static_graph is True:
                     # https://dev-discuss.pytorch.org/t/torchdynamo-update-9-making-ddp-work-with-torchdynamo/860
                     raise RuntimeError(
-                        "Torch compile requires DDPStrategy's static_graph to be False"
+                        "Torch compile requires DDPStrategy's static_graph to be False."
                     )
                 module = prepare_ddp(module, self.device, strategy)
             elif isinstance(strategy, FSDPStrategy):
                 if torch_compile_params and strategy.use_orig_params is False:
                     # as stated here https://pytorch.org/get-started/pytorch-2.0/
-                    rank_zero_warn(
-                        "We recommend setting FSDPStrategy's use_orig_params to True when using torch compile."
+                    raise RuntimeError(
+                        "Torch compile requires FSDPStrategy's use_orig_params to be True, since AOTAutograd needs to be aware of the original parameters."
                     )
                 module = prepare_fsdp(
                     module,
@@ -436,18 +435,19 @@ def __init__(
                 if torch_compile_params and strategy.static_graph is True:
                     # https://dev-discuss.pytorch.org/t/torchdynamo-update-9-making-ddp-work-with-torchdynamo/860
                     raise RuntimeError(
-                        "Torch compile requires DDPStrategy's static_graph to be False"
+                        "Torch compile requires DDPStrategy's static_graph to be False."
                     )
                 module = prepare_ddp(module, self.device, strategy)
             elif isinstance(strategy, FSDPStrategy):
                 if swa_params:
                     raise RuntimeError(
                         "Stochastic Weight Averaging is currently not supported with the FSDP strategy"
                     )
-                # as stated here https://pytorch.org/get-started/pytorch-2.0/
-                rank_zero_warn(
-                    "We recommend setting FSDPStrategy's use_original_params to True when using torch compile."
-                )
+                if torch_compile_params and strategy.use_orig_params is False:
+                    # as stated here https://pytorch.org/get-started/pytorch-2.0/
+                    raise RuntimeError(
+                        "Torch compile requires FSDPStrategy's use_orig_params to be True, since AOTAutograd needs to be aware of the original parameters."
+                    )
                 module = prepare_fsdp(module, self.device, strategy, self.precision)
             else:
                 module = module.to(self.device)
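For context, a minimal sketch of the user-facing effect of this change (not part of the diff). The import paths, the TorchCompileParams and FSDPStrategy constructor calls, and the MyUnit subclass are assumptions for illustration and may differ across torchtnt versions; wrapping a module with FSDP additionally assumes an initialized distributed process group (e.g. launched via torchrun).

    from typing import Tuple

    import torch
    import torch.distributed as dist
    from torchtnt.framework.auto_unit import AutoUnit, TorchCompileParams
    from torchtnt.utils.prepare_module import FSDPStrategy  # import path assumed


    class MyUnit(AutoUnit[Tuple[torch.Tensor, torch.Tensor]]):
        # Minimal concrete subclass: AutoUnit is abstract, so a loss computation
        # and an optimizer factory are needed before it can be instantiated.
        def compute_loss(self, state, data):
            inputs, targets = data
            outputs = self.module(inputs)
            return torch.nn.functional.mse_loss(outputs, targets), outputs

        def configure_optimizers_and_lr_scheduler(self, module):
            return torch.optim.SGD(module.parameters(), lr=0.01), None


    module = torch.nn.Linear(8, 8)

    # Before this commit, this configuration only logged a rank-zero warning;
    # after it, the mismatch is rejected at construction time.
    try:
        MyUnit(
            module=module,
            strategy=FSDPStrategy(use_orig_params=False),
            torch_compile_params=TorchCompileParams(),
        )
    except RuntimeError as err:
        print(f"hard error introduced by this commit: {err}")

    # Supported configuration: keep the original parameters visible to AOTAutograd.
    # Actually wrapping with FSDP still requires an initialized process group.
    if dist.is_initialized():
        unit = MyUnit(
            module=module,
            strategy=FSDPStrategy(use_orig_params=True),
            torch_compile_params=TorchCompileParams(),
        )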
