Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 30 additions & 15 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
run_hp_search_optuna,
run_hp_search_ray,
)
from .modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
from .modeling_utils import PreTrainedModel
from .optimization import AdamW, get_linear_schedule_with_warmup
from .tokenization_utils_base import PreTrainedTokenizerBase
Expand All @@ -45,6 +46,9 @@
default_hp_space,
distributed_broadcast_scalars,
distributed_concat,
nested_concat,
nested_numpify,
nested_xla_mesh_reduce,
set_seed,
)
from .training_args import TrainingArguments
Expand Down Expand Up @@ -293,6 +297,12 @@ def __init__(
self.scaler = torch.cuda.amp.GradScaler()
self.hp_search_backend = None
self.use_tune_checkpoints = False
# Pick the default label keys when the user did not set `label_names`:
# question-answering models get two label entries, every other model gets one.
if self.args.label_names is None:
    # Each key must be its own list element: `prediction_step` looks up every
    # entry of `label_names` individually with `inputs.get(name)`, so a single
    # "start_positions, end_positions" string would never match an input key.
    self.args.label_names = (
        ["start_positions", "end_positions"]
        if type(self.model) in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values()
        else ["labels"]
    )

def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
if not self.args.remove_unused_columns:
Expand Down Expand Up @@ -1307,9 +1317,9 @@ def prediction_loop(
if loss is not None:
eval_losses.extend([loss] * batch_size)
if logits is not None:
preds = logits if preds is None else tuple(torch.cat((p, l), dim=0) for p, l in zip(preds, logits))
preds = logits if preds is None else nested_concat(preds, logits, dim=0)
if labels is not None:
label_ids = labels if label_ids is None else torch.cat((label_ids, labels), dim=0)
label_ids = labels if label_ids is None else nested_concat(label_ids, labels, dim=0)

if self.args.past_index and hasattr(self, "_past"):
# Clean the state at the end of the evaluation loop
Expand All @@ -1318,25 +1328,23 @@ def prediction_loop(
if self.args.local_rank != -1:
# In distributed mode, concatenate all results from all nodes:
if preds is not None:
preds = tuple(distributed_concat(p, num_total_examples=self.num_examples(dataloader)) for p in preds)
preds = distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
if label_ids is not None:
label_ids = distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
elif is_torch_tpu_available():
# tpu-comment: Get all predictions and labels from all worker shards of eval dataset
if preds is not None:
preds = tuple(xm.mesh_reduce(f"eval_preds_{i}", p, torch.cat) for i, p in enumerate(preds))
# `nested_xla_mesh_reduce` is declared as (tensors, name): tensors first, tag second.
preds = nested_xla_mesh_reduce(preds, "eval_preds")
if label_ids is not None:
label_ids = xm.mesh_reduce("eval_label_ids", label_ids, torch.cat)
# `nested_xla_mesh_reduce` takes exactly (tensors, name) and passes `torch.cat`
# to `xm.mesh_reduce` itself, so no reduce function is given here.
label_ids = nested_xla_mesh_reduce(label_ids, "eval_label_ids")
if eval_losses is not None:
eval_losses = xm.mesh_reduce("eval_losses", torch.tensor(eval_losses), torch.cat).tolist()

# Finally, turn the aggregated tensors into numpy arrays.
if preds is not None:
preds = tuple(p.cpu().numpy() for p in preds)
if len(preds) == 1:
preds = preds[0]
preds = nested_numpify(preds)
if label_ids is not None:
label_ids = label_ids.cpu().numpy()
label_ids = nested_numpify(label_ids)

if self.compute_metrics is not None and preds is not None and label_ids is not None:
metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
Expand Down Expand Up @@ -1382,8 +1390,7 @@ def prediction_step(
Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
A tuple with the loss, logits and labels (each being optional).
"""
has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to test for old deprecated argument names since they have all been changed in the lib and the user can now set their own name if they have an old model they are still using.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good! this will reduce cruft


has_labels = all(inputs.get(k) is not None for k in self.args.label_names)
inputs = self._prepare_inputs(inputs)

with torch.no_grad():
Expand All @@ -1402,10 +1409,18 @@ def prediction_step(
if prediction_loss_only:
return (loss, None, None)

labels = inputs.get("labels")
if labels is not None:
labels = labels.detach()
return (loss, tuple(l.detach() for l in logits), labels)
logits = tuple(logit.detach() for logit in logits)
if len(logits) == 1:
logits = logits[0]

if has_labels:
labels = tuple(inputs.get(name).detach() for name in self.args.label_names)
if len(labels) == 1:
labels = labels[0]
else:
labels = None

return (loss, logits, labels)

def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
"""
Expand Down
42 changes: 41 additions & 1 deletion src/transformers/trainer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np

from .file_utils import is_tf_available, is_torch_available
from .file_utils import is_tf_available, is_torch_available, is_torch_tpu_available
from .tokenization_utils_base import ExplicitEnum


Expand Down Expand Up @@ -132,9 +132,49 @@ class HPSearchBackend(ExplicitEnum):
}


def nested_concat(tensors, new_tensors, dim=0):
    """
    Concat the `new_tensors` to `tensors` on `dim`. Works for tensors or nested list/tuples of tensors.

    Args:
        tensors: A tensor, or a (possibly nested) list/tuple of tensors.
        new_tensors: The values to append; must have the same type as `tensors` (checked at every nesting level).
        dim: The dimension passed to `torch.cat`.

    Returns:
        The result of `torch.cat((tensors, new_tensors), dim=dim)` for a single tensor, or a container of the same
        type as `tensors` holding the element-wise concatenations.

    Raises:
        ImportError: If `is_torch_available()` is false.
        AssertionError: If `tensors` and `new_tensors` differ in type, or are containers of different lengths.
    """
    if not is_torch_available():
        raise ImportError("Torch must be installed to use `nested_concat`")

    assert type(tensors) is type(
        new_tensors
    ), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
    if isinstance(tensors, (list, tuple)):
        # `zip` stops at the shorter input, so a length mismatch would silently drop outputs; fail loudly instead.
        assert len(tensors) == len(
            new_tensors
        ), f"Expected `tensors` and `new_tensors` to have the same length but found {len(tensors)} and {len(new_tensors)}."
        return type(tensors)(nested_concat(t, n, dim) for t, n in zip(tensors, new_tensors))
    return torch.cat((tensors, new_tensors), dim=dim)


def nested_numpify(tensors):
    """
    Convert `tensors` to numpy, recursing through nested lists/tuples.

    A list or tuple is rebuilt with the same container type; any other value is converted with
    `.cpu().numpy()`.
    """
    if not isinstance(tensors, (list, tuple)):
        return tensors.cpu().numpy()
    container = type(tensors)
    return container(map(nested_numpify, tensors))


def nested_detach(tensors):
    """
    Call `.detach()` on `tensors`, recursing through nested lists/tuples.

    A list or tuple is rebuilt with the same container type; any other value is returned as
    `tensors.detach()`.
    """
    if not isinstance(tensors, (list, tuple)):
        return tensors.detach()
    rebuild = type(tensors)
    return rebuild(nested_detach(item) for item in tensors)


def nested_xla_mesh_reduce(tensors, name):
    """
    Call `xm.mesh_reduce` with `torch.cat` on `tensors` (even if it's a nested list/tuple of tensors).

    Note the argument order: `tensors` first, `name` second. For a list or tuple, each element is reduced
    recursively under the tag `f"{name}_{i}"` (`i` being its index) and the same container type is rebuilt.

    Raises:
        ImportError: If `is_torch_tpu_available()` is false.
    """
    if not is_torch_tpu_available():
        raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")

    import torch_xla.core.xla_model as xm

    if not isinstance(tensors, (list, tuple)):
        return xm.mesh_reduce(name, tensors, torch.cat)
    container = type(tensors)
    return container(nested_xla_mesh_reduce(item, f"{name}_{idx}") for idx, item in enumerate(tensors))

Comment on lines +162 to +171
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you get a chance to test this on TPU?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, was planning to ask you about it this morning.


def distributed_concat(tensor: "torch.Tensor", num_total_examples: Optional[int] = None) -> "torch.Tensor":
if is_torch_available():
try:
if isinstance(tensor, (tuple, list)):
return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(output_tensors, tensor)
concat = torch.cat(output_tensors, dim=0)
Expand Down
19 changes: 14 additions & 5 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
import os
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

from .file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
from .utils import logging
Expand Down Expand Up @@ -128,6 +128,12 @@ class TrainingArguments:
forward method.

(Note: this behavior is not implemented for :class:`~transformers.TFTrainer` yet.)
label_names (:obj:`List[str]`, `optional`):
The list of keys in your dictionary of inputs that correspond to the labels.

Will eventually default to :obj:`["labels"]` except if the model used is one of the
:obj:`XxxForQuestionAnswering` in which case it will default to
:obj:`["start_positions", "end_positions"]`.
"""

output_dir: str = field(
Expand Down Expand Up @@ -253,13 +259,16 @@ class TrainingArguments:
default=None, metadata={"help": "Whether or not to disable the tqdm progress bars."}
)

def __post_init__(self):
if self.disable_tqdm is None:
self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN

remove_unused_columns: Optional[bool] = field(
default=True, metadata={"help": "Remove columns not required by the model when using an nlp.Dataset."}
)
label_names: Optional[List[str]] = field(
default=None, metadata={"help": "The list of keys in your dictionary of inputs that correspond to the labels."}
)

def __post_init__(self):
if self.disable_tqdm is None:
self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN

@property
def train_batch_size(self) -> int:
Expand Down
33 changes: 25 additions & 8 deletions tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,21 @@


class RegressionDataset:
def __init__(self, a=2, b=3, length=64, seed=42):
def __init__(self, a=2, b=3, length=64, seed=42, label_names=None):
np.random.seed(seed)
self.label_names = ["labels"] if label_names is None else label_names
self.length = length
self.x = np.random.normal(size=(length,)).astype(np.float32)
self.y = a * self.x + b + np.random.normal(scale=0.1, size=(length,))
self.ys = [a * self.x + b + np.random.normal(scale=0.1, size=(length,)) for _ in self.label_names]
self.ys = [y.astype(np.float32) for y in self.ys]

def __len__(self):
return self.length

def __getitem__(self, i):
return {"input_x": self.x[i], "label": self.y[i]}
result = {name: y[i] for name, y in zip(self.label_names, self.ys)}
result["input_x"] = self.x[i]
return result


class AlmostAccuracy:
Expand Down Expand Up @@ -68,16 +72,17 @@ def __init__(self, a=0, b=0, double_output=False):
self.double_output = double_output
self.config = None

def forward(self, input_x=None, labels=None):
def forward(self, input_x=None, labels=None, **kwargs):
y = input_x * self.a + self.b
if labels is None:
return (y, y) if self.double_output else (y,)
loss = torch.nn.functional.mse_loss(y, labels)
return (loss, y, y) if self.double_output else (loss, y)

def get_regression_trainer(a=0, b=0, double_output=False, train_len=64, eval_len=64, **kwargs):
train_dataset = RegressionDataset(length=train_len)
eval_dataset = RegressionDataset(length=eval_len)
label_names = kwargs.get("label_names", None)
train_dataset = RegressionDataset(length=train_len, label_names=label_names)
eval_dataset = RegressionDataset(length=eval_len, label_names=label_names)
model = RegressionModel(a, b, double_output)
compute_metrics = kwargs.pop("compute_metrics", None)
data_collator = kwargs.pop("data_collator", None)
Expand Down Expand Up @@ -174,7 +179,7 @@ def test_evaluate(self):
trainer = get_regression_trainer(a=1.5, b=2.5, compute_metrics=AlmostAccuracy())
results = trainer.evaluate()

x, y = trainer.eval_dataset.x, trainer.eval_dataset.y
x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
pred = 1.5 * x + 2.5
expected_loss = ((pred - y) ** 2).mean()
self.assertAlmostEqual(results["eval_loss"], expected_loss)
Expand All @@ -185,7 +190,7 @@ def test_evaluate(self):
trainer = get_regression_trainer(a=1.5, b=2.5, eval_len=66, compute_metrics=AlmostAccuracy())
results = trainer.evaluate()

x, y = trainer.eval_dataset.x, trainer.eval_dataset.y
x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
pred = 1.5 * x + 2.5
expected_loss = ((pred - y) ** 2).mean()
self.assertAlmostEqual(results["eval_loss"], expected_loss)
Expand All @@ -212,6 +217,18 @@ def test_predict(self):
self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5))
self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5))

# With more than one output/label of the model
trainer = get_regression_trainer(a=1.5, b=2.5, double_output=True, label_names=["labels", "labels_2"])
outputs = trainer.predict(trainer.eval_dataset)
preds = outputs.predictions
labels = outputs.label_ids
x = trainer.eval_dataset.x
# `assertTrue(x, 2)` would treat 2 as the failure message and pass for any non-empty `preds`;
# compare the length explicitly.
self.assertEqual(len(preds), 2)
self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5))
self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5))
self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))

def test_trainer_with_datasets(self):
np.random.seed(42)
x = np.random.normal(size=(64,)).astype(np.float32)
Expand Down