diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 978f1a6708e7..5975649a9eae 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -3822,6 +3822,9 @@ def evaluation_loop(
                 inputs_decode = self.gather_function((inputs_decode))
                 if not self.args.batch_eval_metrics or description == "Prediction":
                     all_inputs.add(inputs_decode)
+            if labels is not None:
+                # Pad labels here, preparing for preprocess_logits_for_metrics in the next logits block.
+                labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
             if logits is not None:
                 logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
                 if self.preprocess_logits_for_metrics is not None:
@@ -3830,7 +3833,6 @@ def evaluation_loop(
                 if not self.args.batch_eval_metrics or description == "Prediction":
                     all_preds.add(logits)
             if labels is not None:
-                labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
                 labels = self.gather_function((labels))
                 if not self.args.batch_eval_metrics or description == "Prediction":
                     all_labels.add(labels)
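The reordering matters because `preprocess_logits_for_metrics` receives both tensors after the logits have been padded across processes; if the labels are padded only afterwards, a hook that indexes logits and labels together can hit a shape mismatch on multi-process runs with uneven sequence lengths. A minimal sketch of such a hook (hypothetical, not part of this PR) that relies on labels and logits sharing the padded sequence length:

```python
import torch

def preprocess_logits_for_metrics(logits, labels):
    # Hypothetical hook: keep only argmax token ids to cut eval memory.
    preds = logits.argmax(dim=-1)
    # Mask positions that are padding in the labels. This elementwise compare
    # assumes labels were already padded (pad_index=-100) to the same sequence
    # length as logits, which is exactly what padding labels *before* this hook
    # runs guarantees when processes hold batches of different lengths.
    return torch.where(labels == -100, torch.full_like(preds, -100), preds)
```

With the old ordering, `logits` here would already be padded to the cross-process maximum length while `labels` still had its local length, so the `labels == -100` comparison could fail to broadcast against `preds`.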