8 changes: 6 additions & 2 deletions examples/seq2seq/seq2seq_trainer.py
@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
@@ -153,7 +153,11 @@ def compute_loss(self, model, inputs):
return loss

def prediction_step(
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
self,
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Perform an evaluation step on :obj:`model` using :obj:`inputs`.
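
Any `Trainer` subclass that overrides `prediction_step`, as `Seq2SeqTrainer` does here, now needs to accept the same keyword so the argument threads through to the base class. A minimal sketch of the pattern (the subclass name is hypothetical, and the base signature is the one introduced in this PR):

```python
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from transformers import Trainer


class MySeq2SeqTrainer(Trainer):  # hypothetical subclass, for illustration only
    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        # Forward ignore_keys so the base Trainer can filter dict outputs
        return super().prediction_step(
            model, inputs, prediction_loss_only, ignore_keys=ignore_keys
        )
```
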
2 changes: 2 additions & 0 deletions src/transformers/configuration_utils.py
@@ -43,6 +43,8 @@ class PretrainedConfig(object):
- **is_composition** (:obj:`bool`): Whether the config class is composed of multiple sub-configs. In this case
the config has to be initialized from two or more configs of type :class:`~transformers.PretrainedConfig`
like: :class:`~transformers.EncoderDecoderConfig` or :class:`~transformers.RagConfig`.
- **keys_to_ignore_at_inference** (:obj:`List[str]`): A list of keys to ignore by default when looking at
dictionary outputs of the model during inference.

Args:
name_or_path (:obj:`str`, `optional`, defaults to :obj:`""`):
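
To illustrate the new class attribute, here is a minimal sketch of a custom config declaring it, together with the kind of filtering the Trainer performs on dictionary outputs (the config class, model type, and helper below are hypothetical):

```python
from transformers import PretrainedConfig


class MyConfig(PretrainedConfig):  # hypothetical config, for illustration
    model_type = "my-model"
    # Output keys the Trainer should drop when gathering predictions
    keys_to_ignore_at_inference = ["past_key_values"]


def gather_logits(outputs: dict, config: PretrainedConfig) -> tuple:
    # Sketch of the consumption side: keep every output except the ignored keys
    ignore = getattr(config, "keys_to_ignore_at_inference", [])
    return tuple(v for k, v in outputs.items() if k not in ignore)
```
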
1 change: 1 addition & 0 deletions src/transformers/models/bart/configuration_bart.py
@@ -110,6 +110,7 @@ class BartConfig(PretrainedConfig):
:obj:`True` for `bart-large-cnn`.
"""
model_type = "bart"
keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
1 change: 1 addition & 0 deletions src/transformers/models/ctrl/configuration_ctrl.py
@@ -77,6 +77,7 @@ class CTRLConfig(PretrainedConfig):
"""

model_type = "ctrl"
keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
1 change: 1 addition & 0 deletions src/transformers/models/gpt2/configuration_gpt2.py
@@ -120,6 +120,7 @@ class GPT2Config(PretrainedConfig):
"""

model_type = "gpt2"
keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
1 change: 1 addition & 0 deletions src/transformers/models/marian/configuration_marian.py
@@ -97,3 +97,4 @@ class MarianConfig(BartConfig):
"""

model_type = "marian"
keys_to_ignore_at_inference = ["past_key_values"]
Contributor:

It's a bit late now, but to be honest I'm not a huge fan of the name: this seems to be very specific to training, but one might now think that past_key_values can never be passed during inference in general. Why not call it keys_to_ignore_at_training?

Collaborator (Author):

No, this is not for training, only for inference. During training we only get the loss in the outputs.
And these keys are not ignored when passed to the model, but ignored because they are not part of the logits/scores/predictions we want to gather. Maybe output_keys_to_ignore_at_inference is clearer?

Contributor:

I see! Yeah I think output_keys_to_ignore_at_inference would be a bit clearer to me :-)

Collaborator (Author):

See #8857
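
To make the semantics from this thread concrete: the attribute only affects which outputs get gathered during evaluation, never what is passed into the model, and it can be overridden per call. A small sketch of the intended behavior (the trainer and dataset variables are assumed to exist):

```python
from transformers import MarianConfig

# The default declared on the config in this PR:
print(MarianConfig().keys_to_ignore_at_inference)  # ['past_key_values']

# During trainer.evaluate()/predict(), past_key_values is dropped from the
# gathered outputs; passing ignore_keys=[] would keep it:
# trainer.evaluate(eval_dataset, ignore_keys=[])
```
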

1 change: 1 addition & 0 deletions src/transformers/models/mbart/configuration_mbart.py
@@ -102,3 +102,4 @@ class MBartConfig(BartConfig):
"""

model_type = "mbart"
keys_to_ignore_at_inference = ["past_key_values"]
1 change: 1 addition & 0 deletions src/transformers/models/mt5/configuration_mt5.py
@@ -62,6 +62,7 @@ class MT5Config(PretrainedConfig):
Type of feed forward layer to be used. Should be one of :obj:`"relu"` or :obj:`"gated-gelu"`.
"""
model_type = "mt5"
keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
1 change: 1 addition & 0 deletions src/transformers/models/pegasus/configuration_pegasus.py
@@ -141,4 +141,5 @@ class PegasusConfig(BartConfig):
"""

model_type = "pegasus"
keys_to_ignore_at_inference = ["past_key_values"]
# The implementation of the config object is in BartConfig
1 change: 1 addition & 0 deletions src/transformers/models/prophetnet/configuration_prophetnet.py
@@ -92,6 +92,7 @@ class ProphetNetConfig(PretrainedConfig):
smoothing is performed.
"""
model_type = "prophetnet"
keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
1 change: 1 addition & 0 deletions src/transformers/models/reformer/configuration_reformer.py
@@ -153,6 +153,7 @@ class ReformerConfig(PretrainedConfig):
>>> configuration = model.config
"""
model_type = "reformer"
keys_to_ignore_at_inference = ["past_buckets_states"]

def __init__(
self,
1 change: 1 addition & 0 deletions src/transformers/models/t5/configuration_t5.py
@@ -71,6 +71,7 @@ class T5Config(PretrainedConfig):
the :obj:`"gated-gelu"` feed forward projection. Original T5 uses :obj:`"relu"`.
"""
model_type = "t5"
keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
1 change: 1 addition & 0 deletions src/transformers/models/transfo_xl/configuration_transfo_xl.py
@@ -105,6 +105,7 @@ class TransfoXLConfig(PretrainedConfig):
"""

model_type = "transfo-xl"
keys_to_ignore_at_inference = ["mems"]

def __init__(
self,
1 change: 1 addition & 0 deletions src/transformers/models/xlnet/configuration_xlnet.py
@@ -128,6 +128,7 @@ class XLNetConfig(PretrainedConfig):
"""

model_type = "xlnet"
keys_to_ignore_at_inference = ["mems"]

def __init__(
self,
57 changes: 44 additions & 13 deletions src/transformers/trainer.py
@@ -1098,10 +1098,11 @@ def compute_loss(self, model, inputs):
"""
outputs = model(**inputs)
# Save past state if it exists
# TODO: this needs to be fixed and made cleaner later.
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index]
# We don't use .loss here since the model may return tuples instead of ModelOutput.
return outputs[0]
return outputs["loss"] if isinstance(outputs, dict) else outputs[0]

def is_local_process_zero(self) -> bool:
"""
@@ -1220,7 +1221,9 @@ def _rotate_checkpoints(self, use_mtime=False) -> None:
logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
shutil.rmtree(checkpoint)

def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
def evaluate(
self, eval_dataset: Optional[Dataset] = None, ignore_keys: Optional[List[str]] = None
) -> Dict[str, float]:
"""
Run evaluation and returns metrics.

@@ -1234,6 +1237,9 @@ def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`,
columns not accepted by the ``model.forward()`` method are automatically removed. It must implement the
:obj:`__len__` method.
ignore_keys (:obj:`List[str]`, `optional`):
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
gathering predictions.

Returns:
A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
@@ -1250,6 +1256,7 @@ def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
# No point gathering the predictions if there are no metrics, otherwise we defer to
# self.args.prediction_loss_only
prediction_loss_only=True if self.compute_metrics is None else None,
ignore_keys=ignore_keys,
)

self.log(output.metrics)
@@ -1261,7 +1268,7 @@ def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
return output.metrics

def predict(self, test_dataset: Dataset) -> PredictionOutput:
def predict(self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None) -> PredictionOutput:
"""
Run prediction and returns predictions and potential metrics.

@@ -1272,6 +1279,9 @@ def predict(self, test_dataset: Dataset) -> PredictionOutput:
test_dataset (:obj:`Dataset`):
Dataset to run the predictions on. If it is an :obj:`datasets.Dataset`, columns not accepted by the
``model.forward()`` method are automatically removed. Has to implement the method :obj:`__len__`.
ignore_keys (:obj:`List[str]`, `optional`):
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
gathering predictions.

.. note::

@@ -1291,10 +1301,14 @@ def predict(self, test_dataset: Dataset) -> PredictionOutput:

test_dataloader = self.get_test_dataloader(test_dataset)

return self.prediction_loop(test_dataloader, description="Prediction")
return self.prediction_loop(test_dataloader, description="Prediction", ignore_keys=ignore_keys)

def prediction_loop(
self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None
self,
dataloader: DataLoader,
description: str,
prediction_loss_only: Optional[bool] = None,
ignore_keys: Optional[List[str]] = None,
) -> PredictionOutput:
"""
Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.
@@ -1346,7 +1360,7 @@ def prediction_loop(
self.callback_handler.eval_dataloader = dataloader

for step, inputs in enumerate(dataloader):
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
if loss is not None:
losses = loss.repeat(batch_size)
losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
@@ -1410,7 +1424,11 @@ def _gather_and_numpify(self, tensors, name):
return nested_numpify(tensors)

def prediction_step(
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
self,
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Perform an evaluation step on :obj:`model` using :obj:`inputs`.
@@ -1427,13 +1445,21 @@ def prediction_step(
argument :obj:`labels`. Check your model's documentation for all accepted arguments.
prediction_loss_only (:obj:`bool`):
Whether or not to return the loss only.
ignore_keys (:obj:`List[str]`, `optional`):
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
gathering predictions.

Return:
Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
labels (each being optional).
"""
has_labels = all(inputs.get(k) is not None for k in self.label_names)
inputs = self._prepare_inputs(inputs)
if ignore_keys is None:
if hasattr(self.model, "config"):
ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
else:
ignore_keys = []

with torch.no_grad():
if self.args.fp16 and _use_native_amp:
@@ -1442,16 +1468,21 @@
else:
outputs = model(**inputs)
if has_labels:
loss = outputs[0].mean().detach()
logits = outputs[1:]
if isinstance(outputs, dict):
loss = outputs["loss"].mean().detach()
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
else:
loss = outputs[0].mean().detach()
logits = outputs[1:]
else:
loss = None
# Slicing so we get a tuple even if `outputs` is a `ModelOutput`.
logits = outputs[:]
if isinstance(outputs, dict):
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
else:
logits = outputs
# TODO: this needs to be fixed and made cleaner later.
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]
# Remove the past from the logits.
logits = logits[: self.args.past_index - 1] + logits[self.args.past_index :]

if prediction_loss_only:
return (loss, None, None)
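
Pulled out of the diff above, the new filtering amounts to the following standalone sketch (a simplification of what `prediction_step` now does; the helper function name is ours, not part of the PR):

```python
import torch


def extract_logits(outputs, ignore_keys=None, has_labels=True):
    """Simplified sketch of the new output filtering in Trainer.prediction_step."""
    if ignore_keys is None:
        ignore_keys = []
    if isinstance(outputs, dict):
        # Drop the ignored keys, plus the loss when labels were provided
        drop = set(ignore_keys) | ({"loss"} if has_labels else set())
        return tuple(v for k, v in outputs.items() if k not in drop)
    # Tuple outputs keep the old behavior: the loss comes first when labels exist
    return outputs[1:] if has_labels else outputs[:]


# Example: a dict output with past_key_values filtered away
out = {"loss": torch.tensor(0.5), "logits": torch.zeros(1, 4), "past_key_values": ()}
print(extract_logits(out, ignore_keys=["past_key_values"]))  # (logits tensor,)
```
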
55 changes: 55 additions & 0 deletions tests/test_trainer.py
@@ -44,6 +44,8 @@
DataCollatorForLanguageModeling,
GlueDataset,
GlueDataTrainingArguments,
GPT2Config,
GPT2LMHeadModel,
LineByLineTextDataset,
PreTrainedModel,
TextDataset,
@@ -73,6 +75,18 @@ def __getitem__(self, i):
return result


class RepeatDataset:
def __init__(self, x, length=64):
self.x = x
self.length = length

def __len__(self):
return self.length

def __getitem__(self, i):
return {"input_ids": self.x, "labels": self.x}


class DynamicShapesDataset:
def __init__(self, length=64, seed=42, batch_size=8):
self.length = length
@@ -136,6 +150,20 @@ def forward(self, input_x=None, labels=None, **kwargs):
loss = torch.nn.functional.mse_loss(y, labels)
return (loss, y, y) if self.double_output else (loss, y)

class RegressionDictModel(torch.nn.Module):
def __init__(self, a=0, b=0):
super().__init__()
self.a = torch.nn.Parameter(torch.tensor(a).float())
self.b = torch.nn.Parameter(torch.tensor(b).float())
self.config = None

def forward(self, input_x=None, labels=None, **kwargs):
y = input_x * self.a + self.b
result = {"output": y}
if labels is not None:
result["loss"] = torch.nn.functional.mse_loss(y, labels)
return result

class RegressionPreTrainedModel(PreTrainedModel):
config_class = RegressionModelConfig
base_model_prefix = "regression"
@@ -236,6 +264,33 @@ def check_best_model_has_been_loaded(
metrics = trainer.evaluate()
self.assertEqual(metrics[metric], best_value)

def test_trainer_works_with_dict(self):
# Edge case because Apex with mode O2 will change our models to return dicts. This test checks it doesn't break
# anything.
train_dataset = RegressionDataset()
eval_dataset = RegressionDataset()
model = RegressionDictModel()
args = TrainingArguments("./regression")
trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
trainer.train()
_ = trainer.evaluate()
_ = trainer.predict(eval_dataset)

def test_evaluation_with_keys_to_drop(self):
config = GPT2Config(vocab_size=100, n_positions=128, n_ctx=128, n_embd=32, n_layer=3, n_head=4)
tiny_gpt2 = GPT2LMHeadModel(config)
x = torch.randint(0, 100, (128,))
eval_dataset = RepeatDataset(x)
args = TrainingArguments("./test")
trainer = Trainer(tiny_gpt2, args, eval_dataset=eval_dataset)
# By default the past_key_values are removed
result = trainer.predict(eval_dataset)
self.assertTrue(isinstance(result.predictions, np.ndarray))
# We can still get them by setting ignore_keys to []
result = trainer.predict(eval_dataset, ignore_keys=[])
self.assertTrue(isinstance(result.predictions, tuple))
self.assertEqual(len(result.predictions), 2)

def test_training_arguments_are_left_untouched(self):
trainer = get_regression_trainer()
trainer.train()