Commit f1dbf2c

ashors1 authored and Hainan Xu committed
Update BaseMegatronSampler for compatibility with PTL's _BatchProgress (NVIDIA#11016)
* Revert "[NeMo-UX] Use custom `BatchProgress` class which does not restore states (NVIDIA#10383)"
  This reverts commit b5798de.
* make megatron sampler return the total number of batches in the dataset

Signed-off-by: ashors1 <[email protected]>

---------

Signed-off-by: ashors1 <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>
1 parent 361c720 commit f1dbf2c

File tree

4 files changed: +3 -29 lines

nemo/lightning/data.py

+3 -4

@@ -286,17 +286,16 @@ def __init__(
         )
 
     def __len__(self):
-        num_available_samples: int = self.total_samples - self.consumed_samples
         if self.global_batch_size is not None:
             if self.drop_last:
-                num_global_batches = num_available_samples // self.global_batch_size
+                num_global_batches = self.total_samples // self.global_batch_size
             else:
-                num_global_batches = (num_available_samples + self.global_batch_size - 1) // self.global_batch_size
+                num_global_batches = (self.total_samples + self.global_batch_size - 1) // self.global_batch_size
             # return len of dataloader in terms of micro batches to avoid discrepancy between len of dataloader and
             # num of batches fetched (as training step fetches in terms of micro batches)
             return num_global_batches * (self.global_batch_size // self.micro_batch_times_data_parallel_size)
         else:
-            return (num_available_samples - 1) // self.micro_batch_times_data_parallel_size + 1
+            return (self.total_samples - 1) // self.micro_batch_times_data_parallel_size + 1
 
     @abc.abstractmethod
     def __iter__(self): ...
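For context, a minimal sketch of the arithmetic the updated __len__ performs, using hypothetical sizes (the variable names mirror the sampler attributes above; the concrete numbers are assumptions for illustration only):

# Hypothetical configuration, chosen only to illustrate the updated __len__ logic.
total_samples = 1000
global_batch_size = 32
micro_batch_times_data_parallel_size = 8  # micro_batch_size * data_parallel_size
drop_last = True

if drop_last:
    num_global_batches = total_samples // global_batch_size  # 1000 // 32 = 31
else:
    num_global_batches = (total_samples + global_batch_size - 1) // global_batch_size

# Length is reported in micro batches so it matches how the training step fetches data.
micro_batches_per_global_batch = global_batch_size // micro_batch_times_data_parallel_size  # 32 // 8 = 4
print(num_global_batches * micro_batches_per_global_batch)  # 31 * 4 = 124

Before this change, total_samples - consumed_samples was used in place of total_samples, so the reported length depended on how many samples had already been consumed; the new form reports the length of the full dataset, in line with the commit's stated goal of compatibility with PTL's _BatchProgress.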

nemo/lightning/pytorch/strategies/fsdp_strategy.py

-8

@@ -35,7 +35,6 @@
 
 from nemo.lightning import io
 from nemo.lightning.pytorch.strategies.utils import (
-    _MegatronBatchProgress,
     ckpt_to_dir,
     create_checkpoint_io,
     fix_progress_bar,
@@ -74,15 +73,13 @@ def __init__(
         ckpt_load_optimizer: bool = True,
         ckpt_save_optimizer: bool = True,
         data_sampler=None,
-        overwrite_batch_progress: bool = True,
         **kwargs,
     ):
         super().__init__(auto_wrap_policy=auto_wrap_policy, state_dict_type=state_dict_type, **kwargs)
 
         self.data_sampler = data_sampler
         self.ckpt_load_optimizer = ckpt_load_optimizer
         self.ckpt_save_optimizer = ckpt_save_optimizer
-        self.overwrite_batch_progress = overwrite_batch_progress
 
     @override
     def setup_environment(self) -> None:
@@ -95,11 +92,6 @@ def setup(self, trainer: pl.Trainer) -> None:
         self.trainer = trainer
         setup_data_sampler(self.trainer)
         fix_progress_bar(trainer)
-
-        trainer_fn = trainer.state.fn
-        if trainer_fn == TrainerFn.FITTING and self.overwrite_batch_progress:
-            trainer.fit_loop.epoch_loop.batch_progress = _MegatronBatchProgress()
-
         super().setup(trainer)
 
     def _get_loss_reduction(self, step_type: str):

nemo/lightning/pytorch/strategies/megatron_strategy.py

-7

@@ -67,7 +67,6 @@
 from nemo.lightning.pytorch.callbacks import ModelTransform
 from nemo.lightning.pytorch.strategies.utils import (
     RestoreConfig,
-    _MegatronBatchProgress,
     ckpt_to_dir,
     create_checkpoint_io,
     fix_progress_bar,
@@ -160,8 +159,6 @@ class MegatronStrategy(DDPStrategy, io.IOMixin):
             that prints the metrics to stdout. Suitable for non-interactive settings.
         progress_interval (int): How frequently to print progress to stdout. Only used when
             replace_progress_bar is True.
-        overwrite_batch_progress (bool): Whether to overwrite _BatchProgress class used in PTL by default with
-            _MegatronBatchProgress. This should be True whenever you're using a Megatron-based dataset.
         **kwargs: Additional keyword arguments.
 
     Note:
@@ -204,7 +201,6 @@ def __init__(
         replace_progress_bar: bool = True,
         progress_interval: int = 1,
         restore_config: Optional[RestoreConfig] = None,
-        overwrite_batch_progress: bool = True,
         **kwargs,
     ) -> None:
         super().__init__(
@@ -245,7 +241,6 @@ def __init__(
 
         self.replace_progress_bar = replace_progress_bar
         self.progress_interval = progress_interval
-        self.overwrite_batch_progress = overwrite_batch_progress
 
         self.restore_config = restore_config
 
@@ -345,8 +340,6 @@ def setup(self, trainer: pl.Trainer) -> None:
         self.configure_ddp()
 
         trainer.fit_loop.epoch_loop.automatic_optimization = _MegatronAutomaticOptimization(trainer)
-        if self.overwrite_batch_progress:
-            trainer.fit_loop.epoch_loop.batch_progress = _MegatronBatchProgress()
 
         import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
nemo/lightning/pytorch/strategies/utils.py

-10

@@ -25,12 +25,10 @@
 from megatron.core.dist_checkpointing.strategies.torch import sharded_tensor_to_torch_sharded_tensor
 from megatron.core.transformer.utils import _get_extra_state_offsets
 from pytorch_lightning.callbacks import TQDMProgressBar
-from pytorch_lightning.loops.progress import _BatchProgress
 from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
 from torch.distributed._sharded_tensor import ShardedTensor as TorchShardedTensor
 from torch.distributed._tensor import DTensor, Replicate, Shard
 from torch.distributed.device_mesh import DeviceMesh
-from typing_extensions import override
 
 from nemo.lightning import _strategy_lib
 from nemo.lightning.io.pl import MegatronCheckpointIO
@@ -48,14 +46,6 @@ class RestoreConfig:
     load_artifacts: bool = True
 
 
-class _MegatronBatchProgress(_BatchProgress):
-    @override
-    def load_state_dict(self, state_dict: dict) -> None:
-        ## in megatron, we want to start the batch progress over when
-        ## restoring from a checkpoint
-        return
-
-
 def setup_parallel_ranks(strategy: pl.strategies.Strategy):
     from megatron.core.model_parallel_config import ModelParallelConfig