diff --git a/nemo/lightning/fabric/strategies.py b/nemo/lightning/fabric/strategies.py index a662386a9119..5c2b634ea282 100644 --- a/nemo/lightning/fabric/strategies.py +++ b/nemo/lightning/fabric/strategies.py @@ -23,7 +23,6 @@ from lightning_fabric.plugins.precision import Precision from lightning_fabric.strategies import DDPStrategy from lightning_fabric.strategies.strategy import _validate_keys_for_strict_loading -from lightning_fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from lightning_fabric.utilities.types import _PATH, _Stateful from megatron.core.distributed import DistributedDataParallelConfig from pytorch_lightning.loops.fetchers import _DataFetcher @@ -208,7 +207,7 @@ def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManag precision_init_ctx = self.precision.module_init_context() module_sharded_ctx = self.megatron_context() stack = ExitStack() - if _TORCH_GREATER_EQUAL_2_1 and empty_init: + if empty_init: # Materialization happens in `setup`. When modules get wrapped by FSDP, the sequence of operations is: # 1) materialize module 2) call `reset_parameters()` 3) shard the module. # These operations are applied to each submodule 'bottom up' in the module hierarchy.