
Commit 88f9a9c

loss mask aware final layer application (NVIDIA#7275)

* loss mask for final output and softmax

Signed-off-by: arendu <[email protected]>

* bs2 working

Signed-off-by: arendu <[email protected]>

* Fix skip generation

Signed-off-by: Cheng-Ping Hsieh <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add metric condition

Signed-off-by: Cheng-Ping Hsieh <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* encoder_input is none check

Signed-off-by: arendu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: arendu <[email protected]>
Signed-off-by: Cheng-Ping Hsieh <[email protected]>
Co-authored-by: Cheng-Ping Hsieh <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: dorotat <[email protected]>
3 people authored and dorotat-nv committed Aug 24, 2023
1 parent 0f0bb53 commit 88f9a9c
Showing 2 changed files with 22 additions and 3 deletions.
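
The diff threads a per-token loss_mask through the GPT model's forward. When a mask is supplied, the final output-layer projection and softmax/cross-entropy are applied only at positions where loss_mask == 1, and the resulting per-token losses are scattered back into a zeros tensor shaped like labels, so downstream loss reduction sees zeros at masked positions. A runnable sketch of this pattern follows the first diff below.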
```diff
@@ -267,6 +267,7 @@ def forward(
         input_ids,
         position_ids,
         attention_mask,
+        loss_mask=None,
         labels=None,
         token_type_ids=None,
         layer_past=None,
@@ -294,9 +295,15 @@ def forward(
         )
 
         if self.post_process:
-            return post_language_model_processing(
-                lm_output,
-                labels,
+            if loss_mask is not None:
+                loss_lm_output = lm_output.transpose(0, 1)[loss_mask == 1].unsqueeze(1)
+                loss_labels = labels[loss_mask == 1].unsqueeze(0)
+            else:
+                loss_lm_output = lm_output
+                loss_labels = labels
+            post_process_result = post_language_model_processing(
+                loss_lm_output,
+                loss_labels,
                 self.language_model.output_layer.weight
                 if not self.share_embeddings_and_output_weights
                 else self.word_embeddings_weight(),
@@ -308,6 +315,17 @@
                 sequence_parallel=self.sequence_parallel,
                 gradient_accumulation_fusion=self.config.gradient_accumulation_fusion,
             )
+            if loss_mask is not None:
+                if isinstance(post_process_result, tuple):
+                    loss, logits = post_process_result
+                else:
+                    loss, logits = post_process_result, None
+
+                res = torch.zeros_like(labels).type_as(loss)
+                res[loss_mask == 1] = loss
+                return res if logits is None else (res, logits)
+            else:
+                return post_process_result
         else:
             return lm_output
```

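A minimal, self-contained sketch of the gather/compute/scatter pattern the hunks above implement. Everything here is an illustrative assumption: masked_final_layer_loss, the plain matmul projection, and the toy shapes stand in for NeMo's post_language_model_processing and its parallel cross-entropy; only the masking logic mirrors the diff.

```python
import torch
import torch.nn.functional as F

def masked_final_layer_loss(hidden, output_weight, labels, loss_mask):
    # hidden: [seq, batch, hidden] (Megatron layout); labels, loss_mask: [batch, seq].
    # Gather only the positions that contribute to the loss...
    selected = hidden.transpose(0, 1)[loss_mask == 1]      # [n_kept, hidden]
    logits = selected @ output_weight.t()                  # [n_kept, vocab]
    per_token = F.cross_entropy(logits, labels[loss_mask == 1], reduction='none')
    # ...then scatter the per-token losses back so the result is shaped like
    # labels, with zeros at masked positions (cf. torch.zeros_like(labels) above).
    res = torch.zeros_like(labels, dtype=per_token.dtype)
    res[loss_mask == 1] = per_token
    return res

# Toy usage: batch 2, sequence 4, hidden 8, vocab 16.
torch.manual_seed(0)
hidden = torch.randn(4, 2, 8)
weight = torch.randn(16, 8)
labels = torch.randint(0, 16, (2, 4))
loss_mask = torch.tensor([[0, 1, 1, 0], [1, 1, 0, 0]])
loss = masked_final_layer_loss(hidden, weight, labels, loss_mask)
print(loss.shape)                         # torch.Size([2, 4])
print((loss[loss_mask == 0] == 0).all())  # tensor(True)
```

Zero-filling keeps the returned tensor the same shape whether or not a mask is passed, so callers that reduce the per-token loss over loss_mask need no changes.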
```diff
@@ -833,6 +833,7 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_layers
                 'position_ids': batch['position_ids'],
                 'attention_mask': batch['attention_mask'],
                 'labels': batch['labels'],
+                'loss_mask': batch['loss_mask'],
             }
             if not self.mcore_gpt:
                 forward_args['checkpoint_activations_all_layers'] = checkpoint_activations_all_layers
```
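For context, a hedged sketch of the call site after this change. The surrounding loop is simplified and the 'input_ids' entry (which sits above the hunk) is an assumed key; only the keys visible in the diff are certain.

```python
# Hypothetical call-site sketch: batch is one microbatch from dataloader_iter.
forward_args = {
    'input_ids': batch['tokens'],             # assumed; this key is above the hunk
    'position_ids': batch['position_ids'],
    'attention_mask': batch['attention_mask'],
    'labels': batch['labels'],
    'loss_mask': batch['loss_mask'],          # new: lets forward() mask the final layer
}
output_tensor = model(**forward_args)         # per-token loss, zeros where the mask is 0
```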
