From 79fcdf35a4a859c09a19304f6523384c2e5ac6d9 Mon Sep 17 00:00:00 2001
From: Willard Sheen
Date: Mon, 17 Jun 2024 18:13:43 +0800
Subject: [PATCH 1/4] [fix BUG] pad labels before using them in preprocess_logits_for_metrics

---
 src/transformers/trainer.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 34cf5aa49046..1a80b57c5083 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -3810,6 +3810,7 @@ def evaluation_loop(
             if logits is not None:
                 logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
                 if self.preprocess_logits_for_metrics is not None:
+                    labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100) if labels is not None else None
                     logits = self.preprocess_logits_for_metrics(logits, labels)
                 logits = self.gather_function((logits))
                 if not self.args.batch_eval_metrics or description == "Prediction":

From 7aab5efdc3c5eb243a28cc8dc464555c04514831 Mon Sep 17 00:00:00 2001
From: Willard Sheen
Date: Tue, 18 Jun 2024 11:03:06 +0800
Subject: [PATCH 2/4] a more readable fix

labels can't be gathered before being passed to
`preprocess_logits_for_metrics`, so the logic must be split into
two if-blocks
---
 src/transformers/trainer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 1a80b57c5083..5ed481d6c790 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -3807,16 +3807,16 @@ def evaluation_loop(
                 inputs_decode = self.gather_function((inputs_decode))
                 if not self.args.batch_eval_metrics or description == "Prediction":
                     all_inputs.add(inputs_decode)
+            if labels is not None:
+                labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
             if logits is not None:
                 logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
                 if self.preprocess_logits_for_metrics is not None:
-                    labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100) if labels is not None else None
                     logits = self.preprocess_logits_for_metrics(logits, labels)
                 logits = self.gather_function((logits))
                 if not self.args.batch_eval_metrics or description == "Prediction":
                     all_preds.add(logits)
             if labels is not None:
-                labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
                 labels = self.gather_function((labels))
                 if not self.args.batch_eval_metrics or description == "Prediction":
                     all_labels.add(labels)

From 376eb0c29fbe65bed6897550e9c71cb126fef7ab Mon Sep 17 00:00:00 2001
From: Willard Sheen
Date: Wed, 19 Jun 2024 09:56:10 +0800
Subject: [PATCH 3/4] add a comment

---
 src/transformers/trainer.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index ee874b13faa3..f5ce34d522e5 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -3823,6 +3823,7 @@ def evaluation_loop(
                 if not self.args.batch_eval_metrics or description == "Prediction":
                     all_inputs.add(inputs_decode)
             if labels is not None:
+                # pad labels here, preparing for preprocess_logits_for_metrics in next logits block
                 labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
             if logits is not None:
                 logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)

From fe59a9aa4dffa75ef3fb37bcaa3f7b6da988a031 Mon Sep 17 00:00:00 2001
From: Willard Sheen
Date: Wed, 19 Jun 2024 10:00:18 +0800
Subject: [PATCH 4/4] oh code quality check

---
 src/transformers/trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index f5ce34d522e5..5975649a9eae 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -3823,7 +3823,7 @@ def evaluation_loop(
                 if not self.args.batch_eval_metrics or description == "Prediction":
                     all_inputs.add(inputs_decode)
             if labels is not None:
-                # pad labels here, preparing for preprocess_logits_for_metrics in next logits block
+                # Pad labels here, preparing for preprocess_logits_for_metrics in next logits block.
                 labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
             if logits is not None:
                 logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
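
For context, here is a minimal sketch of a `preprocess_logits_for_metrics` hook that
reads `labels` — the kind of callback this series fixes. The callback body and masking
logic below are illustrative assumptions, not taken from the patch; only the
`(logits, labels)` signature is the one the Trainer actually calls.

import torch

# Illustrative callback passed to Trainer(preprocess_logits_for_metrics=...).
# It relies on `labels` having the same padded sequence length as `logits`.
# Before this series, `logits` was padded across processes but `labels` was
# not yet, so element-wise operations like the mask below could hit a shape
# mismatch whenever another process held a longer batch.
def preprocess_logits_for_metrics(logits, labels):
    # Reduce vocab-sized logits to predicted token ids to save eval memory.
    preds = logits.argmax(dim=-1)
    if labels is not None:
        # Mask out positions that are padding in the labels (-100).
        preds = torch.where(labels == -100, torch.full_like(preds, -100), preds)
    return preds

Such a callback is registered via the constructor, e.g.
`Trainer(..., preprocess_logits_for_metrics=preprocess_logits_for_metrics)`; the fix
guarantees that by the time it runs, `labels` has already been padded across processes
with the same `pad_index=-100` as `logits`.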