From 3be12de2702ff97e614a44c0912df41a756c7a76 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Tue, 16 Apr 2024 11:06:46 +0000 Subject: [PATCH 1/7] Add evaluation loop container for interm. results --- src/transformers/trainer.py | 87 +++++++++---------------------- src/transformers/trainer_utils.py | 53 +++++++++++++++++++ src/transformers/training_args.py | 6 +++ 3 files changed, 85 insertions(+), 61 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 45b45992bf42..94da32856b3f 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -104,6 +104,7 @@ from .trainer_utils import ( PREFIX_CHECKPOINT_DIR, BestRun, + EvalLoopContainer, EvalLoopOutput, EvalPrediction, HPSearchBackend, @@ -3627,20 +3628,14 @@ def evaluation_loop( self._past = None # Initialize containers - # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps) - losses_host = None - preds_host = None - labels_host = None - inputs_host = None - - # losses/preds/labels on CPU (final containers) - all_losses = None - all_preds = None - all_labels = None - all_inputs = None + all_losses = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100) + all_preds = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100) + all_labels = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100) + all_inputs = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100) + # Will be useful when we have an iterable dataset so don't know its length. - observed_num_examples = 0 + # Main evaluation loop for step, inputs in enumerate(dataloader): # Update the observed num examples @@ -3659,56 +3654,33 @@ def evaluation_loop( if is_torch_xla_available(): xm.mark_step() - # Update containers on host + # Update containers if loss is not None: losses = self.gather_function((loss.repeat(batch_size))) - losses_host = losses if losses_host is None else nested_concat(losses_host, losses, padding_index=-100) - if labels is not None: - labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100) + all_losses.add(losses) if inputs_decode is not None: inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100) inputs_decode = self.gather_function((inputs_decode)) - inputs_host = ( - inputs_decode - if inputs_host is None - else nested_concat(inputs_host, inputs_decode, padding_index=-100) - ) + all_inputs.add(inputs_decode) if logits is not None: logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100) if self.preprocess_logits_for_metrics is not None: logits = self.preprocess_logits_for_metrics(logits, labels) logits = self.gather_function((logits)) - preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100) - + all_preds.add(logits) if labels is not None: + labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100) labels = self.gather_function((labels)) - labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) + all_labels.add(labels) self.control = self.callback_handler.on_prediction_step(args, self.state, self.control) # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. 
if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0: - if losses_host is not None: - losses = nested_numpify(losses_host) - all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0) - if preds_host is not None: - logits = nested_numpify(preds_host) - all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100) - if inputs_host is not None: - inputs_decode = nested_numpify(inputs_host) - all_inputs = ( - inputs_decode - if all_inputs is None - else nested_concat(all_inputs, inputs_decode, padding_index=-100) - ) - if labels_host is not None: - labels = nested_numpify(labels_host) - all_labels = ( - labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100) - ) - - # Set back to None to begin a new accumulation - losses_host, preds_host, inputs_host, labels_host = None, None, None, None + all_losses.to_cpu_and_numpy() + all_preds.to_cpu_and_numpy() + all_labels.to_cpu_and_numpy() + all_inputs.to_cpu_and_numpy() # After all calls to `.gather_function`, reset to `gather_for_metrics`: self.gather_function = self.accelerator.gather_for_metrics @@ -3717,20 +3689,10 @@ def evaluation_loop( delattr(self, "_past") # Gather all remaining tensors and put them back on the CPU - if losses_host is not None: - losses = nested_numpify(losses_host) - all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0) - if preds_host is not None: - logits = nested_numpify(preds_host) - all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100) - if inputs_host is not None: - inputs_decode = nested_numpify(inputs_host) - all_inputs = ( - inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100) - ) - if labels_host is not None: - labels = nested_numpify(labels_host) - all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100) + all_losses = all_losses.get_arrays() + all_preds = all_preds.get_arrays() + all_labels = all_labels.get_arrays() + all_inputs = all_inputs.get_arrays() # Number of samples if has_length(eval_dataset): @@ -3760,8 +3722,10 @@ def evaluation_loop( # To be JSON-serializable, we need to remove numpy types or zero-d tensors metrics = denumpify_detensorize(metrics) - - if all_losses is not None: + + if isinstance(all_losses, list) and all_losses: + metrics[f"{metric_key_prefix}_loss"] = np.concatenate(all_losses).mean().item() + elif isinstance(all_losses, np.ndarray): metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item() if hasattr(self, "jit_compilation_time"): metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time @@ -4204,6 +4168,7 @@ def prediction_loop( logger.info(f"***** Running {description} *****") logger.info(f" Num examples = {num_examples}") logger.info(f" Batch size = {batch_size}") + losses_host: torch.Tensor = None preds_host: Union[torch.Tensor, List[torch.Tensor]] = None labels_host: Union[torch.Tensor, List[torch.Tensor]] = None diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index 5c57ce0696f6..4b35018d62ab 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -29,6 +29,7 @@ import numpy as np +from .trainer_pt_utils import nested_concat, nested_numpify from .utils import ( ExplicitEnum, is_psutil_available, @@ -199,6 +200,58 @@ class TrainOutput(NamedTuple): metrics: Dict[str, 
float] +class EvalLoopContainer: + """ + Container to store intermediate results of evaluation loop + + Args: + do_nested_concat (`bool`, *optional*, defaults to `True`): + If set to `True`, each iteration will recursively concatenate a new object containing tensors to + the existing stored tensors, provided that the structure of the existing object and the new one + are identical. If set to `False`, all newly added tensors will be stored in a list. + padding_index (`int`, *optional*, defaults to -100): + Value used to pad tensors of different shapes when `do_nested_concat=True`. + """ + + def __init__(self, do_nested_concat: bool = True, padding_index: int = -100): + self.do_nested_concat = do_nested_concat + self.padding_index = padding_index + self.tensors = None + self.arrays = None + + def add(self, tensors) -> None: + """Add tensors to the stored objects. If `do_nested_concat=True`, the tensors will be concatenated recursively.""" + if self.tensors is None: + self.tensors = tensors if self.do_nested_concat else [tensors] + elif self.do_nested_concat: + self.tensors = nested_concat(self.tensors, tensors, padding_index=self.padding_index) + else: + self.tensors.append(tensors) + + def to_cpu_and_numpy(self) -> None: + """Move tensors in stored objects to CPU and convert them to numpy arrays.""" + + # Check if we have something to add, if not just return + if self.tensors is None: + return + + new_arrays = nested_numpify(self.tensors) + if self.arrays is None: + self.arrays = new_arrays + elif self.do_nested_concat: + self.arrays = nested_concat(self.arrays, new_arrays, padding_index=self.padding_index) + else: + self.arrays.extend(new_arrays) + + # reset device tensors after adding to cpu + self.tensors = None + + def get_arrays(self): + """Returns the numpified and moved to CPU stored objects.""" + self.to_cpu_and_numpy() + return self.arrays + + PREFIX_CHECKPOINT_DIR = "checkpoint" _re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$") diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index cdf6325c4b4a..464c1188085e 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -639,6 +639,9 @@ class TrainingArguments: include_inputs_for_metrics (`bool`, *optional*, defaults to `False`): Whether or not the inputs will be passed to the `compute_metrics` function. This is intended for metrics that need inputs, predictions and references for scoring calculation in Metric class. + eval_do_concat_batches (`bool`, *optional*, defaults to `True`): + If set to `False`, inputs/losses/labels/predictions are stored as lists, with each batch kept separate. + If set to `True`, tensors in these nested objects are recursively concatenated across batches. auto_find_batch_size (`bool`, *optional*, defaults to `False`) Whether to find a batch size that will fit into memory automatically through exponential decay, avoiding CUDA Out-of-Memory errors. 
Requires accelerate to be installed (`pip install accelerate`) @@ -1261,6 +1264,9 @@ class TrainingArguments: include_inputs_for_metrics: bool = field( default=False, metadata={"help": "Whether or not the inputs will be passed to the `compute_metrics` function."} ) + eval_do_concat_batches: bool = field( + default=True, metadata={"help": "Whether or not tensors in nested objects in batches should be recursively concatenated between batches."} + ) # Deprecated arguments fp16_backend: str = field( default="auto", From 4b9bab54ece9328e84023a94c6558dac1692cff7 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Tue, 16 Apr 2024 11:07:16 +0000 Subject: [PATCH 2/7] Add tests for EvalLoopContainer --- tests/trainer/test_trainer_utils.py | 90 ++++++++++++++++++++++++++++- 1 file changed, 89 insertions(+), 1 deletion(-) diff --git a/tests/trainer/test_trainer_utils.py b/tests/trainer/test_trainer_utils.py index ccf162677e9f..8acca94a59a1 100644 --- a/tests/trainer/test_trainer_utils.py +++ b/tests/trainer/test_trainer_utils.py @@ -20,7 +20,7 @@ from transformers.data.data_collator import default_data_collator from transformers.testing_utils import require_accelerate, require_torch -from transformers.trainer_utils import RemoveColumnsCollator, find_executable_batch_size +from transformers.trainer_utils import RemoveColumnsCollator, find_executable_batch_size, EvalLoopContainer from transformers.utils import is_torch_available @@ -497,3 +497,91 @@ def info(self, msg): remove_columns_collator(data_batch) self.assertEqual(logger.called, 1) self.assertIn("col3", logger.last_msg) + + def test_eval_loop_container(self): + + batch_1 = [ + torch.ones([8, 5]), + {"loss": torch.tensor(1.)}, + (torch.ones([8, 2, 3]), torch.ones([8, 2])), + ] + batch_2 = [ + torch.ones([4, 5]), + {"loss": torch.tensor(2.)}, + (torch.ones([4, 2, 3]), torch.ones([4, 6])), + ] + + concat_container = EvalLoopContainer(do_nested_concat=True, padding_index=False) + concat_container.add(batch_1) + concat_container.add(batch_2) + concat_container.to_cpu_and_numpy() + arrays = concat_container.get_arrays() + + # Test two nested batches concatenation + self.assertIsInstance(arrays, list) + self.assertEqual(len(arrays), 3) + self.assertIsInstance(arrays[0], np.ndarray) + self.assertEqual(arrays[0].shape, (12, 5)) + self.assertIsInstance(arrays[1], dict) + self.assertIsInstance(arrays[1]["loss"], np.ndarray) + self.assertEqual(arrays[1]["loss"].shape, (2,)) + self.assertTrue(np.allclose(arrays[1]["loss"], np.array([1., 2.]))) + self.assertIsInstance(arrays[2], tuple) + self.assertEqual(len(arrays[2]), 2) + self.assertEqual(arrays[2][0].shape, (12, 2, 3)) + self.assertEqual(arrays[2][1].shape, (12, 6)) + + # Test tow batches with no concatenation + list_container = EvalLoopContainer(do_nested_concat=False) + list_container.add(batch_1) + list_container.add(batch_2) + list_container.to_cpu_and_numpy() + arrays = list_container.get_arrays() + + self.assertEqual(len(arrays), 2) + self.assertIsInstance(arrays, list) + np_batch_1, np_batch_2 = arrays + + self.assertIsInstance(np_batch_1, list) + self.assertEqual(len(np_batch_1), 3) + self.assertIsInstance(np_batch_1[0], np.ndarray) + self.assertIsInstance(np_batch_1[1], dict) + self.assertIsInstance(np_batch_1[2], tuple) + self.assertEqual(np_batch_1[0].shape, (8, 5)) + self.assertEqual(np_batch_1[1]["loss"].shape, ()) + self.assertEqual(np_batch_1[2][0].shape, (8, 2, 3)) + self.assertEqual(np_batch_1[2][1].shape, (8, 2)) + + self.assertIsInstance(np_batch_2, list) + 
self.assertEqual(len(np_batch_2), 3) + self.assertIsInstance(np_batch_2[0], np.ndarray) + self.assertIsInstance(np_batch_2[1], dict) + self.assertIsInstance(np_batch_2[2], tuple) + self.assertEqual(np_batch_2[0].shape, (4, 5)) + self.assertEqual(np_batch_2[1]["loss"].shape, ()) + self.assertEqual(np_batch_2[2][0].shape, (4, 2, 3)) + self.assertEqual(np_batch_2[2][1].shape, (4, 6)) + + # Test no batches + none_arr = EvalLoopContainer(do_nested_concat=True, padding_index=False).get_arrays() + self.assertIsNone(none_arr) + + none_arr = EvalLoopContainer(do_nested_concat=False).get_arrays() + self.assertIsNone(none_arr) + + # Test one batch + concat_container = EvalLoopContainer(do_nested_concat=True, padding_index=False) + concat_container.add(batch_1) + arrays = concat_container.get_arrays() + self.assertIsInstance(arrays, list) + self.assertEqual(len(arrays), 3) + self.assertIsInstance(arrays[0], np.ndarray) + self.assertEqual(arrays[0].shape, (8, 5)) + self.assertIsInstance(arrays[1], dict) + self.assertIsInstance(arrays[1]["loss"], np.ndarray) + self.assertEqual(arrays[1]["loss"].shape, ()) + self.assertTrue(np.allclose(arrays[1]["loss"], np.array([1.]))) + self.assertIsInstance(arrays[2], tuple) + self.assertEqual(len(arrays[2]), 2) + self.assertEqual(arrays[2][0].shape, (8, 2, 3)) + self.assertEqual(arrays[2][1].shape, (8, 2)) From bcbe83f8532cf0ad3ef60ff851592dede5ba14e1 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Tue, 16 Apr 2024 11:23:41 +0000 Subject: [PATCH 3/7] Formatting --- src/transformers/trainer.py | 6 +++--- src/transformers/trainer_utils.py | 14 +++++++------- src/transformers/training_args.py | 7 +++++-- tests/trainer/test_trainer_utils.py | 11 +++++------ 4 files changed, 20 insertions(+), 18 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 94da32856b3f..330d2d031653 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -3632,7 +3632,7 @@ def evaluation_loop( all_preds = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100) all_labels = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100) all_inputs = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100) - + # Will be useful when we have an iterable dataset so don't know its length. 
observed_num_examples = 0 @@ -3722,7 +3722,7 @@ def evaluation_loop( # To be JSON-serializable, we need to remove numpy types or zero-d tensors metrics = denumpify_detensorize(metrics) - + if isinstance(all_losses, list) and all_losses: metrics[f"{metric_key_prefix}_loss"] = np.concatenate(all_losses).mean().item() elif isinstance(all_losses, np.ndarray): @@ -4168,7 +4168,7 @@ def prediction_loop( logger.info(f"***** Running {description} *****") logger.info(f" Num examples = {num_examples}") logger.info(f" Batch size = {batch_size}") - + losses_host: torch.Tensor = None preds_host: Union[torch.Tensor, List[torch.Tensor]] = None labels_host: Union[torch.Tensor, List[torch.Tensor]] = None diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index 4b35018d62ab..db725dbd90ff 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -203,11 +203,11 @@ class TrainOutput(NamedTuple): class EvalLoopContainer: """ Container to store intermediate results of evaluation loop - + Args: do_nested_concat (`bool`, *optional*, defaults to `True`): - If set to `True`, each iteration will recursively concatenate a new object containing tensors to - the existing stored tensors, provided that the structure of the existing object and the new one + If set to `True`, each iteration will recursively concatenate a new object containing tensors to + the existing stored tensors, provided that the structure of the existing object and the new one are identical. If set to `False`, all newly added tensors will be stored in a list. padding_index (`int`, *optional*, defaults to -100): Value used to pad tensors of different shapes when `do_nested_concat=True`. @@ -218,7 +218,7 @@ def __init__(self, do_nested_concat: bool = True, padding_index: int = -100): self.padding_index = padding_index self.tensors = None self.arrays = None - + def add(self, tensors) -> None: """Add tensors to the stored objects. If `do_nested_concat=True`, the tensors will be concatenated recursively.""" if self.tensors is None: @@ -227,14 +227,14 @@ def add(self, tensors) -> None: self.tensors = nested_concat(self.tensors, tensors, padding_index=self.padding_index) else: self.tensors.append(tensors) - + def to_cpu_and_numpy(self) -> None: """Move tensors in stored objects to CPU and convert them to numpy arrays.""" # Check if we have something to add, if not just return if self.tensors is None: - return - + return + new_arrays = nested_numpify(self.tensors) if self.arrays is None: self.arrays = new_arrays diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 464c1188085e..f4f28d5ed110 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -640,7 +640,7 @@ class TrainingArguments: Whether or not the inputs will be passed to the `compute_metrics` function. This is intended for metrics that need inputs, predictions and references for scoring calculation in Metric class. eval_do_concat_batches (`bool`, *optional*, defaults to `True`): - If set to `False`, inputs/losses/labels/predictions are stored as lists, with each batch kept separate. + If set to `False`, inputs/losses/labels/predictions are stored as lists, with each batch kept separate. If set to `True`, tensors in these nested objects are recursively concatenated across batches. 
auto_find_batch_size (`bool`, *optional*, defaults to `False`) Whether to find a batch size that will fit into memory automatically through exponential decay, avoiding @@ -1265,7 +1265,10 @@ class TrainingArguments: default=False, metadata={"help": "Whether or not the inputs will be passed to the `compute_metrics` function."} ) eval_do_concat_batches: bool = field( - default=True, metadata={"help": "Whether or not tensors in nested objects in batches should be recursively concatenated between batches."} + default=True, + metadata={ + "help": "Whether or not tensors in nested objects in batches should be recursively concatenated between batches." + }, ) # Deprecated arguments fp16_backend: str = field( diff --git a/tests/trainer/test_trainer_utils.py b/tests/trainer/test_trainer_utils.py index 8acca94a59a1..0a72be470911 100644 --- a/tests/trainer/test_trainer_utils.py +++ b/tests/trainer/test_trainer_utils.py @@ -20,7 +20,7 @@ from transformers.data.data_collator import default_data_collator from transformers.testing_utils import require_accelerate, require_torch -from transformers.trainer_utils import RemoveColumnsCollator, find_executable_batch_size, EvalLoopContainer +from transformers.trainer_utils import EvalLoopContainer, RemoveColumnsCollator, find_executable_batch_size from transformers.utils import is_torch_available @@ -499,15 +499,14 @@ def info(self, msg): self.assertIn("col3", logger.last_msg) def test_eval_loop_container(self): - batch_1 = [ torch.ones([8, 5]), - {"loss": torch.tensor(1.)}, + {"loss": torch.tensor(1.0)}, (torch.ones([8, 2, 3]), torch.ones([8, 2])), ] batch_2 = [ torch.ones([4, 5]), - {"loss": torch.tensor(2.)}, + {"loss": torch.tensor(2.0)}, (torch.ones([4, 2, 3]), torch.ones([4, 6])), ] @@ -525,7 +524,7 @@ def test_eval_loop_container(self): self.assertIsInstance(arrays[1], dict) self.assertIsInstance(arrays[1]["loss"], np.ndarray) self.assertEqual(arrays[1]["loss"].shape, (2,)) - self.assertTrue(np.allclose(arrays[1]["loss"], np.array([1., 2.]))) + self.assertTrue(np.allclose(arrays[1]["loss"], np.array([1.0, 2.0]))) self.assertIsInstance(arrays[2], tuple) self.assertEqual(len(arrays[2]), 2) self.assertEqual(arrays[2][0].shape, (12, 2, 3)) @@ -580,7 +579,7 @@ def test_eval_loop_container(self): self.assertIsInstance(arrays[1], dict) self.assertIsInstance(arrays[1]["loss"], np.ndarray) self.assertEqual(arrays[1]["loss"].shape, ()) - self.assertTrue(np.allclose(arrays[1]["loss"], np.array([1.]))) + self.assertTrue(np.allclose(arrays[1]["loss"], np.array([1.0]))) self.assertIsInstance(arrays[2], tuple) self.assertEqual(len(arrays[2]), 2) self.assertEqual(arrays[2][0].shape, (8, 2, 3)) From 7deb60f167bbef448a1cbf53cb2e62e1c77b4ede Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Tue, 16 Apr 2024 17:53:30 +0000 Subject: [PATCH 4/7] Fix padding_index in test and typo --- tests/trainer/test_trainer_utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/trainer/test_trainer_utils.py b/tests/trainer/test_trainer_utils.py index 0a72be470911..f14f1093c044 100644 --- a/tests/trainer/test_trainer_utils.py +++ b/tests/trainer/test_trainer_utils.py @@ -510,7 +510,7 @@ def test_eval_loop_container(self): (torch.ones([4, 2, 3]), torch.ones([4, 6])), ] - concat_container = EvalLoopContainer(do_nested_concat=True, padding_index=False) + concat_container = EvalLoopContainer(do_nested_concat=True, padding_index=-100) concat_container.add(batch_1) concat_container.add(batch_2) concat_container.to_cpu_and_numpy() @@ -529,8 +529,10 @@ def 
test_eval_loop_container(self):
         self.assertEqual(len(arrays[2]), 2)
         self.assertEqual(arrays[2][0].shape, (12, 2, 3))
         self.assertEqual(arrays[2][1].shape, (12, 6))
+        # check that the first batch is padded with padding index -100 after concatenation
+        self.assertEqual(arrays[2][1][0][2], -100)
 
-        # Test tow batches with no concatenation
+        # Test two batches with no concatenation
         list_container = EvalLoopContainer(do_nested_concat=False)
         list_container.add(batch_1)
         list_container.add(batch_2)
@@ -562,14 +564,14 @@ def test_eval_loop_container(self):
         self.assertEqual(np_batch_2[2][1].shape, (4, 6))
 
         # Test no batches
-        none_arr = EvalLoopContainer(do_nested_concat=True, padding_index=False).get_arrays()
+        none_arr = EvalLoopContainer(do_nested_concat=True, padding_index=-100).get_arrays()
         self.assertIsNone(none_arr)
 
         none_arr = EvalLoopContainer(do_nested_concat=False).get_arrays()
         self.assertIsNone(none_arr)
 
         # Test one batch
-        concat_container = EvalLoopContainer(do_nested_concat=True, padding_index=False)
+        concat_container = EvalLoopContainer(do_nested_concat=True, padding_index=-100)
         concat_container.add(batch_1)
         arrays = concat_container.get_arrays()
         self.assertIsInstance(arrays, list)

From ec048e145ccc4caf8a56dfc2afc94eb849931d49 Mon Sep 17 00:00:00 2001
From: Pavel Iakubovskii
Date: Wed, 17 Apr 2024 10:35:43 +0000
Subject: [PATCH 5/7] Move EvalLoopContainer to pt_utils to avoid additional imports

---
 src/transformers/trainer.py          |  2 +-
 src/transformers/trainer_pt_utils.py | 52 +++++++++++++++++++++++++++
 src/transformers/trainer_utils.py    | 53 ----------------------------
 3 files changed, 53 insertions(+), 54 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 330d2d031653..92025cb979d3 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -82,6 +82,7 @@
 )
 from .trainer_pt_utils import (
     DistributedTensorGatherer,
+    EvalLoopContainer,
     IterableDatasetShard,
     LabelSmoother,
     LayerWiseDummyOptimizer,
@@ -104,7 +105,6 @@
 from .trainer_utils import (
     PREFIX_CHECKPOINT_DIR,
     BestRun,
-    EvalLoopContainer,
     EvalLoopOutput,
     EvalPrediction,
     HPSearchBackend,
diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py
index 9ee670e94288..a4372ae78a79 100644
--- a/src/transformers/trainer_pt_utils.py
+++ b/src/transformers/trainer_pt_utils.py
@@ -299,6 +299,58 @@ def __iter__(self):
         return iter(indices)
 
 
+class EvalLoopContainer:
+    """
+    Container to store intermediate results of the evaluation loop.
+
+    Args:
+        do_nested_concat (`bool`, *optional*, defaults to `True`):
+            If set to `True`, each iteration will recursively concatenate a new object containing tensors to
+            the existing stored tensors, provided that the structure of the existing object and the new one
+            are identical. If set to `False`, all newly added tensors will be stored in a list.
+        padding_index (`int`, *optional*, defaults to -100):
+            Value used to pad tensors of different shapes when `do_nested_concat=True`.
+    """
+
+    def __init__(self, do_nested_concat: bool = True, padding_index: int = -100):
+        self.do_nested_concat = do_nested_concat
+        self.padding_index = padding_index
+        self.tensors = None
+        self.arrays = None
+
+    def add(self, tensors) -> None:
+        """Add tensors to the stored objects.
If `do_nested_concat=True`, the tensors will be concatenated recursively."""
+        if self.tensors is None:
+            self.tensors = tensors if self.do_nested_concat else [tensors]
+        elif self.do_nested_concat:
+            self.tensors = nested_concat(self.tensors, tensors, padding_index=self.padding_index)
+        else:
+            self.tensors.append(tensors)
+
+    def to_cpu_and_numpy(self) -> None:
+        """Move tensors in stored objects to CPU and convert them to numpy arrays."""
+
+        # Check if we have something to move; if not, just return
+        if self.tensors is None:
+            return
+
+        new_arrays = nested_numpify(self.tensors)
+        if self.arrays is None:
+            self.arrays = new_arrays
+        elif self.do_nested_concat:
+            self.arrays = nested_concat(self.arrays, new_arrays, padding_index=self.padding_index)
+        else:
+            self.arrays.extend(new_arrays)
+
+        # Reset device tensors after they have been moved to CPU
+        self.tensors = None
+
+    def get_arrays(self):
+        """Return the stored objects, moved to CPU and converted to numpy arrays."""
+        self.to_cpu_and_numpy()
+        return self.arrays
+
+
 class SequentialDistributedSampler(Sampler):
     """
     Distributed Sampler that subsamples indices sequentially, making it easier to collate all results at the end.
diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py
index db725dbd90ff..5c57ce0696f6 100644
--- a/src/transformers/trainer_utils.py
+++ b/src/transformers/trainer_utils.py
@@ -29,7 +29,6 @@
 
 import numpy as np
 
-from .trainer_pt_utils import nested_concat, nested_numpify
 from .utils import (
     ExplicitEnum,
     is_psutil_available,
@@ -200,58 +199,6 @@ class TrainOutput(NamedTuple):
     metrics: Dict[str, float]
 
 
-class EvalLoopContainer:
-    """
-    Container to store intermediate results of evaluation loop
-
-    Args:
-        do_nested_concat (`bool`, *optional*, defaults to `True`):
-            If set to `True`, each iteration will recursively concatenate a new object containing tensors to
-            the existing stored tensors, provided that the structure of the existing object and the new one
-            are identical. If set to `False`, all newly added tensors will be stored in a list.
-        padding_index (`int`, *optional*, defaults to -100):
-            Value used to pad tensors of different shapes when `do_nested_concat=True`.
-    """
-
-    def __init__(self, do_nested_concat: bool = True, padding_index: int = -100):
-        self.do_nested_concat = do_nested_concat
-        self.padding_index = padding_index
-        self.tensors = None
-        self.arrays = None
-
-    def add(self, tensors) -> None:
-        """Add tensors to the stored objects.
If `do_nested_concat=True`, the tensors will be concatenated recursively.""" - if self.tensors is None: - self.tensors = tensors if self.do_nested_concat else [tensors] - elif self.do_nested_concat: - self.tensors = nested_concat(self.tensors, tensors, padding_index=self.padding_index) - else: - self.tensors.append(tensors) - - def to_cpu_and_numpy(self) -> None: - """Move tensors in stored objects to CPU and convert them to numpy arrays.""" - - # Check if we have something to add, if not just return - if self.tensors is None: - return - - new_arrays = nested_numpify(self.tensors) - if self.arrays is None: - self.arrays = new_arrays - elif self.do_nested_concat: - self.arrays = nested_concat(self.arrays, new_arrays, padding_index=self.padding_index) - else: - self.arrays.extend(new_arrays) - - # reset device tensors after adding to cpu - self.tensors = None - - def get_arrays(self): - """Returns the numpified and moved to CPU stored objects.""" - self.to_cpu_and_numpy() - return self.arrays - - PREFIX_CHECKPOINT_DIR = "checkpoint" _re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$") From b76fe220765c927017c49688a67430ff98fba49a Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Wed, 17 Apr 2024 10:40:07 +0000 Subject: [PATCH 6/7] Fix `eval_do_concat_batches` arg description --- src/transformers/training_args.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index f4f28d5ed110..6de45a6a976a 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -640,8 +640,8 @@ class TrainingArguments: Whether or not the inputs will be passed to the `compute_metrics` function. This is intended for metrics that need inputs, predictions and references for scoring calculation in Metric class. eval_do_concat_batches (`bool`, *optional*, defaults to `True`): - If set to `False`, inputs/losses/labels/predictions are stored as lists, with each batch kept separate. - If set to `True`, tensors in these nested objects are recursively concatenated across batches. + Whether to recursively concat inputs/losses/labels/predictions across batches. If `False`, + will instead store them as lists, with each batch kept separate. auto_find_batch_size (`bool`, *optional*, defaults to `False`) Whether to find a batch size that will fit into memory automatically through exponential decay, avoiding CUDA Out-of-Memory errors. Requires accelerate to be installed (`pip install accelerate`) @@ -1267,7 +1267,7 @@ class TrainingArguments: eval_do_concat_batches: bool = field( default=True, metadata={ - "help": "Whether or not tensors in nested objects in batches should be recursively concatenated between batches." + "help": "Whether to recursively concat inputs/losses/labels/predictions across batches. If `False`, will instead store them as lists, with each batch kept separate." 
}, ) # Deprecated arguments From d93171b12b38c79c802eaa2d31241f82eb5c6e66 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Wed, 17 Apr 2024 10:49:43 +0000 Subject: [PATCH 7/7] Fix EvalLoopContainer import --- tests/trainer/test_trainer_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/trainer/test_trainer_utils.py b/tests/trainer/test_trainer_utils.py index f14f1093c044..a730ff07ccb2 100644 --- a/tests/trainer/test_trainer_utils.py +++ b/tests/trainer/test_trainer_utils.py @@ -20,7 +20,7 @@ from transformers.data.data_collator import default_data_collator from transformers.testing_utils import require_accelerate, require_torch -from transformers.trainer_utils import EvalLoopContainer, RemoveColumnsCollator, find_executable_batch_size +from transformers.trainer_utils import RemoveColumnsCollator, find_executable_batch_size from transformers.utils import is_torch_available @@ -35,6 +35,7 @@ DistributedLengthGroupedSampler, DistributedSamplerWithLoop, DistributedTensorGatherer, + EvalLoopContainer, IterableDatasetShard, LabelSmoother, LengthGroupedSampler,
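
---

Usage sketch for reviewers (illustrative, not part of the patches): this assumes
only the `EvalLoopContainer` API added in this series (`add`, `to_cpu_and_numpy`,
`get_arrays`); the tensor shapes are made up for the example.

    import torch

    from transformers.trainer_pt_utils import EvalLoopContainer

    # do_nested_concat=True: batches are concatenated recursively; tensors whose
    # shapes differ beyond the batch dimension are padded with `padding_index`.
    container = EvalLoopContainer(do_nested_concat=True, padding_index=-100)
    container.add(torch.ones([8, 5]))
    container.add(torch.ones([4, 7]))
    container.to_cpu_and_numpy()  # offload to CPU, as the trainer does every `eval_accumulation_steps`
    print(container.get_arrays().shape)  # (12, 7); rows from the first batch are padded with -100

    # do_nested_concat=False: each batch is kept as a separate numpy array.
    container = EvalLoopContainer(do_nested_concat=False)
    container.add(torch.ones([8, 5]))
    container.add(torch.ones([4, 7]))
    print([a.shape for a in container.get_arrays()])  # [(8, 5), (4, 7)]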
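
And a sketch of the new `TrainingArguments` flag; everything here other than
`eval_do_concat_batches` itself is placeholder wiring. With the flag set to
`False`, the containers return per-batch lists from `get_arrays()`, so
predictions and labels should then reach `compute_metrics` as lists of
per-batch arrays rather than as single concatenated arrays.

    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir="out",  # placeholder
        eval_do_concat_batches=False,  # store eval tensors as per-batch lists
    )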