diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 2e314c896e03..7cb7f7792061 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -31,6 +31,7 @@ run_hp_search_optuna, run_hp_search_ray, ) +from .modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING from .modeling_utils import PreTrainedModel from .optimization import AdamW, get_linear_schedule_with_warmup from .tokenization_utils_base import PreTrainedTokenizerBase @@ -45,6 +46,9 @@ default_hp_space, distributed_broadcast_scalars, distributed_concat, + nested_concat, + nested_numpify, + nested_xla_mesh_reduce, set_seed, ) from .training_args import TrainingArguments @@ -293,6 +297,12 @@ def __init__( self.scaler = torch.cuda.amp.GradScaler() self.hp_search_backend = None self.use_tune_checkpoints = False + if self.args.label_names is None: + self.args.label_names = ( + ["start_positions, end_positions"] + if type(self.model) in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values() + else ["labels"] + ) def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None): if not self.args.remove_unused_columns: @@ -1307,9 +1317,9 @@ def prediction_loop( if loss is not None: eval_losses.extend([loss] * batch_size) if logits is not None: - preds = logits if preds is None else tuple(torch.cat((p, l), dim=0) for p, l in zip(preds, logits)) + preds = logits if preds is None else nested_concat(preds, logits, dim=0) if labels is not None: - label_ids = labels if label_ids is None else torch.cat((label_ids, labels), dim=0) + label_ids = labels if label_ids is None else nested_concat(label_ids, labels, dim=0) if self.args.past_index and hasattr(self, "_past"): # Clean the state at the end of the evaluation loop @@ -1318,25 +1328,23 @@ def prediction_loop( if self.args.local_rank != -1: # In distributed mode, concatenate all results from all nodes: if preds is not None: - preds = tuple(distributed_concat(p, num_total_examples=self.num_examples(dataloader)) for p in preds) + preds = distributed_concat(preds, num_total_examples=self.num_examples(dataloader)) if label_ids is not None: label_ids = distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader)) elif is_torch_tpu_available(): # tpu-comment: Get all predictions and labels from all worker shards of eval dataset if preds is not None: - preds = tuple(xm.mesh_reduce(f"eval_preds_{i}", p, torch.cat) for i, p in enumerate(preds)) + preds = nested_xla_mesh_reduce("eval_preds", preds) if label_ids is not None: - label_ids = xm.mesh_reduce("eval_label_ids", label_ids, torch.cat) + label_ids = nested_xla_mesh_reduce("eval_label_ids", label_ids, torch.cat) if eval_losses is not None: eval_losses = xm.mesh_reduce("eval_losses", torch.tensor(eval_losses), torch.cat).tolist() # Finally, turn the aggregated tensors into numpy arrays. if preds is not None: - preds = tuple(p.cpu().numpy() for p in preds) - if len(preds) == 1: - preds = preds[0] + preds = nested_numpify(preds) if label_ids is not None: - label_ids = label_ids.cpu().numpy() + label_ids = nested_numpify(label_ids) if self.compute_metrics is not None and preds is not None and label_ids is not None: metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids)) @@ -1382,8 +1390,7 @@ def prediction_step( Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and labels (each being optional). """ - has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"]) - + has_labels = all(inputs.get(k) is not None for k in self.args.label_names) inputs = self._prepare_inputs(inputs) with torch.no_grad(): @@ -1402,10 +1409,18 @@ def prediction_step( if prediction_loss_only: return (loss, None, None) - labels = inputs.get("labels") - if labels is not None: - labels = labels.detach() - return (loss, tuple(l.detach() for l in logits), labels) + logits = tuple(logit.detach() for logit in logits) + if len(logits) == 1: + logits = logits[0] + + if has_labels: + labels = tuple(inputs.get(name).detach() for name in self.args.label_names) + if len(labels) == 1: + labels = labels[0] + else: + labels = None + + return (loss, logits, labels) def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]): """ diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index b7215b5f2764..e273b00aa895 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -3,7 +3,7 @@ import numpy as np -from .file_utils import is_tf_available, is_torch_available +from .file_utils import is_tf_available, is_torch_available, is_torch_tpu_available from .tokenization_utils_base import ExplicitEnum @@ -132,9 +132,49 @@ class HPSearchBackend(ExplicitEnum): } +def nested_concat(tensors, new_tensors, dim=0): + "Concat the `new_tensors` to `tensors` on `dim`. Works for tensors or nested list/tuples of tensors." + if is_torch_available(): + assert type(tensors) == type( + new_tensors + ), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}." + if isinstance(tensors, (list, tuple)): + return type(tensors)(nested_concat(t, n, dim) for t, n in zip(tensors, new_tensors)) + return torch.cat((tensors, new_tensors), dim=dim) + else: + raise ImportError("Torch must be installed to use `nested_concat`") + + +def nested_numpify(tensors): + "Numpify `tensors` (even if it's a nested list/tuple of tensors)." + if isinstance(tensors, (list, tuple)): + return type(tensors)(nested_numpify(t) for t in tensors) + return tensors.cpu().numpy() + + +def nested_detach(tensors): + "Detach `tensors` (even if it's a nested list/tuple of tensors)." + if isinstance(tensors, (list, tuple)): + return type(tensors)(nested_detach(t) for t in tensors) + return tensors.detach() + + +def nested_xla_mesh_reduce(tensors, name): + if is_torch_tpu_available(): + import torch_xla.core.xla_model as xm + + if isinstance(tensors, (list, tuple)): + return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors)) + return xm.mesh_reduce(name, tensors, torch.cat) + else: + raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`") + + def distributed_concat(tensor: "torch.Tensor", num_total_examples: Optional[int] = None) -> "torch.Tensor": if is_torch_available(): try: + if isinstance(tensor, (tuple, list)): + return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor) output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())] torch.distributed.all_gather(output_tensors, tensor) concat = torch.cat(output_tensors, dim=0) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 60b86d28b789..c635595f2bfd 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -2,7 +2,7 @@ import json import os from dataclasses import dataclass, field -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from .file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required from .utils import logging @@ -128,6 +128,12 @@ class TrainingArguments: forward method. (Note: this behavior is not implemented for :class:`~transformers.TFTrainer` yet.) + label_names (:obj:`List[str]`, `optional`): + The list of keys in your dictionary of inputs that correspond to the labels. + + Will eventually default to :obj:`["labels"]` except if the model used is one of the + :obj:`XxxForQuestionAnswering` in which case it will default to + :obj:`["start_positions", "end_positions"]`. """ output_dir: str = field( @@ -253,13 +259,16 @@ class TrainingArguments: default=None, metadata={"help": "Whether or not to disable the tqdm progress bars."} ) - def __post_init__(self): - if self.disable_tqdm is None: - self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN - remove_unused_columns: Optional[bool] = field( default=True, metadata={"help": "Remove columns not required by the model when using an nlp.Dataset."} ) + label_names: Optional[List[str]] = field( + default=None, metadata={"help": "The list of keys in your dictionary of inputs that correspond to the labels."} + ) + + def __post_init__(self): + if self.disable_tqdm is None: + self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN @property def train_batch_size(self) -> int: diff --git a/tests/test_trainer.py b/tests/test_trainer.py index f5bbe9145b78..239de5d6eb84 100755 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -24,17 +24,21 @@ class RegressionDataset: - def __init__(self, a=2, b=3, length=64, seed=42): + def __init__(self, a=2, b=3, length=64, seed=42, label_names=None): np.random.seed(seed) + self.label_names = ["labels"] if label_names is None else label_names self.length = length self.x = np.random.normal(size=(length,)).astype(np.float32) - self.y = a * self.x + b + np.random.normal(scale=0.1, size=(length,)) + self.ys = [a * self.x + b + np.random.normal(scale=0.1, size=(length,)) for _ in self.label_names] + self.ys = [y.astype(np.float32) for y in self.ys] def __len__(self): return self.length def __getitem__(self, i): - return {"input_x": self.x[i], "label": self.y[i]} + result = {name: y[i] for name, y in zip(self.label_names, self.ys)} + result["input_x"] = self.x[i] + return result class AlmostAccuracy: @@ -68,7 +72,7 @@ def __init__(self, a=0, b=0, double_output=False): self.double_output = double_output self.config = None - def forward(self, input_x=None, labels=None): + def forward(self, input_x=None, labels=None, **kwargs): y = input_x * self.a + self.b if labels is None: return (y, y) if self.double_output else (y,) @@ -76,8 +80,9 @@ def forward(self, input_x=None, labels=None): return (loss, y, y) if self.double_output else (loss, y) def get_regression_trainer(a=0, b=0, double_output=False, train_len=64, eval_len=64, **kwargs): - train_dataset = RegressionDataset(length=train_len) - eval_dataset = RegressionDataset(length=eval_len) + label_names = kwargs.get("label_names", None) + train_dataset = RegressionDataset(length=train_len, label_names=label_names) + eval_dataset = RegressionDataset(length=eval_len, label_names=label_names) model = RegressionModel(a, b, double_output) compute_metrics = kwargs.pop("compute_metrics", None) data_collator = kwargs.pop("data_collator", None) @@ -174,7 +179,7 @@ def test_evaluate(self): trainer = get_regression_trainer(a=1.5, b=2.5, compute_metrics=AlmostAccuracy()) results = trainer.evaluate() - x, y = trainer.eval_dataset.x, trainer.eval_dataset.y + x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0] pred = 1.5 * x + 2.5 expected_loss = ((pred - y) ** 2).mean() self.assertAlmostEqual(results["eval_loss"], expected_loss) @@ -185,7 +190,7 @@ def test_evaluate(self): trainer = get_regression_trainer(a=1.5, b=2.5, eval_len=66, compute_metrics=AlmostAccuracy()) results = trainer.evaluate() - x, y = trainer.eval_dataset.x, trainer.eval_dataset.y + x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0] pred = 1.5 * x + 2.5 expected_loss = ((pred - y) ** 2).mean() self.assertAlmostEqual(results["eval_loss"], expected_loss) @@ -212,6 +217,18 @@ def test_predict(self): self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5)) self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5)) + # With more than one output/label of the model + trainer = get_regression_trainer(a=1.5, b=2.5, double_output=True, label_names=["labels", "labels_2"]) + outputs = trainer.predict(trainer.eval_dataset) + preds = outputs.predictions + labels = outputs.label_ids + x = trainer.eval_dataset.x + self.assertTrue(len(preds), 2) + self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5)) + self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5)) + self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0])) + self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1])) + def test_trainer_with_datasets(self): np.random.seed(42) x = np.random.normal(size=(64,)).astype(np.float32)