From ede52a0c99cdfe1012290a071564ca9517979cb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Paw=C5=82owicz?= Date: Thu, 5 Jan 2023 20:20:11 +0100 Subject: [PATCH] fix args passed to predict function --- examples/pytorch/question-answering/trainer_seq2seq_qa.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/pytorch/question-answering/trainer_seq2seq_qa.py b/examples/pytorch/question-answering/trainer_seq2seq_qa.py index 73517c06d7cd..6abb41b33feb 100644 --- a/examples/pytorch/question-answering/trainer_seq2seq_qa.py +++ b/examples/pytorch/question-answering/trainer_seq2seq_qa.py @@ -151,7 +151,7 @@ def predict( if self.post_process_function is None or self.compute_metrics is None: return output - predictions = self.post_process_function(predict_examples, predict_dataset, output.predictions, "predict") + predictions = self.post_process_function(predict_examples, predict_dataset, output, "predict") metrics = self.compute_metrics(predictions) # Prefix all keys with metric_key_prefix + '_'