diff --git a/examples/seq2seq/finetune_trainer.py b/examples/seq2seq/finetune_trainer.py
index fb9e7a14a92e..29632ea91f7a 100644
--- a/examples/seq2seq/finetune_trainer.py
+++ b/examples/seq2seq/finetune_trainer.py
@@ -16,7 +16,6 @@
 )
 from transformers.trainer_utils import EvaluationStrategy
 from utils import (
-    LegacySeq2SeqDataset,
     Seq2SeqDataCollator,
     Seq2SeqDataset,
     assert_all_frozen,
@@ -138,6 +137,10 @@ class DataTrainingArguments:
     src_lang: Optional[str] = field(default=None, metadata={"help": "Source language id for translation."})
     tgt_lang: Optional[str] = field(default=None, metadata={"help": "Target language id for translation."})
     eval_beams: Optional[int] = field(default=None, metadata={"help": "# num_beams to use for evaluation."})
+    ignore_pad_token_for_loss: bool = field(
+        default=True,
+        metadata={"help": "If only pad tokens should be ignored. This assumes that `config.pad_token_id` is defined."},
+    )


 def main():
@@ -223,7 +226,7 @@ def main():
         freeze_params(model.get_encoder())
         assert_all_frozen(model.get_encoder())

-    dataset_class = Seq2SeqDataset if hasattr(tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
+    dataset_class = Seq2SeqDataset

     # Get datasets
     train_dataset = (
diff --git a/examples/seq2seq/seq2seq_trainer.py b/examples/seq2seq/seq2seq_trainer.py
index bf002ee2b654..39ca2c9cd237 100644
--- a/examples/seq2seq/seq2seq_trainer.py
+++ b/examples/seq2seq/seq2seq_trainer.py
@@ -1,11 +1,11 @@
-import logging
+import copy
 from typing import Any, Dict, Optional, Tuple, Union

 import torch
 from torch import nn
 from torch.utils.data import DistributedSampler, RandomSampler

-from transformers import Trainer
+from transformers import PreTrainedModel, Trainer, logging
 from transformers.configuration_fsmt import FSMTConfig
 from transformers.file_utils import is_torch_tpu_available
 from transformers.optimization import (
@@ -27,7 +27,7 @@
 from utils import label_smoothed_nll_loss


-logger = logging.getLogger(__name__)
+logger = logging.get_logger(__name__)

 arg_to_scheduler = {
     "linear": get_linear_schedule_with_warmup,
@@ -41,13 +41,25 @@


 class Seq2SeqTrainer(Trainer):
-    def __init__(self, config, data_args, *args, **kwargs):
+    def __init__(self, config=None, data_args=None, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.config = config
+
+        if config is None:
+            assert isinstance(
+                self.model, PreTrainedModel
+            ), f"If no `config` is passed the model to be trained has to be of type `PreTrainedModel`, but is {self.model.__class__}"
+            self.config = self._actual_model(self.model).config
+        else:
+            self.config = config
+
         self.data_args = data_args
-        self.max_gen_length = data_args.val_max_target_length
         self.vocab_size = self.config.tgt_vocab_size if isinstance(self.config, FSMTConfig) else self.config.vocab_size

+        if self.args.label_smoothing != 0 or (self.data_args is not None and self.data_args.ignore_pad_token_for_loss):
+            assert (
+                self.config.pad_token_id is not None
+            ), "Make sure that `config.pad_token_id` is correctly defined when ignoring `pad_token` for loss calculation or doing label smoothing."
+
     def create_optimizer_and_scheduler(self, num_training_steps: int):
         """
         Setup the optimizer and the learning rate scheduler.
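The hunk above makes `config` and `data_args` optional: when `config` is omitted it is read from the wrapped `PreTrainedModel`, and (as the later `prediction_step` hunk shows) generation falls back to `config.max_length` / `config.num_beams`. A minimal usage sketch of the relaxed constructor, not part of the patch; it assumes the script is run from `examples/seq2seq/` so that `Seq2SeqTrainingArguments` can be imported from `finetune_trainer.py`, and uses a small public checkpoint picked only for illustration:

```python
# Sketch only: instantiate Seq2SeqTrainer without an explicit `config`/`data_args`.
from transformers import AutoModelForSeq2SeqLM

from finetune_trainer import Seq2SeqTrainingArguments
from seq2seq_trainer import Seq2SeqTrainer

model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/tiny-mbart")
training_args = Seq2SeqTrainingArguments(output_dir="tmp_trainer", predict_with_generate=True)

# No `config` passed -> taken from model.config; no `data_args` -> generation
# length and beam count fall back to model.config.max_length / model.config.num_beams.
trainer = Seq2SeqTrainer(model=model, args=training_args)
```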
@@ -114,23 +126,31 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
             else DistributedSampler(self.train_dataset)
         )

-    def compute_loss(self, model, inputs):
-        labels = inputs.pop("labels")
-        outputs = model(**inputs, use_cache=False)
-        logits = outputs[0]
-        return self._compute_loss(logits, labels)
-
-    def _compute_loss(self, logits, labels):
+    def _compute_loss(self, model, inputs):
+        inputs = copy.deepcopy(inputs)
         if self.args.label_smoothing == 0:
-            # Same behavior as modeling_bart.py
-            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
-            assert logits.shape[-1] == self.vocab_size
-            loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
+            if self.data_args is not None and self.data_args.ignore_pad_token_for_loss:
+                # force training to ignore pad token
+                labels = inputs.pop("labels")
+                logits = model(**inputs, use_cache=False)[0]
+
+                loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
+                loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
+            else:
+                # compute usual loss via models
+                loss, logits = model(**inputs, use_cache=False)[:2]
         else:
+            # compute label smoothed loss
+            labels = inputs.pop("labels")
+            logits = model(**inputs, use_cache=False)[0]
             lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
-            loss, nll_loss = label_smoothed_nll_loss(
+            loss, _ = label_smoothed_nll_loss(
                 lprobs, labels, self.args.label_smoothing, ignore_index=self.config.pad_token_id
             )
+        return loss, logits
+
+    def compute_loss(self, model, inputs):
+        loss, _ = self._compute_loss(model, inputs)
         return loss

     def prediction_step(
@@ -158,31 +178,37 @@ def prediction_step(
         """
         inputs = self._prepare_inputs(inputs)

+        if self.args.predict_with_generate and not self.args.prediction_loss_only:
+            gen_kwargs = {
+                "max_length": self.data_args.val_max_target_length
+                if self.data_args is not None
+                else self.config.max_length,
+                "num_beams": self.data_args.eval_beams if self.data_args is not None else self.config.num_beams,
+            }
+            generated_tokens = model.generate(
+                inputs["input_ids"],
+                attention_mask=inputs["attention_mask"],
+                **gen_kwargs,
+            )
+            # in case the batch is shorter than max length, the output should be padded
+            if self.config.pad_token_id is not None:
+                generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
+
+        # compute loss on predict data
         with torch.no_grad():
-            if self.args.predict_with_generate and not self.args.prediction_loss_only:
-                generated_tokens = model.generate(
-                    inputs["input_ids"],
-                    attention_mask=inputs["attention_mask"],
-                    use_cache=True,
-                    num_beams=self.data_args.eval_beams,
-                    max_length=self.max_gen_length,
-                )
-                # in case the batch is shorter than max length, the output should be padded
-                generated_tokens = self._pad_tensors_to_max_len(generated_tokens, self.max_gen_length)
-
-            labels_out = inputs.get("labels")
-            # Call forward again to get loss # TODO: avoidable?
-            outputs = model(**inputs, use_cache=False)
-            loss = self._compute_loss(outputs[1], labels_out)
-            loss = loss.mean().detach()
-            if self.args.prediction_loss_only:
-                return (loss, None, None)
-
-            logits = generated_tokens if self.args.predict_with_generate else outputs[1]
-
-            labels_out = labels_out.detach()
-            labels = self._pad_tensors_to_max_len(labels_out, self.max_gen_length)
-            return (loss, logits.detach(), labels)
+            loss, logits = self._compute_loss(model, inputs)
+
+        loss = loss.mean().detach()
+        if self.args.prediction_loss_only:
+            return (loss, None, None)
+
+        logits = generated_tokens if self.args.predict_with_generate else logits
+
+        labels = inputs["labels"]
+        if self.config.pad_token_id is not None:
+            labels = self._pad_tensors_to_max_len(labels, self.config.max_length)
+
+        return (loss, logits, labels)

     def _pad_tensors_to_max_len(self, tensor, max_length):
         padded_tensor = self.config.pad_token_id * torch.ones(
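The trailing context of this file's diff stops inside `_pad_tensors_to_max_len`, which both the generated tokens and the labels are now routed through. For context, a standalone sketch of what such a helper does; the function name and signature below are illustrative, not taken from the file:

```python
import torch


def pad_tensors_to_max_len(tensor: torch.Tensor, max_length: int, pad_token_id: int) -> torch.Tensor:
    # Right-pad a (batch_size, seq_len) tensor of token ids with `pad_token_id`
    # up to `max_length`, so that generated ids / labels coming from batches of
    # different lengths can be concatenated by the prediction loop.
    padded_tensor = pad_token_id * torch.ones(
        (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
    )
    padded_tensor[:, : tensor.shape[-1]] = tensor
    return padded_tensor
```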
diff --git a/examples/seq2seq/test_finetune_trainer.py b/examples/seq2seq/test_finetune_trainer.py
index 1a013debb800..8a0cdf3aa7fa 100644
--- a/examples/seq2seq/test_finetune_trainer.py
+++ b/examples/seq2seq/test_finetune_trainer.py
@@ -5,12 +5,14 @@

 import pytest

-from transformers import is_torch_available
+from transformers import BertTokenizer, EncoderDecoderModel, is_torch_available
+from transformers.file_utils import is_datasets_available
 from transformers.testing_utils import TestCasePlus, slow
 from transformers.trainer_callback import TrainerState
 from transformers.trainer_utils import set_seed

-from .finetune_trainer import main
+from .finetune_trainer import Seq2SeqTrainingArguments, main
+from .seq2seq_trainer import Seq2SeqTrainer
 from .test_seq2seq_examples import MBART_TINY
 from .utils import execute_async_std

@@ -50,6 +52,117 @@ def test_finetune_trainer_slow(self):
         assert "test_generations.txt" in contents
         assert "test_results.json" in contents

+    @slow
+    def test_finetune_bert2bert(self):
+        if not is_datasets_available():
+            return
+
+        import datasets
+
+        bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("prajjwal1/bert-tiny", "prajjwal1/bert-tiny")
+        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+
+        bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size
+        bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id
+
+        train_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="train[:1%]")
+        val_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="validation[:1%]")
+
+        train_dataset = train_dataset.select(range(32))
+        val_dataset = val_dataset.select(range(16))
+
+        rouge = datasets.load_metric("rouge")
+
+        batch_size = 4
+
+        def _map_to_encoder_decoder_inputs(batch):
+            # Tokenizer will automatically set [BOS] [EOS]
+            inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=512)
+            outputs = tokenizer(batch["highlights"], padding="max_length", truncation=True, max_length=128)
+            batch["input_ids"] = inputs.input_ids
+            batch["attention_mask"] = inputs.attention_mask
+
+            batch["decoder_input_ids"] = outputs.input_ids
+            batch["labels"] = outputs.input_ids.copy()
+            batch["labels"] = [
+                [-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch["labels"]
+            ]
+            batch["decoder_attention_mask"] = outputs.attention_mask
+
+            assert all([len(x) == 512 for x in inputs.input_ids])
+            assert all([len(x) == 128 for x in outputs.input_ids])
+
+            return batch
+
+        def _compute_metrics(pred):
+            labels_ids = pred.label_ids
+            pred_ids = pred.predictions
+
+            # all unnecessary tokens are removed
+            pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
+            label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
+
+            rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])[
+                "rouge2"
+            ].mid
+
+            return {
+                "rouge2_precision": round(rouge_output.precision, 4),
+                "rouge2_recall": round(rouge_output.recall, 4),
+                "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
+            }
+
+        # map train dataset
+        train_dataset = train_dataset.map(
+            _map_to_encoder_decoder_inputs,
+            batched=True,
+            batch_size=batch_size,
+            remove_columns=["article", "highlights"],
+        )
+        train_dataset.set_format(
+            type="torch",
+            columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
+        )
+
+        # same for validation dataset
+        val_dataset = val_dataset.map(
+            _map_to_encoder_decoder_inputs,
+            batched=True,
+            batch_size=batch_size,
+            remove_columns=["article", "highlights"],
+        )
+        val_dataset.set_format(
+            type="torch",
+            columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
+        )
+
+        output_dir = self.get_auto_remove_tmp_dir()
+
+        training_args = Seq2SeqTrainingArguments(
+            output_dir=output_dir,
+            per_device_train_batch_size=batch_size,
+            per_device_eval_batch_size=batch_size,
+            predict_with_generate=True,
+            evaluate_during_training=True,
+            do_train=True,
+            do_eval=True,
+            warmup_steps=0,
+            eval_steps=2,
+            logging_steps=2,
+        )
+
+        # instantiate trainer
+        trainer = Seq2SeqTrainer(
+            model=bert2bert,
+            args=training_args,
+            compute_metrics=_compute_metrics,
+            train_dataset=train_dataset,
+            eval_dataset=val_dataset,
+        )
+
+        # start training
+        trainer.train()
+
     def run_trainer(self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
         # XXX: remove hardcoded path
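A side note on the two label-masking conventions that appear in this patch: `_compute_loss` masks padding via `CrossEntropyLoss(ignore_index=config.pad_token_id)` when `ignore_pad_token_for_loss` is set, while the bert2bert test instead replaces padded label positions with -100 so the model's own loss (default `ignore_index=-100`) skips them. A toy sketch with invented values, not part of the patch, showing that both conventions exclude the same positions:

```python
import torch

pad_token_id = 0
logits = torch.randn(2, 4, 10)  # (batch_size, seq_len, vocab_size)
labels = torch.tensor([[5, 6, pad_token_id, pad_token_id], [7, 8, 9, pad_token_id]])

# Convention 1: keep pad ids in the labels, tell the loss to ignore them.
loss_ignore_pad = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)(
    logits.view(-1, logits.shape[-1]), labels.view(-1)
)

# Convention 2: replace pad ids with -100, the default ignore_index.
masked_labels = labels.masked_fill(labels == pad_token_id, -100)
loss_ignore_100 = torch.nn.CrossEntropyLoss()(logits.view(-1, logits.shape[-1]), masked_labels.view(-1))

assert torch.allclose(loss_ignore_pad, loss_ignore_100)
```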