add first_val_step to mcore scheduler #8150

Merged · 11 commits · Jan 25, 2024
@@ -225,6 +225,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
 
         self.mcore_gpt = cfg.get('mcore_gpt', False)
         self.spec_name = cfg.get('name', '')
+        if cfg.get('fp8', False):
+            self.prev_step_training = True
 
         self.rampup_batch_size = self.cfg.get('rampup_batch_size', None)
         if self.rampup_batch_size:
@@ -484,7 +486,7 @@ def forward(self, tokens, text_position_ids, attention_mask, labels):
         output_tensor = self.model(tokens, text_position_ids, attention_mask, labels=labels)
         return output_tensor
 
-    def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):
+    def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only, first_val_step=None):
 
         # handle asynchronous grad reduction
         no_sync_func = None
@@ -514,6 +516,7 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):
             forward_only=forward_only,
             seq_length=self.cfg.encoder_seq_length,
             micro_batch_size=self.cfg.micro_batch_size,
+            first_val_step=first_val_step,
         )
 
         # only the last stages of the pipeline return losses
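For context, the updated `fwd_bwd_step` forwards the flag straight into the Megatron-core pipeline scheduler. Below is a minimal sketch of that call path; the helper names (`run_fwd_bwd`, `forward_step`, `data_iterator`, `model_chunks`) are placeholders, and it assumes the installed `megatron.core` exposes `get_forward_backward_func` whose returned function accepts a `first_val_step` keyword, as this PR requires.

```python
# Sketch only: shows how first_val_step is threaded into the mcore scheduler call.
from megatron.core.pipeline_parallel import get_forward_backward_func


def run_fwd_bwd(forward_step, data_iterator, model_chunks, cfg, num_microbatches,
                forward_only, first_val_step=None):
    fwd_bwd_function = get_forward_backward_func()
    losses_per_micro_batch = fwd_bwd_function(
        forward_step_func=forward_step,
        data_iterator=data_iterator,
        model=model_chunks,
        num_microbatches=num_microbatches,
        forward_only=forward_only,
        seq_length=cfg.encoder_seq_length,
        micro_batch_size=cfg.micro_batch_size,
        # None keeps the scheduler's default behaviour; True marks the first
        # validation micro-batch after a training step (relevant for FP8 caching).
        first_val_step=first_val_step,
    )
    return losses_per_micro_batch
```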
@@ -608,6 +611,9 @@ def training_step(self, dataloader_iter, batch_idx):
 
         loss_mean = self.fwd_bwd_step(dataloader_iter, batch_idx, False)
 
+        if self.cfg.get('fp8', False):
+            self.prev_step_training = self.training
+
         # when using sequence parallelism, the sequence parallel layernorm grads must be all-reduced
         if self.cfg.get('tensor_model_parallel_size', 1) > 1 and self.cfg.get('sequence_parallel', False):
             self.allreduce_sequence_parallel_gradients()
@@ -1024,7 +1030,13 @@ def validation_step(self, dataloader_iter, batch_idx):
             for model_module in self.model:
                 model_module.eval()
 
-        loss = self.fwd_bwd_step(dataloader_iter, batch_idx, True)
+        if self.cfg.get('fp8', False):
+            first_val_step = self.prev_step_training and not self.training
+            self.prev_step_training = self.training
+        else:
+            first_val_step = None
+
+        loss = self.fwd_bwd_step(dataloader_iter, batch_idx, True, first_val_step)
 
         if isinstance(self.model, list):
             for model_module in self.model:
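Taken together, the change amounts to a small piece of per-model state: `prev_step_training` starts as `True` when `fp8` is enabled, is refreshed after every training step, and is compared against the current mode at validation time, so `first_val_step` is `True` exactly on the first validation call after training (or after init). The standalone sketch below reproduces that bookkeeping outside NeMo; the class and method names are illustrative only, not part of the PR.

```python
# Illustrative reconstruction of the flag logic added in this PR.
class Fp8StepTracker:
    def __init__(self, fp8_enabled: bool):
        self.fp8_enabled = fp8_enabled
        # Start as True so a validation pass that runs before any training step
        # (e.g. a sanity-check loop) is still treated as a "first" val step.
        self.prev_step_training = True

    def on_training_step(self) -> None:
        if self.fp8_enabled:
            self.prev_step_training = True  # the module is in training mode here

    def first_val_step(self, currently_training: bool):
        """Return the flag to pass to the scheduler for a validation step."""
        if not self.fp8_enabled:
            return None  # let the scheduler use its default behaviour
        flag = self.prev_step_training and not currently_training
        self.prev_step_training = currently_training
        return flag


tracker = Fp8StepTracker(fp8_enabled=True)
tracker.on_training_step()
print(tracker.first_val_step(currently_training=False))  # True: first val step
print(tracker.first_val_step(currently_training=False))  # False: subsequent val steps
```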