diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 1bf6fde9fc5e..0edd81988a03 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2433,10 +2433,13 @@ def evaluation_loop( losses_host = None preds_host = None labels_host = None + inputs_host = None + # losses/preds/labels on CPU (final containers) all_losses = None all_preds = None all_labels = None + all_inputs = None # Will be useful when we have an iterable dataset so don't know its length. observed_num_examples = 0 @@ -2452,6 +2455,7 @@ def evaluation_loop( # Prediction step loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) + inputs_decode = inputs["input_ids"] if args.include_inputs_for_metrics else None if is_torch_tpu_available(): xm.mark_step() @@ -2464,6 +2468,14 @@ def evaluation_loop( labels = self._pad_across_processes(labels) labels = self._nested_gather(labels) labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) + if inputs_decode is not None: + inputs_decode = self._pad_across_processes(inputs_decode) + inputs_decode = self._nested_gather(inputs_decode) + inputs_host = ( + inputs_decode + if inputs_host is None + else nested_concat(inputs_host, inputs_decode, padding_index=-100) + ) if logits is not None: logits = self._pad_across_processes(logits) logits = self._nested_gather(logits) @@ -2480,6 +2492,13 @@ def evaluation_loop( if preds_host is not None: logits = nested_numpify(preds_host) all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100) + if inputs_host is not None: + inputs_decode = nested_numpify(inputs_host) + all_inputs = ( + inputs_decode + if all_inputs is None + else nested_concat(all_inputs, inputs_decode, padding_index=-100) + ) if labels_host is not None: labels = nested_numpify(labels_host) all_labels = ( @@ -2487,7 +2506,7 @@ def evaluation_loop( ) # Set back to None to begin a new accumulation - losses_host, preds_host, labels_host = None, None, None + losses_host, preds_host, inputs_host, labels_host = None, None, None, None if args.past_index and hasattr(self, "_past"): # Clean the state at the end of the evaluation loop @@ -2500,6 +2519,11 @@ def evaluation_loop( if preds_host is not None: logits = nested_numpify(preds_host) all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100) + if inputs_host is not None: + inputs_decode = nested_numpify(inputs_host) + all_inputs = ( + inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100) + ) if labels_host is not None: labels = nested_numpify(labels_host) all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100) @@ -2522,10 +2546,17 @@ def evaluation_loop( all_preds = nested_truncate(all_preds, num_samples) if all_labels is not None: all_labels = nested_truncate(all_labels, num_samples) + if all_inputs is not None: + all_inputs = nested_truncate(all_inputs, num_samples) # Metrics! if self.compute_metrics is not None and all_preds is not None and all_labels is not None: - metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels)) + if args.include_inputs_for_metrics: + metrics = self.compute_metrics( + EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs) + ) + else: + metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels)) else: metrics = {} @@ -2905,7 +2936,6 @@ def prediction_loop( # if eval is called w/o train init deepspeed here if args.deepspeed and not self.deepspeed: - # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval # from the checkpoint eventually deepspeed_engine, _, _ = deepspeed_init(self, num_training_steps=0, resume_from_checkpoint=None) @@ -2936,6 +2966,7 @@ def prediction_loop( losses_host: torch.Tensor = None preds_host: Union[torch.Tensor, List[torch.Tensor]] = None labels_host: Union[torch.Tensor, List[torch.Tensor]] = None + inputs_host: Union[torch.Tensor, List[torch.Tensor]] = None world_size = max(1, args.world_size) @@ -2948,6 +2979,7 @@ def prediction_loop( make_multiple_of = dataloader.sampler.batch_size preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) + inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) model.eval() @@ -2961,6 +2993,8 @@ def prediction_loop( for step, inputs in enumerate(dataloader): loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) + inputs_decode = inputs["input_ids"] if args.include_inputs_for_metrics else None + if loss is not None: losses = loss.repeat(batch_size) losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0) @@ -2968,6 +3002,12 @@ def prediction_loop( preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100) if labels is not None: labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) + if inputs_decode is not None: + inputs_host = ( + inputs_decode + if inputs_host is None + else nested_concat(inputs_host, inputs_decode, padding_index=-100) + ) self.control = self.callback_handler.on_prediction_step(args, self.state, self.control) # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. @@ -2976,9 +3016,10 @@ def prediction_loop( if not prediction_loss_only: preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids")) + inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids")) # Set back to None to begin a new accumulation - losses_host, preds_host, labels_host = None, None, None + losses_host, preds_host, labels_host, inputs_host = None, None, None, None if args.past_index and hasattr(self, "_past"): # Clean the state at the end of the evaluation loop @@ -2989,13 +3030,20 @@ def prediction_loop( if not prediction_loss_only: preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids")) + inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids")) eval_loss = eval_losses_gatherer.finalize() preds = preds_gatherer.finalize() if not prediction_loss_only else None label_ids = labels_gatherer.finalize() if not prediction_loss_only else None + inputs_ids = inputs_gatherer.finalize() if not prediction_loss_only else None if self.compute_metrics is not None and preds is not None and label_ids is not None: - metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids)) + if args.include_inputs_for_metrics: + metrics = self.compute_metrics( + EvalPrediction(predictions=preds, label_ids=label_ids, inputs=inputs_ids) + ) + else: + metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids)) else: metrics = {} diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index 04f5dc1058ed..0d224740c120 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -64,17 +64,43 @@ def set_seed(seed: int): tf.random.set_seed(seed) -class EvalPrediction(NamedTuple): +class EvalPrediction: """ Evaluation output (always contains labels), to be used to compute metrics. Parameters: predictions (`np.ndarray`): Predictions of the model. label_ids (`np.ndarray`): Targets to be matched. + inputs (`np.ndarray`, *optional*) """ - predictions: Union[np.ndarray, Tuple[np.ndarray]] - label_ids: Union[np.ndarray, Tuple[np.ndarray]] + def __init__( + self, + predictions: Union[np.ndarray, Tuple[np.ndarray]], + label_ids: Union[np.ndarray, Tuple[np.ndarray]], + inputs: Optional[Union[np.ndarray, Tuple[np.ndarray]]] = None, + ): + self.predictions = predictions + self.label_ids = label_ids + self.inputs = inputs + + def __iter__(self): + if self.inputs is not None: + return iter((self.predictions, self.label_ids, self.inputs)) + else: + return iter((self.predictions, self.label_ids)) + + def __getitem__(self, idx): + if idx < 0 or idx > 2: + raise IndexError("tuple index out of range") + if idx == 2 and self.inputs is None: + raise IndexError("tuple index out of range") + if idx == 0: + return self.predictions + elif idx == 1: + return self.label_ids + elif idx == 2: + return self.inputs class EvalLoopOutput(NamedTuple): diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 2087fbb7ace0..b8a5ee25a038 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -417,6 +417,9 @@ class TrainingArguments: `huggingface-cli login`. gradient_checkpointing (`bool`, *optional*, defaults to `False`): If True, use gradient checkpointing to save memory at the expense of slower backward pass. + include_inputs_for_metrics (`bool`, *optional*, defaults to `False`): + Whether or not the inputs will be passed to the `compute_metrics` function. This is intended for metrics + that need inputs, predictions and references for scoring calculation in Metric class. """ output_dir: str = field( @@ -740,6 +743,9 @@ class TrainingArguments: "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass." }, ) + include_inputs_for_metrics: bool = field( + default=False, metadata={"help": "Whether or not the inputs will be passed to the `compute_metrics` function."} + ) # Deprecated arguments fp16_backend: str = field( default="auto",