Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add input_format argument to segmentation metrics #2572

Merged
merged 5 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `MetricInputTransformer` wrapper ([#2392](https://github.com/Lightning-AI/torchmetrics/pull/2392))


- Added `input_format` argument to segmentation metrics ([#2572](https://github.com/Lightning-AI/torchmetrics/pull/2572))


### Changed

-
Expand Down
17 changes: 13 additions & 4 deletions src/torchmetrics/functional/segmentation/generalized_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def _generalized_dice_validate_args(
include_background: bool,
per_class: bool,
weight_type: Literal["square", "simple", "linear"],
input_format: Literal["one-hot", "index"],
) -> None:
"""Validate the arguments of the metric."""
if num_classes <= 0:
Expand All @@ -37,6 +38,8 @@ def _generalized_dice_validate_args(
raise ValueError(
f"Expected argument `weight_type` to be one of 'square', 'simple', 'linear', but got {weight_type}."
)
if input_format not in ["one-hot", "index"]:
raise ValueError(f"Expected argument `input_format` to be one of 'one-hot', 'index', but got {input_format}.")


def _generalized_dice_update(
Expand All @@ -45,15 +48,16 @@ def _generalized_dice_update(
num_classes: int,
include_background: bool,
weight_type: Literal["square", "simple", "linear"] = "square",
input_format: Literal["one-hot", "index"] = "one-hot",
) -> Tensor:
"""Update the state with the current prediction and target."""
_check_same_shape(preds, target)
if preds.ndim < 3:
raise ValueError(f"Expected both `preds` and `target` to have at least 3 dimensions, but got {preds.ndim}.")

if (preds.bool() != preds).any(): # preds is an index tensor
if input_format == "index":
preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1)
if (target.bool() != target).any(): # target is an index tensor
if input_format == "index":
Borda marked this conversation as resolved.
Show resolved Hide resolved
target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1)

if not include_background:
Expand Down Expand Up @@ -104,6 +108,7 @@ def generalized_dice_score(
include_background: bool = True,
per_class: bool = False,
weight_type: Literal["square", "simple", "linear"] = "square",
input_format: Literal["one-hot", "index"] = "one-hot",
) -> Tensor:
"""Compute the Generalized Dice Score for semantic segmentation.

Expand All @@ -114,6 +119,8 @@ def generalized_dice_score(
include_background: Whether to include the background class in the computation
per_class: Whether to compute the Generalized Dice Score for each class separately, else average over all classes
weight_type: Type of weight factor to apply to the classes. One of ``"square"``, ``"simple"``, or ``"linear"``
input_format: What kind of input the function receives. Choose between ``"one-hot"`` for one-hot encoded tensors
or ``"index"`` for index tensors

Returns:
The Generalized Dice Score
Expand All @@ -133,6 +140,8 @@ def generalized_dice_score(
[0.4715, 0.4925, 0.4797, 0.5267, 0.4788]])

"""
_generalized_dice_validate_args(num_classes, include_background, per_class, weight_type)
numerator, denominator = _generalized_dice_update(preds, target, num_classes, include_background, weight_type)
_generalized_dice_validate_args(num_classes, include_background, per_class, weight_type, input_format)
numerator, denominator = _generalized_dice_update(
preds, target, num_classes, include_background, weight_type, input_format
)
return _generalized_dice_compute(numerator, denominator, per_class)
16 changes: 12 additions & 4 deletions src/torchmetrics/functional/segmentation/mean_iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.segmentation.utils import _ignore_background
from torchmetrics.utilities.checks import _check_same_shape
Expand All @@ -25,6 +26,7 @@ def _mean_iou_validate_args(
num_classes: int,
include_background: bool,
per_class: bool,
input_format: Literal["one-hot", "index"] = "one-hot",
) -> None:
"""Validate the arguments of the metric."""
if num_classes <= 0:
Expand All @@ -33,20 +35,23 @@ def _mean_iou_validate_args(
raise ValueError(f"Expected argument `include_background` must be a boolean, but got {include_background}.")
if not isinstance(per_class, bool):
raise ValueError(f"Expected argument `per_class` must be a boolean, but got {per_class}.")
if input_format not in ["one-hot", "index"]:
raise ValueError(f"Expected argument `input_format` to be one of 'one-hot', 'index', but got {input_format}.")


def _mean_iou_update(
preds: Tensor,
target: Tensor,
num_classes: int,
include_background: bool = False,
input_format: Literal["one-hot", "index"] = "one-hot",
) -> Tuple[Tensor, Tensor]:
"""Update the intersection and union counts for the mean IoU computation."""
_check_same_shape(preds, target)

if (preds.bool() != preds).any(): # preds is an index tensor
if input_format == "index":
preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1)
if (target.bool() != target).any(): # target is an index tensor
if input_format == "index":
Borda marked this conversation as resolved.
Show resolved Hide resolved
target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1)

if not include_background:
Expand Down Expand Up @@ -76,6 +81,7 @@ def mean_iou(
num_classes: int,
include_background: bool = True,
per_class: bool = False,
input_format: Literal["one-hot", "index"] = "one-hot",
) -> Tensor:
"""Calculates the mean Intersection over Union (mIoU) for semantic segmentation.

Expand All @@ -85,6 +91,8 @@ def mean_iou(
num_classes: Number of classes
include_background: Whether to include the background class in the computation
per_class: Whether to compute the IoU for each class separately, else average over all classes
input_format: What kind of input the function receives. Choose between ``"one-hot"`` for one-hot encoded tensors
or ``"index"`` for index tensors

Returns:
The mean IoU score
Expand All @@ -104,6 +112,6 @@ def mean_iou(
[0.3085, 0.3267, 0.3155, 0.3575, 0.3147]])

"""
_mean_iou_validate_args(num_classes, include_background, per_class)
intersection, union = _mean_iou_update(preds, target, num_classes, include_background)
_mean_iou_validate_args(num_classes, include_background, per_class, input_format)
intersection, union = _mean_iou_update(preds, target, num_classes, include_background, input_format)
return _mean_iou_compute(intersection, union, per_class=per_class)
18 changes: 12 additions & 6 deletions src/torchmetrics/segmentation/generalized_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@ class GeneralizedDiceScore(Metric):

- ``preds`` (:class:`~torch.Tensor`): An one-hot boolean tensor of shape ``(N, C, ...)`` with ``N`` being
the number of samples and ``C`` the number of classes. Alternatively, an integer tensor of shape ``(N, ...)``
can be provided, where the integer values correspond to the class index. That format will be automatically
converted to a one-hot tensor.
can be provided, where the integer values correspond to the class index. The input type can be controlled
with the ``input_format`` argument.
- ``target`` (:class:`~torch.Tensor`): An one-hot boolean tensor of shape ``(N, C, ...)`` with ``N`` being
the number of samples and ``C`` the number of classes. Alternatively, an integer tensor of shape ``(N, ...)``
can be provided, where the integer values correspond to the class index. That format will be automatically
converted to a one-hot tensor.
can be provided, where the integer values correspond to the class index. The input type can be controlled
with the ``input_format`` argument.

As output to ``forward`` and ``compute`` the metric returns the following output:

Expand All @@ -72,6 +72,8 @@ class GeneralizedDiceScore(Metric):
per_class: Whether to compute the metric for each class separately.
weight_type: The type of weight to apply to each class. Can be one of ``"square"``, ``"simple"``, or
``"linear"``.
input_format: What kind of input the function receives. Choose between ``"one-hot"`` for one-hot encoded tensors
or ``"index"`` for index tensors
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
Expand All @@ -83,6 +85,8 @@ class GeneralizedDiceScore(Metric):
If ``per_class`` is not a boolean
ValueError:
If ``weight_type`` is not one of ``"square"``, ``"simple"``, or ``"linear"``
ValueError:
If ``input_format`` is not one of ``"one-hot"`` or ``"index"``

Example:
>>> import torch
Expand Down Expand Up @@ -116,14 +120,16 @@ def __init__(
include_background: bool = True,
per_class: bool = False,
weight_type: Literal["square", "simple", "linear"] = "square",
input_format: Literal["one-hot", "index"] = "one-hot",
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
_generalized_dice_validate_args(num_classes, include_background, per_class, weight_type)
_generalized_dice_validate_args(num_classes, include_background, per_class, weight_type, input_format)
self.num_classes = num_classes
self.include_background = include_background
self.per_class = per_class
self.weight_type = weight_type
self.input_format = input_format

num_classes = num_classes - 1 if not include_background else num_classes
self.add_state("score", default=torch.zeros(num_classes if per_class else 1), dist_reduce_fx="sum")
Expand All @@ -132,7 +138,7 @@ def __init__(
def update(self, preds: Tensor, target: Tensor) -> None:
"""Update the state with new data."""
numerator, denominator = _generalized_dice_update(
preds, target, self.num_classes, self.include_background, self.weight_type
preds, target, self.num_classes, self.include_background, self.weight_type, self.input_format
)
self.score += _generalized_dice_compute(numerator, denominator, self.per_class).sum(dim=0)
self.samples += preds.shape[0]
Expand Down
21 changes: 15 additions & 6 deletions src/torchmetrics/segmentation/mean_iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.segmentation.mean_iou import _mean_iou_compute, _mean_iou_update, _mean_iou_validate_args
from torchmetrics.metric import Metric
Expand All @@ -36,12 +37,12 @@ class MeanIoU(Metric):

- ``preds`` (:class:`~torch.Tensor`): An one-hot boolean tensor of shape ``(N, C, ...)`` with ``N`` being
the number of samples and ``C`` the number of classes. Alternatively, an integer tensor of shape ``(N, ...)``
can be provided, where the integer values correspond to the class index. That format will be automatically
converted to a one-hot tensor.
can be provided, where the integer values correspond to the class index. The input type can be controlled
with the ``input_format`` argument.
- ``target`` (:class:`~torch.Tensor`): An one-hot boolean tensor of shape ``(N, C, ...)`` with ``N`` being
the number of samples and ``C`` the number of classes. Alternatively, an integer tensor of shape ``(N, ...)``
can be provided, where the integer values correspond to the class index. That format will be automatically
converted to a one-hot tensor.
can be provided, where the integer values correspond to the class index. The input type can be controlled
with the ``input_format`` argument.

As output to ``forward`` and ``compute`` the metric returns the following output:

Expand All @@ -54,6 +55,8 @@ class MeanIoU(Metric):
include_background: Whether to include the background class in the computation
per_class: Whether to compute the IoU for each class separately. If set to ``False``, the metric will
compute the mean IoU over all classes.
input_format: What kind of input the function receives. Choose between ``"one-hot"`` for one-hot encoded tensors
or ``"index"`` for index tensors
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
Expand All @@ -63,6 +66,8 @@ class MeanIoU(Metric):
If ``include_background`` is not a boolean
ValueError:
If ``per_class`` is not a boolean
ValueError:
If ``input_format`` is not one of ``"one-hot"`` or ``"index"``

Example:
>>> import torch
Expand Down Expand Up @@ -95,20 +100,24 @@ def __init__(
num_classes: int,
include_background: bool = True,
per_class: bool = False,
input_format: Literal["one-hot", "index"] = "one-hot",
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
_mean_iou_validate_args(num_classes, include_background, per_class)
_mean_iou_validate_args(num_classes, include_background, per_class, input_format)
self.num_classes = num_classes
self.include_background = include_background
self.per_class = per_class
self.input_format = input_format

num_classes = num_classes - 1 if not include_background else num_classes
self.add_state("score", default=torch.zeros(num_classes if per_class else 1), dist_reduce_fx="mean")

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update the state with the new data."""
intersection, union = _mean_iou_update(preds, target, self.num_classes, self.include_background)
intersection, union = _mean_iou_update(
preds, target, self.num_classes, self.include_background, self.input_format
)
score = _mean_iou_compute(intersection, union, per_class=self.per_class)
self.score += score.mean(0) if self.per_class else score.mean()

Expand Down
44 changes: 32 additions & 12 deletions tests/unittests/segmentation/test_generalized_dice_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,14 @@
def _reference_generalized_dice(
preds: torch.Tensor,
target: torch.Tensor,
input_format: str,
include_background: bool = True,
reduce: bool = True,
):
"""Calculate reference metric for `MeanIoU`."""
if (preds.bool() != preds).any(): # preds is an index tensor
if input_format == "index":
preds = torch.nn.functional.one_hot(preds, num_classes=NUM_CLASSES).movedim(-1, 1)
if (target.bool() != target).any(): # target is an index tensor
if input_format == "index":
Borda marked this conversation as resolved.
Show resolved Hide resolved
target = torch.nn.functional.one_hot(target, num_classes=NUM_CLASSES).movedim(-1, 1)
val = compute_generalized_dice(preds, target, include_background=include_background)
if reduce:
Expand All @@ -55,35 +56,54 @@ def _reference_generalized_dice(


@pytest.mark.parametrize(
"preds, target",
"preds, target, input_format",
[
(_inputs1.preds, _inputs1.target),
(_inputs2.preds, _inputs2.target),
(_inputs3.preds, _inputs3.target),
(_inputs1.preds, _inputs1.target, "one-hot"),
(_inputs2.preds, _inputs2.target, "one-hot"),
(_inputs3.preds, _inputs3.target, "index"),
],
)
@pytest.mark.parametrize("include_background", [True, False])
class TestMeanIoU(MetricTester):
"""Test class for `MeanIoU` metric."""

@pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False])
def test_mean_iou_class(self, preds, target, include_background, ddp):
def test_mean_iou_class(self, preds, target, input_format, include_background, ddp):
"""Test class implementation of metric."""
self.run_class_metric_test(
ddp=ddp,
preds=preds,
target=target,
metric_class=GeneralizedDiceScore,
reference_metric=partial(_reference_generalized_dice, include_background=include_background, reduce=True),
metric_args={"num_classes": NUM_CLASSES, "include_background": include_background},
reference_metric=partial(
_reference_generalized_dice,
input_format=input_format,
include_background=include_background,
reduce=True,
),
metric_args={
"num_classes": NUM_CLASSES,
"include_background": include_background,
"input_format": input_format,
},
)

def test_mean_iou_functional(self, preds, target, include_background):
def test_mean_iou_functional(self, preds, target, input_format, include_background):
"""Test functional implementation of metric."""
self.run_functional_metric_test(
preds=preds,
target=target,
metric_functional=generalized_dice_score,
reference_metric=partial(_reference_generalized_dice, include_background=include_background, reduce=False),
metric_args={"num_classes": NUM_CLASSES, "include_background": include_background, "per_class": False},
reference_metric=partial(
_reference_generalized_dice,
input_format=input_format,
include_background=include_background,
reduce=False,
),
metric_args={
"num_classes": NUM_CLASSES,
"include_background": include_background,
"per_class": False,
"input_format": input_format,
},
)
Loading
Loading