Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki committed Dec 3, 2024
1 parent 50982ab commit 059f6df
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions tests/unittests/segmentation/test_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,18 @@ def test_dice_score_functional(self, preds, target, input_format, include_backgr
"input_format": input_format,
},
)


@pytest.mark.parametrize("average", ["micro", "macro", "weighted", None])
def test_corner_case_no_overlap(average):
"""Check that if no overlap and intersection between target and preds, the dice score is 0.
See issue: https://github.com/Lightning-AI/torchmetrics/issues/2851
"""
target = torch.full((4, 4, 128, 128), 0, dtype=torch.int8)
preds = torch.full((4, 4, 128, 128), 0, dtype=torch.int8)
target[0, 0] = 1
preds[0, 0] = 1
dice = DiceScore(num_classes=3, average=average, include_background=False)
assert dice(preds, target) == 0.0

0 comments on commit 059f6df

Please sign in to comment.