Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New segmentation metric: Hausdorff Distance #2122

Merged
merged 44 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
5ab879d
add docs
matsumotosan Sep 30, 2023
b6519c3
initial commit
matsumotosan Oct 1, 2023
4f5d606
fix hausdorff metric args
matsumotosan Oct 4, 2023
80cbb1a
ci: switch to custom docker images (#2123)
matsumotosan Oct 14, 2023
05c154a
Add `average` to curve metrics (#2084)
matsumotosan Oct 14, 2023
ea58776
symmetric test
matsumotosan Oct 14, 2023
efc972f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 21, 2024
3147712
Merge branch 'master' into 1990-hausdorff-distance
matsumotosan May 27, 2024
0c9b1a8
fix merge error
matsumotosan May 27, 2024
a3dcc86
fix imports
matsumotosan May 27, 2024
bfe0a3b
tests running
matsumotosan May 27, 2024
62b7a4c
Merge branch 'master' into 1990-hausdorff-distance
Borda Jul 22, 2024
abe3069
Merge branch 'master' into 1990-hausdorff-distance
baskrahmer Aug 6, 2024
5e0b253
Merge branch 'master' into 1990-hausdorff-distance
baskrahmer Aug 9, 2024
550770b
Add Torch-to-numpy wrapper for skimage metric
baskrahmer Aug 10, 2024
78da660
Return average over states
baskrahmer Aug 10, 2024
011722d
Fix docs for doctests
baskrahmer Aug 16, 2024
dd837e5
Merge branch 'master' into 1990-hausdorff-distance
baskrahmer Aug 16, 2024
c0091e1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 16, 2024
1a53a22
Merge branch 'master' into 1990-hausdorff-distance
baskrahmer Aug 29, 2024
e2e86d9
Merge branch 'master' into 1990-hausdorff-distance
Borda Sep 2, 2024
873a8ca
Merge branch 'master' into 1990-hausdorff-distance
Borda Sep 16, 2024
223404a
Merge branch 'master' into 1990-hausdorff-distance
baskrahmer Sep 24, 2024
80829e8
Add pytest param for ddp
baskrahmer Sep 24, 2024
0e96276
Fix type hints
baskrahmer Sep 24, 2024
7b84f03
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 24, 2024
cc0d239
Refactor lambda to function definition
baskrahmer Sep 24, 2024
a07e021
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 24, 2024
c7bdce9
Fix docs
baskrahmer Sep 24, 2024
8541d97
Output a tensor for reference metric
baskrahmer Sep 24, 2024
90ad414
Set dtype to float32 in reference metric
baskrahmer Sep 24, 2024
f61e7ac
Add links back
baskrahmer Sep 24, 2024
aaee609
Merge branch 'master' into 1990-hausdorff-distance
baskrahmer Oct 7, 2024
708ce27
Merge branch 'master' into 1990-hausdorff-distance
SkafteNicki Oct 12, 2024
3345393
changelog
SkafteNicki Oct 12, 2024
a4f129f
fix docstring + add input validation
SkafteNicki Oct 12, 2024
3f4b68e
add edge_surface_distance utility
SkafteNicki Oct 14, 2024
6c8b5b6
fix functional implementation
SkafteNicki Oct 14, 2024
d2723c5
fix modular implementation
SkafteNicki Oct 14, 2024
cc7294d
tests
SkafteNicki Oct 14, 2024
879b3e5
Merge branch 'master' into 1990-hausdorff-distance
SkafteNicki Oct 14, 2024
29cc85c
mypy
SkafteNicki Oct 14, 2024
ce37ff0
fix typing issue
SkafteNicki Oct 14, 2024
29abf60
fix docs
Borda Oct 14, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `truncation` argument to `BERTScore` ([#2776](https://github.com/Lightning-AI/torchmetrics/pull/2776))


- Added `HausdorffDistance` to segmentation package ([#2122](https://github.com/Lightning-AI/torchmetrics/pull/2122))


### Changed

- Tracker higher is better integration ([#2649](https://github.com/Lightning-AI/torchmetrics/pull/2649))
Expand Down
2 changes: 2 additions & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -172,4 +172,6 @@
.. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
.. _SCC: https://www.ingentaconnect.com/content/tandf/tres/1998/00000019/00000004/art00013
.. _Generalized Dice Score: https://arxiv.org/abs/1707.03237
.. _Hausdorff Distance: https://en.wikipedia.org/wiki/Hausdorff_distance
.. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
.. _Procrustes Disparity: https://en.wikipedia.org/wiki/Procrustes_analysis
21 changes: 21 additions & 0 deletions docs/source/segmentation/hausdorff_distance.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
.. customcarditem::
:header: Hausdorff Distance
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/text_classification.svg
:tags: segmentation

.. include:: ../links.rst

##################
Hausdorff Distance
##################

Module Interface
________________

.. autoclass:: torchmetrics.segmentation.HausdorffDistance
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.segmentation.hausdorff_distance
Empty file added requirements/integrate.txt
Empty file.
3 changes: 2 additions & 1 deletion src/torchmetrics/functional/segmentation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.functional.segmentation.generalized_dice import generalized_dice_score
from torchmetrics.functional.segmentation.hausdorff_distance import hausdorff_distance
from torchmetrics.functional.segmentation.mean_iou import mean_iou

__all__ = ["generalized_dice_score", "mean_iou"]
__all__ = ["generalized_dice_score", "mean_iou", "hausdorff_distance"]
114 changes: 114 additions & 0 deletions src/torchmetrics/functional/segmentation/hausdorff_distance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Literal, Optional, Union

import torch
from torch import Tensor

from torchmetrics.functional.segmentation.utils import (
_ignore_background,
edge_surface_distance,
)
from torchmetrics.utilities.checks import _check_same_shape


def _hausdorff_distance_validate_args(
num_classes: int,
include_background: bool,
distance_metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean",
spacing: Optional[Union[Tensor, List[float]]] = None,
directed: bool = False,
input_format: Literal["one-hot", "index"] = "one-hot",
) -> None:
"""Validate the arguments of `hausdorff_distance` function."""
if num_classes <= 0:
raise ValueError(f"Expected argument `num_classes` must be a positive integer, but got {num_classes}.")
if not isinstance(include_background, bool):
raise ValueError(f"Expected argument `include_background` must be a boolean, but got {include_background}.")
if distance_metric not in ["euclidean", "chessboard", "taxicab"]:
raise ValueError(
f"Arg `distance_metric` must be one of 'euclidean', 'chessboard', 'taxicab', but got {distance_metric}."
)
if spacing is not None and not isinstance(spacing, (list, Tensor)):
raise ValueError(f"Arg `spacing` must be a list or tensor, but got {type(spacing)}.")
if not isinstance(directed, bool):
raise ValueError(f"Expected argument `directed` must be a boolean, but got {directed}.")
if input_format not in ["one-hot", "index"]:
raise ValueError(f"Expected argument `input_format` to be one of 'one-hot', 'index', but got {input_format}.")


def hausdorff_distance(
preds: Tensor,
target: Tensor,
num_classes: int,
include_background: bool = False,
distance_metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean",
spacing: Optional[Union[Tensor, List[float]]] = None,
directed: bool = False,
input_format: Literal["one-hot", "index"] = "one-hot",
) -> Tensor:
"""Calculate `Hausdorff Distance`_ for semantic segmentation.

Args:
preds: predicted binarized segmentation map
target: target binarized segmentation map
num_classes: number of classes
include_background: whether to include background class in calculation
distance_metric: distance metric to calculate surface distance. Choose one of `"euclidean"`,
`"chessboard"` or `"taxicab"`
spacing: spacing between pixels along each spatial dimension. If not provided the spacing is assumed to be 1
directed: whether to calculate directed or undirected Hausdorff distance
input_format: What kind of input the function receives. Choose between ``"one-hot"`` for one-hot encoded tensors
or ``"index"`` for index tensors

Returns:
Hausdorff Distance for each class and batch element

Example:
>>> from torch import randint
>>> from torchmetrics.functional.segmentation import hausdorff_distance
>>> preds = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 prediction
>>> target = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 target
>>> hausdorff_distance(preds, target, num_classes=5)
tensor([[2.0000, 1.4142, 2.0000, 2.0000],
[1.4142, 2.0000, 2.0000, 2.0000],
[2.0000, 2.0000, 1.4142, 2.0000],
[2.0000, 2.8284, 2.0000, 2.2361]])

"""
_hausdorff_distance_validate_args(num_classes, include_background, distance_metric, spacing, directed, input_format)
_check_same_shape(preds, target)

if input_format == "index":
preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1)
target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1)

if not include_background:
preds, target = _ignore_background(preds, target)

distances = torch.zeros(preds.shape[0], preds.shape[1], device=preds.device)

# TODO: add support for batched inputs
for b in range(preds.shape[0]):
for c in range(preds.shape[1]):
dist = edge_surface_distance(
preds=preds[b, c],
target=target[b, c],
distance_metric=distance_metric,
spacing=spacing,
symmetric=not directed,
)
distances[b, c] = torch.max(dist) if directed else torch.max(dist[0].max(), dist[1].max()) # type: ignore
return distances
57 changes: 44 additions & 13 deletions src/torchmetrics/functional/segmentation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def _ignore_background(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:


def check_if_binarized(x: Tensor) -> None:
"""Check if the input is binarized.
"""Check if tensor is binarized.

Example:
>>> from torchmetrics.functional.segmentation.utils import check_if_binarized
Expand Down Expand Up @@ -200,9 +200,8 @@ def distance_transform(

Args:
x: The binary tensor to calculate the distance transform of.
sampling: Only relevant when distance is calculated using the euclidean distance. The sampling refers to the
pixel spacing in the image, i.e. the distance between two adjacent pixels. If not provided, the pixel
spacing is assumed to be 1.
sampling: The sampling refers to the pixel spacing in the image, i.e. the distance between two adjacent pixels.
If not provided, the pixel spacing is assumed to be 1.
metric: The distance to use for the distance transform. Can be one of ``"euclidean"``, ``"chessboard"``
or ``"taxicab"``.
engine: The engine to use for the distance transform. Can be one of ``["pytorch", "scipy"]``. In general,
Expand Down Expand Up @@ -249,25 +248,25 @@ def distance_transform(
raise ValueError(f"Expected argument `sampling` to have length 2 but got length `{len(sampling)}`.")

if engine == "pytorch":
x = x.float()
# calculate distance from every foreground pixel to every background pixel
i0, j0 = torch.where(x == 0)
i1, j1 = torch.where(x == 1)
dis_row = (i1.unsqueeze(1) - i0.unsqueeze(0)).abs_().mul_(sampling[0])
dis_col = (j1.unsqueeze(1) - j0.unsqueeze(0)).abs_().mul_(sampling[1])
dis_row = (i1.view(-1, 1) - i0.view(1, -1)).abs()
dis_col = (j1.view(-1, 1) - j0.view(1, -1)).abs()

# # calculate distance
h, _ = x.shape
if metric == "euclidean":
dis_row = dis_row.float()
dis_row.pow_(2).add_(dis_col.pow_(2)).sqrt_()
dis = ((sampling[0] * dis_row) ** 2 + (sampling[1] * dis_col) ** 2).sqrt()
if metric == "chessboard":
dis_row = dis_row.max(dis_col)
dis = torch.max(sampling[0] * dis_row, sampling[1] * dis_col).float()
if metric == "taxicab":
dis_row.add_(dis_col)
dis = (sampling[0] * dis_row + sampling[1] * dis_col).float()

# select only the closest distance
mindis, _ = torch.min(dis_row, dim=1)
z = torch.zeros_like(x, dtype=mindis.dtype).view(-1)
mindis, _ = torch.min(dis, dim=1)
z = torch.zeros_like(x).view(-1)
z[i1 * h + j1] = mindis
return z.view(x.shape)

Expand All @@ -279,7 +278,7 @@ def distance_transform(

if metric == "euclidean":
return ndimage.distance_transform_edt(x.cpu().numpy(), sampling)
return ndimage.distance_transform_cdt(x.cpu().numpy(), metric=metric)
return ndimage.distance_transform_cdt(x.cpu().numpy(), sampling, metric=metric)


def mask_edges(
Expand Down Expand Up @@ -390,6 +389,38 @@ def surface_distance(
return dis[preds]


def edge_surface_distance(
preds: Tensor,
target: Tensor,
distance_metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean",
spacing: Optional[Union[Tensor, List[float]]] = None,
symmetric: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
"""Extracts the edges from the input masks and calculates the surface distance between them.

Args:
preds: The predicted binary edge mask.
target: The target binary edge mask.
distance_metric: The distance metric to use. One of `["euclidean", "chessboard", "taxicab"]`.
spacing: The spacing between pixels along each spatial dimension.
symmetric: Whether to calculate the symmetric distance between the edges.

Returns:
A tensor with length equal to the number of edges in predictions e.g. `preds.sum()`. Each element is the
distance from the corresponding edge in `preds` to the closest edge in `target`. If `symmetric` is `True`, the
function returns a tuple containing the distances from the predicted edges to the target edges and vice versa.

"""
output = mask_edges(preds, target)
edges_preds, edges_target = output[0].bool(), output[1].bool()
if symmetric:
return (
surface_distance(edges_preds, edges_target, distance_metric=distance_metric, spacing=spacing),
surface_distance(edges_target, edges_preds, distance_metric=distance_metric, spacing=spacing),
)
return surface_distance(edges_preds, edges_target, distance_metric=distance_metric, spacing=spacing)


@functools.lru_cache
def get_neighbour_tables(
spacing: Union[Tuple[int, int], Tuple[int, int, int]], device: Optional[torch.device] = None
Expand Down
3 changes: 2 additions & 1 deletion src/torchmetrics/segmentation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.segmentation.generalized_dice import GeneralizedDiceScore
from torchmetrics.segmentation.hausdorff_distance import HausdorffDistance
from torchmetrics.segmentation.mean_iou import MeanIoU

__all__ = ["GeneralizedDiceScore", "MeanIoU"]
__all__ = ["GeneralizedDiceScore", "MeanIoU", "HausdorffDistance"]
Loading
Loading