From 650f9328d46a035e01616cab3c95c35b174358bc Mon Sep 17 00:00:00 2001 From: Sara Rabhi Date: Tue, 13 Dec 2022 13:06:23 -0500 Subject: [PATCH] Fix error raised by latest Torchmetrics (0.11.0) (#576) * update default torchmetrics * unpin torchmetrics version * update torchmetrics requirement * update torchmetrics version from Oliver suggestion --- requirements/pytorch.txt | 2 +- tests/torch/test_trainer.py | 12 ++++++------ transformers4rec/torch/model/prediction_task.py | 6 +++--- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/requirements/pytorch.txt b/requirements/pytorch.txt index 9e76107d68..f0311772d3 100644 --- a/requirements/pytorch.txt +++ b/requirements/pytorch.txt @@ -1,2 +1,2 @@ torch>=1.0 -torchmetrics==0.3.2 +torchmetrics>=0.10.0 diff --git a/tests/torch/test_trainer.py b/tests/torch/test_trainer.py index 3ac5a573d4..b8d52ec1b2 100644 --- a/tests/torch/test_trainer.py +++ b/tests/torch/test_trainer.py @@ -325,9 +325,9 @@ def test_evaluate_results(torch_yoochoose_next_item_prediction_model): ( tr.BinaryClassificationTask("click", summary_type="mean"), [ - "eval_/click/binary_classification_task/accuracy", - "eval_/click/binary_classification_task/precision", - "eval_/click/binary_classification_task/recall", + "eval_/click/binary_classification_task/binary_accuracy", + "eval_/click/binary_classification_task/binary_precision", + "eval_/click/binary_classification_task/binary_recall", ], ), ( @@ -446,9 +446,9 @@ def test_trainer_with_multiple_tasks(): "eval_/next-item/avg_precision_at_20", "eval_/next-item/recall_at_10", "eval_/next-item/recall_at_20", - "eval_/click/binary_classification_task/accuracy", - "eval_/click/binary_classification_task/precision", - "eval_/click/binary_classification_task/recall", + "eval_/click/binary_classification_task/binary_accuracy", + "eval_/click/binary_classification_task/binary_precision", + "eval_/click/binary_classification_task/binary_recall", "eval_/play_percentage/regression_task/mean_squared_error", ] diff --git a/transformers4rec/torch/model/prediction_task.py b/transformers4rec/torch/model/prediction_task.py index 3ed3dd212f..46d6143dde 100644 --- a/transformers4rec/torch/model/prediction_task.py +++ b/transformers4rec/torch/model/prediction_task.py @@ -46,9 +46,9 @@ def build(self, input_size) -> SequentialBlock: class BinaryClassificationTask(PredictionTask): DEFAULT_LOSS = torch.nn.BCELoss() DEFAULT_METRICS = ( - tm.Precision(num_classes=2), - tm.Recall(num_classes=2), - tm.Accuracy(), + tm.Precision(num_classes=2, task="binary"), + tm.Recall(num_classes=2, task="binary"), + tm.Accuracy(task="binary"), # TODO: Fix this: tm.AUC() )