replace warning w/ hard error in AutoUnit's fsdp + torch compile
Reviewed By: daniellepintz, rohan-varma

Differential Revision: D47473960

fbshipit-source-id: bbe5017c5eaf1e49a54e3ae14ad114c2df30d984
JKSenthil authored and facebook-github-bot committed Jul 17, 2023
1 parent 3cbec3c commit f3cd143
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions torchtnt/framework/auto_unit.py
@@ -198,14 +198,14 @@ def __init__(
             if torch_compile_params and strategy.static_graph is True:
                 # https://dev-discuss.pytorch.org/t/torchdynamo-update-9-making-ddp-work-with-torchdynamo/860
                 raise RuntimeError(
-                    "Torch compile requires DDPStrategy's static_graph to be False"
+                    "Torch compile requires DDPStrategy's static_graph to be False."
                 )
             module = prepare_ddp(module, self.device, strategy)
         elif isinstance(strategy, FSDPStrategy):
             if torch_compile_params and strategy.use_orig_params is False:
                 # as stated here https://pytorch.org/get-started/pytorch-2.0/
-                rank_zero_warn(
-                    "We recommend setting FSDPStrategy's use_orig_params to True when using torch compile."
+                raise RuntimeError(
+                    "Torch compile requires FSDPStrategy's use_orig_params to be True, since AOTAutograd needs to be aware of the original parameters."
                 )
             module = prepare_fsdp(
                 module,
@@ -436,18 +436,19 @@ def __init__(
             if torch_compile_params and strategy.static_graph is True:
                 # https://dev-discuss.pytorch.org/t/torchdynamo-update-9-making-ddp-work-with-torchdynamo/860
                 raise RuntimeError(
-                    "Torch compile requires DDPStrategy's static_graph to be False"
+                    "Torch compile requires DDPStrategy's static_graph to be False."
                 )
             module = prepare_ddp(module, self.device, strategy)
         elif isinstance(strategy, FSDPStrategy):
             if swa_params:
                 raise RuntimeError(
                     "Stochastic Weight Averaging is currently not supported with the FSDP strategy"
                 )
-            # as stated here https://pytorch.org/get-started/pytorch-2.0/
-            rank_zero_warn(
-                "We recommend setting FSDPStrategy's use_original_params to True when using torch compile."
-            )
+            if torch_compile_params and strategy.use_orig_params is False:
+                # as stated here https://pytorch.org/get-started/pytorch-2.0/
+                raise RuntimeError(
+                    "Torch compile requires FSDPStrategy's use_orig_params to be True, since AOTAutograd needs to be aware of the original parameters."
+                )
             module = prepare_fsdp(module, self.device, strategy, self.precision)
         else:
             module = module.to(self.device)
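For context on the user-facing effect of this change: constructing an AutoUnit with torch_compile_params together with an FSDPStrategy whose use_orig_params is False now raises a RuntimeError instead of only emitting a warning. Below is a minimal configuration sketch that satisfies the new check; the import paths and the FSDPStrategy/TorchCompileParams field names are assumptions based on torchtnt's public API around this release, not taken from the diff itself.

# Hypothetical usage sketch -- verify import paths against your torchtnt version.
from torchtnt.utils.prepare_module import FSDPStrategy, TorchCompileParams

# use_orig_params=True is now required whenever torch_compile_params is passed,
# since AOTAutograd needs to be aware of the original parameters.
strategy = FSDPStrategy(use_orig_params=True)
torch_compile_params = TorchCompileParams(backend="inductor")

# These would then be forwarded to an AutoUnit subclass, roughly:
#   MyAutoUnit(module=my_module, strategy=strategy, torch_compile_params=torch_compile_params)
# Leaving use_orig_params at False while torch_compile_params is set now fails fast
# with the RuntimeError added in this commit.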
