diff --git a/src/torchmetrics/functional/classification/matthews_corrcoef.py b/src/torchmetrics/functional/classification/matthews_corrcoef.py index 93c21f2535e..45e0238dae5 100644 --- a/src/torchmetrics/functional/classification/matthews_corrcoef.py +++ b/src/torchmetrics/functional/classification/matthews_corrcoef.py @@ -66,11 +66,11 @@ def _matthews_corrcoef_reduce(confmat: Tensor) -> Tensor: if denom == 0 and confmat.numel() == 4: if fn == 0 and tn == 0: a, b = tp, fp - if fp == 0 and tn == 0: + elif fp == 0 and tn == 0: a, b = tp, fn - if tp == 0 and fn == 0: + elif tp == 0 and fn == 0: a, b = tn, fp - if tp == 0 and fp == 0: + elif tp == 0 and fp == 0: a, b = tn, fn eps = torch.tensor(torch.finfo(torch.float32).eps, dtype=torch.float32, device=confmat.device) numerator = torch.sqrt(eps) * (a - b)