From 8596130b146e42258d49e369f2f8928e9538aad6 Mon Sep 17 00:00:00 2001 From: Isalia20 Date: Fri, 3 Jan 2025 23:44:06 +0400 Subject: [PATCH 1/4] first working version, some errors on tests --- .../functional/segmentation/mean_iou.py | 11 +++++++---- src/torchmetrics/segmentation/mean_iou.py | 14 ++++++++------ tests/unittests/segmentation/test_mean_iou.py | 7 +++++-- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/src/torchmetrics/functional/segmentation/mean_iou.py b/src/torchmetrics/functional/segmentation/mean_iou.py index 9cfed0fa1bf..7c960ee0ed0 100644 --- a/src/torchmetrics/functional/segmentation/mean_iou.py +++ b/src/torchmetrics/functional/segmentation/mean_iou.py @@ -66,11 +66,9 @@ def _mean_iou_update( def _mean_iou_compute( intersection: Tensor, union: Tensor, - per_class: bool = False, ) -> Tensor: """Compute the mean IoU metric.""" - val = _safe_divide(intersection, union) - return val if per_class else torch.mean(val, 1) + return _safe_divide(intersection, union) def mean_iou( @@ -111,4 +109,9 @@ """ _mean_iou_validate_args(num_classes, include_background, per_class, input_format) intersection, union = _mean_iou_update(preds, target, num_classes, include_background, input_format) - return _mean_iou_compute(intersection, union, per_class=per_class) + score = _mean_iou_compute(intersection, union) + score[torch.isnan(score)] = 0.0 # Handle division by zero like reference + valid_classes = union > 0 + score = (score * valid_classes).sum(dim=0) + num_batches = valid_classes.sum(dim=0) + return score / num_batches if per_class else (score / num_batches).mean() diff --git a/src/torchmetrics/segmentation/mean_iou.py b/src/torchmetrics/segmentation/mean_iou.py index ae8dd3d2aea..0fe1fe699ed 100644 --- a/src/torchmetrics/segmentation/mean_iou.py +++ b/src/torchmetrics/segmentation/mean_iou.py @@ -111,21 +111,23 @@ def __init__( self.input_format = input_format num_classes = num_classes - 1 if not include_background else
num_classes - self.add_state("score", default=torch.zeros(num_classes if per_class else 1), dist_reduce_fx="sum") - self.add_state("num_batches", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("score", default=torch.zeros(num_classes), dist_reduce_fx="sum") + self.add_state("num_batches", default=torch.zeros(num_classes), dist_reduce_fx="sum") def update(self, preds: Tensor, target: Tensor) -> None: """Update the state with the new data.""" intersection, union = _mean_iou_update( preds, target, self.num_classes, self.include_background, self.input_format ) - score = _mean_iou_compute(intersection, union, per_class=self.per_class) - self.score += score.mean(0) if self.per_class else score.mean() - self.num_batches += 1 + score = _mean_iou_compute(intersection, union) + # only update for classes that are present (i.e. union > 0) + valid_classes = union > 0 + self.score += (score * valid_classes).sum(dim=0) + self.num_batches += valid_classes.sum(dim=0) def compute(self) -> Tensor: """Compute the final Mean Intersection over Union (mIoU).""" - return self.score / self.num_batches + return self.score / self.num_batches if self.per_class else (self.score / self.num_batches).mean() def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. 
diff --git a/tests/unittests/segmentation/test_mean_iou.py b/tests/unittests/segmentation/test_mean_iou.py index 8c21d5c70c3..0f9a732c12c 100644 --- a/tests/unittests/segmentation/test_mean_iou.py +++ b/tests/unittests/segmentation/test_mean_iou.py @@ -22,7 +22,7 @@ from unittests import NUM_CLASSES from unittests._helpers.testers import MetricTester -from unittests.segmentation.inputs import _inputs1, _inputs2, _inputs3 +from unittests.segmentation.inputs import _inputs1, _inputs2, _inputs3 # , _inputs5 def _reference_mean_iou( @@ -90,7 +90,10 @@ def test_mean_iou_functional(self, preds, target, input_format, include_backgrou target=target, metric_functional=mean_iou, reference_metric=partial( - _reference_mean_iou, input_format=input_format, include_background=include_background, reduce=False + _reference_mean_iou, + input_format=input_format, + include_background=include_background, + reduce=False, ), metric_args={ "num_classes": NUM_CLASSES, From 68a1d06063513f3cd25d70a110937e778e9f92a7 Mon Sep 17 00:00:00 2001 From: Isalia20 Date: Sat, 4 Jan 2025 00:11:53 +0400 Subject: [PATCH 2/4] functional part rewritten as well --- src/torchmetrics/functional/segmentation/mean_iou.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/torchmetrics/functional/segmentation/mean_iou.py b/src/torchmetrics/functional/segmentation/mean_iou.py index 7c960ee0ed0..b6706bb83f2 100644 --- a/src/torchmetrics/functional/segmentation/mean_iou.py +++ b/src/torchmetrics/functional/segmentation/mean_iou.py @@ -110,8 +110,6 @@ def mean_iou( _mean_iou_validate_args(num_classes, include_background, per_class, input_format) intersection, union = _mean_iou_update(preds, target, num_classes, include_background, input_format) score = _mean_iou_compute(intersection, union) - score[torch.isnan(score)] = 0.0 # Handle division by zero like reference valid_classes = union > 0 - score = (score * valid_classes).sum(dim=0) - num_batches = valid_classes.sum(dim=0) - return score / 
num_batches if per_class else (score / num_batches).mean() + score = score * valid_classes + return score if per_class else score[valid_classes].mean(dim=-1) From 998a2bc22a69b2227e2dea6003f9184ad31afe94 Mon Sep 17 00:00:00 2001 From: Isalia20 Date: Sat, 4 Jan 2025 00:17:17 +0400 Subject: [PATCH 3/4] remove extra code added in tests for now --- tests/unittests/segmentation/test_mean_iou.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/unittests/segmentation/test_mean_iou.py b/tests/unittests/segmentation/test_mean_iou.py index 0f9a732c12c..8c21d5c70c3 100644 --- a/tests/unittests/segmentation/test_mean_iou.py +++ b/tests/unittests/segmentation/test_mean_iou.py @@ -22,7 +22,7 @@ from unittests import NUM_CLASSES from unittests._helpers.testers import MetricTester -from unittests.segmentation.inputs import _inputs1, _inputs2, _inputs3 # , _inputs5 +from unittests.segmentation.inputs import _inputs1, _inputs2, _inputs3 def _reference_mean_iou( @@ -90,10 +90,7 @@ def test_mean_iou_functional(self, preds, target, input_format, include_backgrou target=target, metric_functional=mean_iou, reference_metric=partial( - _reference_mean_iou, - input_format=input_format, - include_background=include_background, - reduce=False, + _reference_mean_iou, input_format=input_format, include_background=include_background, reduce=False ), metric_args={ "num_classes": NUM_CLASSES, From b67a5b6bec0d77786c073a61770990c8d400e076 Mon Sep 17 00:00:00 2001 From: Isalia20 Date: Sat, 4 Jan 2025 00:31:44 +0400 Subject: [PATCH 4/4] proper mean calculation --- src/torchmetrics/functional/segmentation/mean_iou.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/segmentation/mean_iou.py b/src/torchmetrics/functional/segmentation/mean_iou.py index b6706bb83f2..0dfc7a5e692 100644 --- a/src/torchmetrics/functional/segmentation/mean_iou.py +++ b/src/torchmetrics/functional/segmentation/mean_iou.py @@ -112,4 +112,4 @@ def 
mean_iou( score = _mean_iou_compute(intersection, union) valid_classes = union > 0 score = score * valid_classes - return score if per_class else score[valid_classes].mean(dim=-1) + return score if per_class else score.sum(dim=-1) / valid_classes.sum(dim=-1)