From f3cd1437ea02007cd94924c39315c799e44d0c32 Mon Sep 17 00:00:00 2001
From: Jason Senthil
Date: Mon, 17 Jul 2023 10:50:18 -0700
Subject: [PATCH] replace warning w/ hard error in AutoUnit's fsdp + torch compile

Reviewed By: daniellepintz, rohan-varma

Differential Revision: D47473960

fbshipit-source-id: bbe5017c5eaf1e49a54e3ae14ad114c2df30d984
---
 torchtnt/framework/auto_unit.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/torchtnt/framework/auto_unit.py b/torchtnt/framework/auto_unit.py
index e3c8d9f889..a556220610 100644
--- a/torchtnt/framework/auto_unit.py
+++ b/torchtnt/framework/auto_unit.py
@@ -198,14 +198,14 @@ def __init__(
             if torch_compile_params and strategy.static_graph is True:
                 # https://dev-discuss.pytorch.org/t/torchdynamo-update-9-making-ddp-work-with-torchdynamo/860
                 raise RuntimeError(
-                    "Torch compile requires DDPStrategy's static_graph to be False"
+                    "Torch compile requires DDPStrategy's static_graph to be False."
                 )
             module = prepare_ddp(module, self.device, strategy)
         elif isinstance(strategy, FSDPStrategy):
             if torch_compile_params and strategy.use_orig_params is False:
                 # as stated here https://pytorch.org/get-started/pytorch-2.0/
-                rank_zero_warn(
-                    "We recommend setting FSDPStrategy's use_orig_params to True when using torch compile."
+                raise RuntimeError(
+                    "Torch compile requires FSDPStrategy's use_orig_params to be True, since AOTAutograd needs to be aware of the original parameters."
                 )
             module = prepare_fsdp(
                 module,
@@ -436,7 +436,7 @@ def __init__(
             if torch_compile_params and strategy.static_graph is True:
                 # https://dev-discuss.pytorch.org/t/torchdynamo-update-9-making-ddp-work-with-torchdynamo/860
                 raise RuntimeError(
-                    "Torch compile requires DDPStrategy's static_graph to be False"
+                    "Torch compile requires DDPStrategy's static_graph to be False."
                 )
             module = prepare_ddp(module, self.device, strategy)
         elif isinstance(strategy, FSDPStrategy):
@@ -444,10 +444,11 @@ def __init__(
                 raise RuntimeError(
                     "Stochastic Weight Averaging is currently not supported with the FSDP strategy"
                 )
-            # as stated here https://pytorch.org/get-started/pytorch-2.0/
-            rank_zero_warn(
-                "We recommend setting FSDPStrategy's use_original_params to True when using torch compile."
-            )
+            if torch_compile_params and strategy.use_orig_params is False:
+                # as stated here https://pytorch.org/get-started/pytorch-2.0/
+                raise RuntimeError(
+                    "Torch compile requires FSDPStrategy's use_orig_params to be True, since AOTAutograd needs to be aware of the original parameters."
+                )
             module = prepare_fsdp(module, self.device, strategy, self.precision)
         else:
             module = module.to(self.device)
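
Note (not part of the patch): after this change, combining FSDP with torch.compile in AutoUnit fails fast at construction time unless FSDPStrategy is created with use_orig_params=True, matching the check that already existed in the first __init__ touched above. The sketch below shows a configuration that satisfies the new requirement. It is a minimal illustration rather than code from this diff: the MyUnit subclass, the toy Linear module, and the import paths for State, FSDPStrategy, and TorchCompileParams are assumptions about the torchtnt layout at the time and may need adjusting, and the FSDP strategy itself assumes a distributed process group has already been initialized (e.g., when launched via torchrun).

from typing import Tuple

import torch

from torchtnt.framework.auto_unit import AutoUnit
from torchtnt.framework.state import State
from torchtnt.utils.prepare_module import FSDPStrategy, TorchCompileParams

Batch = Tuple[torch.Tensor, torch.Tensor]


class MyUnit(AutoUnit[Batch]):
    # Hypothetical subclass for illustration only.
    def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, torch.Tensor]:
        inputs, targets = data
        outputs = self.module(inputs)
        loss = torch.nn.functional.cross_entropy(outputs, targets)
        return loss, outputs

    def configure_optimizers_and_lr_scheduler(self, module: torch.nn.Module):
        # Plain SGD, no LR scheduler, to keep the example small.
        return torch.optim.SGD(module.parameters(), lr=0.01), None


# use_orig_params=True is now required whenever torch_compile_params is set;
# use_orig_params=False together with torch.compile raises RuntimeError at init.
unit = MyUnit(
    module=torch.nn.Linear(32, 4),
    strategy=FSDPStrategy(use_orig_params=True),
    torch_compile_params=TorchCompileParams(),
)

With the same arguments but FSDPStrategy(use_orig_params=False), __init__ previously emitted a rank-zero warning and continued; it now raises the RuntimeError introduced in this patch.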