From 176711de62e7d374bf5850be610260cbf9d50f9b Mon Sep 17 00:00:00 2001 From: rittik9 Date: Sun, 1 Sep 2024 03:04:34 +0530 Subject: [PATCH 01/11] Fix: Corrected MulticlassRecall macro average calculation when ignore_index is specified --- src/torchmetrics/classification/precision_recall.py | 1 + .../functional/classification/precision_recall.py | 3 ++- src/torchmetrics/utilities/compute.py | 6 +++++- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/classification/precision_recall.py b/src/torchmetrics/classification/precision_recall.py index 0380545b5ac..9a57267fab7 100644 --- a/src/torchmetrics/classification/precision_recall.py +++ b/src/torchmetrics/classification/precision_recall.py @@ -746,6 +746,7 @@ def compute(self) -> Tensor: fn, average=self.average, multidim_average=self.multidim_average, + ignore_index = self.ignore_index, top_k=self.top_k, zero_division=self.zero_division, ) diff --git a/src/torchmetrics/functional/classification/precision_recall.py b/src/torchmetrics/functional/classification/precision_recall.py index 96214c82274..fd5438795d4 100644 --- a/src/torchmetrics/functional/classification/precision_recall.py +++ b/src/torchmetrics/functional/classification/precision_recall.py @@ -43,6 +43,7 @@ def _precision_recall_reduce( average: Optional[Literal["binary", "micro", "macro", "weighted", "none"]], multidim_average: Literal["global", "samplewise"] = "global", multilabel: bool = False, + ignore_index: Optional[int] = None, top_k: int = 1, zero_division: float = 0, ) -> Tensor: @@ -56,7 +57,7 @@ def _precision_recall_reduce( return _safe_divide(tp, tp + different_stat, zero_division) score = _safe_divide(tp, tp + different_stat, zero_division) - return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn, top_k=top_k) + return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn, ignore_index = ignore_index, top_k=top_k) def binary_precision( diff --git 
a/src/torchmetrics/utilities/compute.py b/src/torchmetrics/utilities/compute.py index ee11a36136f..be5f3b5008e 100644 --- a/src/torchmetrics/utilities/compute.py +++ b/src/torchmetrics/utilities/compute.py @@ -61,7 +61,7 @@ def _safe_divide(num: Tensor, denom: Tensor, zero_division: float = 0.0) -> Tens def _adjust_weights_safe_divide( - score: Tensor, average: Optional[str], multilabel: bool, tp: Tensor, fp: Tensor, fn: Tensor, top_k: int = 1 + score: Tensor, average: Optional[str], multilabel: bool, tp: Tensor, fp: Tensor, fn: Tensor, ignore_index:Optional[int] = None, top_k: int = 1 ) -> Tensor: if average is None or average == "none": return score @@ -71,6 +71,10 @@ def _adjust_weights_safe_divide( weights = torch.ones_like(score) if not multilabel: weights[tp + fp + fn == 0 if top_k == 1 else tp + fn == 0] = 0.0 + + if ignore_index is not None and 0 <= ignore_index < len(score): + weights[ignore_index] = 0.0 + return _safe_divide(weights * score, weights.sum(-1, keepdim=True)).sum(-1) From df36d0f5b035314a2dbdb61707dd9d48f8823e1f Mon Sep 17 00:00:00 2001 From: rittik9 Date: Sun, 1 Sep 2024 23:24:35 +0530 Subject: [PATCH 02/11] style: format code to comply with pre-commit hooks --- src/torchmetrics/classification/precision_recall.py | 2 +- .../functional/classification/precision_recall.py | 2 +- src/torchmetrics/utilities/compute.py | 13 ++++++++++--- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/torchmetrics/classification/precision_recall.py b/src/torchmetrics/classification/precision_recall.py index 9a57267fab7..f889b01f770 100644 --- a/src/torchmetrics/classification/precision_recall.py +++ b/src/torchmetrics/classification/precision_recall.py @@ -746,7 +746,7 @@ def compute(self) -> Tensor: fn, average=self.average, multidim_average=self.multidim_average, - ignore_index = self.ignore_index, + ignore_index=self.ignore_index, top_k=self.top_k, zero_division=self.zero_division, ) diff --git 
a/src/torchmetrics/functional/classification/precision_recall.py b/src/torchmetrics/functional/classification/precision_recall.py index fd5438795d4..5f233863e8a 100644 --- a/src/torchmetrics/functional/classification/precision_recall.py +++ b/src/torchmetrics/functional/classification/precision_recall.py @@ -57,7 +57,7 @@ def _precision_recall_reduce( return _safe_divide(tp, tp + different_stat, zero_division) score = _safe_divide(tp, tp + different_stat, zero_division) - return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn, ignore_index = ignore_index, top_k=top_k) + return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn, ignore_index=ignore_index, top_k=top_k) def binary_precision( diff --git a/src/torchmetrics/utilities/compute.py b/src/torchmetrics/utilities/compute.py index be5f3b5008e..deb2e614ca3 100644 --- a/src/torchmetrics/utilities/compute.py +++ b/src/torchmetrics/utilities/compute.py @@ -61,7 +61,14 @@ def _safe_divide(num: Tensor, denom: Tensor, zero_division: float = 0.0) -> Tens def _adjust_weights_safe_divide( - score: Tensor, average: Optional[str], multilabel: bool, tp: Tensor, fp: Tensor, fn: Tensor, ignore_index:Optional[int] = None, top_k: int = 1 + score: Tensor, + average: Optional[str], + multilabel: bool, + tp: Tensor, + fp: Tensor, + fn: Tensor, + ignore_index: Optional[int] = None, + top_k: int = 1, ) -> Tensor: if average is None or average == "none": return score @@ -71,10 +78,10 @@ def _adjust_weights_safe_divide( weights = torch.ones_like(score) if not multilabel: weights[tp + fp + fn == 0 if top_k == 1 else tp + fn == 0] = 0.0 - + if ignore_index is not None and 0 <= ignore_index < len(score): weights[ignore_index] = 0.0 - + return _safe_divide(weights * score, weights.sum(-1, keepdim=True)).sum(-1) From 0773bab4a172a97c6d97d30d5a98527c249f1514 Mon Sep 17 00:00:00 2001 From: rittik9 Date: Tue, 3 Sep 2024 03:37:53 +0530 Subject: [PATCH 03/11] test: Add test for MulticlassRecall with 
ignore_index+macro (fixes #2441) --- .../classification/test_precision_recall.py | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/unittests/classification/test_precision_recall.py b/tests/unittests/classification/test_precision_recall.py index 00eee202cc0..a40427d16c4 100644 --- a/tests/unittests/classification/test_precision_recall.py +++ b/tests/unittests/classification/test_precision_recall.py @@ -659,6 +659,37 @@ def test_corner_case(): assert res == 1.0 +def test_multiclass_recall_ignore_index(): + """Issue: https://github.com/Lightning-AI/torchmetrics/issues/2441.""" + y_true = torch.tensor([0, 0, 1, 1]) + y_pred = torch.tensor([ + [0.9, 0.1], + [0.9, 0.1], + [0.9, 0.1], + [0.1, 0.9], + ]) + + # Test with ignore_index=0 and average="macro" + metric_ignore_0 = MulticlassRecall(num_classes=2, ignore_index=0, average="macro") + res_ignore_0 = metric_ignore_0(y_pred, y_true) + assert res_ignore_0 == 0.5, f"Expected 0.5, but got {res_ignore_0}" + + # Test with ignore_index=1 and average="macro" + metric_ignore_1 = MulticlassRecall(num_classes=2, ignore_index=1, average="macro") + res_ignore_1 = metric_ignore_1(y_pred, y_true) + assert res_ignore_1 == 1.0, f"Expected 1.0, but got {res_ignore_1}" + + # Test with no ignore_index and average="macro" + metric_no_ignore = MulticlassRecall(num_classes=2, average="macro") + res_no_ignore = metric_no_ignore(y_pred, y_true) + assert res_no_ignore == 0.75, f"Expected 0.75, but got {res_no_ignore}" + + # Test with ignore_index=0 and average="none" + metric_none = MulticlassRecall(num_classes=2, ignore_index=0, average="none") + res_none = metric_none(y_pred, y_true) + assert torch.allclose(res_none, torch.tensor([0.0, 0.5])), f"Expected [0.0, 0.5], but got {res_none}" + + @pytest.mark.parametrize( ("metric", "kwargs", "base_metric"), [ From 78177ac753f93a738a5a17f178154dfe2661579d Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Mon, 9 Sep 2024 20:06:05 
+0200 Subject: [PATCH 04/11] chlog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0fc6c936492..578d0485019 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,6 +39,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fixed multiclass recall macro avg. ignore index ([#2710](https://github.com/Lightning-AI/torchmetrics/pull/2710)) + + - Fixed handling zero division error in binary IoU (Jaccard index) calculation ([#2726](https://github.com/Lightning-AI/torchmetrics/pull/2726)) From 259c4bd8c46305285dc4b4e783ccc2db82fd0f07 Mon Sep 17 00:00:00 2001 From: rittik9 Date: Sun, 24 Nov 2024 16:44:48 +0530 Subject: [PATCH 05/11] fix: Reference Metric in multiclass precision recall unittests provides wrong answer when ignore_index is specified --- tests/unittests/classification/test_precision_recall.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/unittests/classification/test_precision_recall.py b/tests/unittests/classification/test_precision_recall.py index a40427d16c4..c813ae578a4 100644 --- a/tests/unittests/classification/test_precision_recall.py +++ b/tests/unittests/classification/test_precision_recall.py @@ -202,6 +202,10 @@ def _reference_sklearn_precision_recall_multiclass( if preds.ndim == target.ndim + 1: preds = torch.argmax(preds, 1) + valid_labels = list(range(NUM_CLASSES)) + if ignore_index is not None: + valid_labels = [label for label in valid_labels if label != ignore_index] + if multidim_average == "global": preds = preds.numpy().flatten() target = target.numpy().flatten() @@ -210,7 +214,7 @@ def _reference_sklearn_precision_recall_multiclass( target, preds, average=average, - labels=list(range(NUM_CLASSES)) if average is None else None, + labels=valid_labels if average in ("macro", "weighted") else None, zero_division=zero_division, ) @@ -235,7 +239,7 @@ def _reference_sklearn_precision_recall_multiclass( true, pred, 
average=average, - labels=list(range(NUM_CLASSES)) if average is None else None, + labels=valid_labels if average in ("macro", "weighted") else None, zero_division=zero_division, ) res.append(0.0 if np.isnan(r).any() else r) From 447031e80447d4e0a22f06e618ec59ae3c3d69a4 Mon Sep 17 00:00:00 2001 From: rittik9 Date: Sun, 24 Nov 2024 19:57:29 +0530 Subject: [PATCH 06/11] refactor: compute.py --- src/torchmetrics/classification/precision_recall.py | 2 +- .../functional/classification/precision_recall.py | 4 ++-- src/torchmetrics/utilities/compute.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/torchmetrics/classification/precision_recall.py b/src/torchmetrics/classification/precision_recall.py index f889b01f770..4964df0ed25 100644 --- a/src/torchmetrics/classification/precision_recall.py +++ b/src/torchmetrics/classification/precision_recall.py @@ -746,9 +746,9 @@ def compute(self) -> Tensor: fn, average=self.average, multidim_average=self.multidim_average, - ignore_index=self.ignore_index, top_k=self.top_k, zero_division=self.zero_division, + ignore_index=self.ignore_index, ) def plot( diff --git a/src/torchmetrics/functional/classification/precision_recall.py b/src/torchmetrics/functional/classification/precision_recall.py index 5f233863e8a..e8c518f3a5e 100644 --- a/src/torchmetrics/functional/classification/precision_recall.py +++ b/src/torchmetrics/functional/classification/precision_recall.py @@ -43,9 +43,9 @@ def _precision_recall_reduce( average: Optional[Literal["binary", "micro", "macro", "weighted", "none"]], multidim_average: Literal["global", "samplewise"] = "global", multilabel: bool = False, - ignore_index: Optional[int] = None, top_k: int = 1, zero_division: float = 0, + ignore_index: Optional[int] = None, ) -> Tensor: different_stat = fp if stat == "precision" else fn # this is what differs between the two scores if average == "binary": @@ -57,7 +57,7 @@ def _precision_recall_reduce( return _safe_divide(tp, tp + 
different_stat, zero_division) score = _safe_divide(tp, tp + different_stat, zero_division) - return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn, ignore_index=ignore_index, top_k=top_k) + return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn, top_k=top_k, ignore_index=ignore_index) def binary_precision( diff --git a/src/torchmetrics/utilities/compute.py b/src/torchmetrics/utilities/compute.py index deb2e614ca3..54bd81ddebe 100644 --- a/src/torchmetrics/utilities/compute.py +++ b/src/torchmetrics/utilities/compute.py @@ -67,8 +67,8 @@ def _adjust_weights_safe_divide( tp: Tensor, fp: Tensor, fn: Tensor, - ignore_index: Optional[int] = None, top_k: int = 1, + ignore_index: Optional[int] = None, ) -> Tensor: if average is None or average == "none": return score From 58c0070aad4fb176c0d28abd9950c16b4b02b31a Mon Sep 17 00:00:00 2001 From: rittik9 Date: Mon, 25 Nov 2024 01:47:54 +0530 Subject: [PATCH 07/11] modify _reference_sklearn_precision_recall_multiclass --- tests/unittests/classification/test_precision_recall.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/unittests/classification/test_precision_recall.py b/tests/unittests/classification/test_precision_recall.py index af1d4e8e2be..5a66df3a35f 100644 --- a/tests/unittests/classification/test_precision_recall.py +++ b/tests/unittests/classification/test_precision_recall.py @@ -216,7 +216,9 @@ def _reference_sklearn_precision_recall_multiclass( target, preds, average=average, - labels=valid_labels if average in ("macro", "weighted") else None, + labels=valid_labels + if average in ("macro", "weighted") + else (list(range(num_classes)) if average is None else None), zero_division=zero_division, ) @@ -241,7 +243,9 @@ def _reference_sklearn_precision_recall_multiclass( true, pred, average=average, - labels=valid_labels if average in ("macro", "weighted") else None, + labels=valid_labels + if average in ("macro", "weighted") + else 
(list(range(num_classes)) if average is None else None), zero_division=zero_division, ) res.append(0.0 if np.isnan(r).any() else r) From 7b1a09fc8fe211c57bceaed6d5fd489ff08863e5 Mon Sep 17 00:00:00 2001 From: Rittik Panda Date: Mon, 25 Nov 2024 20:47:27 +0530 Subject: [PATCH 08/11] Update CHANGELOG.md Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- CHANGELOG.md | 1 - 1 file changed, 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 34c70856b63..bed0a457620 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -117,7 +117,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [1.4.3] - 2024-10-10 ### Fixed -- Fixed handling zero division error in binary IoU (Jaccard index) calculation ([#2726](https://github.com/Lightning-AI/torchmetrics/pull/2726)) - Fixed for Pearson changes inputs ([#2765](https://github.com/Lightning-AI/torchmetrics/pull/2765)) - Fixed bug in `PESQ` metric where `NoUtterancesError` prevented calculating on a batch of data ([#2753](https://github.com/Lightning-AI/torchmetrics/pull/2753)) - Fixed corner case in `MatthewsCorrCoef` ([#2743](https://github.com/Lightning-AI/torchmetrics/pull/2743)) From bf1c29f69b65b80453bc49cdd4ee58465bd0d37d Mon Sep 17 00:00:00 2001 From: baskrahmer Date: Fri, 6 Dec 2024 17:35:54 +0100 Subject: [PATCH 09/11] Pass down ignore_index --- src/torchmetrics/functional/classification/precision_recall.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/torchmetrics/functional/classification/precision_recall.py b/src/torchmetrics/functional/classification/precision_recall.py index e8c518f3a5e..954c0a5fbe7 100644 --- a/src/torchmetrics/functional/classification/precision_recall.py +++ b/src/torchmetrics/functional/classification/precision_recall.py @@ -250,6 +250,7 @@ def multiclass_precision( fp, tn, fn, + ignore_index=ignore_index, average=average, multidim_average=multidim_average, top_k=top_k, @@ -560,6 +561,7 @@ def multiclass_recall( fp, tn, fn, + 
ignore_index=ignore_index, average=average, multidim_average=multidim_average, top_k=top_k, @@ -673,6 +675,7 @@ def multilabel_recall( fp, tn, fn, + ignore_index=ignore_index, average=average, multidim_average=multidim_average, multilabel=True, From 930fba37b40dd2d0faa9672d2179efdcd6634b9e Mon Sep 17 00:00:00 2001 From: baskrahmer Date: Fri, 6 Dec 2024 17:36:38 +0100 Subject: [PATCH 10/11] Set weights only for the classes axis --- src/torchmetrics/utilities/compute.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/utilities/compute.py b/src/torchmetrics/utilities/compute.py index 7529203c7f0..e7cd3a7d277 100644 --- a/src/torchmetrics/utilities/compute.py +++ b/src/torchmetrics/utilities/compute.py @@ -87,7 +87,7 @@ def _adjust_weights_safe_divide( weights[tp + fp + fn == 0 if top_k == 1 else tp + fn == 0] = 0.0 if ignore_index is not None and 0 <= ignore_index < len(score): - weights[ignore_index] = 0.0 + weights[..., ignore_index] = 0.0 return _safe_divide(weights * score, weights.sum(-1, keepdim=True)).sum(-1) From 7e956960fbf2eb5ced856b9b2bf6979a7ac7fa14 Mon Sep 17 00:00:00 2001 From: Rittik Panda Date: Sun, 8 Dec 2024 02:01:45 +0530 Subject: [PATCH 11/11] Update precision_recall.py --- .../functional/classification/precision_recall.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchmetrics/functional/classification/precision_recall.py b/src/torchmetrics/functional/classification/precision_recall.py index 954c0a5fbe7..db00aaca67d 100644 --- a/src/torchmetrics/functional/classification/precision_recall.py +++ b/src/torchmetrics/functional/classification/precision_recall.py @@ -250,11 +250,11 @@ def multiclass_precision( fp, tn, fn, - ignore_index=ignore_index, average=average, multidim_average=multidim_average, top_k=top_k, zero_division=zero_division, + ignore_index=ignore_index, ) @@ -561,11 +561,11 @@ def multiclass_recall( fp, tn, fn, - ignore_index=ignore_index, average=average, 
multidim_average=multidim_average, top_k=top_k, zero_division=zero_division, + ignore_index=ignore_index, ) @@ -675,11 +675,11 @@ def multilabel_recall( fp, tn, fn, - ignore_index=ignore_index, average=average, multidim_average=multidim_average, multilabel=True, zero_division=zero_division, + ignore_index=ignore_index, )