From bed92062f9b844b9c39c6fbac17df4520a2838b8 Mon Sep 17 00:00:00 2001
From: Jason Senthil
Date: Mon, 17 Jul 2023 11:12:02 -0700
Subject: [PATCH] replace warning w/ hard error in AutoUnit's fsdp + torch compile (#460)

Summary:
Pull Request resolved: https://github.com/pytorch/tnt/pull/460

Reviewed By: daniellepintz, rohan-varma

Differential Revision: D47473960

fbshipit-source-id: 728a57f5df5af78318be76d2c8830630fabf3e67
---
 torchtnt/framework/auto_unit.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/torchtnt/framework/auto_unit.py b/torchtnt/framework/auto_unit.py
index e3c8d9f889..cb53d7468a 100644
--- a/torchtnt/framework/auto_unit.py
+++ b/torchtnt/framework/auto_unit.py
@@ -43,7 +43,6 @@
     prepare_fsdp,
     Strategy,
 )
-from torchtnt.utils.rank_zero_log import rank_zero_warn
 from torchtnt.utils.version import is_torch_version_ge_1_13_1
 from typing_extensions import Literal
 
@@ -198,14 +197,14 @@ def __init__(
             if torch_compile_params and strategy.static_graph is True:
                 # https://dev-discuss.pytorch.org/t/torchdynamo-update-9-making-ddp-work-with-torchdynamo/860
                 raise RuntimeError(
-                    "Torch compile requires DDPStrategy's static_graph to be False"
+                    "Torch compile requires DDPStrategy's static_graph to be False."
                 )
             module = prepare_ddp(module, self.device, strategy)
         elif isinstance(strategy, FSDPStrategy):
             if torch_compile_params and strategy.use_orig_params is False:
                 # as stated here https://pytorch.org/get-started/pytorch-2.0/
-                rank_zero_warn(
-                    "We recommend setting FSDPStrategy's use_orig_params to True when using torch compile."
+                raise RuntimeError(
+                    "Torch compile requires FSDPStrategy's use_orig_params to be True, since AOTAutograd needs to be aware of the original parameters."
                 )
             module = prepare_fsdp(
                 module,
@@ -436,7 +435,7 @@ def __init__(
             if torch_compile_params and strategy.static_graph is True:
                 # https://dev-discuss.pytorch.org/t/torchdynamo-update-9-making-ddp-work-with-torchdynamo/860
                 raise RuntimeError(
-                    "Torch compile requires DDPStrategy's static_graph to be False"
+                    "Torch compile requires DDPStrategy's static_graph to be False."
                 )
             module = prepare_ddp(module, self.device, strategy)
         elif isinstance(strategy, FSDPStrategy):
@@ -444,10 +443,11 @@ def __init__(
                 raise RuntimeError(
                     "Stochastic Weight Averaging is currently not supported with the FSDP strategy"
                 )
-            # as stated here https://pytorch.org/get-started/pytorch-2.0/
-            rank_zero_warn(
-                "We recommend setting FSDPStrategy's use_original_params to True when using torch compile."
-            )
+            if torch_compile_params and strategy.use_orig_params is False:
+                # as stated here https://pytorch.org/get-started/pytorch-2.0/
+                raise RuntimeError(
+                    "Torch compile requires FSDPStrategy's use_orig_params to be True, since AOTAutograd needs to be aware of the original parameters."
+                )
             module = prepare_fsdp(module, self.device, strategy, self.precision)
         else:
             module = module.to(self.device)
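
Note: the snippet below is a minimal, self-contained sketch of the behavior change for reference only. FSDPStrategy, TorchCompileParams, and check_fsdp_compile here are simplified stand-ins, not torchtnt's actual classes or API; the sketch only mirrors the guard added by this patch (warning replaced with a hard error when torch.compile is combined with FSDP and use_orig_params=False).

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class FSDPStrategy:
        # Simplified stand-in for torchtnt's FSDPStrategy; the real class has many more fields.
        use_orig_params: bool = False


    @dataclass
    class TorchCompileParams:
        # Simplified stand-in for torchtnt's TorchCompileParams.
        backend: str = "inductor"


    def check_fsdp_compile(
        strategy: FSDPStrategy, torch_compile_params: Optional[TorchCompileParams]
    ) -> None:
        # Mirrors the check added above: previously a rank-zero warning, now a hard
        # error, because AOTAutograd must see the module's original parameters.
        if torch_compile_params and strategy.use_orig_params is False:
            raise RuntimeError(
                "Torch compile requires FSDPStrategy's use_orig_params to be True, "
                "since AOTAutograd needs to be aware of the original parameters."
            )


    # This combination now fails fast instead of emitting a warning:
    try:
        check_fsdp_compile(FSDPStrategy(use_orig_params=False), TorchCompileParams())
    except RuntimeError as err:
        print(f"rejected: {err}")

    # Setting use_orig_params=True is accepted:
    check_fsdp_compile(FSDPStrategy(use_orig_params=True), TorchCompileParams())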