85 changes: 25 additions & 60 deletions src/transformers/trainer.py
@@ -104,6 +104,7 @@
 from .trainer_utils import (
     PREFIX_CHECKPOINT_DIR,
     BestRun,
+    EvalLoopContainer,
     EvalLoopOutput,
     EvalPrediction,
     HPSearchBackend,
@@ -3627,20 +3628,14 @@ def evaluation_loop(
             self._past = None
 
         # Initialize containers
-        # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
-        losses_host = None
-        preds_host = None
-        labels_host = None
-        inputs_host = None
-
-        # losses/preds/labels on CPU (final containers)
-        all_losses = None
-        all_preds = None
-        all_labels = None
-        all_inputs = None
-        # Will be useful when we have an iterable dataset so don't know its length.
+        all_losses = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
+        all_preds = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
+        all_labels = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
+        all_inputs = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
+
+        # Will be useful when we have an iterable dataset so don't know its length.
         observed_num_examples = 0
 
         # Main evaluation loop
         for step, inputs in enumerate(dataloader):
             # Update the observed num examples
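The four `*_host` accumulators and four `all_*` arrays that previously had to be kept in sync collapse into one container per stream. Note that the first constructor argument is positional: it maps to `do_nested_concat`, so each container inherits its behavior from the new `eval_do_concat_batches` training argument. A minimal sketch of that equivalence (the `ToyArgs` stand-in for `self.args` is hypothetical; the import path is the one this PR adds):

```python
from dataclasses import dataclass

from transformers.trainer_utils import EvalLoopContainer  # import path as added in this PR


@dataclass
class ToyArgs:
    # hypothetical stand-in for `self.args`, with only the field used here
    eval_do_concat_batches: bool = True


args = ToyArgs()
# The positional form used in the diff above...
all_losses = EvalLoopContainer(args.eval_do_concat_batches, padding_index=-100)
# ...is equivalent to the explicit keyword form:
all_losses = EvalLoopContainer(do_nested_concat=args.eval_do_concat_batches, padding_index=-100)
```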
@@ -3659,56 +3654,33 @@
             if is_torch_xla_available():
                 xm.mark_step()
 
-            # Update containers on host
+            # Update containers
             if loss is not None:
                 losses = self.gather_function((loss.repeat(batch_size)))
-                losses_host = losses if losses_host is None else nested_concat(losses_host, losses, padding_index=-100)
-            if labels is not None:
-                labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
+                all_losses.add(losses)
             if inputs_decode is not None:
                 inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100)
                 inputs_decode = self.gather_function((inputs_decode))
-                inputs_host = (
-                    inputs_decode
-                    if inputs_host is None
-                    else nested_concat(inputs_host, inputs_decode, padding_index=-100)
-                )
+                all_inputs.add(inputs_decode)
             if logits is not None:
                 logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
                 if self.preprocess_logits_for_metrics is not None:
                     logits = self.preprocess_logits_for_metrics(logits, labels)
                 logits = self.gather_function((logits))
-                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
-
+                all_preds.add(logits)
             if labels is not None:
+                labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
                 labels = self.gather_function((labels))
-                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
+                all_labels.add(labels)
 
             self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
 
             # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
             if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
-                if losses_host is not None:
-                    losses = nested_numpify(losses_host)
-                    all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
-                if preds_host is not None:
-                    logits = nested_numpify(preds_host)
-                    all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
-                if inputs_host is not None:
-                    inputs_decode = nested_numpify(inputs_host)
-                    all_inputs = (
-                        inputs_decode
-                        if all_inputs is None
-                        else nested_concat(all_inputs, inputs_decode, padding_index=-100)
-                    )
-                if labels_host is not None:
-                    labels = nested_numpify(labels_host)
-                    all_labels = (
-                        labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
-                    )
-
-                # Set back to None to begin a new accumulation
-                losses_host, preds_host, inputs_host, labels_host = None, None, None, None
+                all_losses.to_cpu_and_numpy()
+                all_preds.to_cpu_and_numpy()
+                all_labels.to_cpu_and_numpy()
+                all_inputs.to_cpu_and_numpy()
 
             # After all calls to `.gather_function`, reset to `gather_for_metrics`:
             self.gather_function = self.accelerator.gather_for_metrics
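The per-stream branches above now share one rhythm: gathered tensors are `add()`-ed on device each step, and every `eval_accumulation_steps` batches the container offloads them to CPU as numpy arrays, freeing accelerator memory. A standalone sketch of that rhythm (toy shapes and random logits, not Trainer code; the import path is the one this PR adds):

```python
import torch

from transformers.trainer_utils import EvalLoopContainer  # import path as added in this PR

eval_accumulation_steps = 2
all_preds = EvalLoopContainer(do_nested_concat=True, padding_index=-100)

for step in range(6):
    logits = torch.randn(4, 3)  # stand-in for a gathered batch of logits
    all_preds.add(logits)  # kept on the current device until offload
    if (step + 1) % eval_accumulation_steps == 0:
        all_preds.to_cpu_and_numpy()  # move to CPU, convert to numpy, reset the device side

preds = all_preds.get_arrays()  # final flush; a single (24, 3) ndarray here
print(type(preds), preds.shape)
```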
@@ -3717,20 +3689,10 @@
             delattr(self, "_past")
 
         # Gather all remaining tensors and put them back on the CPU
-        if losses_host is not None:
-            losses = nested_numpify(losses_host)
-            all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
-        if preds_host is not None:
-            logits = nested_numpify(preds_host)
-            all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
-        if inputs_host is not None:
-            inputs_decode = nested_numpify(inputs_host)
-            all_inputs = (
-                inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100)
-            )
-        if labels_host is not None:
-            labels = nested_numpify(labels_host)
-            all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
+        all_losses = all_losses.get_arrays()
+        all_preds = all_preds.get_arrays()
+        all_labels = all_labels.get_arrays()
+        all_inputs = all_inputs.get_arrays()
 
         # Number of samples
         if has_length(eval_dataset):
@@ -3761,7 +3723,9 @@
         # To be JSON-serializable, we need to remove numpy types or zero-d tensors
         metrics = denumpify_detensorize(metrics)
 
-        if all_losses is not None:
+        if isinstance(all_losses, list) and all_losses:
+            metrics[f"{metric_key_prefix}_loss"] = np.concatenate(all_losses).mean().item()
+        elif isinstance(all_losses, np.ndarray):
             metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
         if hasattr(self, "jit_compilation_time"):
            metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time
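Since `get_arrays()` can now return either a list of per-batch arrays (`eval_do_concat_batches=False`) or one concatenated array, the loss reduction above branches on the type. A toy illustration of why both branches produce the same mean:

```python
import numpy as np

# eval_do_concat_batches=True: one concatenated ndarray
all_losses = np.array([1.0, 2.0, 3.0, 4.0])
print(all_losses.mean().item())  # 2.5

# eval_do_concat_batches=False: a list of per-batch arrays, which must be
# concatenated before averaging
all_losses = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]
print(np.concatenate(all_losses).mean().item())  # 2.5
```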
@@ -4204,6 +4168,7 @@ def prediction_loop(
         logger.info(f"***** Running {description} *****")
         logger.info(f"  Num examples = {num_examples}")
         logger.info(f"  Batch size = {batch_size}")
+
         losses_host: torch.Tensor = None
         preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
         labels_host: Union[torch.Tensor, List[torch.Tensor]] = None
53 changes: 53 additions & 0 deletions src/transformers/trainer_utils.py
@@ -29,6 +29,7 @@

 import numpy as np
 
+from .trainer_pt_utils import nested_concat, nested_numpify
 from .utils import (
     ExplicitEnum,
     is_psutil_available,
@@ -199,6 +200,58 @@ class TrainOutput(NamedTuple):
     metrics: Dict[str, float]
 
 
+class EvalLoopContainer:
+    """
+    Container to store intermediate results of evaluation loop
+
+    Args:
+        do_nested_concat (`bool`, *optional*, defaults to `True`):
+            If set to `True`, each iteration will recursively concatenate a new object containing tensors to
+            the existing stored tensors, provided that the structure of the existing object and the new one
+            are identical. If set to `False`, all newly added tensors will be stored in a list.
+        padding_index (`int`, *optional*, defaults to -100):
+            Value used to pad tensors of different shapes when `do_nested_concat=True`.
+    """
+
+    def __init__(self, do_nested_concat: bool = True, padding_index: int = -100):
+        self.do_nested_concat = do_nested_concat
+        self.padding_index = padding_index
+        self.tensors = None
+        self.arrays = None
+
+    def add(self, tensors) -> None:
+        """Add tensors to the stored objects. If `do_nested_concat=True`, the tensors will be concatenated recursively."""
+        if self.tensors is None:
+            self.tensors = tensors if self.do_nested_concat else [tensors]
+        elif self.do_nested_concat:
+            self.tensors = nested_concat(self.tensors, tensors, padding_index=self.padding_index)
+        else:
+            self.tensors.append(tensors)
+
+    def to_cpu_and_numpy(self) -> None:
+        """Move tensors in stored objects to CPU and convert them to numpy arrays."""
+
+        # Check if we have something to add, if not just return
+        if self.tensors is None:
+            return
+
+        new_arrays = nested_numpify(self.tensors)
+        if self.arrays is None:
+            self.arrays = new_arrays
+        elif self.do_nested_concat:
+            self.arrays = nested_concat(self.arrays, new_arrays, padding_index=self.padding_index)
+        else:
+            self.arrays.extend(new_arrays)
+
+        # reset device tensors after adding to cpu
+        self.tensors = None
+
+    def get_arrays(self):
+        """Returns the numpified and moved to CPU stored objects."""
+        self.to_cpu_and_numpy()
+        return self.arrays
+
+
 PREFIX_CHECKPOINT_DIR = "checkpoint"
 _re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$")
 
9 changes: 9 additions & 0 deletions src/transformers/training_args.py
@@ -639,6 +639,9 @@ class TrainingArguments:
         include_inputs_for_metrics (`bool`, *optional*, defaults to `False`):
             Whether or not the inputs will be passed to the `compute_metrics` function. This is intended for metrics
             that need inputs, predictions and references for scoring calculation in Metric class.
+        eval_do_concat_batches (`bool`, *optional*, defaults to `True`):
+            If set to `True`, tensors in these nested objects are recursively concatenated across batches.
+            If set to `False`, inputs/losses/labels/predictions are stored as lists, with each batch kept separate.
         auto_find_batch_size (`bool`, *optional*, defaults to `False`)
             Whether to find a batch size that will fit into memory automatically through exponential decay, avoiding
             CUDA Out-of-Memory errors. Requires accelerate to be installed (`pip install accelerate`)
@@ -1261,6 +1264,12 @@ class TrainingArguments:
     include_inputs_for_metrics: bool = field(
         default=False, metadata={"help": "Whether or not the inputs will be passed to the `compute_metrics` function."}
     )
+    eval_do_concat_batches: bool = field(
+        default=True,
+        metadata={
+            "help": "Whether or not tensors in nested objects in batches should be recursively concatenated between batches."
+        },
+    )
     # Deprecated arguments
     fp16_backend: str = field(
         default="auto",
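For completeness, the new flag is an ordinary `TrainingArguments` field; a hypothetical configuration (the `output_dir` is a placeholder):

```python
from transformers import TrainingArguments

# With eval_do_concat_batches=False, compute_metrics receives per-batch lists
# instead of one concatenated array per field.
args = TrainingArguments(output_dir="tmp_out", eval_do_concat_batches=False)
print(args.eval_do_concat_batches)  # False
```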
91 changes: 90 additions & 1 deletion tests/trainer/test_trainer_utils.py
@@ -20,7 +20,7 @@

 from transformers.data.data_collator import default_data_collator
 from transformers.testing_utils import require_accelerate, require_torch
-from transformers.trainer_utils import RemoveColumnsCollator, find_executable_batch_size
+from transformers.trainer_utils import EvalLoopContainer, RemoveColumnsCollator, find_executable_batch_size
 from transformers.utils import is_torch_available
 
 
@@ -497,3 +497,92 @@ def info(self, msg):
         remove_columns_collator(data_batch)
         self.assertEqual(logger.called, 1)
         self.assertIn("col3", logger.last_msg)
+
+    def test_eval_loop_container(self):
+        batch_1 = [
+            torch.ones([8, 5]),
+            {"loss": torch.tensor(1.0)},
+            (torch.ones([8, 2, 3]), torch.ones([8, 2])),
+        ]
+        batch_2 = [
+            torch.ones([4, 5]),
+            {"loss": torch.tensor(2.0)},
+            (torch.ones([4, 2, 3]), torch.ones([4, 6])),
+        ]
+
+        concat_container = EvalLoopContainer(do_nested_concat=True, padding_index=-100)
+        concat_container.add(batch_1)
+        concat_container.add(batch_2)
+        concat_container.to_cpu_and_numpy()
+        arrays = concat_container.get_arrays()
+
+        # Test two nested batches concatenation
+        self.assertIsInstance(arrays, list)
+        self.assertEqual(len(arrays), 3)
+        self.assertIsInstance(arrays[0], np.ndarray)
+        self.assertEqual(arrays[0].shape, (12, 5))
+        self.assertIsInstance(arrays[1], dict)
+        self.assertIsInstance(arrays[1]["loss"], np.ndarray)
+        self.assertEqual(arrays[1]["loss"].shape, (2,))
+        self.assertTrue(np.allclose(arrays[1]["loss"], np.array([1.0, 2.0])))
+        self.assertIsInstance(arrays[2], tuple)
+        self.assertEqual(len(arrays[2]), 2)
+        self.assertEqual(arrays[2][0].shape, (12, 2, 3))
+        self.assertEqual(arrays[2][1].shape, (12, 6))
+        # check that first batch padded with padding index -100 after concatenation
+        self.assertEqual(arrays[2][1][0][2], -100)
+
+        # Test two batches with no concatenation
+        list_container = EvalLoopContainer(do_nested_concat=False)
+        list_container.add(batch_1)
+        list_container.add(batch_2)
+        list_container.to_cpu_and_numpy()
+        arrays = list_container.get_arrays()
+
+        self.assertEqual(len(arrays), 2)
+        self.assertIsInstance(arrays, list)
+        np_batch_1, np_batch_2 = arrays
+
+        self.assertIsInstance(np_batch_1, list)
+        self.assertEqual(len(np_batch_1), 3)
+        self.assertIsInstance(np_batch_1[0], np.ndarray)
+        self.assertIsInstance(np_batch_1[1], dict)
+        self.assertIsInstance(np_batch_1[2], tuple)
+        self.assertEqual(np_batch_1[0].shape, (8, 5))
+        self.assertEqual(np_batch_1[1]["loss"].shape, ())
+        self.assertEqual(np_batch_1[2][0].shape, (8, 2, 3))
+        self.assertEqual(np_batch_1[2][1].shape, (8, 2))
+
+        self.assertIsInstance(np_batch_2, list)
+        self.assertEqual(len(np_batch_2), 3)
+        self.assertIsInstance(np_batch_2[0], np.ndarray)
+        self.assertIsInstance(np_batch_2[1], dict)
+        self.assertIsInstance(np_batch_2[2], tuple)
+        self.assertEqual(np_batch_2[0].shape, (4, 5))
+        self.assertEqual(np_batch_2[1]["loss"].shape, ())
+        self.assertEqual(np_batch_2[2][0].shape, (4, 2, 3))
+        self.assertEqual(np_batch_2[2][1].shape, (4, 6))
+
+        # Test no batches
+        none_arr = EvalLoopContainer(do_nested_concat=True, padding_index=-100).get_arrays()
+        self.assertIsNone(none_arr)
+
+        none_arr = EvalLoopContainer(do_nested_concat=False).get_arrays()
+        self.assertIsNone(none_arr)
+
+        # Test one batch
+        concat_container = EvalLoopContainer(do_nested_concat=True, padding_index=-100)
+        concat_container.add(batch_1)
+        arrays = concat_container.get_arrays()
+        self.assertIsInstance(arrays, list)
+        self.assertEqual(len(arrays), 3)
+        self.assertIsInstance(arrays[0], np.ndarray)
+        self.assertEqual(arrays[0].shape, (8, 5))
+        self.assertIsInstance(arrays[1], dict)
+        self.assertIsInstance(arrays[1]["loss"], np.ndarray)
+        self.assertEqual(arrays[1]["loss"].shape, ())
+        self.assertTrue(np.allclose(arrays[1]["loss"], np.array([1.0])))
+        self.assertIsInstance(arrays[2], tuple)
+        self.assertEqual(len(arrays[2]), 2)
+        self.assertEqual(arrays[2][0].shape, (8, 2, 3))
+        self.assertEqual(arrays[2][1].shape, (8, 2))
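The `-100` check near the top of the test is worth unpacking: when two batches disagree in a trailing dimension, `nested_concat` pads the narrower one with the padding index. A small standalone sketch using the same helper the container relies on:

```python
import torch

from transformers.trainer_pt_utils import nested_concat

# Concatenating (8, 2) with (4, 6) yields (12, 6); the first batch's rows are
# padded from width 2 to 6 with the padding index
a, b = torch.ones(8, 2), torch.ones(4, 6)
merged = nested_concat(a, b, padding_index=-100)
print(tuple(merged.shape))  # (12, 6)
print(merged[0, 2].item())  # -100.0
```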