Update BaseMegatronSampler for compatibility with PTL's _BatchProgress #11016

Merged: 2 commits, Oct 25, 2024
Changes from all commits
7 changes: 3 additions & 4 deletions nemo/lightning/data.py
@@ -286,17 +286,16 @@ def __init__(
         )
 
     def __len__(self):
-        num_available_samples: int = self.total_samples - self.consumed_samples
         if self.global_batch_size is not None:
             if self.drop_last:
-                num_global_batches = num_available_samples // self.global_batch_size
+                num_global_batches = self.total_samples // self.global_batch_size
             else:
-                num_global_batches = (num_available_samples + self.global_batch_size - 1) // self.global_batch_size
+                num_global_batches = (self.total_samples + self.global_batch_size - 1) // self.global_batch_size
             # return len of dataloader in terms of micro batches to avoid discrepancy between len of dataloader and
             # num of batches fetched (as training step fetches in terms of micro batches)
             return num_global_batches * (self.global_batch_size // self.micro_batch_times_data_parallel_size)
         else:
-            return (num_available_samples - 1) // self.micro_batch_times_data_parallel_size + 1
+            return (self.total_samples - 1) // self.micro_batch_times_data_parallel_size + 1
 
     @abc.abstractmethod
     def __iter__(self): ...
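
For intuition about the new __len__, a small worked example (the numbers are illustrative assumptions, not values from this PR). With drop_last=True the sampler now counts global batches over all samples rather than over the not-yet-consumed ones, and reports the length in micro batches:

    # Illustrative values only (assumptions, not taken from the PR)
    total_samples = 1000
    global_batch_size = 32
    micro_batch_times_data_parallel_size = 8

    # drop_last=True: only full global batches count
    num_global_batches = total_samples // global_batch_size  # 31
    # length is reported in micro batches, matching how the training step fetches data
    length = num_global_batches * (global_batch_size // micro_batch_times_data_parallel_size)
    print(length)  # 31 * 4 = 124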
8 changes: 0 additions & 8 deletions nemo/lightning/pytorch/strategies/fsdp_strategy.py
@@ -35,7 +35,6 @@
 
 from nemo.lightning import io
 from nemo.lightning.pytorch.strategies.utils import (
-    _MegatronBatchProgress,
     ckpt_to_dir,
     create_checkpoint_io,
     fix_progress_bar,
@@ -74,15 +73,13 @@ def __init__(
         ckpt_load_optimizer: bool = True,
         ckpt_save_optimizer: bool = True,
         data_sampler=None,
-        overwrite_batch_progress: bool = True,
         **kwargs,
     ):
         super().__init__(auto_wrap_policy=auto_wrap_policy, state_dict_type=state_dict_type, **kwargs)
 
         self.data_sampler = data_sampler
         self.ckpt_load_optimizer = ckpt_load_optimizer
         self.ckpt_save_optimizer = ckpt_save_optimizer
-        self.overwrite_batch_progress = overwrite_batch_progress
 
     @override
     def setup_environment(self) -> None:
@@ -95,11 +92,6 @@ def setup(self, trainer: pl.Trainer) -> None:
         self.trainer = trainer
         setup_data_sampler(self.trainer)
         fix_progress_bar(trainer)
-
-        trainer_fn = trainer.state.fn
-        if trainer_fn == TrainerFn.FITTING and self.overwrite_batch_progress:
-            trainer.fit_loop.epoch_loop.batch_progress = _MegatronBatchProgress()
-
         super().setup(trainer)
 
     def _get_loss_reduction(self, step_type: str):
7 changes: 0 additions & 7 deletions nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -67,7 +67,6 @@
 from nemo.lightning.pytorch.callbacks import ModelTransform
 from nemo.lightning.pytorch.strategies.utils import (
     RestoreConfig,
-    _MegatronBatchProgress,
     ckpt_to_dir,
     create_checkpoint_io,
     fix_progress_bar,
@@ -160,8 +159,6 @@ class MegatronStrategy(DDPStrategy, io.IOMixin):
             that prints the metrics to stdout. Suitable for non-interactive settings.
         progress_interval (int): How frequently to print progress to stdout. Only used when
             replace_progress_bar is True.
-        overwrite_batch_progress (bool): Whether to overwrite _BatchProgress class used in PTL by default with
-            _MegatronBatchProgress. This should be True whenever you're using a Megatron-based dataset.
         **kwargs: Additional keyword arguments.
 
     Note:
@@ -204,7 +201,6 @@ def __init__(
         replace_progress_bar: bool = True,
         progress_interval: int = 1,
         restore_config: Optional[RestoreConfig] = None,
-        overwrite_batch_progress: bool = True,
         **kwargs,
     ) -> None:
         super().__init__(
@@ -245,7 +241,6 @@ def __init__(
 
         self.replace_progress_bar = replace_progress_bar
         self.progress_interval = progress_interval
-        self.overwrite_batch_progress = overwrite_batch_progress
 
         self.restore_config = restore_config
 
@@ -345,8 +340,6 @@ def setup(self, trainer: pl.Trainer) -> None:
         self.configure_ddp()
 
         trainer.fit_loop.epoch_loop.automatic_optimization = _MegatronAutomaticOptimization(trainer)
-        if self.overwrite_batch_progress:
-            trainer.fit_loop.epoch_loop.batch_progress = _MegatronBatchProgress()
 
         import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
 
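
After this change the strategy no longer accepts an overwrite_batch_progress argument. A minimal usage sketch of the updated constructor (the argument values are illustrative assumptions, not taken from the PR):

    # Sketch only: illustrative arguments, not a definitive configuration
    from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy

    strategy = MegatronStrategy(
        replace_progress_bar=True,  # swap PTL's progress bar for a stdout-friendly printer
        progress_interval=1,        # print progress every step
    )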
10 changes: 0 additions & 10 deletions nemo/lightning/pytorch/strategies/utils.py
@@ -25,12 +25,10 @@
 from megatron.core.dist_checkpointing.strategies.torch import sharded_tensor_to_torch_sharded_tensor
 from megatron.core.transformer.utils import _get_extra_state_offsets
 from pytorch_lightning.callbacks import TQDMProgressBar
-from pytorch_lightning.loops.progress import _BatchProgress
 from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
 from torch.distributed._sharded_tensor import ShardedTensor as TorchShardedTensor
 from torch.distributed._tensor import DTensor, Replicate, Shard
 from torch.distributed.device_mesh import DeviceMesh
-from typing_extensions import override
 
 from nemo.lightning import _strategy_lib
 from nemo.lightning.io.pl import MegatronCheckpointIO
@@ -48,14 +46,6 @@ class RestoreConfig:
     load_artifacts: bool = True
 
 
-class _MegatronBatchProgress(_BatchProgress):
-    @override
-    def load_state_dict(self, state_dict: dict) -> None:
-        ## in megatron, we want to start the batch progress over when
-        ## restoring from a checkpoint
-        return
-
-
 def setup_parallel_ranks(strategy: pl.strategies.Strategy):
     from megatron.core.model_parallel_config import ModelParallelConfig
 
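
The net effect of deleting _MegatronBatchProgress is that, on resume, PTL's stock _BatchProgress restores its counters from the checkpoint instead of silently starting over. A hypothetical illustration of the difference (not part of the PR):

    from pytorch_lightning.loops.progress import _BatchProgress

    progress = _BatchProgress()
    progress.increment_completed()   # pretend one batch was processed
    saved = progress.state_dict()    # e.g. what a checkpoint would persist

    restored = _BatchProgress()
    restored.load_state_dict(saved)  # PTL default: counters come back as saved
    # The removed _MegatronBatchProgress.load_state_dict returned early instead,
    # leaving the progress at zero so every restart began a fresh pass over the data.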