Add input_format argument to segmentation metrics (#2572)
* add argument + change tests
* Apply suggestions from code review

---------

Co-authored-by: Daniel Stancl <[email protected]>
SkafteNicki and stancld authored May 31, 2024
1 parent 8ca1735 commit 744905c
Showing 7 changed files with 112 additions and 44 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -15,6 +15,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `MetricInputTransformer` wrapper ([#2392](https://github.com/Lightning-AI/torchmetrics/pull/2392))


- Added `input_format` argument to segmentation metrics ([#2572](https://github.com/Lightning-AI/torchmetrics/pull/2572))


### Changed

-
16 changes: 12 additions & 4 deletions src/torchmetrics/functional/segmentation/generalized_dice.py
@@ -25,6 +25,7 @@ def _generalized_dice_validate_args(
include_background: bool,
per_class: bool,
weight_type: Literal["square", "simple", "linear"],
input_format: Literal["one-hot", "index"],
) -> None:
"""Validate the arguments of the metric."""
if num_classes <= 0:
@@ -37,6 +38,8 @@ def _generalized_dice_validate_args(
raise ValueError(
f"Expected argument `weight_type` to be one of 'square', 'simple', 'linear', but got {weight_type}."
)
if input_format not in ["one-hot", "index"]:
raise ValueError(f"Expected argument `input_format` to be one of 'one-hot', 'index', but got {input_format}.")


def _generalized_dice_update(
@@ -45,15 +48,15 @@ def _generalized_dice_update(
num_classes: int,
include_background: bool,
weight_type: Literal["square", "simple", "linear"] = "square",
input_format: Literal["one-hot", "index"] = "one-hot",
) -> Tensor:
"""Update the state with the current prediction and target."""
_check_same_shape(preds, target)
if preds.ndim < 3:
raise ValueError(f"Expected both `preds` and `target` to have at least 3 dimensions, but got {preds.ndim}.")

if (preds.bool() != preds).any(): # preds is an index tensor
if input_format == "index":
preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1)
if (target.bool() != target).any(): # target is an index tensor
target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1)

if not include_background:
@@ -104,6 +107,7 @@ def generalized_dice_score(
include_background: bool = True,
per_class: bool = False,
weight_type: Literal["square", "simple", "linear"] = "square",
input_format: Literal["one-hot", "index"] = "one-hot",
) -> Tensor:
"""Compute the Generalized Dice Score for semantic segmentation.
@@ -114,6 +118,8 @@ def generalized_dice_score(
include_background: Whether to include the background class in the computation
per_class: Whether to compute the IoU for each class separately, else average over all classes
weight_type: Type of weight factor to apply to the classes. One of ``"square"``, ``"simple"``, or ``"linear"``
input_format: What kind of input the function receives. Choose between ``"one-hot"`` for one-hot encoded tensors
or ``"index"`` for index tensors
Returns:
The Generalized Dice Score
@@ -133,6 +139,8 @@ def generalized_dice_score(
[0.4715, 0.4925, 0.4797, 0.5267, 0.4788]])
"""
_generalized_dice_validate_args(num_classes, include_background, per_class, weight_type)
numerator, denominator = _generalized_dice_update(preds, target, num_classes, include_background, weight_type)
_generalized_dice_validate_args(num_classes, include_background, per_class, weight_type, input_format)
numerator, denominator = _generalized_dice_update(
preds, target, num_classes, include_background, weight_type, input_format
)
return _generalized_dice_compute(numerator, denominator, per_class)
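
To illustrate the new argument, here is a minimal usage sketch of the functional API after this change; the import path follows the package layout in this diff, and the shapes and class count are illustrative assumptions rather than values from the commit:

```python
import torch
from torchmetrics.functional.segmentation import generalized_dice_score

# Illustrative data: 4 samples, 5 classes, 16x16 masks.
preds_idx = torch.randint(0, 5, (4, 16, 16))    # index tensor of shape (N, ...)
target_idx = torch.randint(0, 5, (4, 16, 16))

# input_format="index": both tensors are one-hot encoded internally.
score_idx = generalized_dice_score(preds_idx, target_idx, num_classes=5, input_format="index")

# Default input_format="one-hot": tensors must already be (N, C, ...) one-hot.
preds_oh = torch.nn.functional.one_hot(preds_idx, num_classes=5).movedim(-1, 1)
target_oh = torch.nn.functional.one_hot(target_idx, num_classes=5).movedim(-1, 1)
score_oh = generalized_dice_score(preds_oh, target_oh, num_classes=5)
```

Previously the format was inferred from the tensor values via `(preds.bool() != preds).any()`, which cannot distinguish a binary index map from a one-hot tensor; the explicit argument removes that ambiguity.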
15 changes: 11 additions & 4 deletions src/torchmetrics/functional/segmentation/mean_iou.py
@@ -15,6 +15,7 @@

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.segmentation.utils import _ignore_background
from torchmetrics.utilities.checks import _check_same_shape
@@ -25,6 +26,7 @@ def _mean_iou_validate_args(
num_classes: int,
include_background: bool,
per_class: bool,
input_format: Literal["one-hot", "index"] = "one-hot",
) -> None:
"""Validate the arguments of the metric."""
if num_classes <= 0:
@@ -33,20 +35,22 @@
raise ValueError(f"Expected argument `include_background` must be a boolean, but got {include_background}.")
if not isinstance(per_class, bool):
raise ValueError(f"Expected argument `per_class` must be a boolean, but got {per_class}.")
if input_format not in ["one-hot", "index"]:
raise ValueError(f"Expected argument `input_format` to be one of 'one-hot', 'index', but got {input_format}.")


def _mean_iou_update(
preds: Tensor,
target: Tensor,
num_classes: int,
include_background: bool = False,
input_format: Literal["one-hot", "index"] = "one-hot",
) -> Tuple[Tensor, Tensor]:
"""Update the intersection and union counts for the mean IoU computation."""
_check_same_shape(preds, target)

if (preds.bool() != preds).any(): # preds is an index tensor
if input_format == "index":
preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1)
if (target.bool() != target).any(): # target is an index tensor
target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1)

if not include_background:
@@ -76,6 +80,7 @@ def mean_iou(
num_classes: int,
include_background: bool = True,
per_class: bool = False,
input_format: Literal["one-hot", "index"] = "one-hot",
) -> Tensor:
"""Calculates the mean Intersection over Union (mIoU) for semantic segmentation.
@@ -85,6 +90,8 @@ def mean_iou(
num_classes: Number of classes
include_background: Whether to include the background class in the computation
per_class: Whether to compute the IoU for each class separately, else average over all classes
input_format: What kind of input the function receives. Choose between ``"one-hot"`` for one-hot encoded tensors
or ``"index"`` for index tensors
Returns:
The mean IoU score
@@ -104,6 +111,6 @@ def mean_iou(
[0.3085, 0.3267, 0.3155, 0.3575, 0.3147]])
"""
_mean_iou_validate_args(num_classes, include_background, per_class)
intersection, union = _mean_iou_update(preds, target, num_classes, include_background)
_mean_iou_validate_args(num_classes, include_background, per_class, input_format)
intersection, union = _mean_iou_update(preds, target, num_classes, include_background, input_format)
return _mean_iou_compute(intersection, union, per_class=per_class)
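
A similar hedged sketch for the functional `mean_iou` with the explicit `input_format` argument (import path, shapes, and class count assumed for illustration):

```python
import torch
from torchmetrics.functional.segmentation import mean_iou

# Illustrative data: 2 samples, 3 classes, 8x8 masks.
preds = torch.randint(0, 3, (2, 8, 8))     # (N, H, W) index tensor
target = torch.randint(0, 3, (2, 8, 8))

# Declare the format explicitly instead of relying on value-based detection.
miou = mean_iou(preds, target, num_classes=3, input_format="index")
print(miou)
```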
18 changes: 12 additions & 6 deletions src/torchmetrics/segmentation/generalized_dice.py
@@ -53,12 +53,12 @@ class GeneralizedDiceScore(Metric):
- ``preds`` (:class:`~torch.Tensor`): An one-hot boolean tensor of shape ``(N, C, ...)`` with ``N`` being
the number of samples and ``C`` the number of classes. Alternatively, an integer tensor of shape ``(N, ...)``
can be provided, where the integer values correspond to the class index. That format will be automatically
converted to a one-hot tensor.
can be provided, where the integer values correspond to the class index. The input type can be controlled
with the ``input_format`` argument.
- ``target`` (:class:`~torch.Tensor`): An one-hot boolean tensor of shape ``(N, C, ...)`` with ``N`` being
the number of samples and ``C`` the number of classes. Alternatively, an integer tensor of shape ``(N, ...)``
can be provided, where the integer values correspond to the class index. That format will be automatically
converted to a one-hot tensor.
can be provided, where the integer values correspond to the class index. The input type can be controlled
with the ``input_format`` argument.
As output to ``forward`` and ``compute`` the metric returns the following output:
@@ -72,6 +72,8 @@ class GeneralizedDiceScore(Metric):
per_class: Whether to compute the metric for each class separately.
weight_type: The type of weight to apply to each class. Can be one of ``"square"``, ``"simple"``, or
``"linear"``.
input_format: What kind of input the function receives. Choose between ``"one-hot"`` for one-hot encoded tensors
or ``"index"`` for index tensors
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Raises:
@@ -83,6 +85,8 @@ class GeneralizedDiceScore(Metric):
If ``per_class`` is not a boolean
ValueError:
If ``weight_type`` is not one of ``"square"``, ``"simple"``, or ``"linear"``
ValueError:
If ``input_format`` is not one of ``"one-hot"`` or ``"index"``
Example:
>>> import torch
@@ -116,14 +120,16 @@ def __init__(
include_background: bool = True,
per_class: bool = False,
weight_type: Literal["square", "simple", "linear"] = "square",
input_format: Literal["one-hot", "index"] = "one-hot",
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
_generalized_dice_validate_args(num_classes, include_background, per_class, weight_type)
_generalized_dice_validate_args(num_classes, include_background, per_class, weight_type, input_format)
self.num_classes = num_classes
self.include_background = include_background
self.per_class = per_class
self.weight_type = weight_type
self.input_format = input_format

num_classes = num_classes - 1 if not include_background else num_classes
self.add_state("score", default=torch.zeros(num_classes if per_class else 1), dist_reduce_fx="sum")
@@ -132,7 +138,7 @@ def __init__(
def update(self, preds: Tensor, target: Tensor) -> None:
"""Update the state with new data."""
numerator, denominator = _generalized_dice_update(
preds, target, self.num_classes, self.include_background, self.weight_type
preds, target, self.num_classes, self.include_background, self.weight_type, self.input_format
)
self.score += _generalized_dice_compute(numerator, denominator, self.per_class).sum(dim=0)
self.samples += preds.shape[0]
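For the modular interface, a minimal sketch of how the constructor argument propagates to `update` (data and batching are illustrative, not taken from the diff):

```python
import torch
from torchmetrics.segmentation import GeneralizedDiceScore

metric = GeneralizedDiceScore(num_classes=3, input_format="index")
for _ in range(2):  # e.g. two batches
    preds = torch.randint(0, 3, (4, 8, 8))    # index tensors of shape (N, H, W)
    target = torch.randint(0, 3, (4, 8, 8))
    metric.update(preds, target)
print(metric.compute())
```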
21 changes: 15 additions & 6 deletions src/torchmetrics/segmentation/mean_iou.py
@@ -15,6 +15,7 @@

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.segmentation.mean_iou import _mean_iou_compute, _mean_iou_update, _mean_iou_validate_args
from torchmetrics.metric import Metric
@@ -36,12 +37,12 @@ class MeanIoU(Metric):
- ``preds`` (:class:`~torch.Tensor`): An one-hot boolean tensor of shape ``(N, C, ...)`` with ``N`` being
the number of samples and ``C`` the number of classes. Alternatively, an integer tensor of shape ``(N, ...)``
can be provided, where the integer values correspond to the class index. That format will be automatically
converted to a one-hot tensor.
can be provided, where the integer values correspond to the class index. The input type can be controlled
with the ``input_format`` argument.
- ``target`` (:class:`~torch.Tensor`): An one-hot boolean tensor of shape ``(N, C, ...)`` with ``N`` being
the number of samples and ``C`` the number of classes. Alternatively, an integer tensor of shape ``(N, ...)``
can be provided, where the integer values correspond to the class index. That format will be automatically
converted to a one-hot tensor.
can be provided, where the integer values correspond to the class index. The input type can be controlled
with the ``input_format`` argument.
As output to ``forward`` and ``compute`` the metric returns the following output:
@@ -54,6 +55,8 @@ class MeanIoU(Metric):
include_background: Whether to include the background class in the computation
per_class: Whether to compute the IoU for each class separately. If set to ``False``, the metric will
compute the mean IoU over all classes.
input_format: What kind of input the function receives. Choose between ``"one-hot"`` for one-hot encoded tensors
or ``"index"`` for index tensors
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Raises:
@@ -63,6 +66,8 @@ class MeanIoU(Metric):
If ``include_background`` is not a boolean
ValueError:
If ``per_class`` is not a boolean
ValueError:
If ``input_format`` is not one of ``"one-hot"`` or ``"index"``
Example:
>>> import torch
@@ -95,20 +100,24 @@ def __init__(
num_classes: int,
include_background: bool = True,
per_class: bool = False,
input_format: Literal["one-hot", "index"] = "one-hot",
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
_mean_iou_validate_args(num_classes, include_background, per_class)
_mean_iou_validate_args(num_classes, include_background, per_class, input_format)
self.num_classes = num_classes
self.include_background = include_background
self.per_class = per_class
self.input_format = input_format

num_classes = num_classes - 1 if not include_background else num_classes
self.add_state("score", default=torch.zeros(num_classes if per_class else 1), dist_reduce_fx="mean")

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update the state with the new data."""
intersection, union = _mean_iou_update(preds, target, self.num_classes, self.include_background)
intersection, union = _mean_iou_update(
preds, target, self.num_classes, self.include_background, self.input_format
)
score = _mean_iou_compute(intersection, union, per_class=self.per_class)
self.score += score.mean(0) if self.per_class else score.mean()

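And a corresponding sketch for `MeanIoU` with the default one-hot format; the one-hot tensors are built here from random index maps purely for illustration:

```python
import torch
from torchmetrics.segmentation import MeanIoU

metric = MeanIoU(num_classes=3)  # input_format="one-hot" is the default
preds = torch.nn.functional.one_hot(torch.randint(0, 3, (4, 8, 8)), num_classes=3).movedim(-1, 1)
target = torch.nn.functional.one_hot(torch.randint(0, 3, (4, 8, 8)), num_classes=3).movedim(-1, 1)
score = metric(preds, target)  # forward() updates the state and returns the batch score
```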
43 changes: 31 additions & 12 deletions tests/unittests/segmentation/test_generalized_dice_score.py
@@ -40,13 +40,13 @@
def _reference_generalized_dice(
preds: torch.Tensor,
target: torch.Tensor,
input_format: str,
include_background: bool = True,
reduce: bool = True,
):
"""Calculate reference metric for `MeanIoU`."""
if (preds.bool() != preds).any(): # preds is an index tensor
if input_format == "index":
preds = torch.nn.functional.one_hot(preds, num_classes=NUM_CLASSES).movedim(-1, 1)
if (target.bool() != target).any(): # target is an index tensor
target = torch.nn.functional.one_hot(target, num_classes=NUM_CLASSES).movedim(-1, 1)
val = compute_generalized_dice(preds, target, include_background=include_background)
if reduce:
@@ -55,35 +55,54 @@ def _reference_generalized_dice(


@pytest.mark.parametrize(
"preds, target",
"preds, target, input_format",
[
(_inputs1.preds, _inputs1.target),
(_inputs2.preds, _inputs2.target),
(_inputs3.preds, _inputs3.target),
(_inputs1.preds, _inputs1.target, "one-hot"),
(_inputs2.preds, _inputs2.target, "one-hot"),
(_inputs3.preds, _inputs3.target, "index"),
],
)
@pytest.mark.parametrize("include_background", [True, False])
class TestMeanIoU(MetricTester):
"""Test class for `MeanIoU` metric."""

@pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False])
def test_mean_iou_class(self, preds, target, include_background, ddp):
def test_mean_iou_class(self, preds, target, input_format, include_background, ddp):
"""Test class implementation of metric."""
self.run_class_metric_test(
ddp=ddp,
preds=preds,
target=target,
metric_class=GeneralizedDiceScore,
reference_metric=partial(_reference_generalized_dice, include_background=include_background, reduce=True),
metric_args={"num_classes": NUM_CLASSES, "include_background": include_background},
reference_metric=partial(
_reference_generalized_dice,
input_format=input_format,
include_background=include_background,
reduce=True,
),
metric_args={
"num_classes": NUM_CLASSES,
"include_background": include_background,
"input_format": input_format,
},
)

def test_mean_iou_functional(self, preds, target, include_background):
def test_mean_iou_functional(self, preds, target, input_format, include_background):
"""Test functional implementation of metric."""
self.run_functional_metric_test(
preds=preds,
target=target,
metric_functional=generalized_dice_score,
reference_metric=partial(_reference_generalized_dice, include_background=include_background, reduce=False),
metric_args={"num_classes": NUM_CLASSES, "include_background": include_background, "per_class": False},
reference_metric=partial(
_reference_generalized_dice,
input_format=input_format,
include_background=include_background,
reduce=False,
),
metric_args={
"num_classes": NUM_CLASSES,
"include_background": include_background,
"per_class": False,
"input_format": input_format,
},
)
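
The reference function now receives the format explicitly instead of inferring it from the tensor values; a standalone sketch of the one-hot conversion it applies for index inputs (constant and shapes are illustrative, the test module defines its own):

```python
import torch

NUM_CLASSES = 5  # illustrative value
idx = torch.randint(0, NUM_CLASSES, (2, 4, 4))                       # (N, H, W) index tensor
one_hot = torch.nn.functional.one_hot(idx, num_classes=NUM_CLASSES)  # (N, H, W, C)
one_hot = one_hot.movedim(-1, 1)                                     # (N, C, H, W)
assert one_hot.shape == (2, NUM_CLASSES, 4, 4)
```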
(Diff for the remaining changed file not shown.)
