From b7216b81635e399d2024adafd97ee6173f88d267 Mon Sep 17 00:00:00 2001 From: Farhad Ramezanghorbani Date: Fri, 9 Aug 2024 12:44:44 -0600 Subject: [PATCH] Drop PyTorch 2.1 version check from fabric strategies (#10079) * rm torch version check Signed-off-by: Farhad Ramezanghorbani * bump min torch version Signed-off-by: Farhad Ramezanghorbani * rm version Signed-off-by: Farhad Ramezanghorbani --------- Signed-off-by: Farhad Ramezanghorbani Co-authored-by: Marc Romeyn --- nemo/lightning/fabric/strategies.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/nemo/lightning/fabric/strategies.py b/nemo/lightning/fabric/strategies.py index a662386a9119..5c2b634ea282 100644 --- a/nemo/lightning/fabric/strategies.py +++ b/nemo/lightning/fabric/strategies.py @@ -23,7 +23,6 @@ from lightning_fabric.plugins.precision import Precision from lightning_fabric.strategies import DDPStrategy from lightning_fabric.strategies.strategy import _validate_keys_for_strict_loading -from lightning_fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from lightning_fabric.utilities.types import _PATH, _Stateful from megatron.core.distributed import DistributedDataParallelConfig from pytorch_lightning.loops.fetchers import _DataFetcher @@ -208,7 +207,7 @@ def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManag precision_init_ctx = self.precision.module_init_context() module_sharded_ctx = self.megatron_context() stack = ExitStack() - if _TORCH_GREATER_EQUAL_2_1 and empty_init: + if empty_init: # Materialization happens in `setup`. When modules get wrapped by FSDP, the sequence of operations is: # 1) materialize module 2) call `reset_parameters()` 3) shard the module. # These operations are applied to each submodule 'bottom up' in the module hierarchy.