From fdbc4aa9d4e247eae5b56d799aadf83744fc1dfd Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Sat, 14 Sep 2024 09:59:43 +0200 Subject: [PATCH 1/3] fix + test --- .../functional/classification/matthews_corrcoef.py | 14 ++++++++------ .../classification/test_matthews_corrcoef.py | 6 ++++++ 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/torchmetrics/functional/classification/matthews_corrcoef.py b/src/torchmetrics/functional/classification/matthews_corrcoef.py index 544414ee4a8..93c21f2535e 100644 --- a/src/torchmetrics/functional/classification/matthews_corrcoef.py +++ b/src/torchmetrics/functional/classification/matthews_corrcoef.py @@ -64,12 +64,14 @@ def _matthews_corrcoef_reduce(confmat: Tensor) -> Tensor: denom = cov_ypyp * cov_ytyt if denom == 0 and confmat.numel() == 4: - if tp == 0 or tn == 0: - a = tp + tn - - if fp == 0 or fn == 0: - b = fp + fn - + if fn == 0 and tn == 0: + a, b = tp, fp + if fp == 0 and tn == 0: + a, b = tp, fn + if tp == 0 and fn == 0: + a, b = tn, fp + if tp == 0 and fp == 0: + a, b = tn, fn eps = torch.tensor(torch.finfo(torch.float32).eps, dtype=torch.float32, device=confmat.device) numerator = torch.sqrt(eps) * (a - b) denom = (tp + fp + eps) * (tp + fn + eps) * (tn + fp + eps) * (tn + fn + eps) diff --git a/tests/unittests/classification/test_matthews_corrcoef.py b/tests/unittests/classification/test_matthews_corrcoef.py index 03f649bc0ac..2f881604d09 100644 --- a/tests/unittests/classification/test_matthews_corrcoef.py +++ b/tests/unittests/classification/test_matthews_corrcoef.py @@ -331,6 +331,12 @@ def test_zero_case_in_multiclass(): torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]), 0.0, ), + ( + binary_matthews_corrcoef, + torch.tensor([1, 1, 1, 1, 1, 0, 0, 0, 0, 0]), + torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), + 0.0, + ), (binary_matthews_corrcoef, torch.zeros(10), torch.ones(10), -1.0), (binary_matthews_corrcoef, torch.ones(10), torch.zeros(10), -1.0), ( From b5ee13e42993ed21a9b4e7a249a12175027c3bd1 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Sat, 14 Sep 2024 10:02:47 +0200 Subject: [PATCH 2/3] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ef9ed3c896f..e05c9a5d1d8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -59,6 +59,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed how `prefix`/`postfix` works in `MultitaskWrapper` ([#2722](https://github.com/Lightning-AI/torchmetrics/pull/2722)) +- Fixed corner case in `MatthewsCorrCoef` ([#2743](https://github.com/Lightning-AI/torchmetrics/pull/2743)) + + ## [1.4.1] - 2024-08-02 ### Changed From 4dfa566a5ff519aafe60705a2d3eb67f6fa809a8 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Mon, 16 Sep 2024 10:00:45 +0200 Subject: [PATCH 3/3] Apply suggestions from code review --- .../functional/classification/matthews_corrcoef.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchmetrics/functional/classification/matthews_corrcoef.py b/src/torchmetrics/functional/classification/matthews_corrcoef.py index 93c21f2535e..45e0238dae5 100644 --- a/src/torchmetrics/functional/classification/matthews_corrcoef.py +++ b/src/torchmetrics/functional/classification/matthews_corrcoef.py @@ -66,11 +66,11 @@ def _matthews_corrcoef_reduce(confmat: Tensor) -> Tensor: if denom == 0 and confmat.numel() == 4: if fn == 0 and tn == 0: a, b = tp, fp - if fp == 0 and tn == 0: + elif fp == 0 and tn == 0: a, b = tp, fn - if tp == 0 and fn == 0: + elif tp == 0 and fn == 0: a, b = tn, fp - if tp == 0 and fp == 0: + elif tp == 0 and fp == 0: a, b = tn, fn eps = torch.tensor(torch.finfo(torch.float32).eps, dtype=torch.float32, device=confmat.device) numerator = torch.sqrt(eps) * (a - b)