Revert changes to batch sampler, swap to pretrained sampler
Signed-off-by: SeanNaren <[email protected]>
SeanNaren committed Apr 4, 2023
1 parent bb447e9 commit 9c6db11
Showing 2 changed files with 13 additions and 14 deletions.
@@ -140,24 +140,23 @@ def __iter__(self):
 
 class MegatronPretrainingBatchSampler(BaseMegatronBatchSampler):
     def get_start_end_idx(self) -> Tuple[int, int]:
-        start_idx = self.data_parallel_rank * self.micro_batch_size
-        end_idx = start_idx + self.micro_batch_size
+        start_idx = self.data_parallel_rank * self._global_batch_size_on_this_data_parallel_rank
+        end_idx = start_idx + self._global_batch_size_on_this_data_parallel_rank
         return start_idx, end_idx
 
     def __iter__(self):
         batch = []
-        data_parallel_micro_batch_size = self.data_parallel_size * self.micro_batch_size
         # Last batch will be dropped if drop_last is not set False
         for idx in range(self.consumed_samples, self.total_samples):
             batch.append(idx)
-            if len(batch) == data_parallel_micro_batch_size:
+            if len(batch) == self._global_batch_size:
+                # start_idx, end_idx = self.get_start_end_idx()
                 indices = [
-                    batch[i]
-                    for i in range(self.data_parallel_rank, data_parallel_micro_batch_size, self.data_parallel_size)
+                    batch[i] for i in range(self.data_parallel_rank, self._global_batch_size, self.data_parallel_size,)
                 ]
-                assert len(indices) == self.micro_batch_size
+                assert len(indices) == self._global_batch_size_on_this_data_parallel_rank
                 yield indices
                 # yield batch[start_idx:end_idx]
                 batch = []
 
         # Check the last partial batch and see drop_last is set
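For reference, the replacement logic above strides one accumulated global batch across data-parallel ranks, so each rank yields global_batch_size // data_parallel_size sample indices per step. A minimal standalone sketch of that indexing, using made-up example sizes rather than the NeMo sampler classes:

# Illustrative sketch only: mirrors the strided indexing shown in the hunk above.
global_batch_size = 8        # assumed example value
data_parallel_size = 4       # assumed example value
per_rank = global_batch_size // data_parallel_size  # plays the role of _global_batch_size_on_this_data_parallel_rank

batch = list(range(100, 100 + global_batch_size))   # one accumulated global batch of dataset indices

for data_parallel_rank in range(data_parallel_size):
    # Same striding as the diff: every data_parallel_size-th element, offset by the rank.
    indices = [batch[i] for i in range(data_parallel_rank, global_batch_size, data_parallel_size)]
    assert len(indices) == per_rank
    print(data_parallel_rank, indices)   # rank 0 -> [100, 104], rank 1 -> [101, 105], ...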
@@ -23,9 +23,9 @@
 from pytorch_lightning.accelerators import CPUAccelerator
 from pytorch_lightning.trainer.trainer import Trainer
 
-from nemo.collections.nlp.data.language_modeling.megatron.megatron_batch_samplers import (
-    MegatronPretrainingBatchSampler,
-    MegatronPretrainingRandomBatchSampler,
+from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import (
+    MegatronPretrainingRandomSampler,
+    MegatronPretrainingSampler,
 )
 from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel
 from nemo.collections.nlp.modules.common.megatron.build_model import build_model
@@ -803,7 +803,7 @@ def build_pretraining_data_loader(self, dataset, consumed_samples, num_workers):
         # Megatron sampler
         if hasattr(self._cfg.data, 'dataloader_type') and self._cfg.data.dataloader_type is not None:
             if self._cfg.data.dataloader_type == 'single':
-                batch_sampler = MegatronPretrainingBatchSampler(
+                batch_sampler = MegatronPretrainingSampler(
                     total_samples=len(dataset),
                     consumed_samples=consumed_samples,
                     micro_batch_size=self._cfg.micro_batch_size,
@@ -813,11 +813,11 @@ def build_pretraining_data_loader(self, dataset, consumed_samples, num_workers):
                     drop_last=self._cfg.get('drop_last', True),
                 )
             elif self._cfg.data.dataloader_type == 'cyclic':
-                batch_sampler = MegatronPretrainingRandomBatchSampler(
+                batch_sampler = MegatronPretrainingRandomSampler(
                     total_samples=len(dataset),
                     consumed_samples=consumed_samples,
                     micro_batch_size=self._cfg.micro_batch_size,
-                    global_batch_size=self._cffg.global_batch_size,
+                    global_batch_size=self._cfg.global_batch_size,
                     data_parallel_rank=parallel_state.get_data_parallel_rank(),
                     data_parallel_size=parallel_state.get_data_parallel_world_size(),
                     drop_last=self._cfg.get('drop_last', True),
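For reference, a batch sampler like the ones constructed here is typically passed to torch.utils.data.DataLoader via its batch_sampler argument: the sampler yields lists of dataset indices and the loader turns each list into one micro-batch. A minimal, hypothetical sketch of that wiring with toy stand-ins (ToyDataset and ToyStridedBatchSampler are made up here, not the NeMo classes):

# Hypothetical sketch of the batch_sampler wiring; ToyDataset and
# ToyStridedBatchSampler are stand-ins, not the NeMo classes above.
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    def __len__(self):
        return 32

    def __getitem__(self, idx):
        return torch.tensor(idx)


class ToyStridedBatchSampler:
    """Yields this rank's slice of every accumulated global batch."""

    def __init__(self, total_samples, global_batch_size, data_parallel_rank, data_parallel_size):
        self.total_samples = total_samples
        self.global_batch_size = global_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.data_parallel_size = data_parallel_size

    def __iter__(self):
        batch = []
        for idx in range(self.total_samples):
            batch.append(idx)
            if len(batch) == self.global_batch_size:
                # Same striding as MegatronPretrainingBatchSampler in the first file.
                yield batch[self.data_parallel_rank :: self.data_parallel_size]
                batch = []


loader = DataLoader(
    ToyDataset(),
    batch_sampler=ToyStridedBatchSampler(
        total_samples=32, global_batch_size=8, data_parallel_rank=0, data_parallel_size=4
    ),
    num_workers=0,
)
for micro_batch in loader:
    print(micro_batch)  # e.g. tensor([0, 4]), tensor([ 8, 12]), ...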
