Conversation

@lvwerra
Member

@lvwerra lvwerra commented Oct 15, 2020

What does this PR do?

There is a discrepancy between the fine-tuning script and BartForConditionalGeneration, which is also noted in the code comments.

From examples/seq2seq/finetune.py:

# Same behavior as modeling_bart.py, besides ignoring pad_token_id
ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)

From transformers/src/modeling_bart.py:

loss_fct = CrossEntropyLoss()
# TODO(SS): do we need to ignore pad tokens in labels?

Training with the Trainer and BartForConditionalGeneration results in a model that produces garbled text (lots of repetition and no coherence). Adding ignore_index=self.config.pad_token_id to the CrossEntropyLoss resolves the issue.
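
To illustrate the difference, here is a minimal standalone example (not from the PR; BART's pad_token_id of 1 is hard-coded for brevity):

import torch

logits = torch.randn(1, 4, 10)         # (batch, seq_len, vocab)
labels = torch.tensor([[5, 7, 1, 1]])  # 1 is BART's pad_token_id

# default: padding positions contribute to the loss
loss_default = torch.nn.CrossEntropyLoss()(logits.view(-1, 10), labels.view(-1))
# proposed: padding positions are ignored, as in examples/seq2seq/finetune.py
loss_ignore_pad = torch.nn.CrossEntropyLoss(ignore_index=1)(logits.view(-1, 10), labels.view(-1))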

Besides a before-and-after run, I did not study the behaviour systematically, since training the model requires a significant amount of time and compute. If you would like to see more testing, let me know what you think is the best way to test this thoroughly.

@lvwerra
Member Author

lvwerra commented Oct 15, 2020

I don't seem to be able to add reviewers, but I guess this falls into the domain of @sshleifer.

@sshleifer sshleifer self-assigned this Oct 15, 2020
@sshleifer
Contributor

Thx for the contribution!

FYI we have a seq2seq finetuner with this bugfix.
https://github.com/huggingface/transformers/blob/master/examples/seq2seq/seq2seq_trainer.py#L43

I worked on this at some point and thought I had fixed it.

Any issue with merging this @patil-suraj @patrickvonplaten ?

@sshleifer sshleifer requested a review from sgugger October 15, 2020 18:19
Collaborator

@sgugger sgugger left a comment


Thanks for the fix!

@patil-suraj
Contributor

patil-suraj commented Oct 15, 2020

LGTM, but this should be documented. I've seen a few notebooks where people set pad tokens to -100 in the labels. We should change this for T5 as well.

if labels is not None:
-    loss_fct = CrossEntropyLoss()
-    # TODO(SS): do we need to ignore pad tokens in labels?
+    loss_fct = CrossEntropyLoss(ignore_index=self.config.pad_token_id)
Contributor

@patil-suraj patil-suraj Oct 15, 2020


Maybe also assert that there are no -100 values in labels?

Wdyt @sshleifer ?

Contributor


Great idea. We should probably also do FSMT.

@lvwerra
Member Author

lvwerra commented Oct 15, 2020

FYI we have a seq2seq finetuner with this bugfix.
https://github.com/huggingface/transformers/blob/master/examples/seq2seq/seq2seq_trainer.py#L43

Thanks, I did not see that! With the fix in the model I was able to train Pegasus with the standard Trainer.

@lvwerra
Member Author

lvwerra commented Oct 15, 2020

LGTM, but this should be documented. I've seen a few notebooks where people set pad tokens to -100 in the labels. We should change this for T5 as well.

Good point. I remember that threw me off, because the model's docstring explicitly says that -100 works.

@lvwerra
Member Author

lvwerra commented Oct 15, 2020

I updated the docstring and added two assertions. Are these the assertions you were looking for, @patil-suraj?


if labels is not None:
    assert -100 not in labels
    assert labels.min() > 0, f'negative labels are not supported, got {labels.min()}'
Member Author


This should be labels.min() >= 0, right?

Contributor


yes good catch

Member


I don't think this is something we can accept in the forward pass of the BART model, as it would severely harm TPU performance. The assertion means we would be retrieving the value of the XLA tensor labels.min() > 0 back on the CPU every time, which would cause a big performance drop.

I would advocate for putting this in the dataloader instead; the loss will crash anyway when it sees a negative label value that is not the ignore index.
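
For illustration, a hypothetical sketch of doing that check during collation instead (the collate function below is made up, not part of the library):

import torch

def collate_with_label_check(features):
    # stack already-padded, tokenized features into a CPU batch
    batch = {k: torch.stack([torch.as_tensor(f[k]) for f in features]) for k in features[0]}
    # this check runs on CPU tensors inside the dataloader, so it never forces an XLA transfer
    assert batch['labels'].min().item() >= 0, 'negative labels are not supported'
    return batch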

Contributor

@patil-suraj patil-suraj Oct 16, 2020


aargh, we completely ignored TPU, makes sense. Thanks Lysandre!

@patrickvonplaten
Contributor

I am not in favor of this PR to be honest.

  1. By replacing the default ignore_index = -100 with ignore_index = pad_token_id, we constrain the user to ignoring only padding tokens and nothing else. Previously a user could simply set any tokens that should be ignored to -100 (the pad token, but also any other tokens). After this PR, a user would no longer be able to ignore tokens other than the pad token.

  2. It is not consistent with other models, which only ever use -100 to ignore positions in the loss.

  3. This is just a convenience that should not be handled directly in the model itself. I'm 100% fine with handling it in the Seq2SeqTrainer or Trainer.

Already discussed offline with @sshleifer.

What are your thoughts on this @LysandreJik @sgugger @thomwolf ?

Contributor

@patrickvonplaten patrickvonplaten left a comment


I don't think we should force the model to only ignore pad tokens in the loss.

@lvwerra
Member Author

lvwerra commented Oct 22, 2020

I agree that it would be nice to have a uniform pattern across the model architectures so that the models can be used interchangeably.

It seems some work is needed to allow -100 tokens in Bart, since they break the embedding lookup in the forward pass:

/usr/local/lib/python3.6/dist-packages/transformers/modeling_bart.py in forward(self, input_ids, attention_mask, output_attentions, output_hidden_states, return_dict)
    332             attention_mask = invert_mask(attention_mask)
    333 
--> 334         inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
    335         embed_pos = self.embed_positions(input_ids)
    336         x = inputs_embeds + embed_pos

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/sparse.py in forward(self, input)
    112         return F.embedding(
    113             input, self.weight, self.padding_idx, self.max_norm,
--> 114             self.norm_type, self.scale_grad_by_freq, self.sparse)
    115 
    116     def extra_repr(self):

/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   1482         # remove once script supports set_grad_enabled
   1483         _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 1484     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
   1485 
   1486 

RuntimeError: index out of range: Tried to access index -100 out of table with 50263 rows. at /pytorch/aten/src/TH/generic/THTensorEvenMoreMath.cpp:418

@sshleifer
Contributor

@lvwerra I think we should ignore pad_token_id, but if we go the -100 route it should be fine if you pass decoder_input_ids to BART? I don't see the model call in your traceback, so I can't be sure.

@lvwerra
Member Author

lvwerra commented Oct 24, 2020

@sshleifer I tried to pass -100 in the input_ids:

from transformers import BartForConditionalGeneration, BartTokenizer
import torch

model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')

model(torch.tensor([[0, -100]]))

But I get the same error with:

t = model(torch.tensor([[0, 1]]), decoder_input_ids=torch.tensor([[0, 0, -100]]))

It seems that the error comes from the line x = self.embed_tokens(input_ids) * self.embed_scale, which is called in the forward passes of both the encoder and the decoder modules. How do you usually deal with this?

@sshleifer
Contributor

I think that to successfully implement the -100 strategy (which I have never done), you have to pass labels that contain -100 and decoder_input_ids that do not contain -100.
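
For example, roughly (an illustrative sketch, not taken from the examples):

from transformers import BartForConditionalGeneration, BartTokenizer
from transformers.modeling_bart import shift_tokens_right

model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')

source = tokenizer(['a long article to summarize'], return_tensors='pt')
target = tokenizer(['a short summary'], return_tensors='pt')

# decoder_input_ids come from the unmodified target ids, so -100 never reaches the embedding
decoder_input_ids = shift_tokens_right(target['input_ids'], model.config.pad_token_id)

# -100 goes only into the labels, where CrossEntropyLoss ignores it by default
labels = target['input_ids'].clone()
labels[labels == model.config.pad_token_id] = -100

outputs = model(input_ids=source['input_ids'], attention_mask=source['attention_mask'],
                decoder_input_ids=decoder_input_ids, labels=labels)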

@thomwolf
Member

Yes, I would tend to agree with @patrickvonplaten. I think the usual philosophy of the library is that we let the users handle this themselves and provide a clear and simple example which shows that you should replace pad token ids with the ignore index in the labels.

@sgugger
Collaborator

sgugger commented Oct 25, 2020

I somehow missed the notification when @patrickvonplaten asked for advice earlier, but I agree with what he said. We only handle a basic loss computation inside the model. We recently refused PRs adding weights to the cross-entropy, for the same reason @thomwolf just pointed out: anything fancier should be done by the users themselves, as we can't support every use case.

For the Trainer and Seq2SeqTrainer, there is an easy way to handle a custom loss computation: subclass and override the compute_loss function (see the example in the docs).
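
Something along these lines, for instance (a hypothetical sketch of such a subclass, not the snippet from the docs):

import torch
from transformers import Trainer

class PadIgnoringTrainer(Trainer):
    def compute_loss(self, model, inputs):
        labels = inputs.pop('labels')
        outputs = model(**inputs)
        logits = outputs[0]
        # ignore padding positions in the labels without changing the model itself
        loss_fct = torch.nn.CrossEntropyLoss(ignore_index=model.config.pad_token_id)
        return loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))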

@lvwerra
Member Author

lvwerra commented Oct 25, 2020

Thanks for the feedback @thomwolf & @sgugger! From a user perspective, I think it would be great if one could use a model in combination with the Trainer for the model's standard task (e.g. conditional generation) without customisation or subclassing. If the default loss function of the model does not support that, then for which other use-cases would the default behaviour (not ignoring padding tokens in the labels) still be useful? Customisation already requires advanced knowledge of the inner workings of the Trainer class which not all users might have. If a user wants to do something more sophisticated than the standard task that requires modification of this behaviour, they could still write a custom compute_loss function.

If we go the -100 route, the user has to "manually" right-shift the label tokens to create the decoder_input_ids and then replace the pad_token_id in the labels with -100. As far as I can tell, this is always required to train the model for conditional generation, so I am wondering why it should not be the default behaviour inside BartForConditionalGeneration. Otherwise, the defaults are never used in practice and customisation is always required to train the model.

@patrickvonplaten
Contributor

Hey @lvwerra,

I think the main arguments against ignoring the pad_token_id inside BartForConditionalGeneration are that:

  1. We cannot allow all models to have this behavior because some of them do not have a pad_token_id, e.g. GPT2. Because consistency between models is one of our top priorities, it is not a good idea to use -100 for some models and pad_token_id for others.

  2. There are use cases where users want to ignore not only the padding token but also other tokens, e.g. the eos token. In this case it is cleaner to set both pad and eos to -100 and ignore those positions than to set the eos token to the pad token (see the sketch below).
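
For instance, with the -100 convention a user can mask any set of positions (a small illustrative sketch):

from transformers import BartTokenizer

tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
labels = tokenizer(['a short summary'], return_tensors='pt')['input_ids']
labels[labels == tokenizer.pad_token_id] = -100  # ignore padding in the loss
labels[labels == tokenizer.eos_token_id] = -100  # optionally also ignore eos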

@lvwerra
Member Author

lvwerra commented Nov 1, 2020

Ok, so if I understand correctly, the minimal example for training a Bart model, given a dataset object with columns 'text' and 'summary', is to apply the following function (e.g. with .map()) before passing the model and the dataset to the Trainer:

from transformers.modeling_bart import shift_tokens_right

def convert_to_features(example_batch):
    input_encodings = tokenizer.batch_encode_plus(example_batch['text'], pad_to_max_length=True)
    # return PyTorch tensors so that shift_tokens_right and the boolean mask below work
    target_encodings = tokenizer.batch_encode_plus(example_batch['summary'], pad_to_max_length=True, return_tensors='pt')

    # build decoder_input_ids from the unmodified labels, then mask the padding
    # positions in the labels (use the pad token id rather than a hard-coded 0)
    labels = target_encodings['input_ids']
    decoder_input_ids = shift_tokens_right(labels, model.config.pad_token_id)
    labels[labels == model.config.pad_token_id] = -100

    encodings = {
        'input_ids': input_encodings['input_ids'],
        'attention_mask': input_encodings['attention_mask'],
        'decoder_input_ids': decoder_input_ids,
        'labels': labels,
    }

    return encodings

It took me quite a while of reading examples and code to figure this out: not only the padding tokens and -100, but also the difference between decoder_input_ids and labels. I am more than happy to update the docs to save the next person some time, since this does not seem to be an edge case but rather the minimal work required to train Bart for conditional generation. Is there a good place to point this out?

@sshleifer
Contributor

You could write a forums post and link to it from bart.rst?

@lvwerra lvwerra mentioned this pull request Nov 5, 2020
@lvwerra lvwerra closed this Nov 29, 2020