7 changes: 5 additions & 2 deletions examples/seq2seq/finetune_trainer.py
@@ -16,7 +16,6 @@
)
from transformers.trainer_utils import EvaluationStrategy
from utils import (
LegacySeq2SeqDataset,
Seq2SeqDataCollator,
Seq2SeqDataset,
assert_all_frozen,
@@ -138,6 +137,10 @@ class DataTrainingArguments:
src_lang: Optional[str] = field(default=None, metadata={"help": "Source language id for translation."})
tgt_lang: Optional[str] = field(default=None, metadata={"help": "Target language id for translation."})
eval_beams: Optional[int] = field(default=None, metadata={"help": "# num_beams to use for evaluation."})
ignore_pad_token_for_loss: bool = field(
Comment (Contributor Author): Set to True for backward compatibility.

Reply (Contributor): Thanks!

default=True,
metadata={"help": "If only pad tokens should be ignored. This assumes that `config.pad_token_id` is defined."},
)


def main():
@@ -223,7 +226,7 @@ def main():
freeze_params(model.get_encoder())
assert_all_frozen(model.get_encoder())

dataset_class = Seq2SeqDataset if hasattr(tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
Comment (Contributor Author): prepare_seq2seq_batch is now a method on PreTrainedTokenizer, so this condition can never be False (a short sketch follows the updated line below).

Reply (Contributor): Great catch.

dataset_class = Seq2SeqDataset
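A small illustration of the point above, hedged: it reflects the transformers version this PR targets (prepare_seq2seq_batch was deprecated in later releases), and the checkpoint name is only a placeholder.

from transformers import AutoTokenizer

# In this transformers version, prepare_seq2seq_batch lives on the tokenizer
# base class, so the hasattr() guard can never be False and the
# LegacySeq2SeqDataset branch is unreachable.
tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-mbart")
print(hasattr(tokenizer, "prepare_seq2seq_batch"))  # True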

# Get datasets
train_dataset = (
110 changes: 68 additions & 42 deletions examples/seq2seq/seq2seq_trainer.py
@@ -1,11 +1,11 @@
import logging
import copy
from typing import Any, Dict, Optional, Tuple, Union

import torch
from torch import nn
from torch.utils.data import DistributedSampler, RandomSampler

from transformers import Trainer
from transformers import PreTrainedModel, Trainer, logging
from transformers.configuration_fsmt import FSMTConfig
from transformers.file_utils import is_torch_tpu_available
from transformers.optimization import (
@@ -27,7 +27,7 @@
from utils import label_smoothed_nll_loss


logger = logging.getLogger(__name__)
logger = logging.get_logger(__name__)

arg_to_scheduler = {
"linear": get_linear_schedule_with_warmup,
@@ -41,13 +41,25 @@


class Seq2SeqTrainer(Trainer):
def __init__(self, config, data_args, *args, **kwargs):
def __init__(self, config=None, data_args=None, *args, **kwargs):
Comment (Contributor Author): Make these arguments optional to align better with Trainer and to keep 100% backward compatibility; see the usage sketch below.
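A minimal usage sketch of the two call styles this enables; it assumes the snippet is run from examples/seq2seq, that the tiny checkpoint can be downloaded, and the argument values are placeholders.

from transformers import AutoModelForSeq2SeqLM

from finetune_trainer import Seq2SeqTrainingArguments
from seq2seq_trainer import Seq2SeqTrainer

model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/tiny-mbart")
args = Seq2SeqTrainingArguments(output_dir="tmp_out", per_device_train_batch_size=2)

# New style: config and data_args can be omitted; the trainer reads model.config itself.
trainer = Seq2SeqTrainer(model=model, args=args)

# Old call style keeps working unchanged.
trainer = Seq2SeqTrainer(config=model.config, data_args=None, model=model, args=args)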

super().__init__(*args, **kwargs)
self.config = config

if config is None:
assert isinstance(
self.model, PreTrainedModel
), f"If no `config` is passed the model to be trained has to be of type `PreTrainedModel`, but is {self.model.__class__}"
self.config = self._actual_model(self.model).config
else:
self.config = config

self.data_args = data_args
self.max_gen_length = data_args.val_max_target_length
self.vocab_size = self.config.tgt_vocab_size if isinstance(self.config, FSMTConfig) else self.config.vocab_size

if self.args.label_smoothing != 0 or (self.data_args is not None and self.data_args.ignore_pad_token_for_loss):
assert (
self.config.pad_token_id is not None
), "Make sure that `config.pad_token_id` is correcly defined when ignoring `pad_token` for loss calculation or doing label smoothing."

def create_optimizer_and_scheduler(self, num_training_steps: int):
"""
Setup the optimizer and the learning rate scheduler.
@@ -114,23 +126,31 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
else DistributedSampler(self.train_dataset)
)

def compute_loss(self, model, inputs):
labels = inputs.pop("labels")
outputs = model(**inputs, use_cache=False)
logits = outputs[0]
return self._compute_loss(logits, labels)

def _compute_loss(self, logits, labels):
def _compute_loss(self, model, inputs):
inputs = copy.deepcopy(inputs)
if self.args.label_smoothing == 0:
# Same behavior as modeling_bart.py
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
Comment (Contributor Author): This does not seem to work for all models (EncoderDecoderModel does not work with it) -> let's instead use the loss function of each model here.

Reply (Contributor): The models' loss functions use -100 as ignore_index, so we will also need to replace pad tokens in the labels with -100.

Reply (Contributor Author): I usually do this manually beforehand -> should that be the role of the Seq2SeqTrainer? Trainer does not have this feature either.

Reply (Contributor): Ignoring pad_token_id confused lots of people and helps metrics, so we automated it. Related: #7828

Reply (Contributor): "I usually do this manually beforehand" -> we could do this in the collator, but we won't need to if #7828 is merged.

Reply (Contributor, @sshleifer, Oct 16, 2020): We will still need to cover FSMT/T5. I would definitely not make this change right now; it works as is and is much easier than checking that every model ignores padding.

Reply (Contributor Author): PyTorch's CrossEntropyLoss uses -100 as its default ignore_index, and as I understand it the library's convention is to ignore tokens whose label is -100, not tokens equal to the padding token (often the padding positions in the labels are set to -100): https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html

It would require users to replace pad tokens with -100 manually, but I think that is how it should be done in the library in general. How would we handle models that have no padding token, or cases where more than just the padding token should be excluded from the loss? For such cases it is quite handy if the user overwrites every label they do not want to consider with -100; see the sketch after this thread.

Reply (Contributor): We will discuss on Zoom!
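A minimal sketch of the -100 convention discussed above (the tokenizer checkpoint is only a placeholder; this is not the trainer's code):

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-mbart")

batch = tokenizer(["a short target sentence"], padding="max_length", max_length=16, return_tensors="pt")
labels = batch.input_ids.clone()

# Convention 1: pre-mask the labels so the models' built-in loss (default
# ignore_index=-100) skips padded positions.
labels[labels == tokenizer.pad_token_id] = -100

# Convention 2: keep pad ids in the labels and tell the loss explicitly,
# as the trainer does when ignore_pad_token_for_loss is set.
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)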

assert logits.shape[-1] == self.vocab_size
loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
if self.data_args is not None and self.data_args.ignore_pad_token_for_loss:
# force training to ignore pad token
labels = inputs.pop("labels")
logits = model(**inputs, use_cache=False)[0]

loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
else:
# compute usual loss via models
loss, logits = model(**inputs, use_cache=False)[:2]
else:
# compute label smoothed loss
labels = inputs.pop("labels")
logits = model(**inputs, use_cache=False)[0]
Comment (Contributor): I think use_cache=False everywhere or nowhere.

Reply (Contributor Author): Removed it; I think it's better this way, so as not to give the false impression that use_cache=True would break training. All models have use_cache=True by default and training works by default. It only controls whether past_key_values are returned or not.

Reply (Contributor Author): Oh, this actually breaks a test - it shouldn't. This is related to the Bart bug we never solved: #6353 :-/

Reply (Contributor Author, @patrickvonplaten, Oct 23, 2020): Will add use_cache=False again for now and remove it when fixing the bug in Bart; a small sketch of what use_cache changes follows this thread.
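A small sketch of what use_cache actually toggles; t5-small is just a convenient public checkpoint and downloading it is assumed to work.

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

enc = tokenizer(["translate English to German: hello"], return_tensors="pt")
labels = tokenizer(["hallo"], return_tensors="pt").input_ids

with_cache = model(**enc, labels=labels, use_cache=True, return_dict=True)
without_cache = model(**enc, labels=labels, use_cache=False, return_dict=True)

# The loss is identical either way; use_cache only decides whether the cached
# decoder states (past_key_values) are returned.
assert torch.allclose(with_cache.loss, without_cache.loss)
assert with_cache.past_key_values is not None and without_cache.past_key_values is None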

lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
loss, nll_loss = label_smoothed_nll_loss(
loss, _ = label_smoothed_nll_loss(
lprobs, labels, self.args.label_smoothing, ignore_index=self.config.pad_token_id
)
return loss, logits

def compute_loss(self, model, inputs):
loss, _ = self._compute_loss(model, inputs)
return loss
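For context, a sketch of what a fairseq-style label_smoothed_nll_loss helper typically looks like; this is an illustration of the technique, not necessarily the exact code in utils.py.

import torch

def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index):
    # lprobs: (batch, seq_len, vocab) log-probabilities; target: (batch, seq_len) token ids,
    # where positions equal to ignore_index (e.g. pad_token_id) are excluded from the loss.
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    nll_loss = -lprobs.gather(dim=-1, index=target)    # log-prob of the gold token
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)    # uniform smoothing term over the vocab
    pad_mask = target.eq(ignore_index)
    nll_loss.masked_fill_(pad_mask, 0.0)
    smooth_loss.masked_fill_(pad_mask, 0.0)
    nll_loss, smooth_loss = nll_loss.sum(), smooth_loss.sum()
    eps_i = epsilon / lprobs.size(-1)
    return (1.0 - epsilon) * nll_loss + eps_i * smooth_loss, nll_loss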

def prediction_step(
Expand Down Expand Up @@ -158,31 +178,37 @@ def prediction_step(
"""
inputs = self._prepare_inputs(inputs)

if self.args.predict_with_generate and not self.args.prediction_loss_only:
gen_kwargs = {
"max_length": self.data_args.val_max_target_length
if self.data_args is not None
else self.config.max_length,
"num_beams": self.data_args.eval_beams if self.data_args is not None else self.config.num_beams,
}
generated_tokens = model.generate(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
**gen_kwargs,
)
# in case the batch is shorter than max length, the output should be padded
if self.config.pad_token_id is not None:
Comment (Contributor): I would expect this case to break. _pad_tensors_to_max_len is needed for some sort of Trainer/consistent-shapes reason @patil-suraj.

Reply (Contributor): Yes, Trainer expects all returned preds to be of the same shape, since it concatenates them after every eval batch.

Reply (Contributor Author): I don't get it -> if config.pad_token_id is not defined we cannot run _pad_tensors_to_max_len. How is this breaking anything? I am running all my experiments with no pad_token_id defined, and this case works.

Reply (Contributor): Since Trainer concatenates the preds, I assume they should be of the same length across batches. It was breaking in my last experiment when not using _pad_tensors_to_max_len.

Reply (Contributor Author): I think this is fine -> see the test I added for bert2bert. Such a model does not have self.config.pad_token_id defined and still works. A sketch of this kind of padding helper follows this thread.
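A sketch of the kind of padding helper being discussed: it pads generated sequences to a fixed max_length with pad_token_id so that Trainer can concatenate predictions across batches (an illustration, not the file's exact implementation).

import torch

def pad_tensors_to_max_len(tensor, max_length, pad_token_id):
    # tensor: (batch, cur_len) generated ids, with cur_len <= max_length
    padded = pad_token_id * torch.ones(
        (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
    )
    padded[:, : tensor.shape[-1]] = tensor
    return padded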

generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])

# compute loss on predict data
with torch.no_grad():
Comment (Contributor Author): generate() always runs inside a torch.no_grad() context, so the explicit wrapper is not needed.

if self.args.predict_with_generate and not self.args.prediction_loss_only:
generated_tokens = model.generate(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
use_cache=True,
num_beams=self.data_args.eval_beams,
max_length=self.max_gen_length,
Comment (Contributor): We need these if eval_beams and max_length differ from the defaults.

)
# in case the batch is shorter than max length, the output should be padded
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, self.max_gen_length)

labels_out = inputs.get("labels")
# Call forward again to get loss # TODO: avoidable?
outputs = model(**inputs, use_cache=False)
loss = self._compute_loss(outputs[1], labels_out)
loss = loss.mean().detach()
if self.args.prediction_loss_only:
return (loss, None, None)

logits = generated_tokens if self.args.predict_with_generate else outputs[1]

labels_out = labels_out.detach()
labels = self._pad_tensors_to_max_len(labels_out, self.max_gen_length)
return (loss, logits.detach(), labels)
loss, logits = self._compute_loss(model, inputs)

loss = loss.mean().detach()
if self.args.prediction_loss_only:
return (loss, None, None)

logits = generated_tokens if self.args.predict_with_generate else logits

labels = inputs["labels"]
if self.config.pad_token_id is not None:
labels = self._pad_tensors_to_max_len(labels, self.config.max_length)

return (loss, logits, labels)

def _pad_tensors_to_max_len(self, tensor, max_length):
padded_tensor = self.config.pad_token_id * torch.ones(
117 changes: 115 additions & 2 deletions examples/seq2seq/test_finetune_trainer.py
@@ -5,12 +5,14 @@

import pytest

from transformers import is_torch_available
from transformers import BertTokenizer, EncoderDecoderModel, is_torch_available
from transformers.file_utils import is_datasets_available
from transformers.testing_utils import TestCasePlus, slow
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed

from .finetune_trainer import main
from .finetune_trainer import Seq2SeqTrainingArguments, main
from .seq2seq_trainer import Seq2SeqTrainer
from .test_seq2seq_examples import MBART_TINY
from .utils import execute_async_std

@@ -50,6 +52,117 @@ def test_finetune_trainer_slow(self):
assert "test_generations.txt" in contents
assert "test_results.json" in contents

@slow
def test_finetune_bert2bert(self):
if not is_datasets_available():
return

import datasets

bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("prajjwal1/bert-tiny", "prajjwal1/bert-tiny")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size
bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id

train_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="train[:1%]")
Comment (Contributor): This is cool. cc @stas00 if you ever want to add more training data to a unit test.

val_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="validation[:1%]")

train_dataset = train_dataset.select(range(32))
val_dataset = val_dataset.select(range(16))

rouge = datasets.load_metric("rouge")

batch_size = 4

def _map_to_encoder_decoder_inputs(batch):
# Tokenizer will automatically set [BOS] <text> [EOS]
inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=512)
outputs = tokenizer(batch["highlights"], padding="max_length", truncation=True, max_length=128)
batch["input_ids"] = inputs.input_ids
batch["attention_mask"] = inputs.attention_mask

batch["decoder_input_ids"] = outputs.input_ids
batch["labels"] = outputs.input_ids.copy()
batch["labels"] = [
[-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch["labels"]
]
batch["decoder_attention_mask"] = outputs.attention_mask

assert all([len(x) == 512 for x in inputs.input_ids])
assert all([len(x) == 128 for x in outputs.input_ids])

return batch

def _compute_metrics(pred):
Comment (Contributor): FYI, by default you will get rouge1, rouge2, and rougeL if you don't overwrite compute_metrics.

labels_ids = pred.label_ids
pred_ids = pred.predictions

# all unnecessary tokens are removed
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])[
"rouge2"
].mid

return {
"rouge2_precision": round(rouge_output.precision, 4),
"rouge2_recall": round(rouge_output.recall, 4),
"rouge2_fmeasure": round(rouge_output.fmeasure, 4),
}

# map train dataset
train_dataset = train_dataset.map(
_map_to_encoder_decoder_inputs,
batched=True,
batch_size=batch_size,
remove_columns=["article", "highlights"],
)
train_dataset.set_format(
type="torch",
columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)

# same for validation dataset
val_dataset = val_dataset.map(
_map_to_encoder_decoder_inputs,
batched=True,
batch_size=batch_size,
remove_columns=["article", "highlights"],
)
val_dataset.set_format(
type="torch",
columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)

output_dir = self.get_auto_remove_tmp_dir()

training_args = Seq2SeqTrainingArguments(
output_dir=output_dir,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
predict_with_generate=True,
evaluate_during_training=True,
do_train=True,
do_eval=True,
warmup_steps=0,
eval_steps=2,
logging_steps=2,
)

# instantiate trainer
trainer = Seq2SeqTrainer(
model=bert2bert,
args=training_args,
compute_metrics=_compute_metrics,
train_dataset=train_dataset,
eval_dataset=val_dataset,
)

# start training
trainer.train()

def run_trainer(self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):

# XXX: remove hardcoded path