src/transformers/modeling_bart.py (10 changes: 5 additions & 5 deletions)
@@ -1045,9 +1045,9 @@ def forward(
         r"""
         labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
             Labels for computing the masked language modeling loss.
-            Indices should either be in ``[0, ..., config.vocab_size]`` or -100 (see ``input_ids`` docstring).
-            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens
-            with labels in ``[0, ..., config.vocab_size]``.
+            Indices should either be in ``[0, ..., config.vocab_size]`` (see ``input_ids`` docstring).
+            Tokens with indices set to ``config.pad_token_id`` are ignored (masked); the loss is only computed
+            for the tokens with labels in ``[0, ..., config.vocab_size]``, excluding ``config.pad_token_id``.

         Returns:

@@ -1090,6 +1090,7 @@ def forward(
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         if labels is not None:
+            assert labels.min() > 0, f'negative labels are not supported, got {labels.min()}'
Review comment (Member, PR author):
    This should be labels.min() >= 0, right?

Review comment (Contributor):
    Yes, good catch.

Review comment (Member):
    I don't think this is something we can accept in the forward pass of the BART model, as it would severely harm TPU performance. The assertion means we would be pulling the value of the XLA tensor labels.min() > 0 back to the CPU on every forward call, which would cause a big performance drop.

    I would advocate for putting this check in the dataloader instead; the loss will crash anyway when it sees a label value that is negative and is not the ignored index.

Review comment (@patil-suraj, Contributor, Oct 16, 2020):
    Aargh, we completely ignored TPU, makes sense. Thanks Lysandre!
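Editor's note: a minimal sketch of the dataloader-side validation advocated above. The function name, the list-of-dicts batch layout, and the equal-length assumption are all illustrative; this is not code from the PR:

```python
import torch


def collate_with_label_check(batch, pad_token_id):
    """Hypothetical collate_fn that validates labels on CPU, before the
    batch reaches the (possibly XLA) device, so no graph sync is forced.

    Assumes `batch` is a list of dicts holding equal-length "input_ids"
    and "labels" tensors.
    """
    input_ids = torch.stack([example["input_ids"] for example in batch])
    labels = torch.stack([example["labels"] for example in batch])
    # >= rather than >: label id 0 is a valid vocabulary index.
    assert labels.min() >= 0, f"negative labels are not supported, got {labels.min()}"
    return {"input_ids": input_ids, "labels": labels}
```

Passed to a DataLoader via its `collate_fn` argument, this check runs once per batch on host memory and never touches the TPU.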

             use_cache = False
             if decoder_input_ids is None:
                 decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)
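Editor's note: for context, `shift_tokens_right` builds `decoder_input_ids` from `labels` by rotating each row one position to the right. The sketch below is reconstructed from memory of the helper in `modeling_bart.py` at the time; treat it as an approximation rather than the exact source:

```python
import torch


def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    """Shift input ids one token to the right, wrapping the last non-pad
    token (usually </s>) around to position 0."""
    prev_output_tokens = input_ids.clone()
    # Index of the last non-pad token in each row (assumes right padding).
    index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
    prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
    prev_output_tokens[:, 1:] = input_ids[:, :-1]
    return prev_output_tokens
```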
@@ -1110,8 +1111,7 @@ def forward(

         masked_lm_loss = None
         if labels is not None:
-            loss_fct = CrossEntropyLoss()
-            # TODO(SS): do we need to ignore pad tokens in labels?
+            loss_fct = CrossEntropyLoss(ignore_index=self.config.pad_token_id)
Review comment (@patil-suraj, Contributor, Oct 15, 2020):
    Assert if there are -100s in labels?

    Wdyt @sshleifer?

Review comment (Contributor):
    Great idea. We should probably also do FSMT.

             masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

         if not return_dict:
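Editor's note: a self-contained sketch of what the `ignore_index` change does to the loss. The pad id and label values below are made up for illustration:

```python
import torch
from torch.nn import CrossEntropyLoss

pad_token_id = 1  # BART's default pad token id; illustrative here
vocab_size = 10

# One sequence: two real labels followed by a pad position.
labels = torch.tensor([[4, 7, pad_token_id]])
lm_logits = torch.randn(1, 3, vocab_size)

# Before this PR: every position contributes to the average, including the
# pad, unless the caller pre-masked it with -100.
loss_unmasked = CrossEntropyLoss()(lm_logits.view(-1, vocab_size), labels.view(-1))

# After this PR: positions whose label equals pad_token_id are dropped from
# the average automatically.
loss_fct = CrossEntropyLoss(ignore_index=pad_token_id)
loss_masked = loss_fct(lm_logits.view(-1, vocab_size), labels.view(-1))
print(loss_unmasked.item(), loss_masked.item())
```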