diff --git a/CHANGELOG.md b/CHANGELOG.md
index 66ebcc8803f..e9ba95ce852 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -24,6 +24,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added support for calculating segmentation quality and recognition quality in `PanopticQuality` metric ([#2381](https://github.com/Lightning-AI/torchmetrics/pull/2381))
 
 
+- Added a new segmentation metric `MeanIoU` ([#1236](https://github.com/Lightning-AI/torchmetrics/pull/1236))
+
+
 - Added `pretty-errors` for improving error prints ([#2431](https://github.com/Lightning-AI/torchmetrics/pull/2431))
 
 
diff --git a/docs/source/segmentation/mean_iou.rst b/docs/source/segmentation/mean_iou.rst
new file mode 100644
index 00000000000..7fddd9f316d
--- /dev/null
+++ b/docs/source/segmentation/mean_iou.rst
@@ -0,0 +1,19 @@
+.. customcarditem::
+   :header: Mean Intersection over Union (mIoU)
+   :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/object_detection.svg
+   :tags: segmentation
+
+###################################
+Mean Intersection over Union (mIoU)
+###################################
+
+Module Interface
+________________
+
+.. autoclass:: torchmetrics.segmentation.MeanIoU
+    :exclude-members: update, compute
+
+Functional Interface
+____________________
+
+.. autofunction:: torchmetrics.functional.segmentation.mean_iou
diff --git a/src/torchmetrics/functional/segmentation/__init__.py b/src/torchmetrics/functional/segmentation/__init__.py
index eec2e4dfcf3..3d23192a36a 100644
--- a/src/torchmetrics/functional/segmentation/__init__.py
+++ b/src/torchmetrics/functional/segmentation/__init__.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 from torchmetrics.functional.segmentation.generalized_dice import generalized_dice_score
+from torchmetrics.functional.segmentation.mean_iou import mean_iou
 
-__all__ = ["generalized_dice_score"]
+__all__ = ["generalized_dice_score", "mean_iou"]
diff --git a/src/torchmetrics/functional/segmentation/mean_iou.py b/src/torchmetrics/functional/segmentation/mean_iou.py
new file mode 100644
index 00000000000..0a4e24da6e1
--- /dev/null
+++ b/src/torchmetrics/functional/segmentation/mean_iou.py
@@ -0,0 +1,160 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Tuple
+
+import torch
+from torch import Tensor
+
+from torchmetrics.functional.segmentation.utils import _ignore_background
+from torchmetrics.utilities.checks import _check_same_shape
+from torchmetrics.utilities.compute import _safe_divide
+
+
+def _mean_iou_validate_args(
+    num_classes: int,
+    include_background: bool,
+    per_class: bool,
+) -> None:
+    """Validate the arguments of the metric.
+
+    Raises:
+        ValueError:
+            If ``num_classes`` is not a positive integer
+        ValueError:
+            If ``include_background`` is not a boolean
+        ValueError:
+            If ``per_class`` is not a boolean
+
+    """
+    if not isinstance(num_classes, int) or num_classes <= 0:
+        raise ValueError(f"Expected argument `num_classes` must be a positive integer, but got {num_classes}.")
+    if not isinstance(include_background, bool):
+        raise ValueError(f"Expected argument `include_background` must be a boolean, but got {include_background}.")
+    if not isinstance(per_class, bool):
+        raise ValueError(f"Expected argument `per_class` must be a boolean, but got {per_class}.")
+
+
+def _index_to_one_hot(index: Tensor, num_classes: int) -> Tensor:
+    """Convert a tensor of class indices ``(N, ...)`` into a one-hot tensor ``(N, C, ...)``."""
+    if not index.is_floating_point():
+        # `one_hot` only accepts int64 input, while label maps are often stored as uint8 or int32
+        index = index.long()
+    return torch.nn.functional.one_hot(index, num_classes=num_classes).movedim(-1, 1)
+
+
+def _mean_iou_update(
+    preds: Tensor,
+    target: Tensor,
+    num_classes: int,
+    include_background: bool = False,
+) -> Tuple[Tensor, Tensor]:
+    """Update the intersection and union counts for the mean IoU computation.
+
+    Args:
+        preds: Predictions, either one-hot encoded ``(N, C, ...)`` or class indices ``(N, ...)``
+        target: Ground truth values in the same format and shape as ``preds``
+        num_classes: Number of classes, only used to one-hot encode class index input
+        include_background: Whether to keep the first class channel, which is treated as the background
+
+    Returns:
+        The intersection and the union, both reduced over every dimension after the class dimension
+
+    """
+    _check_same_shape(preds, target)
+
+    # Both inputs share one shape and hence one format, so a value outside {0, 1} in either of them marks both as
+    # index tensors. Deciding per tensor would leave e.g. a target containing only the classes 0 and 1 unconverted.
+    # NOTE: index input in which both tensors only hold the values 0 and 1 is indistinguishable from one-hot input.
+    if (preds.bool() != preds).any() or (target.bool() != target).any():
+        preds = _index_to_one_hot(preds, num_classes)
+        target = _index_to_one_hot(target, num_classes)
+
+    if not include_background:
+        preds, target = _ignore_background(preds, target)
+
+    reduce_axis = list(range(2, preds.ndim))
+    intersection = torch.sum(preds & target, dim=reduce_axis)
+    target_sum = torch.sum(target, dim=reduce_axis)
+    pred_sum = torch.sum(preds, dim=reduce_axis)
+    union = target_sum + pred_sum - intersection
+    return intersection, union
+
+
+def _mean_iou_compute(
+    intersection: Tensor,
+    union: Tensor,
+    per_class: bool = False,
+) -> Tensor:
+    """Compute the mean IoU metric.
+
+    Args:
+        intersection: Intersection counts with the class dimension at index 1
+        union: Union counts of the same shape as ``intersection``
+        per_class: Whether to return the IoU of every class instead of averaging over the class dimension
+
+    Returns:
+        The IoU for every entry of the input if ``per_class=True``, else its mean over dimension 1
+
+    """
+    val = _safe_divide(intersection, union)
+    return val if per_class else torch.mean(val, 1)
+
+
+def mean_iou(
+    preds: Tensor,
+    target: Tensor,
+    num_classes: int,
+    include_background: bool = True,
+    per_class: bool = False,
+) -> Tensor:
+    """Calculates the mean Intersection over Union (mIoU) for semantic segmentation.
+
+    Args:
+        preds: Predictions from model, either a one-hot tensor of shape ``(N, C, ...)`` or an integer tensor of
+            shape ``(N, ...)`` holding class indices
+        target: Ground truth values in the same format and shape as ``preds``
+        num_classes: Number of classes
+        include_background: Whether to include the background class (the first class) in the computation
+        per_class: Whether to compute the IoU for each class separately, else average over all classes
+
+    Returns:
+        The IoU score of every sample: a tensor of shape ``(N,)`` if ``per_class=False``, else a tensor of shape
+        ``(N, C)`` with one score per sample and class (the background class is dropped if it is excluded)
+
+    Raises:
+        ValueError:
+            If ``num_classes`` is not a positive integer
+        ValueError:
+            If ``include_background`` is not a boolean
+        ValueError:
+            If ``per_class`` is not a boolean
+
+    Example:
+        >>> import torch
+        >>> _ = torch.manual_seed(42)
+        >>> from torchmetrics.functional.segmentation import mean_iou
+        >>> preds = torch.randint(0, 2, (4, 5, 16, 16))  # 4 samples, 5 classes, 16x16 prediction
+        >>> target = torch.randint(0, 2, (4, 5, 16, 16))  # 4 samples, 5 classes, 16x16 target
+        >>> mean_iou(preds, target, num_classes=5)
+        tensor([0.3193, 0.3305, 0.3382, 0.3246])
+        >>> mean_iou(preds, target, num_classes=5, per_class=True)
+        tensor([[0.3093, 0.3500, 0.3081, 0.3389, 0.2903],
+                [0.2963, 0.3316, 0.3505, 0.2804, 0.3936],
+                [0.3724, 0.3249, 0.3660, 0.3184, 0.3093],
+                [0.3085, 0.3267, 0.3155, 0.3575, 0.3147]])
+
+    """
+    _mean_iou_validate_args(num_classes, include_background, per_class)
+    intersection, union = _mean_iou_update(preds, target, num_classes, include_background)
+    return _mean_iou_compute(intersection, union, per_class=per_class)
diff --git a/src/torchmetrics/functional/segmentation/utils.py b/src/torchmetrics/functional/segmentation/utils.py
index e8427a69326..6c2fed92df2 100644
--- a/src/torchmetrics/functional/segmentation/utils.py
+++ b/src/torchmetrics/functional/segmentation/utils.py
@@ -25,7 +25,7 @@
 
 
 def _ignore_background(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
-    """Ignore the background class in the computation."""
+    """Ignore the background class in the computation assuming it is the first, index 0."""
     preds = preds[:, 1:] if preds.shape[1] > 1 else preds
     target = target[:, 1:] if target.shape[1] > 1 else target
     return preds, target
diff --git a/src/torchmetrics/segmentation/__init__.py b/src/torchmetrics/segmentation/__init__.py
index 24275594e4c..5b609c2c738 100644
--- a/src/torchmetrics/segmentation/__init__.py
+++ b/src/torchmetrics/segmentation/__init__.py
@@ -1,4 +1,4 @@
-# Copyright The PyTorch Lightning team.
+# Copyright The Lightning team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 from torchmetrics.segmentation.generalized_dice import GeneralizedDiceScore
+from torchmetrics.segmentation.mean_iou import MeanIoU
 
-__all__ = ["GeneralizedDiceScore"]
+__all__ = ["GeneralizedDiceScore", "MeanIoU"]
diff --git a/src/torchmetrics/segmentation/mean_iou.py b/src/torchmetrics/segmentation/mean_iou.py
new file mode 100644
index 00000000000..f36f7fc3bc0
--- /dev/null
+++ b/src/torchmetrics/segmentation/mean_iou.py
@@ -0,0 +1,163 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Optional, Sequence, Union
+
+import torch
+from torch import Tensor
+
+from torchmetrics.functional.segmentation.mean_iou import _mean_iou_compute, _mean_iou_update, _mean_iou_validate_args
+from torchmetrics.metric import Metric
+from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
+from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
+
+if not _MATPLOTLIB_AVAILABLE:
+    __doctest_skip__ = ["MeanIoU.plot"]
+
+
+class MeanIoU(Metric):
+    """Computes Mean Intersection over Union (mIoU) for semantic segmentation.
+
+    The metric is defined by the overlap between the predicted segmentation and the ground truth, divided by the
+    total area covered by the union of the two. The metric can be computed for each class separately or for all
+    classes at once. The metric is optimal at a value of 1 and worst at a value of 0.
+
+    As input to ``forward`` and ``update`` the metric accepts the following input:
+
+    - ``preds`` (:class:`~torch.Tensor`): A one-hot boolean tensor of shape ``(N, C, ...)`` with ``N`` being
+      the number of samples and ``C`` the number of classes. Alternatively, an integer tensor of shape ``(N, ...)``
+      can be provided, where the integer values correspond to the class index. That format will be automatically
+      converted to a one-hot tensor.
+    - ``target`` (:class:`~torch.Tensor`): A one-hot boolean tensor of shape ``(N, C, ...)`` with ``N`` being
+      the number of samples and ``C`` the number of classes. Alternatively, an integer tensor of shape ``(N, ...)``
+      can be provided, where the integer values correspond to the class index. That format will be automatically
+      converted to a one-hot tensor.
+
+    As output to ``forward`` and ``compute`` the metric returns the following output:
+
+    - ``miou`` (:class:`~torch.Tensor`): The mean Intersection over Union (mIoU) score. If ``per_class`` is set to
+      ``True``, the output will be a tensor of shape ``(C,)`` with the IoU score for each class. If ``per_class`` is
+      set to ``False``, the output will be a scalar tensor.
+
+    The metric state accumulates the score of every batch passed to ``update`` and ``compute`` returns the
+    average over these batches.
+
+    Args:
+        num_classes: The number of classes in the segmentation problem.
+        include_background: Whether to include the background class in the computation
+        per_class: Whether to compute the IoU for each class separately. If set to ``False``, the metric will
+            compute the mean IoU over all classes.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+
+    Raises:
+        ValueError:
+            If ``num_classes`` is not a positive integer
+        ValueError:
+            If ``include_background`` is not a boolean
+        ValueError:
+            If ``per_class`` is not a boolean
+
+    Example:
+        >>> import torch
+        >>> _ = torch.manual_seed(0)
+        >>> from torchmetrics.segmentation import MeanIoU
+        >>> miou = MeanIoU(num_classes=3)
+        >>> preds = torch.randint(0, 2, (10, 3, 128, 128))
+        >>> target = torch.randint(0, 2, (10, 3, 128, 128))
+        >>> miou(preds, target)
+        tensor(0.3318)
+        >>> miou = MeanIoU(num_classes=3, per_class=True)
+        >>> miou(preds, target)
+        tensor([0.3322, 0.3303, 0.3329])
+        >>> miou = MeanIoU(num_classes=3, per_class=True, include_background=False)
+        >>> miou(preds, target)
+        tensor([0.3303, 0.3329])
+
+    """
+
+    score: Tensor
+    num_batches: Tensor
+    full_state_update: bool = False
+    is_differentiable: bool = False
+    higher_is_better: bool = True
+    plot_lower_bound: float = 0.0
+    plot_upper_bound: float = 1.0
+
+    def __init__(
+        self,
+        num_classes: int,
+        include_background: bool = True,
+        per_class: bool = False,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        _mean_iou_validate_args(num_classes, include_background, per_class)
+        self.num_classes = num_classes
+        self.include_background = include_background
+        self.per_class = per_class
+
+        num_classes = num_classes - 1 if not include_background else num_classes
+        # Sum-reduced states, so that ``compute`` can average the accumulated scores over all batches seen
+        self.add_state("score", default=torch.zeros(num_classes if per_class else 1), dist_reduce_fx="sum")
+        self.add_state("num_batches", default=torch.tensor(0), dist_reduce_fx="sum")
+
+    def update(self, preds: Tensor, target: Tensor) -> None:
+        """Update the state with the new data."""
+        intersection, union = _mean_iou_update(preds, target, self.num_classes, self.include_background)
+        score = _mean_iou_compute(intersection, union, per_class=self.per_class)
+        self.score += score.mean(0) if self.per_class else score.mean()
+        self.num_batches += 1
+
+    def compute(self) -> Tensor:
+        """Compute the final mean IoU as the average of the accumulated batch scores."""
+        return self.score / self.num_batches
+
+    def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
+        """Plot a single or multiple values from the metric.
+
+        Args:
+            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
+                If no value is provided, will automatically call `metric.compute` and plot that result.
+            ax: A matplotlib axis object. If provided will add plot to that axis
+
+        Returns:
+            Figure and Axes object
+
+        Raises:
+            ModuleNotFoundError:
+                If `matplotlib` is not installed
+
+        .. plot::
+            :scale: 75
+
+            >>> # Example plotting a single value
+            >>> import torch
+            >>> from torchmetrics.segmentation import MeanIoU
+            >>> metric = MeanIoU(num_classes=3)
+            >>> metric.update(torch.randint(0, 2, (10, 3, 128, 128)), torch.randint(0, 2, (10, 3, 128, 128)))
+            >>> fig_, ax_ = metric.plot()
+
+        .. plot::
+            :scale: 75
+
+            >>> # Example plotting multiple values
+            >>> import torch
+            >>> from torchmetrics.segmentation import MeanIoU
+            >>> metric = MeanIoU(num_classes=3)
+            >>> values = [ ]
+            >>> for _ in range(10):
+            ...     values.append(metric(torch.randint(0, 2, (10, 3, 128, 128)), torch.randint(0, 2, (10, 3, 128, 128))))
+            >>> fig_, ax_ = metric.plot(values)
+
+        """
+        return self._plot(val, ax)
diff --git a/tests/unittests/segmentation/test_mean_iou.py b/tests/unittests/segmentation/test_mean_iou.py
new file mode 100644
index 00000000000..013b71572d6
--- /dev/null
+++ b/tests/unittests/segmentation/test_mean_iou.py
@@ -0,0 +1,115 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+
+import pytest
+import torch
+from monai.metrics.meaniou import compute_iou
+from torchmetrics.functional.segmentation.mean_iou import mean_iou
+from torchmetrics.segmentation.mean_iou import MeanIoU
+
+from unittests import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, _Input
+from unittests._helpers.testers import MetricTester
+
+_inputs1 = _Input(
+    preds=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 16)),
+    target=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 16)),
+)
+_inputs2 = _Input(
+    preds=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 32, 32)),
+    target=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 32, 32)),
+)
+_inputs3 = _Input(
+    preds=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32)),
+    target=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32)),
+)
+
+
+def _reference_mean_iou(
+    preds: torch.Tensor,
+    target: torch.Tensor,
+    include_background: bool = True,
+    per_class: bool = True,
+    reduce: bool = True,
+):
+    """Calculate reference metric for `MeanIoU`."""
+    if (preds.bool() != preds).any():  # preds is an index tensor
+        preds = torch.nn.functional.one_hot(preds, num_classes=NUM_CLASSES).movedim(-1, 1)
+    if (target.bool() != target).any():  # target is an index tensor
+        target = torch.nn.functional.one_hot(target, num_classes=NUM_CLASSES).movedim(-1, 1)
+
+    val = compute_iou(preds, target, include_background=include_background)
+    val[torch.isnan(val)] = 0.0
+    if reduce:
+        return torch.mean(val, 0) if per_class else torch.mean(val)
+    return val
+
+
+@pytest.mark.parametrize(
+    "preds, target",
+    [
+        (_inputs1.preds, _inputs1.target),
+        (_inputs2.preds, _inputs2.target),
+        (_inputs3.preds, _inputs3.target),
+    ],
+)
+@pytest.mark.parametrize("include_background", [True, False])
+class TestMeanIoU(MetricTester):
+    """Test class for `MeanIoU` metric."""
+
+    atol = 1e-4
+
+    @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False])
+    @pytest.mark.parametrize("per_class", [True, False])
+    def test_mean_iou_class(self, preds, target, include_background, per_class, ddp):
+        """Test class implementation of metric."""
+        self.run_class_metric_test(
+            ddp=ddp,
+            preds=preds,
+            target=target,
+            metric_class=MeanIoU,
+            reference_metric=partial(
+                _reference_mean_iou, include_background=include_background, per_class=per_class, reduce=True
+            ),
+            metric_args={"num_classes": NUM_CLASSES, "include_background": include_background, "per_class": per_class},
+        )
+
+    def test_mean_iou_functional(self, preds, target, include_background):
+        """Test functional implementation of metric."""
+        self.run_functional_metric_test(
+            preds=preds,
+            target=target,
+            metric_functional=mean_iou,
+            reference_metric=partial(_reference_mean_iou, include_background=include_background, reduce=False),
+            metric_args={"num_classes": NUM_CLASSES, "include_background": include_background, "per_class": True},
+        )
+
+
+def test_mean_iou_repeated_update():
+    """Test that calling ``update`` several times averages the batch scores instead of summing them."""
+    metric = MeanIoU(num_classes=NUM_CLASSES)
+    perfect = torch.ones(BATCH_SIZE, NUM_CLASSES, 16, dtype=torch.long)
+    for _ in range(3):
+        metric.update(perfect, perfect)
+    assert torch.allclose(metric.compute(), torch.tensor(1.0))
+
+
+@pytest.mark.parametrize("dtype", [torch.long, torch.uint8])
+def test_mean_iou_index_input_with_binary_target(dtype):
+    """Test index input where only ``preds`` contains a class index above 1, for several integer dtypes."""
+    preds = torch.tensor([[[0, 1], [2, 1]]], dtype=dtype)
+    target = torch.tensor([[[0, 1], [0, 1]]], dtype=dtype)
+    res = mean_iou(preds, target, num_classes=3, per_class=True)
+    assert torch.allclose(res, torch.tensor([[0.5, 1.0, 0.0]]))