From b88c8ec796b38b9bc86057423ea853d5d4f510f2 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Mon, 26 Sep 2022 09:44:25 +0200 Subject: [PATCH 01/19] First draft --- CHANGELOG.md | 4 +- README.md | 1 + docs/source/index.rst | 8 + docs/source/segmentation/mean_iou.rst | 20 ++ .../functional/segmentation/__init__.py | 15 ++ .../functional/segmentation/mean_iou.py | 238 ++++++++++++++++++ .../functional/segmentation/test.py | 22 ++ 7 files changed, 306 insertions(+), 2 deletions(-) create mode 100644 docs/source/segmentation/mean_iou.rst create mode 100644 src/torchmetrics/functional/segmentation/__init__.py create mode 100644 src/torchmetrics/functional/segmentation/mean_iou.py create mode 100644 src/torchmetrics/functional/segmentation/test.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 008f16ba52b..2eb659beed3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,16 +13,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added a new NLP metric `InfoLM` ([#915](https://github.com/PyTorchLightning/metrics/pull/915)) - - Added `Perplexity` metric ([#922](https://github.com/PyTorchLightning/metrics/pull/922)) - - Added argument `normalize` to `LPIPS` metric ([#1216](https://github.com/Lightning-AI/metrics/pull/1216)) - Added support for multiprocessing of batches in `PESQ` metric ([#1227](https://github.com/Lightning-AI/metrics/pull/1227)) - Added support for multioutput in `PearsonCorrCoef` and `SpearmanCorrCoef` ([#1200](https://github.com/Lightning-AI/metrics/pull/1200)) +- Added a new segmentation metric `mean IoU` ([#915](https://github.com/PyTorchLightning/metrics/pull/915)) + ### Changed - Classification refactor ( diff --git a/README.md b/README.md index 926e145f27f..3a2b059c6ea 100644 --- a/README.md +++ b/README.md @@ -274,6 +274,7 @@ We currently have implemented metrics within the following domains: - Audio - Classification - Detection +- Segmentation - Information Retrieval - Image - Regression diff --git a/docs/source/index.rst b/docs/source/index.rst index f3bc95a8ae9..111e8b61104 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -167,6 +167,14 @@ Or directly from conda detection/* +.. toctree:: + :maxdepth: 2 + :name: segmentation + :caption: Segmentation + :glob: + + segmentation/* + .. toctree:: :maxdepth: 2 :name: pairwise diff --git a/docs/source/segmentation/mean_iou.rst b/docs/source/segmentation/mean_iou.rst new file mode 100644 index 00000000000..d578cc768b2 --- /dev/null +++ b/docs/source/segmentation/mean_iou.rst @@ -0,0 +1,20 @@ +.. customcarditem:: + :header: Mean Intersection over Union (mIoU) + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/object_detection.svg + :tags: segmentation + +############################ +Mean Intersection over Union (mIoU) +############################ + +Module Interface +________________ + +.. autoclass:: torchmetrics.segmentation.mean_iou.MeanIntersectionOverUnion + :noindex: + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.mean_iou + :noindex: \ No newline at end of file diff --git a/src/torchmetrics/functional/segmentation/__init__.py b/src/torchmetrics/functional/segmentation/__init__.py new file mode 100644 index 00000000000..8ebd7985ace --- /dev/null +++ b/src/torchmetrics/functional/segmentation/__init__.py @@ -0,0 +1,15 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from torchmetrics.functional.segmentation.mean_iou import mean_iou # noqa: F401 \ No newline at end of file diff --git a/src/torchmetrics/functional/segmentation/mean_iou.py b/src/torchmetrics/functional/segmentation/mean_iou.py new file mode 100644 index 00000000000..58b5804de9d --- /dev/null +++ b/src/torchmetrics/functional/segmentation/mean_iou.py @@ -0,0 +1,238 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, List, Optional, Sequence, Tuple + +import torch +from torch import Tensor + +from torchmetrics.utilities.checks import _check_same_shape + + +def _input_validator(preds: Sequence[Tensor], target: Sequence[ Tensor]) -> None: + """Ensure the correct input format of `preds` and `targets`""" + if not isinstance(preds, Sequence): + raise ValueError("Expected argument `preds` to be of type Sequence") + if not isinstance(target, Sequence): + raise ValueError("Expected argument `target` to be of type Sequence") + if len(preds) != len(target): + raise ValueError("Expected argument `preds` and `target` to have the same length") + for prediction, ground_truth in zip(preds, target): + _check_same_shape(prediction, ground_truth) + + +def intersect_and_union(pred_label, + label, + num_classes, + ignore_index, + label_map=dict(), + reduce_zero_label=False): + """Calculate Intersection and Union. + + Args: + pred_label (torch.Tensor): + Prediction segmentation map. + label (torch.Tensor): + Ground truth segmentation map. + num_classes (int): + Number of categories. + ignore_index (int): + Index that will be ignored in evaluation. + label_map (dict): + Mapping old labels to new labels. The parameter will work only when label is str. Default: dict(). + reduce_zero_label (bool): + Whether ignore zero label. The parameter will work only when label is str. Default: False. + + Returns: + torch.Tensor: + The intersection of prediction and ground truth histogram on all classes. + torch.Tensor: + The union of prediction and ground truth histogram on all classes. + torch.Tensor: + The prediction histogram on all classes. + torch.Tensor: + The ground truth histogram on all classes. + """ + + if label_map is not None: + label_copy = label.clone() + for old_id, new_id in label_map.items(): + label[label_copy == old_id] = new_id + + if reduce_zero_label: + label[label == 0] = 255 + label = label - 1 + label[label == 254] = 255 + + mask = (label != ignore_index) + pred_label = pred_label[mask] + label = label[mask] + + intersect = pred_label[pred_label == label] + area_intersect = torch.histc(intersect.float(), bins=(num_classes), min=0, max=num_classes - 1) + area_pred_label = torch.histc(pred_label.float(), bins=(num_classes), min=0, max=num_classes - 1) + area_label = torch.histc(label.float(), bins=(num_classes), min=0, max=num_classes - 1) + area_union = area_pred_label + area_label - area_intersect + + return area_intersect, area_union, area_pred_label, area_label + + +def total_intersect_and_union(preds, + target, + num_classes, + ignore_index, + label_map=dict(), + reduce_zero_label=False): + """Calculate Total Intersection and Union. + + Args: + preds (list[torch.Tensor]): + List of prediction segmentation maps. + target (list[torch.Tensor]): + List of ground truth segmentation maps. + num_classes (int): + Number of categories. + ignore_index (int): + Index that will be ignored in evaluation. + label_map (dict): + Mapping old labels to new labels. Default: dict(). + reduce_zero_label (bool): + Whether ignore zero label. Default: False. + + Returns: + torch.Tensor: + The intersection of prediction and ground truth histogram on all classes. + torch.Tensor: + The union of prediction and ground truth histogram on all classes. + torch.Tensor: + The prediction histogram on all classes. + torch.Tensor: + The ground truth histogram on all classes. + """ + total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64) + total_area_union = torch.zeros((num_classes, ), dtype=torch.float64) + total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64) + total_area_label = torch.zeros((num_classes, ), dtype=torch.float64) + + for result, gt_seg_map in zip(preds, target): + area_intersect, area_union, area_pred_label, area_label = \ + intersect_and_union( + result, gt_seg_map, num_classes, ignore_index, + label_map, reduce_zero_label) + + total_area_intersect += area_intersect + total_area_union += area_union + total_area_pred_label += area_pred_label + total_area_label += area_label + + return total_area_intersect, total_area_union, total_area_pred_label, total_area_label + + +def _mean_iou_update(preds: Tensor, + target: Tensor, + num_labels: int, + ignore_index: bool, + nan_to_num: Optional[int] = None, + label_map: Optional[Dict[int, int]] = None, + reduce_labels: bool = False,) -> Tuple[Tensor, int]: + """Updates and returns variables required to compute Mean Intersection over Union. + + Checks for same shape of each element of the ``preds`` and ``target`` lists. + + Args: + preds (list[torch.Tensor]): + List of prediction segmentation maps. + target (list[torch.Tensor]): + List of ground truth segmentation maps. + num_classes (int): + Number of categories. + ignore_index (int): + Index that will be ignored in evaluation. + label_map (dict): + Mapping old labels to new labels. Default: dict(). + reduce_zero_label (bool): + Whether ignore zero label. Default: False. + """ + _input_validator(preds, target) + + total_area_intersect, total_area_union, total_area_pred_label, total_area_label = total_intersect_and_union( + preds, target, num_labels, ignore_index, label_map, reduce_labels + ) + + return total_area_intersect, total_area_union, total_area_pred_label, total_area_label + + +def _mean_iou_compute(total_area_intersect, total_area_union, total_area_pred_label, total_area_label) -> Tensor: + """Computes Mean Intersection over Union. + + Args: + total_area_intersect: + ... + total_area_union: + ... + total_area_pred_label: + ... + total_area_label: + ... + + Example: + >>> preds = torch.tensor([0., 1, 2, 3]) + >>> target = torch.tensor([0., 1, 2, 2]) + >>> total_area_intersect, total_area_union, total_area_pred_label, total_area_label = _mean_iou_update(preds, target) + >>> _mean_iou_compute(total_area_intersect, total_area_union, total_area_pred_label, total_area_label) + tensor(0.2500) + """ + iou = total_area_intersect / total_area_union + + mean_iou = torch.nanmean(iou) + + return mean_iou + + +def mean_iou(preds: List[Tensor], + target: List[Tensor], + num_labels: int, + ignore_index: bool, + nan_to_num: Optional[int] = None, + label_map: Optional[Dict[int, int]] = None, + reduce_labels: bool = False,) -> Tensor: + """Computes Mean Intersection over Union (mIoU). + + Args: + preds: + estimated labels + target: + ground truth labels + num_labels: + number of labels + ignore_index: + index that will be ignored in evaluation + nan_to_num: + If specified, NaN values will be replaced by the numbers defined by the user. Default: None. + label_map: + Mapping old labels to new labels. Default: None. + reduce_labels: + Whether to ignore the zero label and reduce all labels by one. Default: False. + + Return: + Tensor with mIoU. + + Example: + >>> from torchmetrics.functional.segmentation import mean_iou + >>> preds = [torch.tensor([[2,0],[2,3]])] + >>> target = [torch.tensor([[255,255],[2,3]])] + >>> mean_iou(preds, target) + tensor(0.2500) + """ + total_area_intersect, total_area_union, total_area_pred_label, total_area_label = _mean_iou_update(preds, target, num_labels, ignore_index, nan_to_num, label_map, reduce_labels) + return _mean_iou_compute(total_area_intersect, total_area_union, total_area_pred_label, total_area_label) \ No newline at end of file diff --git a/src/torchmetrics/functional/segmentation/test.py b/src/torchmetrics/functional/segmentation/test.py new file mode 100644 index 00000000000..b38af0e1a81 --- /dev/null +++ b/src/torchmetrics/functional/segmentation/test.py @@ -0,0 +1,22 @@ +from torchmetrics.functional.segmentation.mean_iou import mean_iou +import torch + +# suppose one has 3 different segmentation maps predicted +predicted_1 = torch.tensor([[1, 2], [3, 4], [5, 255]]) +actual_1 = torch.tensor([[0, 3], [5, 4], [6, 255]]) + +predicted_2 = torch.tensor([[2, 7], [9, 2], [3, 6]]) +actual_2 = torch.tensor([[1, 7], [9, 2], [3, 6]]) + +predicted_3 = torch.tensor([[2, 2, 3], [8, 2, 4], [3, 255, 2]]) +actual_3 = torch.tensor([[1, 2, 2], [8, 2, 1], [3, 255, 1]]) + +predicted = [predicted_1, predicted_2, predicted_3] +ground_truth = [actual_1, actual_2, actual_3] +results = mean_iou(preds=predicted, target=ground_truth, num_labels=10, ignore_index=255, reduce_labels=False) + +print(results) + +preds = [torch.tensor([[2,0],[2,3]])] +target = [torch.tensor([[255,255],[2,3]])] +print(mean_iou(preds, target, num_labels=4, ignore_index=255, reduce_labels=False)) \ No newline at end of file From 2c85af3e08b511ebb86e19b0964ba1408db9387f Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Mon, 26 Sep 2022 09:52:08 +0200 Subject: [PATCH 02/19] Update PR number --- CHANGELOG.md | 2 +- docs/source/segmentation/mean_iou.rst | 6 ------ 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2eb659beed3..c11b0aa5a5f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,7 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added support for multioutput in `PearsonCorrCoef` and `SpearmanCorrCoef` ([#1200](https://github.com/Lightning-AI/metrics/pull/1200)) -- Added a new segmentation metric `mean IoU` ([#915](https://github.com/PyTorchLightning/metrics/pull/915)) +- Added a new segmentation metric `mean IoU` ([#1236](https://github.com/PyTorchLightning/metrics/pull/1236)) ### Changed diff --git a/docs/source/segmentation/mean_iou.rst b/docs/source/segmentation/mean_iou.rst index d578cc768b2..1576a3a81e7 100644 --- a/docs/source/segmentation/mean_iou.rst +++ b/docs/source/segmentation/mean_iou.rst @@ -7,12 +7,6 @@ Mean Intersection over Union (mIoU) ############################ -Module Interface -________________ - -.. autoclass:: torchmetrics.segmentation.mean_iou.MeanIntersectionOverUnion - :noindex: - Functional Interface ____________________ From f976f14e0d86890c9a676db37c27b4b1db6a9e72 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 26 Sep 2022 07:52:54 +0000 Subject: [PATCH 03/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/source/segmentation/mean_iou.rst | 2 +- .../functional/segmentation/__init__.py | 2 +- .../functional/segmentation/mean_iou.py | 91 +++++++++---------- .../functional/segmentation/test.py | 9 +- 4 files changed, 50 insertions(+), 54 deletions(-) diff --git a/docs/source/segmentation/mean_iou.rst b/docs/source/segmentation/mean_iou.rst index 1576a3a81e7..cfb9cfcd8d6 100644 --- a/docs/source/segmentation/mean_iou.rst +++ b/docs/source/segmentation/mean_iou.rst @@ -11,4 +11,4 @@ Functional Interface ____________________ .. autofunction:: torchmetrics.functional.mean_iou - :noindex: \ No newline at end of file + :noindex: diff --git a/src/torchmetrics/functional/segmentation/__init__.py b/src/torchmetrics/functional/segmentation/__init__.py index 8ebd7985ace..e1a20dab0b1 100644 --- a/src/torchmetrics/functional/segmentation/__init__.py +++ b/src/torchmetrics/functional/segmentation/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from torchmetrics.functional.segmentation.mean_iou import mean_iou # noqa: F401 \ No newline at end of file +from torchmetrics.functional.segmentation.mean_iou import mean_iou # noqa: F401 diff --git a/src/torchmetrics/functional/segmentation/mean_iou.py b/src/torchmetrics/functional/segmentation/mean_iou.py index 58b5804de9d..66cdeecc5f2 100644 --- a/src/torchmetrics/functional/segmentation/mean_iou.py +++ b/src/torchmetrics/functional/segmentation/mean_iou.py @@ -19,7 +19,7 @@ from torchmetrics.utilities.checks import _check_same_shape -def _input_validator(preds: Sequence[Tensor], target: Sequence[ Tensor]) -> None: +def _input_validator(preds: Sequence[Tensor], target: Sequence[Tensor]) -> None: """Ensure the correct input format of `preds` and `targets`""" if not isinstance(preds, Sequence): raise ValueError("Expected argument `preds` to be of type Sequence") @@ -31,14 +31,9 @@ def _input_validator(preds: Sequence[Tensor], target: Sequence[ Tensor]) -> None _check_same_shape(prediction, ground_truth) -def intersect_and_union(pred_label, - label, - num_classes, - ignore_index, - label_map=dict(), - reduce_zero_label=False): +def intersect_and_union(pred_label, label, num_classes, ignore_index, label_map=dict(), reduce_zero_label=False): """Calculate Intersection and Union. - + Args: pred_label (torch.Tensor): Prediction segmentation map. @@ -52,7 +47,7 @@ def intersect_and_union(pred_label, Mapping old labels to new labels. The parameter will work only when label is str. Default: dict(). reduce_zero_label (bool): Whether ignore zero label. The parameter will work only when label is str. Default: False. - + Returns: torch.Tensor: The intersection of prediction and ground truth histogram on all classes. @@ -68,13 +63,13 @@ def intersect_and_union(pred_label, label_copy = label.clone() for old_id, new_id in label_map.items(): label[label_copy == old_id] = new_id - + if reduce_zero_label: label[label == 0] = 255 label = label - 1 label[label == 254] = 255 - mask = (label != ignore_index) + mask = label != ignore_index pred_label = pred_label[mask] label = label[mask] @@ -83,18 +78,13 @@ def intersect_and_union(pred_label, area_pred_label = torch.histc(pred_label.float(), bins=(num_classes), min=0, max=num_classes - 1) area_label = torch.histc(label.float(), bins=(num_classes), min=0, max=num_classes - 1) area_union = area_pred_label + area_label - area_intersect - + return area_intersect, area_union, area_pred_label, area_label -def total_intersect_and_union(preds, - target, - num_classes, - ignore_index, - label_map=dict(), - reduce_zero_label=False): +def total_intersect_and_union(preds, target, num_classes, ignore_index, label_map=dict(), reduce_zero_label=False): """Calculate Total Intersection and Union. - + Args: preds (list[torch.Tensor]): List of prediction segmentation maps. @@ -108,7 +98,7 @@ def total_intersect_and_union(preds, Mapping old labels to new labels. Default: dict(). reduce_zero_label (bool): Whether ignore zero label. Default: False. - + Returns: torch.Tensor: The intersection of prediction and ground truth histogram on all classes. @@ -119,32 +109,33 @@ def total_intersect_and_union(preds, torch.Tensor: The ground truth histogram on all classes. """ - total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64) - total_area_union = torch.zeros((num_classes, ), dtype=torch.float64) - total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64) - total_area_label = torch.zeros((num_classes, ), dtype=torch.float64) - + total_area_intersect = torch.zeros((num_classes,), dtype=torch.float64) + total_area_union = torch.zeros((num_classes,), dtype=torch.float64) + total_area_pred_label = torch.zeros((num_classes,), dtype=torch.float64) + total_area_label = torch.zeros((num_classes,), dtype=torch.float64) + for result, gt_seg_map in zip(preds, target): - area_intersect, area_union, area_pred_label, area_label = \ - intersect_and_union( - result, gt_seg_map, num_classes, ignore_index, - label_map, reduce_zero_label) - + area_intersect, area_union, area_pred_label, area_label = intersect_and_union( + result, gt_seg_map, num_classes, ignore_index, label_map, reduce_zero_label + ) + total_area_intersect += area_intersect total_area_union += area_union total_area_pred_label += area_pred_label total_area_label += area_label - + return total_area_intersect, total_area_union, total_area_pred_label, total_area_label -def _mean_iou_update(preds: Tensor, - target: Tensor, - num_labels: int, - ignore_index: bool, - nan_to_num: Optional[int] = None, - label_map: Optional[Dict[int, int]] = None, - reduce_labels: bool = False,) -> Tuple[Tensor, int]: +def _mean_iou_update( + preds: Tensor, + target: Tensor, + num_labels: int, + ignore_index: bool, + nan_to_num: Optional[int] = None, + label_map: Optional[Dict[int, int]] = None, + reduce_labels: bool = False, +) -> Tuple[Tensor, int]: """Updates and returns variables required to compute Mean Intersection over Union. Checks for same shape of each element of the ``preds`` and ``target`` lists. @@ -164,7 +155,7 @@ def _mean_iou_update(preds: Tensor, Whether ignore zero label. Default: False. """ _input_validator(preds, target) - + total_area_intersect, total_area_union, total_area_pred_label, total_area_label = total_intersect_and_union( preds, target, num_labels, ignore_index, label_map, reduce_labels ) @@ -199,13 +190,15 @@ def _mean_iou_compute(total_area_intersect, total_area_union, total_area_pred_la return mean_iou -def mean_iou(preds: List[Tensor], - target: List[Tensor], - num_labels: int, - ignore_index: bool, - nan_to_num: Optional[int] = None, - label_map: Optional[Dict[int, int]] = None, - reduce_labels: bool = False,) -> Tensor: +def mean_iou( + preds: List[Tensor], + target: List[Tensor], + num_labels: int, + ignore_index: bool, + nan_to_num: Optional[int] = None, + label_map: Optional[Dict[int, int]] = None, + reduce_labels: bool = False, +) -> Tensor: """Computes Mean Intersection over Union (mIoU). Args: @@ -234,5 +227,7 @@ def mean_iou(preds: List[Tensor], >>> mean_iou(preds, target) tensor(0.2500) """ - total_area_intersect, total_area_union, total_area_pred_label, total_area_label = _mean_iou_update(preds, target, num_labels, ignore_index, nan_to_num, label_map, reduce_labels) - return _mean_iou_compute(total_area_intersect, total_area_union, total_area_pred_label, total_area_label) \ No newline at end of file + total_area_intersect, total_area_union, total_area_pred_label, total_area_label = _mean_iou_update( + preds, target, num_labels, ignore_index, nan_to_num, label_map, reduce_labels + ) + return _mean_iou_compute(total_area_intersect, total_area_union, total_area_pred_label, total_area_label) diff --git a/src/torchmetrics/functional/segmentation/test.py b/src/torchmetrics/functional/segmentation/test.py index b38af0e1a81..560a9b5c4cc 100644 --- a/src/torchmetrics/functional/segmentation/test.py +++ b/src/torchmetrics/functional/segmentation/test.py @@ -1,6 +1,7 @@ -from torchmetrics.functional.segmentation.mean_iou import mean_iou import torch +from torchmetrics.functional.segmentation.mean_iou import mean_iou + # suppose one has 3 different segmentation maps predicted predicted_1 = torch.tensor([[1, 2], [3, 4], [5, 255]]) actual_1 = torch.tensor([[0, 3], [5, 4], [6, 255]]) @@ -17,6 +18,6 @@ print(results) -preds = [torch.tensor([[2,0],[2,3]])] -target = [torch.tensor([[255,255],[2,3]])] -print(mean_iou(preds, target, num_labels=4, ignore_index=255, reduce_labels=False)) \ No newline at end of file +preds = [torch.tensor([[2, 0], [2, 3]])] +target = [torch.tensor([[255, 255], [2, 3]])] +print(mean_iou(preds, target, num_labels=4, ignore_index=255, reduce_labels=False)) From dd43d31f9646d28bfbae9e1246b1ed19809ff6cb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 29 Mar 2024 17:50:05 +0000 Subject: [PATCH 04/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/functional/segmentation/mean_iou.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/functional/segmentation/mean_iou.py b/src/torchmetrics/functional/segmentation/mean_iou.py index 66cdeecc5f2..424711ffc99 100644 --- a/src/torchmetrics/functional/segmentation/mean_iou.py +++ b/src/torchmetrics/functional/segmentation/mean_iou.py @@ -57,8 +57,8 @@ def intersect_and_union(pred_label, label, num_classes, ignore_index, label_map= The prediction histogram on all classes. torch.Tensor: The ground truth histogram on all classes. - """ + """ if label_map is not None: label_copy = label.clone() for old_id, new_id in label_map.items(): @@ -99,7 +99,7 @@ def total_intersect_and_union(preds, target, num_classes, ignore_index, label_ma reduce_zero_label (bool): Whether ignore zero label. Default: False. - Returns: + Returns: torch.Tensor: The intersection of prediction and ground truth histogram on all classes. torch.Tensor: @@ -108,6 +108,7 @@ def total_intersect_and_union(preds, target, num_classes, ignore_index, label_ma The prediction histogram on all classes. torch.Tensor: The ground truth histogram on all classes. + """ total_area_intersect = torch.zeros((num_classes,), dtype=torch.float64) total_area_union = torch.zeros((num_classes,), dtype=torch.float64) @@ -153,6 +154,7 @@ def _mean_iou_update( Mapping old labels to new labels. Default: dict(). reduce_zero_label (bool): Whether ignore zero label. Default: False. + """ _input_validator(preds, target) @@ -182,6 +184,7 @@ def _mean_iou_compute(total_area_intersect, total_area_union, total_area_pred_la >>> total_area_intersect, total_area_union, total_area_pred_label, total_area_label = _mean_iou_update(preds, target) >>> _mean_iou_compute(total_area_intersect, total_area_union, total_area_pred_label, total_area_label) tensor(0.2500) + """ iou = total_area_intersect / total_area_union @@ -226,6 +229,7 @@ def mean_iou( >>> target = [torch.tensor([[255,255],[2,3]])] >>> mean_iou(preds, target) tensor(0.2500) + """ total_area_intersect, total_area_union, total_area_pred_label, total_area_label = _mean_iou_update( preds, target, num_labels, ignore_index, nan_to_num, label_map, reduce_labels From 20eae44886ef0f29585fdf31f7dc3b0a5cc488bf Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 12 Apr 2024 13:12:07 +0200 Subject: [PATCH 05/19] move testing file --- .../functional => tests/unittests}/segmentation/test.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename {src/torchmetrics/functional => tests/unittests}/segmentation/test.py (100%) diff --git a/src/torchmetrics/functional/segmentation/test.py b/tests/unittests/segmentation/test.py similarity index 100% rename from src/torchmetrics/functional/segmentation/test.py rename to tests/unittests/segmentation/test.py From efd4443e9c6b7f9ae430201d523c60ced7b6c110 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 12 Apr 2024 13:12:45 +0200 Subject: [PATCH 06/19] rename testing file --- tests/unittests/segmentation/{test.py => test_mean_iou.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/unittests/segmentation/{test.py => test_mean_iou.py} (100%) diff --git a/tests/unittests/segmentation/test.py b/tests/unittests/segmentation/test_mean_iou.py similarity index 100% rename from tests/unittests/segmentation/test.py rename to tests/unittests/segmentation/test_mean_iou.py From 63172f85d080692fb121f11b86f712f968798bd8 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 12 Apr 2024 13:23:04 +0200 Subject: [PATCH 07/19] more structure for class interface --- docs/source/segmentation/mean_iou.rst | 9 ++++-- .../functional/segmentation/__init__.py | 4 ++- src/torchmetrics/segmentation/__init__.py | 16 ++++++++++ src/torchmetrics/segmentation/mean_iou.py | 29 +++++++++++++++++++ 4 files changed, 55 insertions(+), 3 deletions(-) create mode 100644 src/torchmetrics/segmentation/__init__.py create mode 100644 src/torchmetrics/segmentation/mean_iou.py diff --git a/docs/source/segmentation/mean_iou.rst b/docs/source/segmentation/mean_iou.rst index cfb9cfcd8d6..90322225f8d 100644 --- a/docs/source/segmentation/mean_iou.rst +++ b/docs/source/segmentation/mean_iou.rst @@ -7,8 +7,13 @@ Mean Intersection over Union (mIoU) ############################ +Module Interface +________________ + +.. autoclass:: torchmetrics.segmentation.MeanIOU + :exclude-members: update, compute + Functional Interface ____________________ -.. autofunction:: torchmetrics.functional.mean_iou - :noindex: +.. autofunction:: torchmetrics.functional.segmentation.mean_iou diff --git a/src/torchmetrics/functional/segmentation/__init__.py b/src/torchmetrics/functional/segmentation/__init__.py index 4a0504fc612..4a8f5d3a943 100644 --- a/src/torchmetrics/functional/segmentation/__init__.py +++ b/src/torchmetrics/functional/segmentation/__init__.py @@ -12,4 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from torchmetrics.functional.segmentation.mean_iou import mean_iou # noqa: F401 +from torchmetrics.functional.segmentation.mean_iou import mean_iou + +__all__ = ["mean_iou"] diff --git a/src/torchmetrics/segmentation/__init__.py b/src/torchmetrics/segmentation/__init__.py new file mode 100644 index 00000000000..33872e5c057 --- /dev/null +++ b/src/torchmetrics/segmentation/__init__.py @@ -0,0 +1,16 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from torchmetrics.segmentation.mean_iou import MeanIOU + +__all__ = ["MeanIOU"] diff --git a/src/torchmetrics/segmentation/mean_iou.py b/src/torchmetrics/segmentation/mean_iou.py new file mode 100644 index 00000000000..df09044a38b --- /dev/null +++ b/src/torchmetrics/segmentation/mean_iou.py @@ -0,0 +1,29 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from torch import Tensor + +from torchmetrics.metric import Metric + + +class MeanIOU(Metric): + """Computes Mean Intersection over Union (mIoU) for semantic segmentation.""" + + def __init__(self) -> None: + pass + + def update(self, preds: Tensor, target: Tensor) -> None: + """Update the state with the new data.""" + + def compute(self) -> Tensor: + """Update the state with the new data.""" From 08613fbe2b5e6eff74044da44a952bd657734878 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 12 Apr 2024 15:50:56 +0200 Subject: [PATCH 08/19] working implementation --- docs/source/segmentation/mean_iou.rst | 2 +- .../functional/segmentation/mean_iou.py | 264 ++++-------------- .../functional/segmentation/utils.py | 6 + src/torchmetrics/segmentation/__init__.py | 4 +- src/torchmetrics/segmentation/mean_iou.py | 29 +- tests/unittests/segmentation/test_mean_iou.py | 89 +++++- 6 files changed, 163 insertions(+), 231 deletions(-) diff --git a/docs/source/segmentation/mean_iou.rst b/docs/source/segmentation/mean_iou.rst index 90322225f8d..1c24eceaabc 100644 --- a/docs/source/segmentation/mean_iou.rst +++ b/docs/source/segmentation/mean_iou.rst @@ -10,7 +10,7 @@ Mean Intersection over Union (mIoU) Module Interface ________________ -.. autoclass:: torchmetrics.segmentation.MeanIOU +.. autoclass:: torchmetrics.segmentation.MeanIoU :exclude-members: update, compute Functional Interface diff --git a/src/torchmetrics/functional/segmentation/mean_iou.py b/src/torchmetrics/functional/segmentation/mean_iou.py index 424711ffc99..cc6cbb5718c 100644 --- a/src/torchmetrics/functional/segmentation/mean_iou.py +++ b/src/torchmetrics/functional/segmentation/mean_iou.py @@ -16,222 +16,66 @@ import torch from torch import Tensor -from torchmetrics.utilities.checks import _check_same_shape - - -def _input_validator(preds: Sequence[Tensor], target: Sequence[Tensor]) -> None: - """Ensure the correct input format of `preds` and `targets`""" - if not isinstance(preds, Sequence): - raise ValueError("Expected argument `preds` to be of type Sequence") - if not isinstance(target, Sequence): - raise ValueError("Expected argument `target` to be of type Sequence") - if len(preds) != len(target): - raise ValueError("Expected argument `preds` and `target` to have the same length") - for prediction, ground_truth in zip(preds, target): - _check_same_shape(prediction, ground_truth) - - -def intersect_and_union(pred_label, label, num_classes, ignore_index, label_map=dict(), reduce_zero_label=False): - """Calculate Intersection and Union. - - Args: - pred_label (torch.Tensor): - Prediction segmentation map. - label (torch.Tensor): - Ground truth segmentation map. - num_classes (int): - Number of categories. - ignore_index (int): - Index that will be ignored in evaluation. - label_map (dict): - Mapping old labels to new labels. The parameter will work only when label is str. Default: dict(). - reduce_zero_label (bool): - Whether ignore zero label. The parameter will work only when label is str. Default: False. - - Returns: - torch.Tensor: - The intersection of prediction and ground truth histogram on all classes. - torch.Tensor: - The union of prediction and ground truth histogram on all classes. - torch.Tensor: - The prediction histogram on all classes. - torch.Tensor: - The ground truth histogram on all classes. - - """ - if label_map is not None: - label_copy = label.clone() - for old_id, new_id in label_map.items(): - label[label_copy == old_id] = new_id - - if reduce_zero_label: - label[label == 0] = 255 - label = label - 1 - label[label == 254] = 255 - - mask = label != ignore_index - pred_label = pred_label[mask] - label = label[mask] - - intersect = pred_label[pred_label == label] - area_intersect = torch.histc(intersect.float(), bins=(num_classes), min=0, max=num_classes - 1) - area_pred_label = torch.histc(pred_label.float(), bins=(num_classes), min=0, max=num_classes - 1) - area_label = torch.histc(label.float(), bins=(num_classes), min=0, max=num_classes - 1) - area_union = area_pred_label + area_label - area_intersect - - return area_intersect, area_union, area_pred_label, area_label - - -def total_intersect_and_union(preds, target, num_classes, ignore_index, label_map=dict(), reduce_zero_label=False): - """Calculate Total Intersection and Union. - - Args: - preds (list[torch.Tensor]): - List of prediction segmentation maps. - target (list[torch.Tensor]): - List of ground truth segmentation maps. - num_classes (int): - Number of categories. - ignore_index (int): - Index that will be ignored in evaluation. - label_map (dict): - Mapping old labels to new labels. Default: dict(). - reduce_zero_label (bool): - Whether ignore zero label. Default: False. - - Returns: - torch.Tensor: - The intersection of prediction and ground truth histogram on all classes. - torch.Tensor: - The union of prediction and ground truth histogram on all classes. - torch.Tensor: - The prediction histogram on all classes. - torch.Tensor: - The ground truth histogram on all classes. - - """ - total_area_intersect = torch.zeros((num_classes,), dtype=torch.float64) - total_area_union = torch.zeros((num_classes,), dtype=torch.float64) - total_area_pred_label = torch.zeros((num_classes,), dtype=torch.float64) - total_area_label = torch.zeros((num_classes,), dtype=torch.float64) - - for result, gt_seg_map in zip(preds, target): - area_intersect, area_union, area_pred_label, area_label = intersect_and_union( - result, gt_seg_map, num_classes, ignore_index, label_map, reduce_zero_label +from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.functional.segmentation.utils import _ignore_background + +def _mean_iou_validate_args( + num_classes: int, + include_background: bool, + per_class: bool, +) -> None: + if num_classes <= 0: + raise ValueError( + f"Expected argument `num_classes` must be a positive integer, but got {num_classes}." ) - - total_area_intersect += area_intersect - total_area_union += area_union - total_area_pred_label += area_pred_label - total_area_label += area_label - - return total_area_intersect, total_area_union, total_area_pred_label, total_area_label - + if not isinstance(include_background, bool): + raise ValueError( + f"Expected argument `include_background` must be a boolean, but got {include_background}." + ) + if not isinstance(per_class, bool): + raise ValueError(f"Expected argument `per_class` must be a boolean, but got {per_class}.") def _mean_iou_update( preds: Tensor, target: Tensor, - num_labels: int, - ignore_index: bool, - nan_to_num: Optional[int] = None, - label_map: Optional[Dict[int, int]] = None, - reduce_labels: bool = False, -) -> Tuple[Tensor, int]: - """Updates and returns variables required to compute Mean Intersection over Union. - - Checks for same shape of each element of the ``preds`` and ``target`` lists. - - Args: - preds (list[torch.Tensor]): - List of prediction segmentation maps. - target (list[torch.Tensor]): - List of ground truth segmentation maps. - num_classes (int): - Number of categories. - ignore_index (int): - Index that will be ignored in evaluation. - label_map (dict): - Mapping old labels to new labels. Default: dict(). - reduce_zero_label (bool): - Whether ignore zero label. Default: False. - - """ - _input_validator(preds, target) - - total_area_intersect, total_area_union, total_area_pred_label, total_area_label = total_intersect_and_union( - preds, target, num_labels, ignore_index, label_map, reduce_labels - ) - - return total_area_intersect, total_area_union, total_area_pred_label, total_area_label - - -def _mean_iou_compute(total_area_intersect, total_area_union, total_area_pred_label, total_area_label) -> Tensor: - """Computes Mean Intersection over Union. - - Args: - total_area_intersect: - ... - total_area_union: - ... - total_area_pred_label: - ... - total_area_label: - ... - - Example: - >>> preds = torch.tensor([0., 1, 2, 3]) - >>> target = torch.tensor([0., 1, 2, 2]) - >>> total_area_intersect, total_area_union, total_area_pred_label, total_area_label = _mean_iou_update(preds, target) - >>> _mean_iou_compute(total_area_intersect, total_area_union, total_area_pred_label, total_area_label) - tensor(0.2500) - - """ - iou = total_area_intersect / total_area_union - - mean_iou = torch.nanmean(iou) - - return mean_iou - + num_classes: int, + include_background: bool = False, +): + if preds.shape != target.shape: # assume preds is probabilities with an extra dimension + preds = preds.argmax(dim=1) + if (preds.bool() != preds).any(): # preds is an index tensor + preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1) + if (target.bool() != target).any(): # target is an index tensor + target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1) + + if not include_background: + preds, target = _ignore_background(preds, target) + + reduce_axis = list(range(2, preds.ndim)) + intersection = torch.sum(preds & target, dim=reduce_axis) + target_sum = torch.sum(target, dim=reduce_axis) + pred_sum = torch.sum(preds, dim=reduce_axis) + union = target_sum + pred_sum - intersection + return intersection, union + +def _mean_iou_compute( + intersection: Tensor, + union: Tensor, + per_class: bool = False, +) -> Tensor: + val = _safe_divide(intersection, union) + return val if per_class else torch.mean(val, 1) def mean_iou( - preds: List[Tensor], - target: List[Tensor], - num_labels: int, - ignore_index: bool, - nan_to_num: Optional[int] = None, - label_map: Optional[Dict[int, int]] = None, - reduce_labels: bool = False, + preds: Tensor, + target: Tensor, + num_classes: int, + include_background: bool = False, + per_class: bool = False, ) -> Tensor: - """Computes Mean Intersection over Union (mIoU). - - Args: - preds: - estimated labels - target: - ground truth labels - num_labels: - number of labels - ignore_index: - index that will be ignored in evaluation - nan_to_num: - If specified, NaN values will be replaced by the numbers defined by the user. Default: None. - label_map: - Mapping old labels to new labels. Default: None. - reduce_labels: - Whether to ignore the zero label and reduce all labels by one. Default: False. - - Return: - Tensor with mIoU. - - Example: - >>> from torchmetrics.functional.segmentation import mean_iou - >>> preds = [torch.tensor([[2,0],[2,3]])] - >>> target = [torch.tensor([[255,255],[2,3]])] - >>> mean_iou(preds, target) - tensor(0.2500) - """ - total_area_intersect, total_area_union, total_area_pred_label, total_area_label = _mean_iou_update( - preds, target, num_labels, ignore_index, nan_to_num, label_map, reduce_labels - ) - return _mean_iou_compute(total_area_intersect, total_area_union, total_area_pred_label, total_area_label) + + """ + intersection, union = _mean_iou_update(preds, target, num_classes, include_background) + return _mean_iou_compute(intersection, union, per_class=per_class) + diff --git a/src/torchmetrics/functional/segmentation/utils.py b/src/torchmetrics/functional/segmentation/utils.py index bbf5c48ded3..34be884b282 100644 --- a/src/torchmetrics/functional/segmentation/utils.py +++ b/src/torchmetrics/functional/segmentation/utils.py @@ -23,6 +23,12 @@ from torchmetrics.utilities.checks import _check_same_shape from torchmetrics.utilities.imports import _SCIPY_AVAILABLE +def _ignore_background(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: + """Ignore the background class in the computation.""" + preds = preds[:, 1:] if preds.shape[1] > 1 else preds + target = target[:, 1:] if target.shape[1] > 1 else target + return preds, target + def check_if_binarized(x: Tensor) -> None: """Check if the input is binarized. diff --git a/src/torchmetrics/segmentation/__init__.py b/src/torchmetrics/segmentation/__init__.py index 33872e5c057..c0e886a57cd 100644 --- a/src/torchmetrics/segmentation/__init__.py +++ b/src/torchmetrics/segmentation/__init__.py @@ -11,6 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from torchmetrics.segmentation.mean_iou import MeanIOU +from torchmetrics.segmentation.mean_iou import MeanIoU -__all__ = ["MeanIOU"] +__all__ = ["MeanIoU"] diff --git a/src/torchmetrics/segmentation/mean_iou.py b/src/torchmetrics/segmentation/mean_iou.py index df09044a38b..2b4fc2b21fc 100644 --- a/src/torchmetrics/segmentation/mean_iou.py +++ b/src/torchmetrics/segmentation/mean_iou.py @@ -12,18 +12,39 @@ # See the License for the specific language governing permissions and # limitations under the License. from torch import Tensor - +import torch from torchmetrics.metric import Metric +from torchmetrics.functional.segmentation.mean_iou import _mean_iou_update, _mean_iou_compute, _mean_iou_validate_args -class MeanIOU(Metric): +class MeanIoU(Metric): """Computes Mean Intersection over Union (mIoU) for semantic segmentation.""" - def __init__(self) -> None: - pass + def __init__( + self, + num_classes: int, + include_background: bool = False, + per_class: bool = False, + **kwargs, + ) -> None: + super().__init__(**kwargs) + _mean_iou_validate_args(num_classes, include_background, per_class) + self.num_classes = num_classes + self.include_background = include_background + self.per_class = per_class + + num_classes = num_classes - 1 if not include_background else num_classes + self.add_state("score", default=torch.zeros(num_classes if per_class else 1), dist_reduce_fx="mean") + self.add_state("num_batches", default=torch.tensor(0), dist_reduce_fx="sum") def update(self, preds: Tensor, target: Tensor) -> None: """Update the state with the new data.""" + intersection, union = _mean_iou_update(preds, target, self.num_classes, self.include_background) + score = _mean_iou_compute(intersection, union, per_class=self.per_class) + self.score += score.mean(0) if self.per_class else score.mean() + self.num_batches += 1 def compute(self) -> Tensor: """Update the state with the new data.""" + return self.score / self.num_batches + diff --git a/tests/unittests/segmentation/test_mean_iou.py b/tests/unittests/segmentation/test_mean_iou.py index 560a9b5c4cc..8843805decc 100644 --- a/tests/unittests/segmentation/test_mean_iou.py +++ b/tests/unittests/segmentation/test_mean_iou.py @@ -1,23 +1,84 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial + +import pytest +import torch + import torch from torchmetrics.functional.segmentation.mean_iou import mean_iou +from torchmetrics.segmentation.mean_iou import MeanIoU +from monai.metrics.meaniou import compute_iou +from unittests._helpers.testers import MetricTester +from unittests import BATCH_SIZE, NUM_BATCHES, _Input, NUM_CLASSES + +_inputs1 = _Input( + preds=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 16)), + target=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 16)), +) +_inputs2 = _Input( + preds=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 32, 32)), + target=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 32, 32)), +) +# _inputs3 = _Input( +# preds=torch.randn(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 32, 32), +# target=torch.randint(0, 5, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 32, 32)), +# ) -# suppose one has 3 different segmentation maps predicted -predicted_1 = torch.tensor([[1, 2], [3, 4], [5, 255]]) -actual_1 = torch.tensor([[0, 3], [5, 4], [6, 255]]) +def _reference_mean_iou(preds: torch.Tensor, target: torch.Tensor, include_background: bool = True, per_class: bool = True, reduce: bool = True): + """Calculate reference metric for `MeanIoU`.""" + val = compute_iou(preds, target, include_background=include_background) + if reduce: + return torch.mean(val, 0) if per_class else torch.mean(val) + return val -predicted_2 = torch.tensor([[2, 7], [9, 2], [3, 6]]) -actual_2 = torch.tensor([[1, 7], [9, 2], [3, 6]]) -predicted_3 = torch.tensor([[2, 2, 3], [8, 2, 4], [3, 255, 2]]) -actual_3 = torch.tensor([[1, 2, 2], [8, 2, 1], [3, 255, 1]]) +@pytest.mark.parametrize( + "preds, target", + [ + (_inputs1.preds, _inputs1.target), + (_inputs2.preds, _inputs2.target), + #(_inputs3.preds, _inputs3.target), + ], +) +@pytest.mark.parametrize("include_background", [True, False]) +class TestMeanIoU(MetricTester): + """Test class for `MeanIoU` metric.""" + atol = 1e-4 -predicted = [predicted_1, predicted_2, predicted_3] -ground_truth = [actual_1, actual_2, actual_3] -results = mean_iou(preds=predicted, target=ground_truth, num_labels=10, ignore_index=255, reduce_labels=False) + @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) + @pytest.mark.parametrize("per_class", [True, False]) + def test_mean_iou_class(self, preds, target, include_background, per_class, ddp): + """Test class implementation of metric.""" + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MeanIoU, + reference_metric=partial(_reference_mean_iou, include_background=include_background, per_class=per_class, reduce=True), + metric_args={"num_classes": NUM_CLASSES, "include_background": include_background, "per_class": per_class}, + ) -print(results) + def test_mean_iou_functional(self, preds, target, include_background): + """Test functional implementation of metric.""" + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=mean_iou, + reference_metric=partial(_reference_mean_iou, include_background=include_background, reduce=False), + metric_args={"num_classes": NUM_CLASSES, "include_background": include_background, "per_class": True}, + ) -preds = [torch.tensor([[2, 0], [2, 3]])] -target = [torch.tensor([[255, 255], [2, 3]])] -print(mean_iou(preds, target, num_labels=4, ignore_index=255, reduce_labels=False)) From 5d06e010bedaed1a7bfbd9c5b7e3b68cea637a5e Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 12 Apr 2024 16:23:46 +0200 Subject: [PATCH 09/19] docstrings --- .../functional/segmentation/mean_iou.py | 59 ++++++--- src/torchmetrics/segmentation/mean_iou.py | 119 +++++++++++++++++- tests/unittests/segmentation/test_mean_iou.py | 42 ++++--- 3 files changed, 184 insertions(+), 36 deletions(-) diff --git a/src/torchmetrics/functional/segmentation/mean_iou.py b/src/torchmetrics/functional/segmentation/mean_iou.py index cc6cbb5718c..793cfd45fdf 100644 --- a/src/torchmetrics/functional/segmentation/mean_iou.py +++ b/src/torchmetrics/functional/segmentation/mean_iou.py @@ -11,41 +11,42 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Optional, Sequence, Tuple +from typing import Tuple import torch from torch import Tensor -from torchmetrics.utilities.compute import _safe_divide from torchmetrics.functional.segmentation.utils import _ignore_background +from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.utilities.compute import _safe_divide + def _mean_iou_validate_args( num_classes: int, include_background: bool, per_class: bool, ) -> None: + """Validate the arguments of the metric.""" if num_classes <= 0: - raise ValueError( - f"Expected argument `num_classes` must be a positive integer, but got {num_classes}." - ) + raise ValueError(f"Expected argument `num_classes` must be a positive integer, but got {num_classes}.") if not isinstance(include_background, bool): - raise ValueError( - f"Expected argument `include_background` must be a boolean, but got {include_background}." - ) + raise ValueError(f"Expected argument `include_background` must be a boolean, but got {include_background}.") if not isinstance(per_class, bool): raise ValueError(f"Expected argument `per_class` must be a boolean, but got {per_class}.") + def _mean_iou_update( preds: Tensor, target: Tensor, num_classes: int, include_background: bool = False, -): - if preds.shape != target.shape: # assume preds is probabilities with an extra dimension - preds = preds.argmax(dim=1) +) -> Tuple[Tensor, Tensor]: + """Update the intersection and union counts for the mean IoU computation.""" + _check_same_shape(preds, target) + if (preds.bool() != preds).any(): # preds is an index tensor preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1) - if (target.bool() != target).any(): # target is an index tensor + if (target.bool() != target).any(): # target is an index tensor target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1) if not include_background: @@ -58,24 +59,50 @@ def _mean_iou_update( union = target_sum + pred_sum - intersection return intersection, union + def _mean_iou_compute( intersection: Tensor, union: Tensor, per_class: bool = False, ) -> Tensor: + """Compute the mean IoU metric.""" val = _safe_divide(intersection, union) return val if per_class else torch.mean(val, 1) + def mean_iou( preds: Tensor, target: Tensor, num_classes: int, - include_background: bool = False, + include_background: bool = True, per_class: bool = False, ) -> Tensor: - """ - + """Calculates the mean Intersection over Union (mIoU) for semantic segmentation. + + Args: + preds: Predictions from model + target: Ground truth values + num_classes: Number of classes + include_background: Whether to include the background class in the computation + per_class: Whether to compute the IoU for each class separately, else average over all classes + + Returns: + The mean IoU score + + Example: + >>> import torch + >>> _ = torch.manual_seed(42) + >>> from torchmetrics.functional.segmentation import mean_iou + >>> preds = torch.randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 prediction + >>> target = torch.randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 target + >>> mean_iou(preds, target, num_classes=5) + tensor([0.3193, 0.3305, 0.3382, 0.3246]) + >>> mean_iou(preds, target, num_classes=5, per_class=True) + tensor([[0.3093, 0.3500, 0.3081, 0.3389, 0.2903], + [0.2963, 0.3316, 0.3505, 0.2804, 0.3936], + [0.3724, 0.3249, 0.3660, 0.3184, 0.3093], + [0.3085, 0.3267, 0.3155, 0.3575, 0.3147]]) + """ intersection, union = _mean_iou_update(preds, target, num_classes, include_background) return _mean_iou_compute(intersection, union, per_class=per_class) - diff --git a/src/torchmetrics/segmentation/mean_iou.py b/src/torchmetrics/segmentation/mean_iou.py index 2b4fc2b21fc..d770af2817c 100644 --- a/src/torchmetrics/segmentation/mean_iou.py +++ b/src/torchmetrics/segmentation/mean_iou.py @@ -11,21 +11,91 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from torch import Tensor +from typing import Any, Optional, Sequence, Union + import torch +from torch import Tensor + +from torchmetrics.functional.segmentation.mean_iou import _mean_iou_compute, _mean_iou_update, _mean_iou_validate_args from torchmetrics.metric import Metric -from torchmetrics.functional.segmentation.mean_iou import _mean_iou_update, _mean_iou_compute, _mean_iou_validate_args +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["MeanIoU.plot"] class MeanIoU(Metric): - """Computes Mean Intersection over Union (mIoU) for semantic segmentation.""" + """Computes Mean Intersection over Union (mIoU) for semantic segmentation. + + The metric is defined by the overlap between the predicted segmentation and the ground truth, divided by the + total area covered by the union of the two. The metric can be computed for each class separately or for all + classes at once. The metric is optimal at a value of 1 and worst at a value of 0. + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~torch.Tensor`): An one-hot boolean tensor of shape ``(N, C, ...)`` with ``N`` being + the number of samples and ``C`` the number of classes. Alternatively, an integer tensor of shape ``(N, ...)`` + can be provided, where the integer values correspond to the class index. That format will be automatically + converted to a one-hot tensor. + - ``target`` (:class:`~torch.Tensor`): An one-hot boolean tensor of shape ``(N, C, ...)`` with ``N`` being + the number of samples and ``C`` the number of classes. Alternatively, an integer tensor of shape ``(N, ...)`` + can be provided, where the integer values correspond to the class index. That format will be automatically + converted to a one-hot tensor. + + As output to ``forward`` and ``compute`` the metric returns the following output: + + - ``miou`` (:class:`~torch.Tensor`): The mean Intersection over Union (mIoU) score. If ``per_class`` is set to + ``True``, the output will be a tensor of shape ``(C,)`` with the IoU score for each class. If ``per_class`` is + set to ``False``, the output will be a scalar tensor. + + Args: + num_classes: The number of classes in the segmentation problem. + include_background: Whether to include the background class in the computation + per_class: Whether to compute the IoU for each class separately. If set to ``False``, the metric will + compute the mean IoU over all classes. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Raises: + ValueError: + If ``num_classes`` is not a positive integer + ValueError: + If ``include_background`` is not a boolean + ValueError: + If ``per_class`` is not a boolean + + Example: + >>> import torch + >>> _ = torch.manual_seed(0) + >>> from torchmetrics.segmentation import MeanIoU + >>> miou = MeanIoU(num_classes=3) + >>> preds = torch.randint(0, 2, (10, 3, 128, 128)) + >>> target = torch.randint(0, 2, (10, 3, 128, 128)) + >>> miou(preds, target) + tensor(0.3318) + >>> miou = MeanIoU(num_classes=3, per_class=True) + >>> miou(preds, target) + tensor([0.3322, 0.3303, 0.3329]) + >>> miou = MeanIoU(num_classes=3, per_class=True, include_background=False) + >>> miou(preds, target) + tensor([0.3303, 0.3329]) + + """ + + score: Tensor + num_batches: Tensor + full_state_update: bool = False + is_differentiable: bool = False + higher_is_better: bool = True + plot_lower_bound: float = 0.0 + plot_upper_bound: float = 1.0 def __init__( self, num_classes: int, - include_background: bool = False, + include_background: bool = True, per_class: bool = False, - **kwargs, + **kwargs: Any, ) -> None: super().__init__(**kwargs) _mean_iou_validate_args(num_classes, include_background, per_class) @@ -48,3 +118,42 @@ def compute(self) -> Tensor: """Update the state with the new data.""" return self.score / self.num_batches + def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics.audio import PerceptualEvaluationSpeechQuality + >>> metric = PerceptualEvaluationSpeechQuality(8000, 'nb') + >>> metric.update(torch.rand(8000), torch.rand(8000)) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.audio import PerceptualEvaluationSpeechQuality + >>> metric = PerceptualEvaluationSpeechQuality(8000, 'nb') + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(torch.rand(8000), torch.rand(8000))) + >>> fig_, ax_ = metric.plot(values) + + """ + return self._plot(val, ax) diff --git a/tests/unittests/segmentation/test_mean_iou.py b/tests/unittests/segmentation/test_mean_iou.py index 8843805decc..1d8426c823c 100644 --- a/tests/unittests/segmentation/test_mean_iou.py +++ b/tests/unittests/segmentation/test_mean_iou.py @@ -16,30 +16,40 @@ import pytest import torch - -import torch - +from monai.metrics.meaniou import compute_iou from torchmetrics.functional.segmentation.mean_iou import mean_iou from torchmetrics.segmentation.mean_iou import MeanIoU -from monai.metrics.meaniou import compute_iou + +from unittests import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, _Input from unittests._helpers.testers import MetricTester -from unittests import BATCH_SIZE, NUM_BATCHES, _Input, NUM_CLASSES _inputs1 = _Input( - preds=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 16)), + preds=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 16)), target=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 16)), ) _inputs2 = _Input( - preds=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 32, 32)), + preds=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 32, 32)), target=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 32, 32)), ) -# _inputs3 = _Input( -# preds=torch.randn(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 32, 32), -# target=torch.randint(0, 5, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 32, 32)), -# ) +_inputs3 = _Input( + preds=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32)), + target=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32)), +) + -def _reference_mean_iou(preds: torch.Tensor, target: torch.Tensor, include_background: bool = True, per_class: bool = True, reduce: bool = True): +def _reference_mean_iou( + preds: torch.Tensor, + target: torch.Tensor, + include_background: bool = True, + per_class: bool = True, + reduce: bool = True, +): """Calculate reference metric for `MeanIoU`.""" + if (preds.bool() != preds).any(): # preds is an index tensor + preds = torch.nn.functional.one_hot(preds, num_classes=NUM_CLASSES).movedim(-1, 1) + if (target.bool() != target).any(): # target is an index tensor + target = torch.nn.functional.one_hot(target, num_classes=NUM_CLASSES).movedim(-1, 1) + val = compute_iou(preds, target, include_background=include_background) if reduce: return torch.mean(val, 0) if per_class else torch.mean(val) @@ -51,12 +61,13 @@ def _reference_mean_iou(preds: torch.Tensor, target: torch.Tensor, include_backg [ (_inputs1.preds, _inputs1.target), (_inputs2.preds, _inputs2.target), - #(_inputs3.preds, _inputs3.target), + # (_inputs3.preds, _inputs3.target), ], ) @pytest.mark.parametrize("include_background", [True, False]) class TestMeanIoU(MetricTester): """Test class for `MeanIoU` metric.""" + atol = 1e-4 @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) @@ -68,7 +79,9 @@ def test_mean_iou_class(self, preds, target, include_background, per_class, ddp) preds=preds, target=target, metric_class=MeanIoU, - reference_metric=partial(_reference_mean_iou, include_background=include_background, per_class=per_class, reduce=True), + reference_metric=partial( + _reference_mean_iou, include_background=include_background, per_class=per_class, reduce=True + ), metric_args={"num_classes": NUM_CLASSES, "include_background": include_background, "per_class": per_class}, ) @@ -81,4 +94,3 @@ def test_mean_iou_functional(self, preds, target, include_background): reference_metric=partial(_reference_mean_iou, include_background=include_background, reduce=False), metric_args={"num_classes": NUM_CLASSES, "include_background": include_background, "per_class": True}, ) - From 45fd00230345177e12551a0d35d62ef13efe0577 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 12 Apr 2024 16:24:23 +0200 Subject: [PATCH 10/19] changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ec038a82b63..4febe7a114a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,7 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added support for calculating segmentation quality and recognition quality in `PanopticQuality` metric ([#2381](https://github.com/Lightning-AI/torchmetrics/pull/2381)) -- Added a new segmentation metric `mean IoU` ([#1236](https://github.com/PyTorchLightning/metrics/pull/1236)) +- Added a new segmentation metric `MeanIoU` ([#1236](https://github.com/PyTorchLightning/metrics/pull/1236)) ### Changed From 179f6d99bd7cc07f62902dcfcaa94aa47e3772d3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 12 Apr 2024 14:25:08 +0000 Subject: [PATCH 11/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/functional/segmentation/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/torchmetrics/functional/segmentation/utils.py b/src/torchmetrics/functional/segmentation/utils.py index 34be884b282..e8427a69326 100644 --- a/src/torchmetrics/functional/segmentation/utils.py +++ b/src/torchmetrics/functional/segmentation/utils.py @@ -23,6 +23,7 @@ from torchmetrics.utilities.checks import _check_same_shape from torchmetrics.utilities.imports import _SCIPY_AVAILABLE + def _ignore_background(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: """Ignore the background class in the computation.""" preds = preds[:, 1:] if preds.shape[1] > 1 else preds From 1f918221472ab06a2fff2079ea63b8c0d7e84391 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 12 Apr 2024 16:27:09 +0200 Subject: [PATCH 12/19] docs fix --- docs/source/segmentation/mean_iou.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/segmentation/mean_iou.rst b/docs/source/segmentation/mean_iou.rst index 1c24eceaabc..7fddd9f316d 100644 --- a/docs/source/segmentation/mean_iou.rst +++ b/docs/source/segmentation/mean_iou.rst @@ -3,9 +3,9 @@ :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/object_detection.svg :tags: segmentation -############################ +################################### Mean Intersection over Union (mIoU) -############################ +################################### Module Interface ________________ From f56ead4d89e41b6cf43e7a8953eba7cfae9b9633 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 19 Apr 2024 16:05:31 +0200 Subject: [PATCH 13/19] fix docs --- src/torchmetrics/segmentation/mean_iou.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/segmentation/mean_iou.py b/src/torchmetrics/segmentation/mean_iou.py index d770af2817c..e6f9672547d 100644 --- a/src/torchmetrics/segmentation/mean_iou.py +++ b/src/torchmetrics/segmentation/mean_iou.py @@ -38,7 +38,7 @@ class MeanIoU(Metric): the number of samples and ``C`` the number of classes. Alternatively, an integer tensor of shape ``(N, ...)`` can be provided, where the integer values correspond to the class index. That format will be automatically converted to a one-hot tensor. - - ``target`` (:class:`~torch.Tensor`): An one-hot boolean tensor of shape ``(N, C, ...)`` with ``N`` being + - ``target`` (:class:`~torch.Tensor`): An one-hot boolean tensor of shape ``(N, C, ...)`` with ``N`` being the number of samples and ``C`` the number of classes. Alternatively, an integer tensor of shape ``(N, ...)`` can be provided, where the integer values correspond to the class index. That format will be automatically converted to a one-hot tensor. From c62c50c296a84a7d9488988350e5d495a4616bcb Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 22 Apr 2024 08:07:02 +0200 Subject: [PATCH 14/19] fix + tests --- src/torchmetrics/segmentation/mean_iou.py | 4 +--- tests/unittests/segmentation/test_mean_iou.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/torchmetrics/segmentation/mean_iou.py b/src/torchmetrics/segmentation/mean_iou.py index e6f9672547d..f36f7fc3bc0 100644 --- a/src/torchmetrics/segmentation/mean_iou.py +++ b/src/torchmetrics/segmentation/mean_iou.py @@ -105,18 +105,16 @@ def __init__( num_classes = num_classes - 1 if not include_background else num_classes self.add_state("score", default=torch.zeros(num_classes if per_class else 1), dist_reduce_fx="mean") - self.add_state("num_batches", default=torch.tensor(0), dist_reduce_fx="sum") def update(self, preds: Tensor, target: Tensor) -> None: """Update the state with the new data.""" intersection, union = _mean_iou_update(preds, target, self.num_classes, self.include_background) score = _mean_iou_compute(intersection, union, per_class=self.per_class) self.score += score.mean(0) if self.per_class else score.mean() - self.num_batches += 1 def compute(self) -> Tensor: """Update the state with the new data.""" - return self.score / self.num_batches + return self.score # / self.num_batches def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. diff --git a/tests/unittests/segmentation/test_mean_iou.py b/tests/unittests/segmentation/test_mean_iou.py index 1d8426c823c..ed7e07d06bb 100644 --- a/tests/unittests/segmentation/test_mean_iou.py +++ b/tests/unittests/segmentation/test_mean_iou.py @@ -61,7 +61,7 @@ def _reference_mean_iou( [ (_inputs1.preds, _inputs1.target), (_inputs2.preds, _inputs2.target), - # (_inputs3.preds, _inputs3.target), + (_inputs3.preds, _inputs3.target), ], ) @pytest.mark.parametrize("include_background", [True, False]) From 6cd0d74bd775e1575b2107b7f1e6b1924c9099d8 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Mon, 22 Apr 2024 08:13:35 +0200 Subject: [PATCH 15/19] validate args in functional --- src/torchmetrics/functional/segmentation/mean_iou.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/torchmetrics/functional/segmentation/mean_iou.py b/src/torchmetrics/functional/segmentation/mean_iou.py index 793cfd45fdf..0a4e24da6e1 100644 --- a/src/torchmetrics/functional/segmentation/mean_iou.py +++ b/src/torchmetrics/functional/segmentation/mean_iou.py @@ -104,5 +104,6 @@ def mean_iou( [0.3085, 0.3267, 0.3155, 0.3575, 0.3147]]) """ + _mean_iou_validate_args(num_classes, include_background, per_class) intersection, union = _mean_iou_update(preds, target, num_classes, include_background) return _mean_iou_compute(intersection, union, per_class=per_class) From 46e1183f368c7061b874dd3103714fec68348433 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Tue, 23 Apr 2024 13:31:21 +0200 Subject: [PATCH 16/19] Update src/torchmetrics/functional/segmentation/utils.py Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- src/torchmetrics/functional/segmentation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/segmentation/utils.py b/src/torchmetrics/functional/segmentation/utils.py index e8427a69326..6c2fed92df2 100644 --- a/src/torchmetrics/functional/segmentation/utils.py +++ b/src/torchmetrics/functional/segmentation/utils.py @@ -25,7 +25,7 @@ def _ignore_background(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: - """Ignore the background class in the computation.""" + """Ignore the background class in the computation assuming it is the first, index 0.""" preds = preds[:, 1:] if preds.shape[1] > 1 else preds target = target[:, 1:] if target.shape[1] > 1 else target return preds, target From d46f3967c696a6907e84a4951cbd251ecc85afe4 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 23 Apr 2024 13:35:25 +0200 Subject: [PATCH 17/19] fix nan case --- tests/unittests/segmentation/test_mean_iou.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unittests/segmentation/test_mean_iou.py b/tests/unittests/segmentation/test_mean_iou.py index ed7e07d06bb..013b71572d6 100644 --- a/tests/unittests/segmentation/test_mean_iou.py +++ b/tests/unittests/segmentation/test_mean_iou.py @@ -51,6 +51,7 @@ def _reference_mean_iou( target = torch.nn.functional.one_hot(target, num_classes=NUM_CLASSES).movedim(-1, 1) val = compute_iou(preds, target, include_background=include_background) + val[torch.isnan(val)] = 0.0 if reduce: return torch.mean(val, 0) if per_class else torch.mean(val) return val From 54942fc1833302142d37b566a4e8a95cf7b0f729 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Tue, 23 Apr 2024 14:15:19 +0200 Subject: [PATCH 18/19] Docs --- docs/source/index.rst | 8 -------- 1 file changed, 8 deletions(-) diff --git a/docs/source/index.rst b/docs/source/index.rst index 6a3e9a9e20d..880a6a2657e 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -198,14 +198,6 @@ Or directly from conda nominal/* -.. toctree:: - :maxdepth: 2 - :name: segmentation - :caption: Segmentation - :glob: - - segmentation/* - .. toctree:: :maxdepth: 2 :name: pairwise From ad6e6b9d4861eb0701eb7d0b55e209ed6eef7bbd Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Tue, 23 Apr 2024 15:08:56 +0200 Subject: [PATCH 19/19] fix --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index fcb4808284f..82e370c91f5 100644 --- a/README.md +++ b/README.md @@ -283,7 +283,6 @@ covers the following domains: - Audio - Classification - Detection -- Segmentation - Information Retrieval - Image - Multimodal (Image-Text)