From 1335c7b7c47e729ef0f99785d7d1692803155beb Mon Sep 17 00:00:00 2001 From: Shion Matsumoto Date: Mon, 14 Oct 2024 16:20:47 -0400 Subject: [PATCH] New segmentation metric: Hausdorff Distance (#2122) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> Co-authored-by: Bas Krahmer Co-authored-by: Nicki Skafte Detlefsen Co-authored-by: Jirka B --- CHANGELOG.md | 3 + docs/source/links.rst | 2 + .../segmentation/hausdorff_distance.rst | 21 +++ requirements/integrate.txt | 0 .../functional/segmentation/__init__.py | 3 +- .../segmentation/hausdorff_distance.py | 114 +++++++++++++ .../functional/segmentation/utils.py | 57 +++++-- src/torchmetrics/segmentation/__init__.py | 3 +- .../segmentation/hausdorff_distance.py | 157 ++++++++++++++++++ tests/unittests/segmentation/inputs.py | 28 ++++ .../segmentation/test_hausdorff_distance.py | 116 +++++++++++++ tests/unittests/segmentation/test_utils.py | 49 ++++++ 12 files changed, 538 insertions(+), 15 deletions(-) create mode 100644 docs/source/segmentation/hausdorff_distance.rst create mode 100644 requirements/integrate.txt create mode 100644 src/torchmetrics/functional/segmentation/hausdorff_distance.py create mode 100644 src/torchmetrics/segmentation/hausdorff_distance.py create mode 100644 tests/unittests/segmentation/inputs.py create mode 100644 tests/unittests/segmentation/test_hausdorff_distance.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 50be7dd1f53..f432c7afa26 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `truncation` argument to `BERTScore` ([#2776](https://github.com/Lightning-AI/torchmetrics/pull/2776)) +- Added `HausdorffDistance` to segmentation package ([#2122](https://github.com/Lightning-AI/torchmetrics/pull/2122)) + + ### Changed - Tracker higher is better integration ([#2649](https://github.com/Lightning-AI/torchmetrics/pull/2649)) diff --git a/docs/source/links.rst b/docs/source/links.rst index 035cbbec8b7..2e9b222f28f 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -172,4 +172,6 @@ .. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html .. _SCC: https://www.ingentaconnect.com/content/tandf/tres/1998/00000019/00000004/art00013 .. _Generalized Dice Score: https://arxiv.org/abs/1707.03237 +.. _Hausdorff Distance: https://en.wikipedia.org/wiki/Hausdorff_distance +.. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html .. _Procrustes Disparity: https://en.wikipedia.org/wiki/Procrustes_analysis diff --git a/docs/source/segmentation/hausdorff_distance.rst b/docs/source/segmentation/hausdorff_distance.rst new file mode 100644 index 00000000000..cfe1d3fdb5b --- /dev/null +++ b/docs/source/segmentation/hausdorff_distance.rst @@ -0,0 +1,21 @@ +.. customcarditem:: + :header: Hausdorff Distance + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/text_classification.svg + :tags: segmentation + +.. include:: ../links.rst + +################## +Hausdorff Distance +################## + +Module Interface +________________ + +.. autoclass:: torchmetrics.segmentation.HausdorffDistance + :exclude-members: update, compute + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.segmentation.hausdorff_distance diff --git a/requirements/integrate.txt b/requirements/integrate.txt new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/torchmetrics/functional/segmentation/__init__.py b/src/torchmetrics/functional/segmentation/__init__.py index 3d23192a36a..068bf77d775 100644 --- a/src/torchmetrics/functional/segmentation/__init__.py +++ b/src/torchmetrics/functional/segmentation/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from torchmetrics.functional.segmentation.generalized_dice import generalized_dice_score +from torchmetrics.functional.segmentation.hausdorff_distance import hausdorff_distance from torchmetrics.functional.segmentation.mean_iou import mean_iou -__all__ = ["generalized_dice_score", "mean_iou"] +__all__ = ["generalized_dice_score", "mean_iou", "hausdorff_distance"] diff --git a/src/torchmetrics/functional/segmentation/hausdorff_distance.py b/src/torchmetrics/functional/segmentation/hausdorff_distance.py new file mode 100644 index 00000000000..daadc90f6ba --- /dev/null +++ b/src/torchmetrics/functional/segmentation/hausdorff_distance.py @@ -0,0 +1,114 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Literal, Optional, Union + +import torch +from torch import Tensor + +from torchmetrics.functional.segmentation.utils import ( + _ignore_background, + edge_surface_distance, +) +from torchmetrics.utilities.checks import _check_same_shape + + +def _hausdorff_distance_validate_args( + num_classes: int, + include_background: bool, + distance_metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean", + spacing: Optional[Union[Tensor, List[float]]] = None, + directed: bool = False, + input_format: Literal["one-hot", "index"] = "one-hot", +) -> None: + """Validate the arguments of `hausdorff_distance` function.""" + if num_classes <= 0: + raise ValueError(f"Expected argument `num_classes` must be a positive integer, but got {num_classes}.") + if not isinstance(include_background, bool): + raise ValueError(f"Expected argument `include_background` must be a boolean, but got {include_background}.") + if distance_metric not in ["euclidean", "chessboard", "taxicab"]: + raise ValueError( + f"Arg `distance_metric` must be one of 'euclidean', 'chessboard', 'taxicab', but got {distance_metric}." + ) + if spacing is not None and not isinstance(spacing, (list, Tensor)): + raise ValueError(f"Arg `spacing` must be a list or tensor, but got {type(spacing)}.") + if not isinstance(directed, bool): + raise ValueError(f"Expected argument `directed` must be a boolean, but got {directed}.") + if input_format not in ["one-hot", "index"]: + raise ValueError(f"Expected argument `input_format` to be one of 'one-hot', 'index', but got {input_format}.") + + +def hausdorff_distance( + preds: Tensor, + target: Tensor, + num_classes: int, + include_background: bool = False, + distance_metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean", + spacing: Optional[Union[Tensor, List[float]]] = None, + directed: bool = False, + input_format: Literal["one-hot", "index"] = "one-hot", +) -> Tensor: + """Calculate `Hausdorff Distance`_ for semantic segmentation. + + Args: + preds: predicted binarized segmentation map + target: target binarized segmentation map + num_classes: number of classes + include_background: whether to include background class in calculation + distance_metric: distance metric to calculate surface distance. Choose one of `"euclidean"`, + `"chessboard"` or `"taxicab"` + spacing: spacing between pixels along each spatial dimension. If not provided the spacing is assumed to be 1 + directed: whether to calculate directed or undirected Hausdorff distance + input_format: What kind of input the function receives. Choose between ``"one-hot"`` for one-hot encoded tensors + or ``"index"`` for index tensors + + Returns: + Hausdorff Distance for each class and batch element + + Example: + >>> from torch import randint + >>> from torchmetrics.functional.segmentation import hausdorff_distance + >>> preds = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 prediction + >>> target = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 target + >>> hausdorff_distance(preds, target, num_classes=5) + tensor([[2.0000, 1.4142, 2.0000, 2.0000], + [1.4142, 2.0000, 2.0000, 2.0000], + [2.0000, 2.0000, 1.4142, 2.0000], + [2.0000, 2.8284, 2.0000, 2.2361]]) + + """ + _hausdorff_distance_validate_args(num_classes, include_background, distance_metric, spacing, directed, input_format) + _check_same_shape(preds, target) + + if input_format == "index": + preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1) + target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1) + + if not include_background: + preds, target = _ignore_background(preds, target) + + distances = torch.zeros(preds.shape[0], preds.shape[1], device=preds.device) + + # TODO: add support for batched inputs + for b in range(preds.shape[0]): + for c in range(preds.shape[1]): + dist = edge_surface_distance( + preds=preds[b, c], + target=target[b, c], + distance_metric=distance_metric, + spacing=spacing, + symmetric=not directed, + ) + distances[b, c] = torch.max(dist) if directed else torch.max(dist[0].max(), dist[1].max()) # type: ignore + return distances diff --git a/src/torchmetrics/functional/segmentation/utils.py b/src/torchmetrics/functional/segmentation/utils.py index 6c2fed92df2..59d42e16171 100644 --- a/src/torchmetrics/functional/segmentation/utils.py +++ b/src/torchmetrics/functional/segmentation/utils.py @@ -32,7 +32,7 @@ def _ignore_background(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: def check_if_binarized(x: Tensor) -> None: - """Check if the input is binarized. + """Check if tensor is binarized. Example: >>> from torchmetrics.functional.segmentation.utils import check_if_binarized @@ -200,9 +200,8 @@ def distance_transform( Args: x: The binary tensor to calculate the distance transform of. - sampling: Only relevant when distance is calculated using the euclidean distance. The sampling refers to the - pixel spacing in the image, i.e. the distance between two adjacent pixels. If not provided, the pixel - spacing is assumed to be 1. + sampling: The sampling refers to the pixel spacing in the image, i.e. the distance between two adjacent pixels. + If not provided, the pixel spacing is assumed to be 1. metric: The distance to use for the distance transform. Can be one of ``"euclidean"``, ``"chessboard"`` or ``"taxicab"``. engine: The engine to use for the distance transform. Can be one of ``["pytorch", "scipy"]``. In general, @@ -249,25 +248,25 @@ def distance_transform( raise ValueError(f"Expected argument `sampling` to have length 2 but got length `{len(sampling)}`.") if engine == "pytorch": + x = x.float() # calculate distance from every foreground pixel to every background pixel i0, j0 = torch.where(x == 0) i1, j1 = torch.where(x == 1) - dis_row = (i1.unsqueeze(1) - i0.unsqueeze(0)).abs_().mul_(sampling[0]) - dis_col = (j1.unsqueeze(1) - j0.unsqueeze(0)).abs_().mul_(sampling[1]) + dis_row = (i1.view(-1, 1) - i0.view(1, -1)).abs() + dis_col = (j1.view(-1, 1) - j0.view(1, -1)).abs() # # calculate distance h, _ = x.shape if metric == "euclidean": - dis_row = dis_row.float() - dis_row.pow_(2).add_(dis_col.pow_(2)).sqrt_() + dis = ((sampling[0] * dis_row) ** 2 + (sampling[1] * dis_col) ** 2).sqrt() if metric == "chessboard": - dis_row = dis_row.max(dis_col) + dis = torch.max(sampling[0] * dis_row, sampling[1] * dis_col).float() if metric == "taxicab": - dis_row.add_(dis_col) + dis = (sampling[0] * dis_row + sampling[1] * dis_col).float() # select only the closest distance - mindis, _ = torch.min(dis_row, dim=1) - z = torch.zeros_like(x, dtype=mindis.dtype).view(-1) + mindis, _ = torch.min(dis, dim=1) + z = torch.zeros_like(x).view(-1) z[i1 * h + j1] = mindis return z.view(x.shape) @@ -279,7 +278,7 @@ def distance_transform( if metric == "euclidean": return ndimage.distance_transform_edt(x.cpu().numpy(), sampling) - return ndimage.distance_transform_cdt(x.cpu().numpy(), metric=metric) + return ndimage.distance_transform_cdt(x.cpu().numpy(), sampling, metric=metric) def mask_edges( @@ -390,6 +389,38 @@ def surface_distance( return dis[preds] +def edge_surface_distance( + preds: Tensor, + target: Tensor, + distance_metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean", + spacing: Optional[Union[Tensor, List[float]]] = None, + symmetric: bool = False, +) -> Union[Tensor, Tuple[Tensor, Tensor]]: + """Extracts the edges from the input masks and calculates the surface distance between them. + + Args: + preds: The predicted binary edge mask. + target: The target binary edge mask. + distance_metric: The distance metric to use. One of `["euclidean", "chessboard", "taxicab"]`. + spacing: The spacing between pixels along each spatial dimension. + symmetric: Whether to calculate the symmetric distance between the edges. + + Returns: + A tensor with length equal to the number of edges in predictions e.g. `preds.sum()`. Each element is the + distance from the corresponding edge in `preds` to the closest edge in `target`. If `symmetric` is `True`, the + function returns a tuple containing the distances from the predicted edges to the target edges and vice versa. + + """ + output = mask_edges(preds, target) + edges_preds, edges_target = output[0].bool(), output[1].bool() + if symmetric: + return ( + surface_distance(edges_preds, edges_target, distance_metric=distance_metric, spacing=spacing), + surface_distance(edges_target, edges_preds, distance_metric=distance_metric, spacing=spacing), + ) + return surface_distance(edges_preds, edges_target, distance_metric=distance_metric, spacing=spacing) + + @functools.lru_cache def get_neighbour_tables( spacing: Union[Tuple[int, int], Tuple[int, int, int]], device: Optional[torch.device] = None diff --git a/src/torchmetrics/segmentation/__init__.py b/src/torchmetrics/segmentation/__init__.py index 5b609c2c738..6e9b1c63313 100644 --- a/src/torchmetrics/segmentation/__init__.py +++ b/src/torchmetrics/segmentation/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from torchmetrics.segmentation.generalized_dice import GeneralizedDiceScore +from torchmetrics.segmentation.hausdorff_distance import HausdorffDistance from torchmetrics.segmentation.mean_iou import MeanIoU -__all__ = ["GeneralizedDiceScore", "MeanIoU"] +__all__ = ["GeneralizedDiceScore", "MeanIoU", "HausdorffDistance"] diff --git a/src/torchmetrics/segmentation/hausdorff_distance.py b/src/torchmetrics/segmentation/hausdorff_distance.py new file mode 100644 index 00000000000..f1e8812ed30 --- /dev/null +++ b/src/torchmetrics/segmentation/hausdorff_distance.py @@ -0,0 +1,157 @@ +# Copyright The Lightning team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, List, Literal, Optional, Sequence, Union + +import torch +from torch import Tensor + +from torchmetrics.functional.segmentation.hausdorff_distance import ( + _hausdorff_distance_validate_args, + hausdorff_distance, +) +from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["HausdorffDistance.plot"] + + +class HausdorffDistance(Metric): + r"""Compute the `Hausdorff Distance`_ between two subsets of a metric space for semantic segmentation. + + .. math:: + d_{\Pi}(X,Y) = \max{/sup_{x\in X} {d(x,Y)}, /sup_{y\in Y} {d(X,y)}} + + where :math:`\X, \Y` are two subsets of a metric space with distance metric :math:`d`. The Hausdorff distance is + the maximum distance from a point in one set to the closest point in the other set. The Hausdorff distance is a + measure of the degree of mismatch between two sets. + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~torch.Tensor`): An one-hot boolean tensor of shape ``(N, C, ...)`` with ``N`` being + the number of samples and ``C`` the number of classes. Alternatively, an integer tensor of shape ``(N, ...)`` + can be provided, where the integer values correspond to the class index. The input type can be controlled + with the ``input_format`` argument. + - ``target`` (:class:`~torch.Tensor`): An one-hot boolean tensor of shape ``(N, C, ...)`` with ``N`` being + the number of samples and ``C`` the number of classes. Alternatively, an integer tensor of shape ``(N, ...)`` + can be provided, where the integer values correspond to the class index. The input type can be controlled + with the ``input_format`` argument. + + As output of ``forward`` and ``compute`` the metric returns the following output: + + - ``hausdorff_distance`` (:class:`~torch.Tensor`): A scalar float tensor with the Hausdorff distance averaged over + classes and samples + + Args: + num_classes: number of classes + include_background: whether to include background class in calculation + distance_metric: distance metric to calculate surface distance. Choose one of `"euclidean"`, + `"chessboard"` or `"taxicab"` + spacing: spacing between pixels along each spatial dimension. If not provided the spacing is assumed to be 1 + directed: whether to calculate directed or undirected Hausdorff distance + input_format: What kind of input the function receives. Choose between ``"one-hot"`` for one-hot encoded tensors + or ``"index"`` for index tensors + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example: + >>> from torch import randint + >>> from torchmetrics.segmentation import HausdorffDistance + >>> preds = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 prediction + >>> target = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 target + >>> hausdorff_distance = HausdorffDistance(distance_metric="euclidean", num_classes=5) + >>> hausdorff_distance(preds, target) + tensor(1.9567) + + """ + + is_differentiable: bool = True + higher_is_better: bool = False + full_state_update: bool = False + plot_lower_bound: float = 0.0 + + score: Tensor + total: Tensor + + def __init__( + self, + num_classes: int, + include_background: bool = False, + distance_metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean", + spacing: Optional[Union[Tensor, List[float]]] = None, + directed: bool = False, + input_format: Literal["one-hot", "index"] = "one-hot", + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + _hausdorff_distance_validate_args( + num_classes, include_background, distance_metric, spacing, directed, input_format + ) + self.num_classes = num_classes + self.include_background = include_background + self.distance_metric = distance_metric + self.spacing = spacing + self.directed = directed + self.input_format = input_format + self.add_state("score", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: + """Update state with predictions and targets.""" + score = hausdorff_distance( + preds, + target, + self.num_classes, + include_background=self.include_background, + distance_metric=self.distance_metric, + spacing=self.spacing, + directed=self.directed, + input_format=self.input_format, + ) + self.score += score.sum() + self.total += score.numel() + + def compute(self) -> Tensor: + """Compute final Hausdorff distance over states.""" + return self.score / self.total + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> from torch import randint + >>> from torchmetrics.segmentation import HausdorffDistance + >>> preds = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 prediction + >>> target = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 target + >>> metric = HausdorffDistance(num_classes=5) + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + """ + return self._plot(val, ax) diff --git a/tests/unittests/segmentation/inputs.py b/tests/unittests/segmentation/inputs.py new file mode 100644 index 00000000000..996b8364e9c --- /dev/null +++ b/tests/unittests/segmentation/inputs.py @@ -0,0 +1,28 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +__all__ = ["_Input"] + +from typing import NamedTuple + +from torch import Tensor + +from unittests._helpers import seed_all + +seed_all(42) + + +# extrinsic input for clustering metrics that requires predicted clustering labels and target clustering labels +class _Input(NamedTuple): + preds: Tensor + target: Tensor diff --git a/tests/unittests/segmentation/test_hausdorff_distance.py b/tests/unittests/segmentation/test_hausdorff_distance.py new file mode 100644 index 00000000000..afd77c1f4b2 --- /dev/null +++ b/tests/unittests/segmentation/test_hausdorff_distance.py @@ -0,0 +1,116 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial +from typing import Any + +import pytest +import torch +from monai.metrics.hausdorff_distance import compute_hausdorff_distance as monai_hausdorff_distance +from torchmetrics.functional.segmentation.hausdorff_distance import hausdorff_distance +from torchmetrics.segmentation.hausdorff_distance import HausdorffDistance + +from unittests import NUM_BATCHES, _Input +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester + +seed_all(42) +BATCH_SIZE = 4 # use smaller than normal batch size to reduce test time +NUM_CLASSES = 3 # use smaller than normal class size to reduce test time + +_inputs1 = _Input( + preds=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 16, 16)), + target=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 16, 16)), +) +_inputs2 = _Input( + preds=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32)), + target=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32)), +) + + +def reference_metric(preds, target, input_format, reduce, **kwargs: Any): + """Reference implementation of metric.""" + if input_format == "index": + preds = torch.nn.functional.one_hot(preds, num_classes=NUM_CLASSES).movedim(-1, 1) + target = torch.nn.functional.one_hot(target, num_classes=NUM_CLASSES).movedim(-1, 1) + score = monai_hausdorff_distance(preds, target, **kwargs) + return score.mean() if reduce else score + + +@pytest.mark.parametrize("inputs, input_format", [(_inputs1, "one-hot"), (_inputs2, "index")]) +@pytest.mark.parametrize("distance_metric", ["euclidean", "chessboard", "taxicab"]) +@pytest.mark.parametrize("directed", [True, False]) +@pytest.mark.parametrize("spacing", [None, [2, 2]]) +class TestHausdorffDistance(MetricTester): + """Test class for `HausdorffDistance` metric.""" + + atol = 1e-5 + + @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) + def test_hausdorff_distance_class(self, inputs, input_format, distance_metric, directed, spacing, ddp): + """Test class implementation of metric.""" + if spacing is not None and distance_metric != "euclidean": + pytest.skip("Spacing is only supported for Euclidean distance metric.") + preds, target = inputs + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=HausdorffDistance, + reference_metric=partial( + reference_metric, + input_format=input_format, + distance_metric=distance_metric, + directed=directed, + spacing=spacing, + reduce=True, + ), + metric_args={ + "num_classes": NUM_CLASSES, + "distance_metric": distance_metric, + "directed": directed, + "spacing": spacing, + "input_format": input_format, + }, + ) + + def test_hausdorff_distance_functional(self, inputs, input_format, distance_metric, directed, spacing): + """Test functional implementation of metric.""" + if spacing is not None and distance_metric != "euclidean": + pytest.skip("Spacing is only supported for Euclidean distance metric.") + preds, target = inputs + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=hausdorff_distance, + reference_metric=partial( + reference_metric, + input_format=input_format, + distance_metric=distance_metric, + directed=directed, + spacing=spacing, + reduce=False, + ), + metric_args={ + "num_classes": NUM_CLASSES, + "distance_metric": distance_metric, + "directed": directed, + "spacing": spacing, + "input_format": input_format, + }, + ) + + +def test_hausdorff_distance_raises_error(): + """Check that metric raises appropriate errors.""" + preds, target = _inputs1 diff --git a/tests/unittests/segmentation/test_utils.py b/tests/unittests/segmentation/test_utils.py index d37941a6ff3..39cff09a2dd 100644 --- a/tests/unittests/segmentation/test_utils.py +++ b/tests/unittests/segmentation/test_utils.py @@ -14,6 +14,7 @@ import pytest import torch from monai.metrics.utils import get_code_to_measure_table +from monai.metrics.utils import get_edge_surface_distance as monai_get_edge_surface_distance from monai.metrics.utils import get_mask_edges as monai_get_mask_edges from monai.metrics.utils import get_surface_distance as monai_get_surface_distance from scipy.ndimage import binary_erosion as scibinary_erosion @@ -23,6 +24,7 @@ from torchmetrics.functional.segmentation.utils import ( binary_erosion, distance_transform, + edge_surface_distance, generate_binary_structure, get_neighbour_tables, mask_edges, @@ -231,3 +233,50 @@ def test_mask_edges(cases, spacing, crop, device): for r1, r2 in zip(res, reference_res): assert torch.allclose(r1.cpu().float(), torch.from_numpy(r2).float()) + + +@pytest.mark.parametrize( + "cases", + [ + ( + torch.tensor( + [[1, 1, 1, 1, 1], [1, 0, 0, 0, 1], [1, 0, 0, 0, 1], [1, 0, 0, 0, 1], [1, 1, 1, 1, 1]], dtype=torch.bool + ), + torch.tensor( + [[1, 1, 1, 1, 0], [1, 0, 0, 1, 0], [1, 0, 0, 1, 0], [1, 0, 0, 1, 0], [1, 1, 1, 1, 0]], dtype=torch.bool + ), + ), + (torch.randint(0, 2, (5, 5), dtype=torch.bool), torch.randint(0, 2, (5, 5), dtype=torch.bool)), + (torch.randint(0, 2, (50, 50), dtype=torch.bool), torch.randint(0, 2, (50, 50), dtype=torch.bool)), + ], +) +@pytest.mark.parametrize("distance_metric", ["euclidean", "chessboard", "taxicab"]) +@pytest.mark.parametrize("symmetric", [False, True]) +@pytest.mark.parametrize("spacing", [None, 1, 2]) +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_edge_surface_distance(cases, distance_metric, symmetric, spacing, device): + """Test the edge surface distance function.""" + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA device not available.") + if spacing == 2 and distance_metric != "euclidean": + pytest.skip("Only euclidean distance is supported for spacing != 1 in reference") + preds, target = cases + if spacing is not None: + spacing = preds.ndim * [spacing] + + res = edge_surface_distance( + preds.to(device), target.to(device), spacing=spacing, distance_metric=distance_metric, symmetric=symmetric + ) + _, reference_res, _ = monai_get_edge_surface_distance( + preds, + target, + spacing=tuple(spacing) if spacing is not None else spacing, + distance_metric=distance_metric, + symmetric=symmetric, + ) + + if symmetric: + assert torch.allclose(res[0].cpu(), reference_res[0].to(res[0].dtype)) + assert torch.allclose(res[1].cpu(), reference_res[1].to(res[1].dtype)) + else: + assert torch.allclose(res.cpu(), reference_res[0].to(res.dtype))