From f6a55acefb4cd3162ff09cc45ca691eef2a049be Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Tue, 3 Sep 2024 16:32:21 +0200 Subject: [PATCH] smaller array to fix torch.unique case --- tests/unittests/classification/test_stat_scores.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unittests/classification/test_stat_scores.py b/tests/unittests/classification/test_stat_scores.py index 53fa78d0368..5ea4c206bc0 100644 --- a/tests/unittests/classification/test_stat_scores.py +++ b/tests/unittests/classification/test_stat_scores.py @@ -582,8 +582,8 @@ def test_support_for_int(): """See issue: https://github.com/Lightning-AI/torchmetrics/issues/1970.""" seed_all(42) metric = MulticlassStatScores(num_classes=4, average="none", multidim_average="samplewise", ignore_index=0) - prediction = torch.randint(low=0, high=4, size=(1, 224, 224)).to(torch.uint8) - label = torch.randint(low=0, high=4, size=(1, 224, 224)).to(torch.uint8) + prediction = torch.randint(low=0, high=4, size=(1, 50, 50)).to(torch.uint8) + label = torch.randint(low=0, high=4, size=(1, 50, 50)).to(torch.uint8) score = metric(preds=prediction, target=label) assert score.shape == (1, 4, 5)