From fa3874ac48c4f90775992491a5205ba2ff674fd0 Mon Sep 17 00:00:00 2001 From: Sangkug Lym Date: Fri, 19 Jan 2024 08:30:40 +0900 Subject: [PATCH] Add the interface to use SHARP to FSDP strategy Signed-off-by: Sangkug Lym --- nemo/collections/nlp/parts/megatron_trainer_builder.py | 1 + nemo/collections/nlp/parts/nlp_overrides.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/nemo/collections/nlp/parts/megatron_trainer_builder.py b/nemo/collections/nlp/parts/megatron_trainer_builder.py index c58e0be4a508..3a19d8e186e8 100644 --- a/nemo/collections/nlp/parts/megatron_trainer_builder.py +++ b/nemo/collections/nlp/parts/megatron_trainer_builder.py @@ -69,6 +69,7 @@ def _training_strategy(self) -> Union[NLPDDPStrategy, NLPFSDPStrategy]: sharded_checkpoint=sharded_checkpoint, precision=self.cfg.trainer.precision, nccl_communicator_config_path=self.cfg.model.get('nccl_communicator_config_path', None), + sharp=self.cfg.model.get('sharp', False), ) return NLPDDPStrategy( diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index 5075863c3dbb..6ee36d6983cb 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -517,6 +517,7 @@ def __init__( sharded_checkpoint: bool = False, precision: Union[int, str] = 'bf16-mixed', nccl_communicator_config_path: Optional[str] = None, + sharp: bool = False, **kwargs: Union[Any, Dict[str, Any]], ) -> None: if not HAVE_APEX: @@ -561,6 +562,7 @@ def __init__( ) self.nccl_communicator_config_path = nccl_communicator_config_path + self.sharp = sharp super().__init__(**kwargs) def _set_mixed_precision_recipe(