diff --git a/examples/seq2seq/seq2seq_trainer.py b/examples/seq2seq/seq2seq_trainer.py
index 520df0e87b1c..99826e2228fb 100644
--- a/examples/seq2seq/seq2seq_trainer.py
+++ b/examples/seq2seq/seq2seq_trainer.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
@@ -153,7 +153,11 @@ def compute_loss(self, model, inputs):
         return loss
 
     def prediction_step(
-        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
+        self,
+        model: nn.Module,
+        inputs: Dict[str, Union[torch.Tensor, Any]],
+        prediction_loss_only: bool,
+        ignore_keys: Optional[List[str]] = None,
     ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
         """
         Perform an evaluation step on :obj:`model` using obj:`inputs`.
diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py
index 4e55c4db656f..f7587faac0a2 100755
--- a/src/transformers/configuration_utils.py
+++ b/src/transformers/configuration_utils.py
@@ -43,6 +43,8 @@ class PretrainedConfig(object):
         - **is_composition** (:obj:`bool`): Whether the config class is composed of multiple sub-configs. In this case
           the config has to be initialized from two or more configs of type :class:`~transformers.PretrainedConfig`
          like: :class:`~transformers.EncoderDecoderConfig` or :class:`~RagConfig`.
+        - **keys_to_ignore_at_inference** (:obj:`List[str]`): A list of keys to ignore by default when looking at
+          dictionary outputs of the model during inference.
 
     Args:
         name_or_path (:obj:`str`, `optional`, defaults to :obj:`""`):
diff --git a/src/transformers/models/bart/configuration_bart.py b/src/transformers/models/bart/configuration_bart.py
index 8533a013be5a..67e92e89663b 100644
--- a/src/transformers/models/bart/configuration_bart.py
+++ b/src/transformers/models/bart/configuration_bart.py
@@ -110,6 +110,7 @@ class BartConfig(PretrainedConfig):
             :obj:`True` for `bart-large-cnn`.
""" model_type = "bart" + keys_to_ignore_at_inference = ["past_key_values"] def __init__( self, diff --git a/src/transformers/models/ctrl/configuration_ctrl.py b/src/transformers/models/ctrl/configuration_ctrl.py index faffaa0df96e..c2633c49b8d1 100644 --- a/src/transformers/models/ctrl/configuration_ctrl.py +++ b/src/transformers/models/ctrl/configuration_ctrl.py @@ -77,6 +77,7 @@ class CTRLConfig(PretrainedConfig): """ model_type = "ctrl" + keys_to_ignore_at_inference = ["past_key_values"] def __init__( self, diff --git a/src/transformers/models/gpt2/configuration_gpt2.py b/src/transformers/models/gpt2/configuration_gpt2.py index 25cdcb49f21c..a30da248a5e4 100644 --- a/src/transformers/models/gpt2/configuration_gpt2.py +++ b/src/transformers/models/gpt2/configuration_gpt2.py @@ -120,6 +120,7 @@ class GPT2Config(PretrainedConfig): """ model_type = "gpt2" + keys_to_ignore_at_inference = ["past_key_values"] def __init__( self, diff --git a/src/transformers/models/marian/configuration_marian.py b/src/transformers/models/marian/configuration_marian.py index d5769bcb9cc1..a17531bb2f4d 100644 --- a/src/transformers/models/marian/configuration_marian.py +++ b/src/transformers/models/marian/configuration_marian.py @@ -97,3 +97,4 @@ class MarianConfig(BartConfig): """ model_type = "marian" + keys_to_ignore_at_inference = ["past_key_values"] diff --git a/src/transformers/models/mbart/configuration_mbart.py b/src/transformers/models/mbart/configuration_mbart.py index 743666027835..c8b4540e1efd 100644 --- a/src/transformers/models/mbart/configuration_mbart.py +++ b/src/transformers/models/mbart/configuration_mbart.py @@ -102,3 +102,4 @@ class MBartConfig(BartConfig): """ model_type = "mbart" + keys_to_ignore_at_inference = ["past_key_values"] diff --git a/src/transformers/models/mt5/configuration_mt5.py b/src/transformers/models/mt5/configuration_mt5.py index 23bde1004798..09e9ac2262c9 100644 --- a/src/transformers/models/mt5/configuration_mt5.py +++ b/src/transformers/models/mt5/configuration_mt5.py @@ -62,6 +62,7 @@ class MT5Config(PretrainedConfig): Type of feed forward layer to be used. Should be one of :obj:`"relu"` or :obj:`"gated-gelu"`. """ model_type = "mt5" + keys_to_ignore_at_inference = ["past_key_values"] def __init__( self, diff --git a/src/transformers/models/pegasus/configuration_pegasus.py b/src/transformers/models/pegasus/configuration_pegasus.py index f134ea583201..585f06ddb46e 100644 --- a/src/transformers/models/pegasus/configuration_pegasus.py +++ b/src/transformers/models/pegasus/configuration_pegasus.py @@ -141,4 +141,5 @@ class PegasusConfig(BartConfig): """ model_type = "pegasus" + keys_to_ignore_at_inference = ["past_key_values"] # The implementation of the config object is in BartConfig diff --git a/src/transformers/models/prophetnet/configuration_prophetnet.py b/src/transformers/models/prophetnet/configuration_prophetnet.py index f652043e660b..fdb6f5f30020 100644 --- a/src/transformers/models/prophetnet/configuration_prophetnet.py +++ b/src/transformers/models/prophetnet/configuration_prophetnet.py @@ -92,6 +92,7 @@ class ProphetNetConfig(PretrainedConfig): smoothing is performed. 
""" model_type = "prophetnet" + keys_to_ignore_at_inference = ["past_key_values"] def __init__( self, diff --git a/src/transformers/models/reformer/configuration_reformer.py b/src/transformers/models/reformer/configuration_reformer.py index 69d178875ea3..9e860a48c9d7 100755 --- a/src/transformers/models/reformer/configuration_reformer.py +++ b/src/transformers/models/reformer/configuration_reformer.py @@ -153,6 +153,7 @@ class ReformerConfig(PretrainedConfig): >>> configuration = model.config """ model_type = "reformer" + keys_to_ignore_at_inference = ["past_buckets_states"] def __init__( self, diff --git a/src/transformers/models/t5/configuration_t5.py b/src/transformers/models/t5/configuration_t5.py index 48bdb6c32944..75b396742c86 100644 --- a/src/transformers/models/t5/configuration_t5.py +++ b/src/transformers/models/t5/configuration_t5.py @@ -71,6 +71,7 @@ class T5Config(PretrainedConfig): the :obj:`"gated-gelu"` feed forward projection. Original T5 uses :obj:`"relu"`. """ model_type = "t5" + keys_to_ignore_at_inference = ["past_key_values"] def __init__( self, diff --git a/src/transformers/models/transfo_xl/configuration_transfo_xl.py b/src/transformers/models/transfo_xl/configuration_transfo_xl.py index 9885cbfa2e08..1008f3488a69 100644 --- a/src/transformers/models/transfo_xl/configuration_transfo_xl.py +++ b/src/transformers/models/transfo_xl/configuration_transfo_xl.py @@ -105,6 +105,7 @@ class TransfoXLConfig(PretrainedConfig): """ model_type = "transfo-xl" + keys_to_ignore_at_inference = ["mems"] def __init__( self, diff --git a/src/transformers/models/xlnet/configuration_xlnet.py b/src/transformers/models/xlnet/configuration_xlnet.py index db102317903b..f0592a8d0b0c 100644 --- a/src/transformers/models/xlnet/configuration_xlnet.py +++ b/src/transformers/models/xlnet/configuration_xlnet.py @@ -128,6 +128,7 @@ class XLNetConfig(PretrainedConfig): """ model_type = "xlnet" + keys_to_ignore_at_inference = ["mems"] def __init__( self, diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 950e24291368..64c363afb5b1 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1098,10 +1098,11 @@ def compute_loss(self, model, inputs): """ outputs = model(**inputs) # Save past state if it exists + # TODO: this needs to be fixed and made cleaner later. if self.args.past_index >= 0: self._past = outputs[self.args.past_index] # We don't use .loss here since the model may return tuples instead of ModelOutput. - return outputs[0] + return outputs["loss"] if isinstance(outputs, dict) else outputs[0] def is_local_process_zero(self) -> bool: """ @@ -1220,7 +1221,9 @@ def _rotate_checkpoints(self, use_mtime=False) -> None: logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint)) shutil.rmtree(checkpoint) - def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]: + def evaluate( + self, eval_dataset: Optional[Dataset] = None, ignore_keys: Optional[List[str]] = None + ) -> Dict[str, float]: """ Run evaluation and returns metrics. @@ -1234,6 +1237,9 @@ def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]: Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`, columns not accepted by the ``model.forward()`` method are automatically removed. It must implement the :obj:`__len__` method. 
+            ignore_keys (:obj:`List[str]`, `optional`):
+                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+                gathering predictions.
 
         Returns:
             A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
@@ -1250,6 +1256,7 @@ def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
             # No point gathering the predictions if there are no metrics, otherwise we defer to
             # self.args.prediction_loss_only
             prediction_loss_only=True if self.compute_metrics is None else None,
+            ignore_keys=ignore_keys,
         )
 
         self.log(output.metrics)
@@ -1261,7 +1268,7 @@ def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
         self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
         return output.metrics
 
-    def predict(self, test_dataset: Dataset) -> PredictionOutput:
+    def predict(self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None) -> PredictionOutput:
         """
         Run prediction and returns predictions and potential metrics.
 
@@ -1272,6 +1279,9 @@ def predict(self, test_dataset: Dataset) -> PredictionOutput:
             test_dataset (:obj:`Dataset`):
                 Dataset to run the predictions on. If it is an :obj:`datasets.Dataset`, columns not accepted by the
                 ``model.forward()`` method are automatically removed. Has to implement the method :obj:`__len__`
+            ignore_keys (:obj:`List[str]`, `optional`):
+                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+                gathering predictions.
 
         .. note::
 
@@ -1291,10 +1301,14 @@ def predict(self, test_dataset: Dataset) -> PredictionOutput:
         test_dataloader = self.get_test_dataloader(test_dataset)
 
-        return self.prediction_loop(test_dataloader, description="Prediction")
+        return self.prediction_loop(test_dataloader, description="Prediction", ignore_keys=ignore_keys)
 
     def prediction_loop(
-        self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None
+        self,
+        dataloader: DataLoader,
+        description: str,
+        prediction_loss_only: Optional[bool] = None,
+        ignore_keys: Optional[List[str]] = None,
     ) -> PredictionOutput:
         """
         Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.
@@ -1346,7 +1360,7 @@ def prediction_loop(
         self.callback_handler.eval_dataloader = dataloader
 
         for step, inputs in enumerate(dataloader):
-            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
+            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
             if loss is not None:
                 losses = loss.repeat(batch_size)
                 losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
@@ -1410,7 +1424,11 @@ def _gather_and_numpify(self, tensors, name):
         return nested_numpify(tensors)
 
     def prediction_step(
-        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
+        self,
+        model: nn.Module,
+        inputs: Dict[str, Union[torch.Tensor, Any]],
+        prediction_loss_only: bool,
+        ignore_keys: Optional[List[str]] = None,
     ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
         """
         Perform an evaluation step on :obj:`model` using obj:`inputs`.
@@ -1427,6 +1445,9 @@
                 argument :obj:`labels`. Check your model's documentation for all accepted arguments.
             prediction_loss_only (:obj:`bool`):
                 Whether or not to return the loss only.
+            ignore_keys (:obj:`List[str]`, `optional`):
+                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+                gathering predictions.
 
         Return:
             Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
@@ -1434,6 +1455,11 @@
         """
         has_labels = all(inputs.get(k) is not None for k in self.label_names)
         inputs = self._prepare_inputs(inputs)
+        if ignore_keys is None:
+            if hasattr(self.model, "config"):
+                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
+            else:
+                ignore_keys = []
 
         with torch.no_grad():
             if self.args.fp16 and _use_native_amp:
@@ -1442,16 +1468,21 @@
             else:
                 outputs = model(**inputs)
             if has_labels:
-                loss = outputs[0].mean().detach()
-                logits = outputs[1:]
+                if isinstance(outputs, dict):
+                    loss = outputs["loss"].mean().detach()
+                    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
+                else:
+                    loss = outputs[0].mean().detach()
+                    logits = outputs[1:]
             else:
                 loss = None
-                # Slicing so we get a tuple even if `outputs` is a `ModelOutput`.
-                logits = outputs[:]
+                if isinstance(outputs, dict):
+                    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
+                else:
+                    logits = outputs
+            # TODO: this needs to be fixed and made cleaner later.
             if self.args.past_index >= 0:
                 self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]
-                # Remove the past from the logits.
-                logits = logits[: self.args.past_index - 1] + logits[self.args.past_index :]
 
         if prediction_loss_only:
             return (loss, None, None)
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index b5db8c07121c..5d80654d48cd 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -44,6 +44,8 @@
         DataCollatorForLanguageModeling,
         GlueDataset,
         GlueDataTrainingArguments,
+        GPT2Config,
+        GPT2LMHeadModel,
         LineByLineTextDataset,
         PreTrainedModel,
         TextDataset,
@@ -73,6 +75,18 @@ def __getitem__(self, i):
             return result
 
+    class RepeatDataset:
+        def __init__(self, x, length=64):
+            self.x = x
+            self.length = length
+
+        def __len__(self):
+            return self.length
+
+        def __getitem__(self, i):
+            return {"input_ids": self.x, "labels": self.x}
+
+
     class DynamicShapesDataset:
         def __init__(self, length=64, seed=42, batch_size=8):
             self.length = length
@@ -136,6 +150,20 @@ def forward(self, input_x=None, labels=None, **kwargs):
                 loss = torch.nn.functional.mse_loss(y, labels)
             return (loss, y, y) if self.double_output else (loss, y)
 
+    class RegressionDictModel(torch.nn.Module):
+        def __init__(self, a=0, b=0):
+            super().__init__()
+            self.a = torch.nn.Parameter(torch.tensor(a).float())
+            self.b = torch.nn.Parameter(torch.tensor(b).float())
+            self.config = None
+
+        def forward(self, input_x=None, labels=None, **kwargs):
+            y = input_x * self.a + self.b
+            result = {"output": y}
+            if labels is not None:
+                result["loss"] = torch.nn.functional.mse_loss(y, labels)
+            return result
+
     class RegressionPreTrainedModel(PreTrainedModel):
         config_class = RegressionModelConfig
         base_model_prefix = "regression"
@@ -236,6 +264,33 @@ def check_best_model_has_been_loaded(
         metrics = trainer.evaluate()
         self.assertEqual(metrics[metric], best_value)
 
+    def test_trainer_works_with_dict(self):
+        # Edge case because Apex with mode O2 will change our models to return dicts. This test checks it doesn't break
+        # anything.
+        train_dataset = RegressionDataset()
+        eval_dataset = RegressionDataset()
+        model = RegressionDictModel()
+        args = TrainingArguments("./regression")
+        trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
+        trainer.train()
+        _ = trainer.evaluate()
+        _ = trainer.predict(eval_dataset)
+
+    def test_evaluation_with_keys_to_drop(self):
+        config = GPT2Config(vocab_size=100, n_positions=128, n_ctx=128, n_embd=32, n_layer=3, n_head=4)
+        tiny_gpt2 = GPT2LMHeadModel(config)
+        x = torch.randint(0, 100, (128,))
+        eval_dataset = RepeatDataset(x)
+        args = TrainingArguments("./test")
+        trainer = Trainer(tiny_gpt2, args, eval_dataset=eval_dataset)
+        # By default the past_key_values are removed
+        result = trainer.predict(eval_dataset)
+        self.assertTrue(isinstance(result.predictions, np.ndarray))
+        # We can still get them by setting ignore_keys to []
+        result = trainer.predict(eval_dataset, ignore_keys=[])
+        self.assertTrue(isinstance(result.predictions, tuple))
+        self.assertEqual(len(result.predictions), 2)
+
     def test_training_arguments_are_left_untouched(self):
         trainer = get_regression_trainer()
         trainer.train()
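A minimal usage sketch of the behavior this patch adds, mirroring test_evaluation_with_keys_to_drop above: GPT-2's config now declares keys_to_ignore_at_inference = ["past_key_values"], so Trainer.predict drops the cache from the gathered predictions unless ignore_keys=[] is passed explicitly. Everything below is taken from the diff; only the output directory is a throwaway path.

import numpy as np
import torch

from transformers import GPT2Config, GPT2LMHeadModel, Trainer, TrainingArguments


class RepeatDataset:
    # Serves the same tokenized example `length` times, as in the test above.
    def __init__(self, x, length=64):
        self.x = x
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        return {"input_ids": self.x, "labels": self.x}


config = GPT2Config(vocab_size=100, n_positions=128, n_ctx=128, n_embd=32, n_layer=3, n_head=4)
tiny_gpt2 = GPT2LMHeadModel(config)
eval_dataset = RepeatDataset(torch.randint(0, 100, (128,)))
trainer = Trainer(tiny_gpt2, TrainingArguments("./test"), eval_dataset=eval_dataset)

# GPT2Config sets keys_to_ignore_at_inference = ["past_key_values"], so
# prediction_step filters the cache out of the model's output dict and
# predictions is a single logits array.
result = trainer.predict(eval_dataset)
assert isinstance(result.predictions, np.ndarray)

# Passing ignore_keys=[] disables the filtering: every non-loss entry of the
# output dict is gathered, here (logits, past_key_values).
result = trainer.predict(eval_dataset, ignore_keys=[])
assert isinstance(result.predictions, tuple) and len(result.predictions) == 2

Models without a config attribute, or whose config does not set keys_to_ignore_at_inference, fall back to an empty ignore list in prediction_step, so their behavior is unchanged.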