Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

loss mask aware final layer applicaiton #7275

Merged
merged 14 commits into from
Aug 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ def forward(
input_ids,
position_ids,
attention_mask,
loss_mask=None,
labels=None,
token_type_ids=None,
layer_past=None,
Expand Down Expand Up @@ -294,9 +295,15 @@ def forward(
)

if self.post_process:
return post_language_model_processing(
lm_output,
labels,
if loss_mask is not None:
loss_lm_output = lm_output.transpose(0, 1)[loss_mask == 1].unsqueeze(1)
loss_labels = labels[loss_mask == 1].unsqueeze(0)
else:
loss_lm_output = lm_output
loss_labels = labels
post_process_result = post_language_model_processing(
loss_lm_output,
loss_labels,
self.language_model.output_layer.weight
if not self.share_embeddings_and_output_weights
else self.word_embeddings_weight(),
Expand All @@ -308,6 +315,17 @@ def forward(
sequence_parallel=self.sequence_parallel,
gradient_accumulation_fusion=self.config.gradient_accumulation_fusion,
)
if loss_mask is not None:
stevehuang52 marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(post_process_result, tuple):
loss, logits = post_process_result
else:
loss, logits = post_process_result, None

res = torch.zeros_like(labels).type_as(loss)
res[loss_mask == 1] = loss
return res if logits is None else (res, logits)
else:
return post_process_result
else:
return lm_output

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -833,6 +833,7 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
'position_ids': batch['position_ids'],
'attention_mask': batch['attention_mask'],
'labels': batch['labels'],
'loss_mask': batch['loss_mask'],
}
if not self.mcore_gpt:
forward_args['checkpoint_activations_all_layers'] = checkpoint_activations_all_layers
Expand Down
Loading