From 5d5fa51c5f144d0c17683526c5f863b50543e5eb Mon Sep 17 00:00:00 2001
From: Adi Renduchintala
Date: Tue, 22 Aug 2023 14:17:17 -0700
Subject: [PATCH] loss mask aware final layer application (#7275)

* loss mask for final output and softmax

Signed-off-by: arendu

* bs2 working

Signed-off-by: arendu

* Fix skip generation

Signed-off-by: Cheng-Ping Hsieh

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add metric condition

Signed-off-by: Cheng-Ping Hsieh

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* encoder_input is none check

Signed-off-by: arendu

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: arendu
Signed-off-by: Cheng-Ping Hsieh
Co-authored-by: Cheng-Ping Hsieh
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 .../language_modeling/megatron/gpt_model.py  | 24 ++++++++++++++++---
 .../language_modeling/megatron_gpt_model.py  |  1 +
 2 files changed, 22 insertions(+), 3 deletions(-)

diff --git a/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py
index 1b9ef415c64b..feed809ec737 100755
--- a/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py
@@ -267,6 +267,7 @@ def forward(
         input_ids,
         position_ids,
         attention_mask,
+        loss_mask=None,
         labels=None,
         token_type_ids=None,
         layer_past=None,
@@ -294,9 +295,15 @@ def forward(
         )

         if self.post_process:
-            return post_language_model_processing(
-                lm_output,
-                labels,
+            if loss_mask is not None:
+                loss_lm_output = lm_output.transpose(0, 1)[loss_mask == 1].unsqueeze(1)
+                loss_labels = labels[loss_mask == 1].unsqueeze(0)
+            else:
+                loss_lm_output = lm_output
+                loss_labels = labels
+            post_process_result = post_language_model_processing(
+                loss_lm_output,
+                loss_labels,
                 self.language_model.output_layer.weight
                 if not self.share_embeddings_and_output_weights
                 else self.word_embeddings_weight(),
@@ -308,6 +315,17 @@ def forward(
                 sequence_parallel=self.sequence_parallel,
                 gradient_accumulation_fusion=self.config.gradient_accumulation_fusion,
             )
+            if loss_mask is not None:
+                if isinstance(post_process_result, tuple):
+                    loss, logits = post_process_result
+                else:
+                    loss, logits = post_process_result, None
+
+                res = torch.zeros_like(labels).type_as(loss)
+                res[loss_mask == 1] = loss
+                return res if logits is None else (res, logits)
+            else:
+                return post_process_result
         else:
             return lm_output

diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index 358f3387b812..b8b81783dac3 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -833,6 +833,7 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
             'position_ids': batch['position_ids'],
             'attention_mask': batch['attention_mask'],
             'labels': batch['labels'],
+            'loss_mask': batch['loss_mask'],
         }
         if not self.mcore_gpt:
             forward_args['checkpoint_activations_all_layers'] = checkpoint_activations_all_layers
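
Note (illustrative, not part of the patch): the change above gathers only the positions where loss_mask == 1 before the final vocabulary projection, and then scatters the resulting per-token loss back into a tensor shaped like labels so downstream masked averaging is unchanged. The standalone PyTorch sketch below shows that gather/scatter pattern under assumed names (hidden, vocab_proj, selected_loss, full_loss are made up for the example; the toy Linear stands in for NeMo's parallel output layer).

# Minimal sketch of the loss-mask-aware final-layer pattern (illustrative only;
# tensor names and the toy projection are assumptions, not the NeMo implementation).
import torch
import torch.nn.functional as F

batch, seq, hidden_size, vocab = 2, 8, 16, 32
hidden = torch.randn(seq, batch, hidden_size)        # [s, b, h], Megatron-style ordering
labels = torch.randint(0, vocab, (batch, seq))       # [b, s]
loss_mask = (torch.rand(batch, seq) > 0.5).long()    # 1 = position contributes to the loss
vocab_proj = torch.nn.Linear(hidden_size, vocab, bias=False)

# Gather: keep only the hidden states at masked-in positions, so the expensive
# vocabulary projection and cross-entropy run on fewer tokens.
selected_hidden = hidden.transpose(0, 1)[loss_mask == 1]   # [n_selected, h]
selected_labels = labels[loss_mask == 1]                   # [n_selected]

logits = vocab_proj(selected_hidden)                       # [n_selected, vocab]
selected_loss = F.cross_entropy(logits, selected_labels, reduction='none')

# Scatter: write the per-token losses back into a full [b, s] tensor so the
# caller's usual sum(loss * loss_mask) / sum(loss_mask) reduction still applies.
full_loss = torch.zeros_like(labels).type_as(selected_loss)
full_loss[loss_mask == 1] = selected_loss
print(full_loss.shape)  # torch.Size([2, 8])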