Skip to content

Commit

Permalink
add first_val_step to mcore scheduler
Browse files Browse the repository at this point in the history
Signed-off-by: jiemingz <[email protected]>
  • Loading branch information
jiemingz committed Jan 10, 2024
1 parent 0a1a5b1 commit 50a8f2c
Showing 1 changed file with 6 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):

self.mcore_gpt = cfg.get('mcore_gpt', False)
self.spec_name = cfg.get('name', '')
if cfg.get('fp8', False):
self.prev_step_training = True

self.rampup_batch_size = self.cfg.get('rampup_batch_size', None)
if self.rampup_batch_size:
Expand Down Expand Up @@ -502,6 +504,8 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):
module.config.grad_sync_func = grad_sync_func
module.config.param_sync_func = param_sync_func

first_val_step = self.prev_step_training and not self.training

# run forward and backwards passes for an entire global batch
# we do this inside training_step to support pipeline parallelism
fwd_bwd_function = get_forward_backward_func()
Expand All @@ -515,7 +519,9 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):
forward_only=forward_only,
seq_length=self.cfg.encoder_seq_length,
micro_batch_size=self.cfg.micro_batch_size,
first_val_step=first_val_step,
)
self.prev_step_training = self.training

# only the last stages of the pipeline return losses
if losses_reduced_per_micro_batch:
Expand Down

0 comments on commit 50a8f2c

Please sign in to comment.