Skip to content

Commit fdbc4aa

Browse files
SkafteNickiBorda
authored andcommitted
fix + test
1 parent ead5cbb commit fdbc4aa

File tree

2 files changed

+14
-6
lines changed

2 files changed

+14
-6
lines changed

src/torchmetrics/functional/classification/matthews_corrcoef.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,14 @@ def _matthews_corrcoef_reduce(confmat: Tensor) -> Tensor:
6464
denom = cov_ypyp * cov_ytyt
6565

6666
if denom == 0 and confmat.numel() == 4:
67-
if tp == 0 or tn == 0:
68-
a = tp + tn
69-
70-
if fp == 0 or fn == 0:
71-
b = fp + fn
72-
67+
if fn == 0 and tn == 0:
68+
a, b = tp, fp
69+
if fp == 0 and tn == 0:
70+
a, b = tp, fn
71+
if tp == 0 and fn == 0:
72+
a, b = tn, fp
73+
if tp == 0 and fp == 0:
74+
a, b = tn, fn
7375
eps = torch.tensor(torch.finfo(torch.float32).eps, dtype=torch.float32, device=confmat.device)
7476
numerator = torch.sqrt(eps) * (a - b)
7577
denom = (tp + fp + eps) * (tp + fn + eps) * (tn + fp + eps) * (tn + fn + eps)

tests/unittests/classification/test_matthews_corrcoef.py

+6
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,12 @@ def test_zero_case_in_multiclass():
331331
torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]),
332332
0.0,
333333
),
334+
(
335+
binary_matthews_corrcoef,
336+
torch.tensor([1, 1, 1, 1, 1, 0, 0, 0, 0, 0]),
337+
torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
338+
0.0,
339+
),
334340
(binary_matthews_corrcoef, torch.zeros(10), torch.ones(10), -1.0),
335341
(binary_matthews_corrcoef, torch.ones(10), torch.zeros(10), -1.0),
336342
(

0 commit comments

Comments
 (0)