Skip to content

Commit 0d1d727

Browse files
Pass task="binary" to binary metrics to handle latest torchmetrics (#920)
1 parent df84c81 commit 0d1d727

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

merlin/models/torch/model/prediction_task.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ def build(self, input_size) -> SequentialBlock:
4343
class BinaryClassificationTask(PredictionTask):
4444
DEFAULT_LOSS = torch.nn.BCELoss()
4545
DEFAULT_METRICS = (
46-
tm.Precision(num_classes=2),
47-
tm.Recall(num_classes=2),
48-
tm.Accuracy(),
46+
tm.Precision(num_classes=2, task="binary"),
47+
tm.Recall(num_classes=2, task="binary"),
48+
tm.Accuracy(task="binary"),
4949
# TODO: Fix this: tm.AUC()
5050
)
5151

requirements/pytorch.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
torch>=1.0
2-
torchmetrics==0.3.2
2+
torchmetrics>=0.10.0

0 commit comments

Comments
 (0)