From 39ad0373bd49a35272616b61f78ef28ef4299897 Mon Sep 17 00:00:00 2001
From: Jason Senthil
Date: Mon, 14 Aug 2023 10:51:32 -0700
Subject: [PATCH] replace warning w/ hard error in AutoUnit's fsdp + torch
 compile (#460)

Summary:
Pull Request resolved: https://github.com/pytorch/tnt/pull/460

Reviewed By: daniellepintz, rohan-varma

Differential Revision: D47473960

fbshipit-source-id: 6b8e943fb2e8bca40eb1b6542d6608078c3629d4
---
 torchtnt/framework/auto_unit.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/torchtnt/framework/auto_unit.py b/torchtnt/framework/auto_unit.py
index 41c255295a..7a9a57443f 100644
--- a/torchtnt/framework/auto_unit.py
+++ b/torchtnt/framework/auto_unit.py
@@ -194,14 +194,14 @@ def __init__(
         if torch_compile_params and strategy.static_graph is True:
             # https://dev-discuss.pytorch.org/t/torchdynamo-update-9-making-ddp-work-with-torchdynamo/860
             raise RuntimeError(
-                "Torch compile requires DDPStrategy's static_graph to be False"
+                "Torch compile requires DDPStrategy's static_graph to be False."
             )
             module = prepare_ddp(module, self.device, strategy)
         elif isinstance(strategy, FSDPStrategy):
             if torch_compile_params and strategy.use_orig_params is False:
                 # as stated here https://pytorch.org/get-started/pytorch-2.0/
-                rank_zero_warn(
-                    "We recommend setting FSDPStrategy's use_orig_params to True when using torch compile."
+                raise RuntimeError(
+                    "Torch compile requires FSDPStrategy's use_orig_params to be True, since AOTAutograd needs to be aware of the original parameters."
                 )
             module = prepare_fsdp(
                 module,
@@ -435,7 +435,7 @@ def __init__(
         if torch_compile_params and strategy.static_graph is True:
             # https://dev-discuss.pytorch.org/t/torchdynamo-update-9-making-ddp-work-with-torchdynamo/860
             raise RuntimeError(
-                "Torch compile requires DDPStrategy's static_graph to be False"
+                "Torch compile requires DDPStrategy's static_graph to be False."
             )
             module = prepare_ddp(module, self.device, strategy)
         elif isinstance(strategy, FSDPStrategy):
@@ -443,10 +443,11 @@ def __init__(
             raise RuntimeError(
                 "Stochastic Weight Averaging is currently not supported with the FSDP strategy"
             )
-            # as stated here https://pytorch.org/get-started/pytorch-2.0/
-            rank_zero_warn(
-                "We recommend setting FSDPStrategy's use_original_params to True when using torch compile."
-            )
+            if torch_compile_params and strategy.use_orig_params is False:
+                # as stated here https://pytorch.org/get-started/pytorch-2.0/
+                raise RuntimeError(
+                    "Torch compile requires FSDPStrategy's use_orig_params to be True, since AOTAutograd needs to be aware of the original parameters."
+                )
             module = prepare_fsdp(module, self.device, strategy, self.precision)
         else:
             module = module.to(self.device)
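
Note: the snippet below is an illustrative sketch of the user-facing effect of this change, not part of the patch. The MyAutoUnit subclass, the import paths, and the TorchCompileParams fields shown are assumptions that may differ across torchtnt versions; FSDP additionally requires an initialized torch.distributed process group and a CUDA device.

# Illustrative sketch only; names and import paths below are assumptions.
import torch

from torchtnt.framework.auto_unit import AutoUnit
from torchtnt.utils.prepare_module import FSDPStrategy, TorchCompileParams


class MyAutoUnit(AutoUnit):
    # AutoUnit is abstract: a concrete subclass supplies the loss computation
    # and the optimizer/LR-scheduler configuration.
    def compute_loss(self, state, data):
        inputs, targets = data
        outputs = self.module(inputs)
        return torch.nn.functional.cross_entropy(outputs, targets), outputs

    def configure_optimizers_and_lr_scheduler(self, module):
        return torch.optim.SGD(module.parameters(), lr=0.01), None


module = torch.nn.Linear(8, 2)

# After this change, combining torch compile with FSDP while
# use_orig_params=False raises RuntimeError during AutoUnit.__init__
# (previously it only emitted a rank-zero warning):
#
#   MyAutoUnit(
#       module=module,
#       strategy=FSDPStrategy(use_orig_params=False),
#       torch_compile_params=TorchCompileParams(backend="inductor"),
#   )
#
# The supported combination keeps the original parameters visible to
# AOTAutograd:
unit = MyAutoUnit(
    module=module,
    strategy=FSDPStrategy(use_orig_params=True),
    torch_compile_params=TorchCompileParams(backend="inductor"),
)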