Override limit_val_batches for pretraining models
Signed-off-by: Abhishree <[email protected]>
athitten committed Aug 25, 2023
1 parent de1e5fd commit a18a36e
Showing 5 changed files with 20 additions and 6 deletions.
@@ -192,6 +192,15 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True):
 
         self._val_micro_batches_consumed = 0
 
+    def _reconfigure_val_batches(self):
+        """
+        Reconfigure trainer.limit_val_batches for pretraining
+        """
+        # Override limit_val_batches to be a multiple of num microbatches, so that it corresponds to limit_val_batches // num_micro_batches global batches
+        self.trainer.limit_val_batches *= get_num_microbatches()
+        # Override num_sanity_val_steps to equal num microbatches so that sanity checking performs one full val_step
+        self.trainer.num_sanity_val_steps = get_num_microbatches()
+
     def _enable_nvidia_optimizations(self):
         "These optimizations are present in NVIDIA NGC PyTorch Containers"
 
@@ -811,7 +820,7 @@ def _val_iterator_done(self, iterator):
         """
         Check if we reached trainer.limit_val_batches; if so, exhaust the iterator to raise StopIteration and exit validation_step
         """
-        if self._val_micro_batches_consumed == self.trainer.limit_val_batches:
+        if self._val_micro_batches_consumed == self.trainer.num_sanity_val_steps or self._val_micro_batches_consumed == self.trainer.limit_val_batches:
             self._val_micro_batches_consumed = 0
             try:
                 _ = next(iterator)  # exhausting the iterator so that PTL knows to go to validation_epoch_end
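To make the override concrete, here is a small, self-contained sketch of the arithmetic. The batch sizes, the data-parallel size, and the num_microbatches formula below are illustrative assumptions; in NeMo the count comes from get_num_microbatches().

# Illustrative sketch only: assumed config values, not taken from the commit.
global_batch_size = 256     # assumed global batch size
micro_batch_size = 4        # assumed per-GPU micro batch size
data_parallel_size = 8      # assumed number of data-parallel replicas

# One global batch is split into this many microbatches (assumed Megatron-style formula).
num_microbatches = global_batch_size // (micro_batch_size * data_parallel_size)  # 8

# The user-facing limit_val_batches counts global batches, while the validation dataloader
# yields microbatches, so the trainer limit is scaled up to whole global batches.
limit_val_batches = 4                                               # user config: 4 global batches
trainer_limit_val_batches = limit_val_batches * num_microbatches    # 32 microbatches

# Sanity checking runs exactly one global batch worth of microbatches.
num_sanity_val_steps = num_microbatches                             # 8

print(num_microbatches, trainer_limit_val_batches, num_sanity_val_steps)  # -> 8 32 8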
@@ -148,6 +148,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
             self._nsys_profile_start_step *= grad_accum_steps
             self._nsys_profile_end_step *= grad_accum_steps
 
+        # Override limit_val_batches to be a multiple of num microbatches to prevent val_step from exiting in the middle of a step
+        self._reconfigure_val_batches()
+
     def model_provider_func(self, pre_process, post_process):
         cfg = self.cfg
         num_tokentypes = 2 if cfg.bert_binary_head else 0
@@ -124,6 +124,9 @@ def __init__(self, model):
         else:
             raise ValueError(f"precision: {model.cfg['precision']} is not supported.")
 
+        # Override limit_val_batches to be a multiple of num microbatches to prevent val_step from exiting in the middle of a step
+        self._reconfigure_val_batches()
+
     def forward(self, tokens, position_ids, attention_mask):
         if self.fp8_enabled and HAVE_TE:
             with transformer_engine.pytorch.onnx_export(self.fp8_enabled), transformer_engine.pytorch.fp8_autocast(
@@ -370,12 +370,8 @@ def optimizer_zero_grad(self, *args, **kwargs):
         return
 
     def validation_step(self, dataloader_iter, batch_idx):
-        # Check if iterator is exhausted
-        # dataloader_iter, done = self._val_iterator_done(dataloader_iter)
-        # if done:
-        # return
         mode = 'test' if self.trainer.testing else 'val'
-        # try except is sufficient only one batch is passed to the fwd_bwd_step
+        # Use try/except to catch the end of the iterator and exit validation_step
         try:
             batch = next(dataloader_iter)
         except StopIteration:
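The validation_step change above relies on the dataloader iterator raising StopIteration once _val_iterator_done has drained it. The following is a minimal, self-contained sketch of that control flow; the class, method bodies, and toy data are assumptions for illustration, not the actual NeMo implementation.

# Toy illustration of exiting validation once the microbatch limit is reached.
# ToyModel and its attributes are hypothetical stand-ins, not NeMo code.

class ToyModel:
    def __init__(self, val_batches):
        self._consumed = 0
        self._limit = val_batches

    def _val_iterator_done(self, iterator):
        # Once the limit is reached, drain the iterator so the next call to next() raises StopIteration.
        if self._consumed == self._limit:
            self._consumed = 0
            for _ in iterator:
                pass
        return iterator

    def validation_step(self, dataloader_iter, batch_idx):
        try:
            batch = next(dataloader_iter)
        except StopIteration:
            return None  # iterator exhausted: skip the step so the validation loop can end
        self._consumed += 1
        return sum(batch)  # stand-in for the real forward/backward computation

model = ToyModel(val_batches=2)
it = iter([[1, 2], [3, 4], [5, 6]])
for step in range(3):
    it = model._val_iterator_done(it)
    print(model.validation_step(it, step))  # prints 3, then 7, then None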
@@ -144,6 +144,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
 
         self.enc_dec_model.model_type = ModelType.encoder_and_decoder
 
+        # Override limit_val_batches to be a multiple of num microbatches to prevent val_step from exiting in the middle of a step
+        self._reconfigure_val_batches()
+
     def setup_optimizer_param_groups(self):
         """ModelPT override. Optimizer will get self._optimizer_param_groups"""
         self._optimizer_param_groups = get_params_for_weight_decay_optimization([self.enc_dec_model])
