From 9c6db112d5045934b824cea5dc6c3669bdb4ee70 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Wed, 29 Mar 2023 13:11:57 -0700
Subject: [PATCH] Revert changes to batch sampler, swap to pretrained sampler

Signed-off-by: SeanNaren
---
 .../megatron/megatron_batch_samplers.py       | 15 +++++++--------
 .../megatron_lm_encoder_decoder_model.py      | 12 ++++++------
 2 files changed, 13 insertions(+), 14 deletions(-)

diff --git a/nemo/collections/nlp/data/language_modeling/megatron/megatron_batch_samplers.py b/nemo/collections/nlp/data/language_modeling/megatron/megatron_batch_samplers.py
index 71c02af766ea..c9791bc0147d 100644
--- a/nemo/collections/nlp/data/language_modeling/megatron/megatron_batch_samplers.py
+++ b/nemo/collections/nlp/data/language_modeling/megatron/megatron_batch_samplers.py
@@ -140,24 +140,23 @@ def __iter__(self):
 
 class MegatronPretrainingBatchSampler(BaseMegatronBatchSampler):
     def get_start_end_idx(self) -> Tuple[int, int]:
-        start_idx = self.data_parallel_rank * self.micro_batch_size
-        end_idx = start_idx + self.micro_batch_size
+        start_idx = self.data_parallel_rank * self._global_batch_size_on_this_data_parallel_rank
+        end_idx = start_idx + self._global_batch_size_on_this_data_parallel_rank
         return start_idx, end_idx
 
     def __iter__(self):
         batch = []
-
-        data_parallel_micro_batch_size = self.data_parallel_size * self.micro_batch_size
         # Last batch will be dropped if drop_last is not set False
         for idx in range(self.consumed_samples, self.total_samples):
             batch.append(idx)
-            if len(batch) == data_parallel_micro_batch_size:
+            if len(batch) == self._global_batch_size:
+                # start_idx, end_idx = self.get_start_end_idx()
                 indices = [
-                    batch[i]
-                    for i in range(self.data_parallel_rank, data_parallel_micro_batch_size, self.data_parallel_size)
+                    batch[i] for i in range(self.data_parallel_rank, self._global_batch_size, self.data_parallel_size,)
                 ]
-                assert len(indices) == self.micro_batch_size
+                assert len(indices) == self._global_batch_size_on_this_data_parallel_rank
                 yield indices
+                # yield batch[start_idx:end_idx]
                 batch = []
 
         # Check the last partial batch and see drop_last is set
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py
index e2ee05686956..0d166a69f3ec 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py
@@ -23,9 +23,9 @@
 from pytorch_lightning.accelerators import CPUAccelerator
 from pytorch_lightning.trainer.trainer import Trainer
 
-from nemo.collections.nlp.data.language_modeling.megatron.megatron_batch_samplers import (
-    MegatronPretrainingBatchSampler,
-    MegatronPretrainingRandomBatchSampler,
+from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import (
+    MegatronPretrainingRandomSampler,
+    MegatronPretrainingSampler,
 )
 from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel
 from nemo.collections.nlp.modules.common.megatron.build_model import build_model
@@ -803,7 +803,7 @@ def build_pretraining_data_loader(self, dataset, consumed_samples, num_workers):
         # Megatron sampler
         if hasattr(self._cfg.data, 'dataloader_type') and self._cfg.data.dataloader_type is not None:
             if self._cfg.data.dataloader_type == 'single':
-                batch_sampler = MegatronPretrainingBatchSampler(
+                batch_sampler = MegatronPretrainingSampler(
                     total_samples=len(dataset),
                     consumed_samples=consumed_samples,
                     micro_batch_size=self._cfg.micro_batch_size,
@@ -813,11 +813,11 @@ def build_pretraining_data_loader(self, dataset, consumed_samples, num_workers):
                     drop_last=self._cfg.get('drop_last', True),
                 )
             elif self._cfg.data.dataloader_type == 'cyclic':
-                batch_sampler = MegatronPretrainingRandomBatchSampler(
+                batch_sampler = MegatronPretrainingRandomSampler(
                     total_samples=len(dataset),
                     consumed_samples=consumed_samples,
                     micro_batch_size=self._cfg.micro_batch_size,
-                    global_batch_size=self._cffg.global_batch_size,
+                    global_batch_size=self._cfg.global_batch_size,
                     data_parallel_rank=parallel_state.get_data_parallel_rank(),
                     data_parallel_size=parallel_state.get_data_parallel_world_size(),
                     drop_last=self._cfg.get('drop_last', True),
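
Note: the following is an illustrative sketch, not part of the patch and not NeMo's API. It approximates the strided index selection that MegatronPretrainingBatchSampler.__iter__ performs after this change: the sampler buffers global_batch_size indices, and rank r keeps every data_parallel_size-th one, so each data-parallel rank yields global_batch_size // data_parallel_size indices per step. The class name ToyPretrainingBatchSampler and the demo values are made up for illustration.

from typing import Iterator, List


class ToyPretrainingBatchSampler:
    """Minimal stand-in for the strided batch sampling shown in the patch (not NeMo's class)."""

    def __init__(
        self,
        total_samples: int,
        consumed_samples: int,
        micro_batch_size: int,
        global_batch_size: int,
        data_parallel_rank: int,
        data_parallel_size: int,
        drop_last: bool = True,
    ) -> None:
        assert global_batch_size % data_parallel_size == 0
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self._global_batch_size = global_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.data_parallel_size = data_parallel_size
        self.drop_last = drop_last
        # Per-rank share of the global batch, mirroring
        # _global_batch_size_on_this_data_parallel_rank in the patch.
        self._global_batch_size_on_this_data_parallel_rank = global_batch_size // data_parallel_size

    def __iter__(self) -> Iterator[List[int]]:
        batch: List[int] = []
        for idx in range(self.consumed_samples, self.total_samples):
            batch.append(idx)
            if len(batch) == self._global_batch_size:
                # Strided selection, as in the patch: rank r takes indices
                # r, r + world_size, r + 2 * world_size, ... from the global batch.
                indices = [
                    batch[i]
                    for i in range(self.data_parallel_rank, self._global_batch_size, self.data_parallel_size)
                ]
                assert len(indices) == self._global_batch_size_on_this_data_parallel_rank
                yield indices
                batch = []
        # Last, partial global batch: hand out a contiguous per-rank slice unless drop_last is set.
        if batch and not self.drop_last:
            start = self.data_parallel_rank * self._global_batch_size_on_this_data_parallel_rank
            yield batch[start : start + self._global_batch_size_on_this_data_parallel_rank]


if __name__ == "__main__":
    # With global_batch_size=8 and 2 data-parallel ranks, rank 0 gets even
    # sample indices and rank 1 gets odd ones, 4 per rank per step.
    for rank in range(2):
        sampler = ToyPretrainingBatchSampler(
            total_samples=16,
            consumed_samples=0,
            micro_batch_size=4,
            global_batch_size=8,
            data_parallel_rank=rank,
            data_parallel_size=2,
        )
        print(rank, list(sampler))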