[Seq2Seq] Allow EncoderDecoderModels to be trained with Seq2Seq #7809
finetune_trainer.py

@@ -16,7 +16,6 @@
 )
 from transformers.trainer_utils import EvaluationStrategy
 from utils import (
-    LegacySeq2SeqDataset,
     Seq2SeqDataCollator,
     Seq2SeqDataset,
     assert_all_frozen,

@@ -138,6 +137,10 @@ class DataTrainingArguments:
     src_lang: Optional[str] = field(default=None, metadata={"help": "Source language id for translation."})
     tgt_lang: Optional[str] = field(default=None, metadata={"help": "Target language id for translation."})
     eval_beams: Optional[int] = field(default=None, metadata={"help": "# num_beams to use for evaluation."})
+    ignore_pad_token_for_loss: bool = field(
+        default=True,
+        metadata={"help": "If only pad tokens should be ignored. This assumes that `config.pad_token_id` is defined."},
+    )
@@ -223,7 +226,7 @@ def main():
     freeze_params(model.get_encoder())
     assert_all_frozen(model.get_encoder())

-    dataset_class = Seq2SeqDataset if hasattr(tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
**Contributor:** great catch
+    dataset_class = Seq2SeqDataset

     # Get datasets
     train_dataset = (
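Aside: the `Seq2SeqDataset`-only path above relies on the tokenizer exposing `prepare_seq2seq_batch`. A minimal usage sketch follows, assuming a tokenizer of that era which implements it; the texts, lengths, and returned keys are illustrative and may differ between versions.

```python
# Illustrative only -- not part of the diff. Assumes `tokenizer` implements
# prepare_seq2seq_batch (the reason the Legacy fallback could be dropped).
batch = tokenizer.prepare_seq2seq_batch(
    src_texts=["UN Chief Says There Is No Military Solution in Syria"],
    tgt_texts=["Şeful ONU declară că nu există o soluţie militară în Siria"],
    max_length=64,
    max_target_length=64,
    return_tensors="pt",
)
# expected to carry input_ids, attention_mask and the target ids used as labels
print(batch.keys())
```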
seq2seq_trainer.py

@@ -1,11 +1,11 @@
-import logging
+import copy
 from typing import Any, Dict, Optional, Tuple, Union

 import torch
 from torch import nn
 from torch.utils.data import DistributedSampler, RandomSampler

-from transformers import Trainer
+from transformers import PreTrainedModel, Trainer, logging
 from transformers.configuration_fsmt import FSMTConfig
 from transformers.file_utils import is_torch_tpu_available
 from transformers.optimization import (
@@ -27,7 +27,7 @@
 from utils import label_smoothed_nll_loss


-logger = logging.getLogger(__name__)
+logger = logging.get_logger(__name__)


 arg_to_scheduler = {
     "linear": get_linear_schedule_with_warmup,
@@ -41,13 +41,25 @@


 class Seq2SeqTrainer(Trainer):
-    def __init__(self, config, data_args, *args, **kwargs):
+    def __init__(self, config=None, data_args=None, *args, **kwargs):
**Contributor (Author):** Make those variables optional to align better with the base `Trainer`.
         super().__init__(*args, **kwargs)
-        self.config = config
+        if config is None:
+            assert isinstance(
+                self.model, PreTrainedModel
+            ), f"If no `config` is passed the model to be trained has to be of type `PreTrainedModel`, but is {self.model.__class__}"
+            self.config = self._actual_model(self.model).config
+        else:
+            self.config = config
+
         self.data_args = data_args
-        self.max_gen_length = data_args.val_max_target_length
         self.vocab_size = self.config.tgt_vocab_size if isinstance(self.config, FSMTConfig) else self.config.vocab_size

+        if self.args.label_smoothing != 0 or (self.data_args is not None and self.data_args.ignore_pad_token_for_loss):
+            assert (
+                self.config.pad_token_id is not None
+            ), "Make sure that `config.pad_token_id` is correctly defined when ignoring `pad_token` for loss calculation or doing label smoothing."

     def create_optimizer_and_scheduler(self, num_training_steps: int):
         """
         Setup the optimizer and the learning rate scheduler.
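With `config` and `data_args` now optional, the trainer can be constructed like the base `Trainer` and will pull the config off the model. A rough usage sketch; the model name and the surrounding variables are illustrative, not taken from this diff.

```python
# Sketch: `config` is inferred because the model is a PreTrainedModel.
# Seq2SeqTrainer is the class defined in this file; training_args, train_dataset
# and val_dataset are assumed to exist (see the bert2bert test further down).
from transformers import EncoderDecoderModel

model = EncoderDecoderModel.from_encoder_decoder_pretrained("prajjwal1/bert-tiny", "prajjwal1/bert-tiny")
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

# The explicit form still works, e.g. for the finetune_trainer.py script:
# trainer = Seq2SeqTrainer(config=model.config, data_args=data_args, model=model, args=training_args, ...)
```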
@@ -114,23 +126,31 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
             else DistributedSampler(self.train_dataset)
         )

-    def compute_loss(self, model, inputs):
-        labels = inputs.pop("labels")
-        outputs = model(**inputs, use_cache=False)
-        logits = outputs[0]
-        return self._compute_loss(logits, labels)
-
-    def _compute_loss(self, logits, labels):
+    def _compute_loss(self, model, inputs):
+        inputs = copy.deepcopy(inputs)
         if self.args.label_smoothing == 0:
-            # Same behavior as modeling_bart.py
-            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
**Contributor (Author):** This does not seem to work for all models (…).

**Contributor:** The loss functions of the models use -100 as the ignore index.

**Contributor (Author):** I usually do this manually before -> should that be the role of the data collator?

**Contributor:** Ignoring `pad_token_id` confused lots of people and helps metrics, so we automated it.

**Contributor:** We could do this in the data collator.

**Contributor:** We will still need to cover FSMT/T5.

**Contributor (Author):** PyTorch's CE loss function has `ignore_index=-100` by default. It would require models to manually replace tokens with -100, but I think that's how it should be done in general in the library. How would we handle models that don't have a padding token, or that want to disregard the loss on more than just the padding token? For such cases I think it can be quite handy if the user overwrites all labels they do not want to consider with -100.

**Contributor:** We will discuss on zoom!
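As an aside, here is a tiny sketch (not from the PR) of the two equivalent ways of excluding padding from the loss that the thread is weighing: passing the pad id as `ignore_index`, or pre-masking the labels with -100, which is `CrossEntropyLoss`'s default `ignore_index`.

```python
import torch

pad_token_id = 0
logits = torch.randn(2, 5, 32)          # (batch, seq_len, vocab)
labels = torch.randint(1, 32, (2, 5))
labels[:, -2:] = pad_token_id           # pretend the last positions are padding

# Option A: tell the loss which id to ignore
loss_a = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)(
    logits.view(-1, logits.size(-1)), labels.view(-1)
)

# Option B: replace pad positions with -100 (the default ignore_index) beforehand,
# e.g. in the dataset or data collator, and use the loss with its defaults
masked_labels = labels.masked_fill(labels == pad_token_id, -100)
loss_b = torch.nn.CrossEntropyLoss()(
    logits.view(-1, logits.size(-1)), masked_labels.view(-1)
)

assert torch.allclose(loss_a, loss_b)
```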
-            assert logits.shape[-1] == self.vocab_size
-            loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
+            if self.data_args is not None and self.data_args.ignore_pad_token_for_loss:
+                # force training to ignore pad token
+                labels = inputs.pop("labels")
+                logits = model(**inputs, use_cache=False)[0]
+
+                loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
+                loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
+            else:
+                # compute usual loss via models
+                loss, logits = model(**inputs, use_cache=False)[:2]
         else:
+            # compute label smoothed loss
+            labels = inputs.pop("labels")
+            logits = model(**inputs, use_cache=False)[0]
**Contributor:** I think `nll_loss` is not needed here.

**Contributor (Author):** Removed it - I think it's better this way, so as not to give the false impression that `nll_loss` is used.

**Contributor (Author):** Oh, this actually breaks a test - it shouldn't. This is related to this Bart bug we never solved: #6353 :-/

**Contributor (Author):** will add …
             lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
-            loss, nll_loss = label_smoothed_nll_loss(
+            loss, _ = label_smoothed_nll_loss(
                 lprobs, labels, self.args.label_smoothing, ignore_index=self.config.pad_token_id
             )
+        return loss, logits
+
+    def compute_loss(self, model, inputs):
+        loss, _ = self._compute_loss(model, inputs)
+        return loss

     def prediction_step(
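For reference, a fairseq-style label-smoothed NLL loss of the kind the `label_smoothed_nll_loss` helper implements looks roughly like the sketch below. It assumes the labels contain valid token ids rather than -100, which is why the trainer passes `pad_token_id` as `ignore_index`; this is not the exact utils implementation.

```python
import torch

def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
    """Sketch of a label-smoothed NLL loss; not a verbatim copy of utils.py."""
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    # negative log-likelihood of the gold token and of the uniform "smoothing" distribution
    nll_loss = -lprobs.gather(dim=-1, index=target)
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        nll_loss.masked_fill_(pad_mask, 0.0)
        smooth_loss.masked_fill_(pad_mask, 0.0)
    nll_loss, smooth_loss = nll_loss.sum(), smooth_loss.sum()
    eps_i = epsilon / lprobs.size(-1)
    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
    # the second return value is the plain NLL, which compute_loss above discards via `_`
    return loss, nll_loss
```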
@@ -158,31 +178,37 @@ def prediction_step(
         """
         inputs = self._prepare_inputs(inputs)

+        if self.args.predict_with_generate and not self.args.prediction_loss_only:
+            gen_kwargs = {
+                "max_length": self.data_args.val_max_target_length
+                if self.data_args is not None
+                else self.config.max_length,
+                "num_beams": self.data_args.eval_beams if self.data_args is not None else self.config.num_beams,
+            }
+            generated_tokens = model.generate(
+                inputs["input_ids"],
+                attention_mask=inputs["attention_mask"],
+                **gen_kwargs,
+            )
+            # in case the batch is shorter than max length, the output should be padded
+            if self.config.pad_token_id is not None:
**Contributor:** I would expect this case to break.

**Contributor:** Yes, …

**Contributor (Author):** I don't get it -> if …

**Contributor:** since …

**Contributor (Author):** I think this is fine -> see the test I added for bert2bert. Such a model does not have a `pad_token_id` set on its config.
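To illustrate the point (hypothetical code, not part of the diff): the shared `EncoderDecoderConfig` does not automatically carry the special token ids that the trainer reads from `self.config`, so the padding branch above is simply skipped unless the user sets them explicitly.

```python
# Hypothetical setup; attribute names follow the usual transformers config conventions.
from transformers import BertTokenizer, EncoderDecoderModel

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("prajjwal1/bert-tiny", "prajjwal1/bert-tiny")

print(bert2bert.config.pad_token_id)   # typically None on the shared EncoderDecoderConfig

# Opting in to padding of generated ids / labels by defining the ids on the shared config:
bert2bert.config.pad_token_id = tokenizer.pad_token_id
bert2bert.config.eos_token_id = tokenizer.sep_token_id
bert2bert.config.max_length = 128
```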
+                generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])

-        # compute loss on predict data
         with torch.no_grad():
-            if self.args.predict_with_generate and not self.args.prediction_loss_only:
-                generated_tokens = model.generate(
-                    inputs["input_ids"],
-                    attention_mask=inputs["attention_mask"],
-                    use_cache=True,
-                    num_beams=self.data_args.eval_beams,
-                    max_length=self.max_gen_length,
**Contributor:** we need this if …
-                )
-                # in case the batch is shorter than max length, the output should be padded
-                generated_tokens = self._pad_tensors_to_max_len(generated_tokens, self.max_gen_length)
-
-            labels_out = inputs.get("labels")
-            # Call forward again to get loss # TODO: avoidable?
-            outputs = model(**inputs, use_cache=False)
-            loss = self._compute_loss(outputs[1], labels_out)
-            loss = loss.mean().detach()
-            if self.args.prediction_loss_only:
-                return (loss, None, None)
-
-            logits = generated_tokens if self.args.predict_with_generate else outputs[1]
-
-            labels_out = labels_out.detach()
-            labels = self._pad_tensors_to_max_len(labels_out, self.max_gen_length)
-            return (loss, logits.detach(), labels)
+            loss, logits = self._compute_loss(model, inputs)
+        loss = loss.mean().detach()
+        if self.args.prediction_loss_only:
+            return (loss, None, None)
+
+        logits = generated_tokens if self.args.predict_with_generate else logits
+
+        labels = inputs["labels"]
+        if self.config.pad_token_id is not None:
+            labels = self._pad_tensors_to_max_len(labels, self.config.max_length)
+
+        return (loss, logits, labels)

     def _pad_tensors_to_max_len(self, tensor, max_length):
         padded_tensor = self.config.pad_token_id * torch.ones(
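The helper is truncated above; based on how it is called, it presumably continues roughly as follows (a sketch consistent with the calls, not a verbatim copy of the file).

```python
    def _pad_tensors_to_max_len(self, tensor, max_length):
        # Right-pad a (batch, seq_len) tensor with pad_token_id up to max_length,
        # so generated ids and labels can be gathered/stacked across batches.
        padded_tensor = self.config.pad_token_id * torch.ones(
            (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
        )
        padded_tensor[:, : tensor.shape[-1]] = tensor
        return padded_tensor
```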
test_finetune_trainer.py

@@ -5,12 +5,14 @@
 import pytest

-from transformers import is_torch_available
+from transformers import BertTokenizer, EncoderDecoderModel, is_torch_available
+from transformers.file_utils import is_datasets_available
 from transformers.testing_utils import TestCasePlus, slow
 from transformers.trainer_callback import TrainerState
 from transformers.trainer_utils import set_seed

-from .finetune_trainer import main
+from .finetune_trainer import Seq2SeqTrainingArguments, main
+from .seq2seq_trainer import Seq2SeqTrainer
 from .test_seq2seq_examples import MBART_TINY
 from .utils import execute_async_std
@@ -50,6 +52,117 @@ def test_finetune_trainer_slow(self):
         assert "test_generations.txt" in contents
         assert "test_results.json" in contents

+    @slow
+    def test_finetune_bert2bert(self):
+        if not is_datasets_available():
+            return
+
+        import datasets
+
+        bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("prajjwal1/bert-tiny", "prajjwal1/bert-tiny")
+        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+
+        bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size
+        bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id
+
+        train_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="train[:1%]")
**Contributor:** this is cool.
+        val_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="validation[:1%]")
+
+        train_dataset = train_dataset.select(range(32))
+        val_dataset = val_dataset.select(range(16))
+
+        rouge = datasets.load_metric("rouge")
+
+        batch_size = 4
+
+        def _map_to_encoder_decoder_inputs(batch):
+            # Tokenizer will automatically set [BOS] <text> [EOS]
+            inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=512)
+            outputs = tokenizer(batch["highlights"], padding="max_length", truncation=True, max_length=128)
+            batch["input_ids"] = inputs.input_ids
+            batch["attention_mask"] = inputs.attention_mask
+
+            batch["decoder_input_ids"] = outputs.input_ids
+            batch["labels"] = outputs.input_ids.copy()
+            batch["labels"] = [
+                [-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch["labels"]
+            ]
+            batch["decoder_attention_mask"] = outputs.attention_mask
+
+            assert all([len(x) == 512 for x in inputs.input_ids])
+            assert all([len(x) == 128 for x in outputs.input_ids])
+
+            return batch
+
+        def _compute_metrics(pred):
**Contributor:** FYI, by default you will get rouge1, rouge2, and rougeL (if you don't overwrite `rouge_types`).
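Aside (not part of the test): the `datasets` rouge metric returns, per requested type, an aggregate with `low`/`mid`/`high` bootstrap estimates, each holding precision/recall/fmeasure, which is why `_compute_metrics` below takes `.mid`.

```python
# Illustrative of the metric's output structure; requires the `rouge_score` package.
import datasets

rouge = datasets.load_metric("rouge")
out = rouge.compute(
    predictions=["the cat sat on the mat"],
    references=["the cat sat on the mat"],
    rouge_types=["rouge2"],
)
mid = out["rouge2"].mid
print(mid.precision, mid.recall, mid.fmeasure)  # 1.0 1.0 1.0 for identical strings
```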
+            labels_ids = pred.label_ids
+            pred_ids = pred.predictions
+
+            # all unnecessary tokens are removed
+            pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
+            label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
+
+            rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])[
+                "rouge2"
+            ].mid
+
+            return {
+                "rouge2_precision": round(rouge_output.precision, 4),
+                "rouge2_recall": round(rouge_output.recall, 4),
+                "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
+            }
+
+        # map train dataset
+        train_dataset = train_dataset.map(
+            _map_to_encoder_decoder_inputs,
+            batched=True,
+            batch_size=batch_size,
+            remove_columns=["article", "highlights"],
+        )
+        train_dataset.set_format(
+            type="torch",
+            columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
+        )
+
+        # same for validation dataset
+        val_dataset = val_dataset.map(
+            _map_to_encoder_decoder_inputs,
+            batched=True,
+            batch_size=batch_size,
+            remove_columns=["article", "highlights"],
+        )
+        val_dataset.set_format(
+            type="torch",
+            columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
+        )
+
+        output_dir = self.get_auto_remove_tmp_dir()
+
+        training_args = Seq2SeqTrainingArguments(
+            output_dir=output_dir,
+            per_device_train_batch_size=batch_size,
+            per_device_eval_batch_size=batch_size,
+            predict_with_generate=True,
+            evaluate_during_training=True,
+            do_train=True,
+            do_eval=True,
+            warmup_steps=0,
+            eval_steps=2,
+            logging_steps=2,
+        )
+
+        # instantiate trainer
+        trainer = Seq2SeqTrainer(
+            model=bert2bert,
+            args=training_args,
+            compute_metrics=_compute_metrics,
+            train_dataset=train_dataset,
+            eval_dataset=val_dataset,
+        )
+
+        # start training
+        trainer.train()
+
     def run_trainer(self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
         # XXX: remove hardcoded path
**Contributor:** Put it at `True` for backward compatibility.

**Contributor (Author):** thanks!