diff --git a/tests/unittests/segmentation/test_dice.py b/tests/unittests/segmentation/test_dice.py index b009401f481..b324bdaa0f9 100644 --- a/tests/unittests/segmentation/test_dice.py +++ b/tests/unittests/segmentation/test_dice.py @@ -106,3 +106,18 @@ def test_dice_score_functional(self, preds, target, input_format, include_backgr "input_format": input_format, }, ) + + +@pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) +def test_corner_case_no_overlap(average): + """Check that if no overlap and intersection between target and preds, the dice score is 0. + + See issue: https://github.com/Lightning-AI/torchmetrics/issues/2851 + + """ + target = torch.full((4, 4, 128, 128), 0, dtype=torch.int8) + preds = torch.full((4, 4, 128, 128), 0, dtype=torch.int8) + target[0, 0] = 1 + preds[0, 0] = 1 + dice = DiceScore(num_classes=3, average=average, include_background=False) + assert dice(preds, target) == 0.0