Skip to content

Commit b5798de

Browse files
ashors1, ShriyaPalsamudram, pablo-garay
authored
[NeMo-UX] Use custom BatchProgress class which does not restore states (NVIDIA#10383)
* [WIP] fix batch sampler to match megatron dataloaders

  Signed-off-by: ashors1 <[email protected]>

* make batchprogress configurable

  Signed-off-by: ashors1 <[email protected]>

* Apply isort and black reformatting

  Signed-off-by: ashors1 <[email protected]>

---------

Signed-off-by: ashors1 <[email protected]>
Signed-off-by: ashors1 <[email protected]>
Co-authored-by: ashors1 <[email protected]>
Co-authored-by: Shriya Rishab <[email protected]>
Co-authored-by: Pablo Garay <[email protected]>
1 parent 3a60491 commit b5798de

File tree

3 files changed

+25
-0
lines changed

3 files changed

+25
-0
lines changed

nemo/lightning/pytorch/strategies/fsdp_strategy.py

+8
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
from nemo.lightning import io
3737
from nemo.lightning.pytorch.strategies.utils import (
38+
_MegatronBatchProgress,
3839
ckpt_to_dir,
3940
create_checkpoint_io,
4041
fix_progress_bar,
@@ -73,13 +74,15 @@ def __init__(
7374
ckpt_load_optimizer: bool = True,
7475
ckpt_save_optimizer: bool = True,
7576
data_sampler=None,
77+
overwrite_batch_progress: bool = True,
7678
**kwargs,
7779
):
7880
super().__init__(auto_wrap_policy=auto_wrap_policy, state_dict_type=state_dict_type, **kwargs)
7981

8082
self.data_sampler = data_sampler
8183
self.ckpt_load_optimizer = ckpt_load_optimizer
8284
self.ckpt_save_optimizer = ckpt_save_optimizer
85+
self.overwrite_batch_progress = overwrite_batch_progress
8386

8487
@override
8588
def setup_environment(self) -> None:
@@ -92,6 +95,11 @@ def setup(self, trainer: pl.Trainer) -> None:
9295
self.trainer = trainer
9396
setup_data_sampler(self.trainer)
9497
fix_progress_bar(trainer)
98+
99+
trainer_fn = trainer.state.fn
100+
if trainer_fn == TrainerFn.FITTING and self.overwrite_batch_progress:
101+
trainer.fit_loop.epoch_loop.batch_progress = _MegatronBatchProgress()
102+
95103
super().setup(trainer)
96104

97105
def _get_loss_reduction(self, step_type: str):

nemo/lightning/pytorch/strategies/megatron_strategy.py

+7
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
from nemo.lightning.pytorch.callbacks import ModelTransform
6262
from nemo.lightning.pytorch.strategies.utils import (
6363
RestoreConfig,
64+
_MegatronBatchProgress,
6465
ckpt_to_dir,
6566
create_checkpoint_io,
6667
fix_progress_bar,
@@ -152,6 +153,8 @@ class MegatronStrategy(DDPStrategy, io.IOMixin):
152153
that prints the metrics to stdout. Suitable for non-interactive settings.
153154
progress_interval (int): How frequently to print progress to stdout. Only used when
154155
replace_progress_bar is True.
156+
overwrite_batch_progress (bool): Whether to overwrite _BatchProgress class used in PTL by default with
157+
_MegatronBatchProgress. This should be True whenever you're using a Megatron-based dataset.
155158
**kwargs: Additional keyword arguments.
156159
157160
Note:
@@ -194,6 +197,7 @@ def __init__(
194197
replace_progress_bar: bool = True,
195198
progress_interval: int = 1,
196199
restore_config: Optional[RestoreConfig] = None,
200+
overwrite_batch_progress: bool = True,
197201
**kwargs,
198202
) -> None:
199203
super().__init__(
@@ -234,6 +238,7 @@ def __init__(
234238

235239
self.replace_progress_bar = replace_progress_bar
236240
self.progress_interval = progress_interval
241+
self.overwrite_batch_progress = overwrite_batch_progress
237242

238243
self.restore_config = restore_config
239244

@@ -331,6 +336,8 @@ def setup(self, trainer: pl.Trainer) -> None:
331336
self.configure_ddp()
332337

333338
trainer.fit_loop.epoch_loop.automatic_optimization = _MegatronAutomaticOptimization(trainer)
339+
if self.overwrite_batch_progress:
340+
trainer.fit_loop.epoch_loop.batch_progress = _MegatronBatchProgress()
334341

335342
import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
336343

nemo/lightning/pytorch/strategies/utils.py

+10
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,12 @@
2525
from megatron.core.dist_checkpointing.strategies.torch import sharded_tensor_to_torch_sharded_tensor
2626
from megatron.core.transformer.utils import _get_extra_state_offsets
2727
from pytorch_lightning.callbacks import TQDMProgressBar
28+
from pytorch_lightning.loops.progress import _BatchProgress
2829
from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
2930
from torch.distributed._sharded_tensor import ShardedTensor as TorchShardedTensor
3031
from torch.distributed._tensor import DTensor, Replicate, Shard
3132
from torch.distributed.device_mesh import DeviceMesh
33+
from typing_extensions import override
3234

3335
from nemo.lightning import _strategy_lib
3436
from nemo.lightning.io.pl import MegatronCheckpointIO
@@ -46,6 +48,14 @@ class RestoreConfig:
4648
load_artifacts: bool = True
4749

4850

51+
class _MegatronBatchProgress(_BatchProgress):
    """Batch-progress tracker that deliberately skips state restoration.

    With Megatron-based datasets the batch progress should start over when
    resuming from a checkpoint, so restoring the saved progress state is
    intentionally a no-op.
    """

    @override
    def load_state_dict(self, state_dict: dict) -> None:
        # Intentionally discard the checkpointed progress state so that
        # batch progress restarts from zero on resume.
        return
57+
58+
4959
def setup_parallel_ranks(strategy: pl.strategies.Strategy):
5060
from megatron.core.model_parallel_config import ModelParallelConfig
5161

0 commit comments

Comments
 (0)