Reworked MegatronPretrainingRandomBatchSampler to correctly handle epochs > 1 (NVIDIA#7920)

* Initial commit of reworked MegatronPretrainingRandomBatchSampler

Signed-off-by: Daniel Egert <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixed small length based bug

Signed-off-by: Daniel Egert <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Daniel Egert <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Harper <[email protected]>
Signed-off-by: Piotr Żelasko <[email protected]>
3 people authored and pzelasko committed Jan 3, 2024
1 parent cefafb6 commit 0dc0b2c
Showing 1 changed file with 44 additions and 12 deletions.
@@ -175,6 +175,11 @@ class MegatronPretrainingRandomBatchSampler(BaseMegatronBatchSampler):
    # are necessary for ViT training. However, to keep this simple,
    # I omit those two arguments.
    # commit: https://github.com/NVIDIA/Megatron-LM/commit/7a77abd9b6267dc0020a60b424b4748fc22790bb
    #
    # NOTE (degert): I have re-written this class somewhat, as the previous implementation relied on the
    # base class constructor, which would have thrown in the case of consumed_samples >= total_samples.
    # This class is designed to support that case, since that is how it implicitly calculates the current epoch.
    # I have also added an explicit seed, which allows us to remove Dataset-side shuffling in NeMo-Aligner.
    def __init__(
        self,
        total_samples: int,
@@ -184,20 +189,47 @@ def __init__(
        data_parallel_rank: int,
        data_parallel_size: int,
        drop_last: bool,
        pad_samples_to_global_batch_size: bool = False,
        seed: int = 0,
    ) -> None:
        super().__init__(
            total_samples=total_samples,
            consumed_samples=consumed_samples,
            micro_batch_size=micro_batch_size,
            global_batch_size=global_batch_size,
            data_parallel_rank=data_parallel_rank,
            data_parallel_size=data_parallel_size,
            drop_last=drop_last,
        )

        # Sanity checks.
        if total_samples <= 0:
            raise RuntimeError("no sample to consume: {}".format(total_samples))
        if micro_batch_size <= 0:
            raise RuntimeError(f"micro_batch_size size must be greater than 0, but {micro_batch_size}")
        if data_parallel_size <= 0:
            raise RuntimeError(f"data parallel size must be greater than 0, but {data_parallel_size}")
        if data_parallel_rank >= data_parallel_size:
            raise RuntimeError(
                "data_parallel_rank should be smaller than data size, but {} >= {}".format(
                    data_parallel_rank, data_parallel_size
                )
            )

        self.total_samples: int = total_samples
        self.consumed_samples: int = consumed_samples
        self.micro_batch_size: int = micro_batch_size
        self.data_parallel_rank: int = data_parallel_rank
        self.data_parallel_size: int = data_parallel_size
        self.drop_last: bool = drop_last
        self.pad_samples_to_global_batch_size = pad_samples_to_global_batch_size
        self.seed = seed

        self.update_global_batch_size(global_batch_size)
        self.last_batch_size = self.total_samples % self._global_batch_size

    def __len__(self):
        num_available_samples = self.total_samples
    def __len__(self) -> int:
        """Length of Random Batch Sampler.
        .. note::
            When `rampup_batch_size` is enabled, the return value may not be exact.
        """
        active_total_samples = self.total_samples - self.last_batch_size
        num_available_samples = (
            active_total_samples * (1 + (self.consumed_samples // active_total_samples))
        ) - self.consumed_samples
        if self.drop_last:
            return num_available_samples // self.global_batch_size
        else:
@@ -215,7 +247,7 @@ def __iter__(self):
        start_idx = self.data_parallel_rank * bucket_size

        g = torch.Generator()
        g.manual_seed(self.epoch)
        g.manual_seed(self.seed + self.epoch)
        random_idx = torch.randperm(bucket_size, generator=g).tolist()
        idx_range = [start_idx + x for x in random_idx[bucket_offset:]]

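For intuition on the new __len__ logic, here is a minimal standalone sketch of the bookkeeping it introduces; the helper name and the numbers are illustrative and not part of the commit. The epoch is derived implicitly from consumed_samples, so the arithmetic keeps working once consumed_samples exceeds total_samples (epochs > 1):

def remaining_samples_in_epoch(total_samples: int, consumed_samples: int, global_batch_size: int) -> int:
    # Drop the ragged last batch, mirroring last_batch_size in the sampler.
    active_total_samples = total_samples - (total_samples % global_batch_size)
    # consumed_samples // active_total_samples is the implicit (zero-based) epoch index.
    return active_total_samples * (1 + consumed_samples // active_total_samples) - consumed_samples

# 1000 samples with a global batch of 64 leave 960 "active" samples per epoch.
# After 1500 consumed samples we are 540 samples into the second epoch,
# so 960 * 2 - 1500 = 420 samples remain in that epoch.
assert remaining_samples_in_epoch(1000, 1500, 64) == 420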

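The seeding change in __iter__ (seed + epoch instead of just epoch) is what allows Dataset-side shuffling to be dropped: each epoch gets a different but fully reproducible permutation. A rough sketch of that behaviour, assuming only that torch is installed; the function name is illustrative:

import torch

def epoch_permutation(bucket_size: int, seed: int, epoch: int) -> list:
    # The same (seed, epoch) pair always yields the same order; a new epoch reshuffles.
    g = torch.Generator()
    g.manual_seed(seed + epoch)
    return torch.randperm(bucket_size, generator=g).tolist()

print(epoch_permutation(8, seed=0, epoch=0))  # reproducible across runs
print(epoch_permutation(8, seed=0, epoch=1))  # a fresh shuffle for the next epoch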