From 9507f08a04d452dd8b5c319fa9fddb2a219cc459 Mon Sep 17 00:00:00 2001 From: Jaemin Choi Date: Thu, 18 Apr 2024 10:39:48 -0700 Subject: [PATCH] Fix memory leak at loss func (#8868) * PR #8803: Update embedding init prototype to match mc Signed-off-by: Jaemin Choi * PR #8810: Fix import of get_gpt_layer_ammo_spec Signed-off-by: Jaemin Choi * PR #8853: Fix memory leak at loss func Signed-off-by: Jaemin Choi --------- Signed-off-by: Jaemin Choi Signed-off-by: Shriya Palsamudram <69161273+ShriyaPalsamudram@users.noreply.github.com> Co-authored-by: Jaemin Choi Co-authored-by: Eric Harper Co-authored-by: Shriya Palsamudram <69161273+ShriyaPalsamudram@users.noreply.github.com> Co-authored-by: Pablo Garay --- .../language_modeling/megatron_gpt_model.py | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 43cc8c26444f..a660af46f13d 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -367,6 +367,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): self.log_train_loss = bool(int(os.getenv("NEMO_LOG_TRAIN_LOSS", 1))) self.log_memory_usage = bool(int(os.getenv("NEMO_LOG_MEMORY_USAGE", 0))) self.loss_broadcast_src_rank = None + self.return_output_tensors = cfg.data.get('return_output_tensors', False) + self.validation_drop_last = cfg.data.get('validation_drop_last', True) + self.sample_weight = cfg.data.get('sample_weight', 'token') self.validation_param_sync_overlap = self.cfg.get('validation_param_sync_overlap', False) self.inference_params = None @@ -621,7 +624,7 @@ def fwd_bwd_step(self, dataloader_iter, forward_only, first_val_step=None): # only the last stages of the pipeline return losses if losses_reduced_per_micro_batch: - if (not forward_only) or self.cfg.data.get('validation_drop_last', True): + if (not forward_only) or self.validation_drop_last: # average loss across micro batches loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch] loss_tensor = torch.concat(loss_tensors_list) @@ -1136,10 +1139,9 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_ def loss_func(output_tensor): # Loss for a micro-batch (ub) loss_for_ub = self.loss_func(batch['loss_mask'], batch['num_valid_tokens_in_ub'], output_tensor) - cp_size = self.cfg.get('context_parallel_size', 1) - if self.cfg.data.get( - "return_output_tensors", False - ): # TODO: need a better way to check if loss_func is returning more stuff than just loss... (@adithyare) + cp_size = parallel_state.get_context_parallel_world_size() + if self.return_output_tensors: + # TODO: need a better way to check if loss_func is returning more stuff than just loss... (@adithyare) loss_for_ub, q_hs, d_hs, pos_cs, neg_cs, diff_cs = loss_for_ub reduced_loss = average_losses_across_data_parallel_group([loss_for_ub]) pos_cs = average_losses_across_data_parallel_group([pos_cs]) @@ -1156,15 +1158,14 @@ def loss_func(output_tensor): 'diff_cs': diff_cs, }, ) - elif validation_step and not self.cfg.data.get('validation_drop_last', True): - sample_weight = self.cfg.data.get('sample_weight', 'token') + elif validation_step and not self.validation_drop_last: num_valid_tokens_in_ub = batch['num_valid_tokens_in_ub'] if loss_for_ub.isnan(): assert batch['loss_mask'].count_nonzero() == 0, 'Got NaN loss with non-empty input' loss_sum_for_ub = torch.zeros_like(loss_for_ub) num_valid_tokens_in_ub = 0 else: - if sample_weight == 'constant': + if self.sample_weight == 'constant': num_valid_tokens_in_ub = 1 loss_sum_for_ub = num_valid_tokens_in_ub * loss_for_ub @@ -1296,7 +1297,7 @@ def validation_step(self, dataloader_iter, dataloader_idx=0): def on_validation_epoch_end(self): if parallel_state.is_pipeline_last_stage(): # only the last pipeline parallel stages return loss with their batch size - if self.cfg.data.get('validation_drop_last', True): + if self.validation_drop_last: averaged_loss = torch.stack(self.validation_step_outputs).mean() else: # Compute the avg loss by total_loss across all samples / total number of samples @@ -1534,7 +1535,7 @@ def setup_validation_data(self, cfg): ) drop_last = True - if not self.cfg.data.get('validation_drop_last', True): + if not self.validation_drop_last: logging.info(f'Drop last in validation dataset is set to False') drop_last = False pad_samples_to_global_batch_size = False