From d48dd72911f640bfb61a3d6f4247c848d830d42b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Louren=C3=A7o=20Silva?= Date: Wed, 15 Jun 2022 17:24:56 +0100 Subject: [PATCH 01/39] Added generalized dice score metric --- .../classification/generalized_dice.rst | 22 ++ src/torchmetrics/__init__.py | 2 + src/torchmetrics/classification/__init__.py | 1 + src/torchmetrics/classification/accuracy.py | 3 - src/torchmetrics/classification/dice.py | 10 +- src/torchmetrics/classification/f_beta.py | 15 + .../classification/generalized_dice.py | 169 ++++++++++ .../classification/precision_recall.py | 12 + .../classification/specificity.py | 6 + .../classification/stat_scores.py | 19 +- src/torchmetrics/functional/__init__.py | 2 + .../functional/classification/__init__.py | 1 + .../functional/classification/dice.py | 11 +- .../classification/generalized_dice.py | 247 ++++++++++++++ .../functional/classification/stat_scores.py | 19 +- src/torchmetrics/utilities/checks.py | 2 +- .../test_generalized_dice_score.py | 307 ++++++++++++++++++ tests/unittests/classification/inputs.py | 29 +- tests/unittests/classification/test_dice.py | 22 +- .../classification/test_stat_scores.py | 16 +- 20 files changed, 885 insertions(+), 30 deletions(-) create mode 100644 docs/source/classification/generalized_dice.rst create mode 100644 src/torchmetrics/classification/generalized_dice.py create mode 100644 src/torchmetrics/functional/classification/generalized_dice.py create mode 100644 tests/classification/test_generalized_dice_score.py diff --git a/docs/source/classification/generalized_dice.rst b/docs/source/classification/generalized_dice.rst new file mode 100644 index 00000000000..f662a5dfd89 --- /dev/null +++ b/docs/source/classification/generalized_dice.rst @@ -0,0 +1,22 @@ +.. customcarditem:: + :header: Generalized Dice Score + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg + :tags: Classification + +.. include:: ../links.rst + +###################### +Generalized Dice Score +###################### + +Module Interface +________________ + +.. autoclass:: torchmetrics.GeneralizedDiceScore + :noindex: + +Functional Interface +____________________ + +.. 
autofunction:: torchmetrics.functional.generalized_dice_score + :noindex: diff --git a/src/torchmetrics/__init__.py b/src/torchmetrics/__init__.py index 549367ce4da..d6eb03dfdc3 100644 --- a/src/torchmetrics/__init__.py +++ b/src/torchmetrics/__init__.py @@ -36,6 +36,7 @@ Dice, F1Score, FBetaScore, + GeneralizedDiceScore, HammingDistance, HingeLoss, JaccardIndex, @@ -134,6 +135,7 @@ "ExtendedEditDistance", "F1Score", "FBetaScore", + "GeneralizedDiceScore", "HammingDistance", "HingeLoss", "JaccardIndex", diff --git a/src/torchmetrics/classification/__init__.py b/src/torchmetrics/classification/__init__.py index 70ae4d5179c..6e2dbefc778 100644 --- a/src/torchmetrics/classification/__init__.py +++ b/src/torchmetrics/classification/__init__.py @@ -23,6 +23,7 @@ from torchmetrics.classification.confusion_matrix import ConfusionMatrix # noqa: F401 from torchmetrics.classification.dice import Dice # noqa: F401 from torchmetrics.classification.f_beta import F1Score, FBetaScore # noqa: F401 +from torchmetrics.classification.generalized_dice import GeneralizedDiceScore # noqa: F401 from torchmetrics.classification.hamming import HammingDistance # noqa: F401 from torchmetrics.classification.hinge import HingeLoss # noqa: F401 from torchmetrics.classification.jaccard import JaccardIndex # noqa: F401 diff --git a/src/torchmetrics/classification/accuracy.py b/src/torchmetrics/classification/accuracy.py index 472643d3991..05e5ed88ed7 100644 --- a/src/torchmetrics/classification/accuracy.py +++ b/src/torchmetrics/classification/accuracy.py @@ -191,9 +191,6 @@ def __init__( **kwargs, ) - if top_k is not None and (not isinstance(top_k, int) or top_k <= 0): - raise ValueError(f"The `top_k` should be an integer larger than 0, got {top_k}") - self.average = average self.threshold = threshold self.top_k = top_k diff --git a/src/torchmetrics/classification/dice.py b/src/torchmetrics/classification/dice.py index bc19c44dbff..bb0d34cb5e0 100644 --- a/src/torchmetrics/classification/dice.py +++ b/src/torchmetrics/classification/dice.py @@ -25,12 +25,12 @@ class Dice(StatScores): .. math:: \text{Dice} = \frac{\text{2 * TP}}{\text{2 * TP} + \text{FP} + \text{FN}} - Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and - false positives respecitively. + Where :math:`\text{TP}`, :math:`\text{FP}` and :math:`\text{FN}` represent the numbers of + true positives, false positives and false negatives, respectively. It is recommend set `ignore_index` to index of background class. - The reduction method (how the precision scores are aggregated) is controlled by the + The reduction method (how the dice scores are aggregated) is controlled by the ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the multi-dimensional multi-class case. Accepts all inputs listed in :ref:`pages/classification:input types`. @@ -102,8 +102,12 @@ class Dice(StatScores): If ``mdmc_average`` is not one of ``None``, ``"samplewise"``, ``"global"``. ValueError: If ``average`` is set but ``num_classes`` is not provided. + ValueError: + If ``num_classes`` is set and is not larger than ``0``. ValueError: If ``num_classes`` is set and ``ignore_index`` is not in the range ``[0, num_classes)``. + ValueError: + If ``top_k`` is not an ``integer`` larger than ``0``. 
Example: >>> import torch diff --git a/src/torchmetrics/classification/f_beta.py b/src/torchmetrics/classification/f_beta.py index 11c7861ef8d..14b72e03dcb 100644 --- a/src/torchmetrics/classification/f_beta.py +++ b/src/torchmetrics/classification/f_beta.py @@ -107,6 +107,12 @@ class FBetaScore(StatScores): Raises: ValueError: If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"none"``, ``None``. + ValueError: + If ``num_classes`` is set and is not larger than ``0``. + ValueError: + If ``num_classes`` is set and ``ignore_index`` is not in the range ``[0, num_classes)``. + ValueError: + If ``top_k`` is not an ``integer`` larger than ``0``. Example: >>> import torch @@ -236,6 +242,15 @@ class F1Score(FBetaScore): kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + Raises: + ValueError: + If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"none"``, ``None``. + ValueError: + If ``num_classes`` is set and is not larger than ``0``. + ValueError: + If ``num_classes`` is set and ``ignore_index`` is not in the range ``[0, num_classes)``. + ValueError: + If ``top_k`` is not an ``integer`` larger than ``0``. Example: >>> import torch diff --git a/src/torchmetrics/classification/generalized_dice.py b/src/torchmetrics/classification/generalized_dice.py new file mode 100644 index 00000000000..42a6c66407c --- /dev/null +++ b/src/torchmetrics/classification/generalized_dice.py @@ -0,0 +1,169 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional + +from torch import Tensor + +from torchmetrics.classification.stat_scores import StatScores +from torchmetrics.functional.classification.generalized_dice import _generalized_dice_compute + + +class GeneralizedDiceScore(StatScores): + r"""Computes the Generalized Dice Score (GDS) metric: + + .. math:: + \text{GDS}=\frac{2\sum_{i=1}^{C}w_i\cdot\text{TP}_i}{\sum_{i=1}^{C}w_i\cdot(2\cdot\text{TP}_i+\text{FP}_i+\text{FN}_i)} + + Where :math:`\text{C}` is the number of classes and :math:`\text{TP}_i`, :math:`\text{FP}_i` and :math:`\text{FN}_i` + represent the numbers of true positives, false positives and false negatives for class :math:`i`, respectively. + :math:`w_i` represents the weight of class :math:`i`. + + The reduction method (how the generalized dice scores are aggregated) is controlled by the + ``average`` parameter. Accepts all inputs listed in :ref:`pages/classification:input types`. + Does not accept multidimensional multi-label data. + + Args: + num_classes: + Number of classes. + + weight_type: Defines the type of weighting to apply. Should be one of the following: + + - ``'square'`` [default]: Weight each class by the squared inverse of its support, + i.e., the inverse of its squared volume - :math:`\frac{1}{(tp + fn)^2}`. + - ``'simple'``: Weight each class by the inverse of its support, i.e., + the inverse of its volume - :math:`\frac{1}{tp + fn}`. + - ``None``: All classes are assigned unitary weight. Equivalent to dice score. + + zero_division: + The value to use for the score if denominator equals zero.
If set to None, the score will be 1 if the + numerator is also 0, and 0 otherwise. Defaults to None. + + threshold: + Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case + of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities. + + average: + Defines the reduction that is applied. Should be one of the following: + + - ``'samples'`` [default]: Calculate the metric for each sample, and average the metrics + across samples (with equal weights for each sample). + - ``'none'`` or ``None``: Calculate the metric for each sample separately, and return + the metric for every sample. + + ignore_index: + Integer specifying a target class to ignore. If given, this class index does not contribute + to the returned score, regardless of reduction method. + + top_k: + Number of the highest probability or logit score predictions considered finding the correct label. + The default value (``None``) will be interpreted as 1. + + multiclass: + Determines whether the input is multiclass (if True) or multilabel (if False). Defaults to True. + + multidim: + Determines whether the input is multidim or not. Defaults to True. + + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Raises: + ValueError: + If ``weight_type`` is not ``"simple"``, ``"square"`` or ``None``. + ValueError: + If ``average`` is none of ``"samples"``, ``"none"``, ``None``. + ValueError: + If ``num_classes`` is not larger than ``0``. + ValueError: + If ``ignore_index`` is not in the range ``[0, num_classes)``. + ValueError: + If ``top_k`` is not an ``integer`` larger than ``0``. + + Example: + >>> import torch + >>> from torchmetrics import GeneralizedDiceScore + >>> preds = torch.tensor([2, 0, 2, 1]) + >>> target = torch.tensor([1, 1, 2, 0]) + >>> generalized_dice_score = GeneralizedDiceScore(num_classes=3) + >>> generalized_dice_score(preds, target) + tensor(0.3478) + + """ + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = False + + def __init__( + self, + num_classes: int, + weight_type: str = "square", + zero_division: Optional[int] = None, + threshold: float = 0.5, + average: str = "samples", + ignore_index: Optional[int] = None, + top_k: Optional[int] = None, + multiclass: bool = True, + multidim: bool = True, + **kwargs: Any, + ) -> None: + allowed_weight_type = ("square", "simple", None) + if weight_type not in allowed_weight_type: + raise ValueError(f"The `weight_type` has to be one of {allowed_weight_type}, got {weight_type}.") + + allowed_average = ("samples", "none", None) + if average not in allowed_average: + raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") + + if ignore_index is not None and (not ignore_index < num_classes or num_classes == 1): + raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") + + if top_k is not None and (not isinstance(top_k, int) or top_k <= 0): + raise ValueError(f"The `top_k` should be an integer larger than 0, got {top_k}") + + # Provide "mdmc_reduce" and "reduce" as kwargs + kwargs["mdmc_reduce"] = "samplewise" + kwargs["reduce"] = "macro" if multidim else None + + super().__init__( + threshold=threshold, + top_k=top_k, + num_classes=num_classes, + multiclass=multiclass, + ignore_index=ignore_index, + **kwargs, + ) + + self.multidim = multidim + self.average = average + self.weight_type = weight_type +
self.zero_division = zero_division + + def compute(self) -> Tensor: + """Computes the generalized dice score based on inputs passed in to ``update`` previously. + + Return: + The shape of the returned tensor depends on the ``average`` parameter: + + - If ``average == 'samples'``, a one-element tensor will be returned + - If ``average in ['none', None]``, the shape will be ``(N,)``, where ``N`` stands + for the number of samples + """ + tp, fp, _, fn = self._get_final_stats() + return _generalized_dice_compute( + tp, + fp, + fn, + average=self.average, + ignore_index=None if self.reduce is None else self.ignore_index, + weight_type=self.weight_type, + zero_division=self.zero_division, + ) diff --git a/src/torchmetrics/classification/precision_recall.py b/src/torchmetrics/classification/precision_recall.py index c0c93e68a88..7b0d84d2c83 100644 --- a/src/torchmetrics/classification/precision_recall.py +++ b/src/torchmetrics/classification/precision_recall.py @@ -95,6 +95,12 @@ class Precision(StatScores): Raises: ValueError: If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"``, ``None``. + ValueError: + If ``num_classes`` is set and is not larger than ``0``. + ValueError: + If ``num_classes`` is set and ``ignore_index`` is not in the range ``[0, num_classes)``. + ValueError: + If ``top_k`` is not an ``integer`` larger than ``0``. Example: >>> import torch @@ -234,6 +240,12 @@ class Recall(StatScores): Raises: ValueError: If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"``, ``None``. + ValueError: + If ``num_classes`` is set and is not larger than ``0``. + ValueError: + If ``num_classes`` is set and ``ignore_index`` is not in the range ``[0, num_classes)``. + ValueError: + If ``top_k`` is not an ``integer`` larger than ``0``. Example: >>> import torch diff --git a/src/torchmetrics/classification/specificity.py b/src/torchmetrics/classification/specificity.py index b1bfeb8badb..54bd5f0162b 100644 --- a/src/torchmetrics/classification/specificity.py +++ b/src/torchmetrics/classification/specificity.py @@ -98,6 +98,12 @@ class Specificity(StatScores): Raises: ValueError: If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"``, ``None``. + ValueError: + If ``num_classes`` is set and is not larger than ``0``. + ValueError: + If ``num_classes`` is set and ``ignore_index`` is not in the range ``[0, num_classes)``. + ValueError: + If ``top_k`` is not an ``integer`` larger than ``0``. Example: >>> from torchmetrics import Specificity diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py index eca2150d63b..06da7976eb9 100644 --- a/src/torchmetrics/classification/stat_scores.py +++ b/src/torchmetrics/classification/stat_scores.py @@ -50,6 +50,8 @@ class StatScores(Metric): Each statistic is represented by a ``(C,)`` tensor. Requires ``num_classes`` to be set. - ``'samples'``: Counts the statistics for each sample separately (over all classes). Each statistic is represented by a ``(N, )`` 1d tensor. + - ``'none'`` or ``None``: Calculate the metric for each class separately, and return + the metric for every class. .. note:: What is considered a sample in the multi-dimensional multi-class case depends on the value of ``mdmc_reduce``. @@ -87,14 +89,17 @@ class StatScores(Metric): Raises: ValueError: - If ``reduce`` is none of ``"micro"``, ``"macro"`` or ``"samples"``. 
+ If ``reduce`` is none of ``"micro"``, ``"macro"``, ``"samples"``, ``"none"`` or None. ValueError: If ``mdmc_reduce`` is none of ``None``, ``"samplewise"``, ``"global"``. ValueError: If ``reduce`` is set to ``"macro"`` and ``num_classes`` is not provided. ValueError: - If ``num_classes`` is set - and ``ignore_index`` is not in the range ``0`` <= ``ignore_index`` < ``num_classes``. + If ``num_classes`` is set and is not larger than ``0``. + ValueError: + If ``num_classes`` is set and ``ignore_index`` is not in the range ``[0, num_classes)``. + ValueError: + If ``top_k`` is not an ``integer`` larger than ``0``. Example: >>> from torchmetrics.classification import StatScores @@ -140,7 +145,7 @@ def __init__( self.ignore_index = ignore_index self.top_k = top_k - if reduce not in ["micro", "macro", "samples"]: + if reduce not in ["micro", "macro", "samples", None]: raise ValueError(f"The `reduce` {reduce} is not valid.") if mdmc_reduce not in [None, "samplewise", "global"]: @@ -149,9 +154,15 @@ def __init__( if reduce == "macro" and (not num_classes or num_classes < 1): raise ValueError("When you set `reduce` as 'macro', you have to provide the number of classes.") + if num_classes and num_classes < 1: + raise ValueError("Number of classes must be larger than 0.") + if num_classes and ignore_index is not None and (not ignore_index < num_classes or num_classes == 1): raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") + if top_k is not None and (not isinstance(top_k, int) or top_k <= 0): + raise ValueError(f"The `top_k` should be an integer larger than 0, got {top_k}") + default: Callable = lambda: [] reduce_fn: Optional[str] = "cat" if mdmc_reduce != "samplewise" and reduce != "samples": diff --git a/src/torchmetrics/functional/__init__.py b/src/torchmetrics/functional/__init__.py index 08f0ac2e46c..bcae4b6e373 100644 --- a/src/torchmetrics/functional/__init__.py +++ b/src/torchmetrics/functional/__init__.py @@ -23,6 +23,7 @@ from torchmetrics.functional.classification.confusion_matrix import confusion_matrix from torchmetrics.functional.classification.dice import dice, dice_score from torchmetrics.functional.classification.f_beta import f1_score, fbeta_score +from torchmetrics.functional.classification.generalized_dice import generalized_dice_score from torchmetrics.functional.classification.hamming import hamming_distance from torchmetrics.functional.classification.hinge import hinge_loss from torchmetrics.functional.classification.jaccard import jaccard_index @@ -111,6 +112,7 @@ "extended_edit_distance", "f1_score", "fbeta_score", + "generalized_dice_score", "hamming_distance", "hinge_loss", "image_gradients", diff --git a/src/torchmetrics/functional/classification/__init__.py b/src/torchmetrics/functional/classification/__init__.py index 70f777b56e0..7414fee88b0 100644 --- a/src/torchmetrics/functional/classification/__init__.py +++ b/src/torchmetrics/functional/classification/__init__.py @@ -20,6 +20,7 @@ from torchmetrics.functional.classification.confusion_matrix import confusion_matrix # noqa: F401 from torchmetrics.functional.classification.dice import dice, dice_score # noqa: F401 from torchmetrics.functional.classification.f_beta import f1_score, fbeta_score # noqa: F401 +from torchmetrics.functional.classification.generalized_dice import generalized_dice_score # noqa: F401 from torchmetrics.functional.classification.hamming import hamming_distance # noqa: F401 from torchmetrics.functional.classification.hinge import hinge_loss # 
noqa: F401 from torchmetrics.functional.classification.jaccard import jaccard_index # noqa: F401 diff --git a/src/torchmetrics/functional/classification/dice.py b/src/torchmetrics/functional/classification/dice.py index 441ff8e69bc..11dee726338 100644 --- a/src/torchmetrics/functional/classification/dice.py +++ b/src/torchmetrics/functional/classification/dice.py @@ -63,6 +63,7 @@ def dice_score( Tensor containing dice score Example: + >>> import torch >>> from torchmetrics.functional import dice_score >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05], ... [0.05, 0.85, 0.05, 0.05], @@ -123,6 +124,7 @@ def _dice_compute( ``average`` parameter) Example: + >>> import torch >>> from torchmetrics.functional.classification.stat_scores import _stat_scores_update >>> from torchmetrics.functional.classification.dice import _dice_compute >>> preds = torch.tensor([2, 0, 2, 1]) @@ -171,8 +173,8 @@ def dice( .. math:: \text{Dice} = \frac{\text{2 * TP}}{\text{2 * TP} + \text{FP} + \text{FN}} - Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and - false negatives respecitively. + Where :math:`\text{TP}`, :math:`\text{FP}` and :math:`\text{FN}` represent the numbers of + true positives, false positives and false negatives, respectively. It is recommend set `ignore_index` to index of background class. @@ -259,8 +261,11 @@ def dice( If ``average`` is set but ``num_classes`` is not provided. ValueError: If ``num_classes`` is set and ``ignore_index`` is not in the range ``[0, num_classes)``. + ValueError: + If ``top_k`` is not an integer greater than ``0``. Example: + >>> import torch >>> from torchmetrics.functional import dice >>> preds = torch.tensor([2, 0, 2, 1]) >>> target = torch.tensor([1, 1, 2, 0]) @@ -283,7 +288,7 @@ def dice( raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") if top_k is not None and (not isinstance(top_k, int) or top_k <= 0): - raise ValueError(f"The `top_k` should be an integer larger than 0, got {top_k}") + raise ValueError(f"The `top_k` should be an integer greater than 0, got {top_k}") preds, target = _input_squeeze(preds, target) reduce = "macro" if average in ("weighted", "none", None) else average diff --git a/src/torchmetrics/functional/classification/generalized_dice.py b/src/torchmetrics/functional/classification/generalized_dice.py new file mode 100644 index 00000000000..8be56bb275d --- /dev/null +++ b/src/torchmetrics/functional/classification/generalized_dice.py @@ -0,0 +1,247 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Any, Optional + +import torch +from torch import Tensor + +from torchmetrics.functional.classification.stat_scores import _reduce_stat_scores, _stat_scores_update +from torchmetrics.utilities.checks import _input_squeeze + + +def _generalized_dice_compute( + tp: Tensor, + fp: Tensor, + fn: Tensor, + average: str = "samples", + weight_type: str = "square", + ignore_index: Optional[int] = None, + zero_division: Optional[int] = None, +) -> Tensor: + """Computes generalized dice score from the stat scores: true positives, false positives, false negatives. + + Args: + tp: True positives + fp: False positives + fn: False negatives + average: Defines the reduction that is applied + weight_type: Defines the type of weights applied to different classes + ignore_index: Optional index of the class to ignore in the score computation + zero_division: The value to use for the score if denominator equals zero. If set to None, score will be 1 + if the numerator is also 0, and 0 otherwise + + Example: + >>> from torchmetrics.functional.classification.stat_scores import _stat_scores_update + >>> from torchmetrics.functional.classification.generalized_dice import _generalized_dice_compute + >>> preds = torch.tensor([2, 0, 2, 1]) + >>> target = torch.tensor([1, 1, 2, 0]) + >>> tp, fp, tn, fn = _stat_scores_update(preds, target, reduce='macro', mdmc_reduce='samplewise') + >>> _generalized_dice_compute(tp, fp, fn, average='samples') + tensor(0.3478) + """ + # Compute ground-truth class volume and class weights + target_volume = tp + fn + if weight_type == "simple": + weights = torch.reciprocal(target_volume.float()) + elif weight_type == "square": + weights = torch.reciprocal(target_volume.float() * target_volume.float()) + elif weight_type is None: + weights = torch.ones_like(target_volume.float()) + + # Replace weights and stats for ignore_index by 0 + if ignore_index is not None: + weights[..., ignore_index] = 0 + tp[..., ignore_index] = 0 + fp[..., ignore_index] = 0 + fn[..., ignore_index] = 0 + + # Replace infinite weights for non-appearing classes by the max weight or 0, if all weights are infinite + if weights.dim() > 1: + for sample_weights in weights: + infs = torch.isinf(sample_weights) + sample_weights[infs] = torch.max(sample_weights[~infs]) if len(sample_weights[~infs]) > 0 else 0 + else: + infs = torch.isinf(weights) + weights[infs] = torch.max(weights[~infs]) + + # Compute weighted numerator and denominator + numerator = 2 * (tp * weights).sum(dim=-1) + denominator = ((2 * tp + fp + fn) * weights).sum(dim=-1) + + # Handle zero division + denominator_zeros = denominator == 0 + denominator[denominator_zeros] = 1 + if zero_division is not None: + # If zero_division score is specified, use it as numerator and set denominator to 1 + numerator[denominator_zeros] = zero_division + else: + # If both denominator and total sample prediction volume are 0, score is 1. Otherwise 0.
+ pred_volume = (tp + fp).sum(dim=-1) + pred_zeros = pred_volume == 0 + numerator[denominator_zeros] = torch.where( + pred_zeros[denominator_zeros], + torch.tensor(1, device=numerator.device).float(), + torch.tensor(0, device=numerator.device).float(), + ) + + return _reduce_stat_scores( + numerator=numerator, + denominator=denominator, + weights=None, + average=average, + mdmc_average=None, + ) + + +def generalized_dice_score( + preds: Tensor, + target: Tensor, + weight_type: str = "square", + zero_division: Optional[int] = None, + average: str = "samples", + threshold: float = 0.5, + top_k: Optional[int] = None, + num_classes: Optional[int] = None, + multiclass: bool = True, + multidim: bool = True, + ignore_index: Optional[int] = None, + **kwargs: Any, +) -> Tensor: + r"""Computes the Generalized Dice Score (GDS) metric: + + .. math:: + \text{GDS}=\frac{2\sum_{i=1}^{C}w_i\cdot\text{TP}_i}{\sum_{i=1}^{C}w_i\cdot(2\cdot\text{TP}_i+\text{FP}_i+\text{FN}_i)} + + Where :math:`\text{C}` is the number of classes and :math:`\text{TP}_i`, :math:`\text{FP}_i` and :math:`\text{FN}_i` + represent the numbers of true positives, false positives and false negatives for class :math:`i`, respectively. + :math:`w_i` represents the weight of class :math:`i`. + + The reduction method (how the generalized dice scores are aggregated) is controlled by the + ``average`` parameter. Accepts all inputs listed in :ref:`pages/classification:input types`. + + Args: + preds: Predictions from model (probabilities, logits or labels). + + target: Ground truth values. + + weight_type: Defines the type of weighting to apply. Should be one of the following: + + - ``'square'`` [default]: Weight each class by the squared inverse of its support, + i.e., the inverse of its squared volume - :math:`\frac{1}{(tp + fn)^2}`. + - ``'simple'``: Weight each class by the inverse of its support, i.e., + the inverse of its volume - :math:`\frac{1}{tp + fn}`. + - ``None``: All classes are assigned unitary weight. Equivalent to dice score. + + zero_division: + The value to use for the score if denominator equals zero. If set to None, the score will be 1 if the + numerator is also 0, and 0 otherwise. Defaults to None. + + average: + Defines the reduction that is applied. Should be one of the following: + + - ``'samples'`` [default]: Calculate the metric for each sample, and average the metrics + across samples (with equal weights for each sample). + - ``'none'`` or ``None``: Calculate the metric for each sample separately, and return + the metric for every sample. + + threshold: + Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case + of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities. + + top_k: + Number of the highest probability or logit score predictions considered finding the correct label. + The default value (``None``) will be interpreted as 1. + + num_classes: + Number of classes. + + multiclass: + Determines whether the input is multiclass (if True) or multilabel (if False). Defaults to True. + + multidim: + Determines whether the input is multidim or not. Defaults to True. + + ignore_index: + Integer specifying a target class to ignore. If given, this class index does not contribute + to the returned score, regardless of reduction method.
+ + Return: + The shape of the returned tensor depends on the ``average`` parameter: + + - If ``average == 'samples'``, a one-element tensor will be returned + - If ``average in ['none', None]``, the shape will be ``(N,)``, where ``N`` stands for the number of samples + + Raises: + ValueError: + If ``weight_type`` is not ``"simple"``, ``"square"`` or ``None``. + ValueError: + If ``average`` is not one of ``"samples"``, ``"none"`` or ``None``. + ValueError: + If ``num_classes`` is provided but is not an integer larger than 0. + ValueError: + If ``num_classes`` is set and ``ignore_index`` is not in the range ``[0, num_classes)``. + ValueError: + If ``top_k`` is not an integer larger than ``0``. + + Example: + >>> import torch + >>> from torchmetrics.functional import generalized_dice_score + >>> preds = torch.tensor([2, 0, 2, 1]) + >>> target = torch.tensor([1, 1, 2, 0]) + >>> generalized_dice_score(preds, target, average='samples') + tensor(0.3478) + + """ + allowed_weight_type = ("square", "simple", None) + if weight_type not in allowed_weight_type: + raise ValueError(f"The `weight_type` has to be one of {allowed_weight_type}, got {weight_type}.") + + allowed_average = ("samples", "none", None) + if average not in allowed_average: + raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") + + if num_classes and num_classes < 1: + raise ValueError("Number of classes must be larger than 0.") + + if num_classes and ignore_index is not None and (not ignore_index < num_classes or num_classes == 1): + raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") + + if top_k is not None and (not isinstance(top_k, int) or top_k <= 0): + raise ValueError(f"The `top_k` should be an integer larger than 0, got {top_k}") + + preds, target = _input_squeeze(preds, target) + + # Obtain tp, fp and fn per sample per class + reduce = "macro" if multidim else None + tp, fp, _, fn = _stat_scores_update( + preds, + target, + reduce=reduce, + mdmc_reduce="samplewise", + threshold=threshold, + num_classes=num_classes, + top_k=top_k, + multiclass=multiclass, + ignore_index=ignore_index, + ) + + return _generalized_dice_compute( + tp, + fp, + fn, + average=average, + ignore_index=None if reduce is None else ignore_index, + weight_type=weight_type, + zero_division=zero_division, + ) diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py index b3cb7786e49..89bc5f8bb52 100644 --- a/src/torchmetrics/functional/classification/stat_scores.py +++ b/src/torchmetrics/functional/classification/stat_scores.py @@ -77,6 +77,9 @@ def _stat_scores( The shape of the returned tensors depends on the shape of the inputs and the ``reduce`` parameter: + If ``reduce=None`` or ``reduce='none'``, the returned tensors have the same shape + as the input.
+ If inputs are of the shape ``(N, C)``, then: - If ``reduce='micro'``, the returned tensors are 1 element tensors @@ -89,7 +92,7 @@ def _stat_scores( - If ``reduce='macro'``, the returned tensors are ``(N,C)`` tensors - If ``reduce='samples'``, the returned tensors are ``(N,X)`` tensors """ - dim: Union[int, List[int]] = 1 # for "samples" + dim: Union[int, List[int]] = 1 if reduce == "micro": dim = [0, 1] if preds.ndim == 2 else [1, 2] elif reduce == "macro": @@ -98,11 +101,17 @@ def _stat_scores( true_pred, false_pred = target == preds, target != preds pos_pred, neg_pred = preds == 1, preds == 0 - tp = (true_pred * pos_pred).sum(dim=dim) - fp = (false_pred * pos_pred).sum(dim=dim) + tp = true_pred * pos_pred + fp = false_pred * pos_pred + + tn = true_pred * neg_pred + fn = false_pred * neg_pred - tn = (true_pred * neg_pred).sum(dim=dim) - fn = (false_pred * neg_pred).sum(dim=dim) + if reduce is not None and reduce != "none": + tp = tp.sum(dim=dim) + fp = fp.sum(dim=dim) + tn = tn.sum(dim=dim) + fn = fn.sum(dim=dim) return tp.long(), fp.long(), tn.long(), fn.long() diff --git a/src/torchmetrics/utilities/checks.py b/src/torchmetrics/utilities/checks.py index 54d15b3d455..a4e6af7854f 100644 --- a/src/torchmetrics/utilities/checks.py +++ b/src/torchmetrics/utilities/checks.py @@ -177,7 +177,7 @@ def _check_num_classes_ml(num_classes: int, multiclass: Optional[bool], implied_ if multiclass and num_classes != 2: raise ValueError( - "Your have set `multiclass=True`, but `num_classes` is not equal to 2." + "You have set `multiclass=True`, but `num_classes` is not equal to 2." " If you are trying to transform multi-label data to 2 class multi-dimensional" " multi-class, you should set `num_classes` to either 2 or None." ) diff --git a/tests/classification/test_generalized_dice_score.py b/tests/classification/test_generalized_dice_score.py new file mode 100644 index 00000000000..0899d456bd1 --- /dev/null +++ b/tests/classification/test_generalized_dice_score.py @@ -0,0 +1,307 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from functools import partial +from typing import Optional + +import pytest +from tests.helpers import seed_all +from tests.helpers.testers import NUM_CLASSES, MetricTester +from torch import Tensor, isinf, max, ones_like, reciprocal, tensor, where + +# from tests.classification.inputs import _input_multilabel_multidim as _input_mlmd +# from tests.classification.inputs import _input_multilabel_multidim_logits as _input_mlmd_logits +# from tests.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob +from tests.classification.inputs import ( # EXTRA_DIM, + _input_binary, + _input_binary_logits, + _input_binary_multidim, + _input_binary_multidim_logits, + _input_binary_multidim_prob, + _input_binary_prob, +) +from tests.classification.inputs import _input_multiclass as _input_mcls +from tests.classification.inputs import _input_multiclass_logits as _input_mcls_logits +from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob +from tests.classification.inputs import _input_multiclass_with_missing_class as _input_miss_class +from tests.classification.inputs import _input_multidim_multiclass as _input_mdmc +from tests.classification.inputs import _input_multidim_multiclass_logits as _input_mdmc_logits +from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob +from tests.classification.inputs import _input_multilabel as _input_mlb +from tests.classification.inputs import _input_multilabel_logits as _input_mlb_logits +from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob +from torchmetrics import GeneralizedDiceScore +from torchmetrics.functional import generalized_dice_score +from torchmetrics.functional.classification.stat_scores import _del_column +from torchmetrics.utilities.checks import _input_format_classification + +seed_all(42) + + +def _sk_generalized_dice( + preds: Tensor, + target: Tensor, + weight_type: str, + multiclass: bool, + num_classes: int, + ignore_index: Optional[int] = None, + zero_division: Optional[int] = None, +) -> float: + """Compute generalized dice score from 1D prediction and target. + + Args: + preds: prediction tensor + target: target tensor + weight_type: type of weight to use. + multiclass: whether problem is multiclass. + num_classes: number of classes. + ignore_index: integer specifying a target class to ignore. + zero_division: The value to use for the score if denominator equals zero. 
If set to None, score will be 1 + if the numerator is also 0, and 0 otherwise + Return: + Float generalized dice score + """ + sk_preds, sk_target, mode = _input_format_classification( + preds, target, multiclass=multiclass, num_classes=num_classes + ) + + if ignore_index is not None: + sk_preds = _del_column(sk_preds, ignore_index) + sk_target = _del_column(sk_target, ignore_index) + + # Compute intersection, target and prediction volumes + intersection = sk_preds * sk_target + target_volume = sk_target + pred_volume = sk_preds + volume = target_volume + pred_volume + + # Reduce over the spatial dimension, if there is one, from (N, C, X) to (N, C) + if sk_preds.ndim == 3: + intersection = intersection.sum(dim=2) + target_volume = target_volume.sum(dim=2) + pred_volume = pred_volume.sum(dim=2) + volume = volume.sum(dim=2) + + # Weight computation per sample per class + if weight_type == "simple": + weights = reciprocal(target_volume.float()) + elif weight_type == "square": + weights = reciprocal(target_volume.float() * target_volume.float()) + elif weight_type is None: + weights = ones_like(target_volume.float()) + + # Replace infinites by maximum weight value for the sample. If all weights are infinite, replace by 0 + if weights.dim() > 1: + for sample_weights in weights: + infs = isinf(sample_weights) + sample_weights[infs] = max(sample_weights[~infs]) if len(sample_weights[~infs]) > 0 else 0 + else: + infs = isinf(weights) + weights[infs] = max(weights[~infs]) + + # Reduce from (N, C) into (N) + numerator = 2 * (weights * intersection).sum(dim=-1) + denominator = (weights * volume).sum(dim=-1) + pred_volume = pred_volume.sum(dim=-1) + + # Compute score and handle zero division + score = numerator / denominator + if zero_division is None: + score = where((denominator == 0) & (pred_volume == 0), tensor(1).float(), score) + score = where((denominator == 0) & (pred_volume != 0), tensor(0).float(), score) + else: + score[denominator == 0] = zero_division + + # Return mean over samples + return score.mean() + + +@pytest.mark.parametrize( + ["pred", "target", "expected"], + [ + ([[0, 0], [1, 1]], [[0, 0], [1, 1]], 1.0), + ([[1, 1], [0, 0]], [[0, 0], [1, 1]], 0.0), + ([[1, 1], [1, 1]], [[1, 1], [0, 0]], 0.5), + ([[1, 1], [0, 0]], [[1, 1], [0, 0]], 1.0), + ], +) +def test_generalized_dice_score(pred, target, expected): + score = generalized_dice_score(tensor(pred), tensor(target)) + assert score == expected + + +@pytest.mark.parametrize( + "preds, target, multiclass, multidim, num_classes", + [ + (_input_binary_multidim.preds, _input_binary_multidim.target, True, True, 2), + (_input_binary_multidim_logits.preds, _input_binary_multidim_logits.target, True, True, 2), + (_input_binary_multidim_prob.preds, _input_binary_multidim_prob.target, True, True, 2), + (_input_binary.preds, _input_binary.target, True, False, 2), + (_input_binary_logits.preds, _input_binary_logits.target, True, False, 2), + (_input_binary_prob.preds, _input_binary_prob.target, True, False, 2), + ], +) +@pytest.mark.parametrize("zero_division", [None, 0, 1]) +@pytest.mark.parametrize("ignore_index", [None, 0]) +@pytest.mark.parametrize("weight_type", ["simple", "square", None]) +class TestGeneralizedDiceBinary(MetricTester): + @pytest.mark.parametrize("ddp", [False]) + @pytest.mark.parametrize("dist_sync_on_step", [False]) + def test_generalized_dice_class( + self, + ddp, + dist_sync_on_step, + preds, + target, + multiclass, + multidim, + weight_type, + ignore_index, + num_classes, + zero_division, + ): +
self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=GeneralizedDiceScore, + sk_metric=partial( + _sk_generalized_dice, + weight_type=weight_type, + ignore_index=ignore_index, + multiclass=multiclass, + num_classes=num_classes, + zero_division=zero_division, + ), + dist_sync_on_step=dist_sync_on_step, + metric_args={ + "ignore_index": ignore_index, + "weight_type": weight_type, + "multiclass": multiclass, + "multidim": multidim, + "num_classes": num_classes, + "zero_division": zero_division, + }, + ) + + def test_generalized_dice_fn( + self, preds, target, multiclass, multidim, weight_type, ignore_index, num_classes, zero_division + ): + self.run_functional_metric_test( + preds, + target, + metric_functional=generalized_dice_score, + sk_metric=partial( + _sk_generalized_dice, + weight_type=weight_type, + ignore_index=ignore_index, + multiclass=multiclass, + num_classes=num_classes, + zero_division=zero_division, + ), + metric_args={ + "ignore_index": ignore_index, + "weight_type": weight_type, + "multiclass": multiclass, + "multidim": multidim, + "num_classes": num_classes, + "zero_division": zero_division, + }, + ) + + +@pytest.mark.parametrize( + "preds, target, multiclass, multidim, num_classes", + [ + (_input_mcls.preds, _input_mcls.target, True, False, NUM_CLASSES), + (_input_mcls_logits.preds, _input_mcls_logits.target, True, False, NUM_CLASSES), + (_input_mcls_prob.preds, _input_mcls_prob.target, True, False, NUM_CLASSES), + (_input_mdmc.preds, _input_mdmc.target, True, True, NUM_CLASSES), + (_input_mdmc_logits.preds, _input_mdmc_logits.target, True, True, NUM_CLASSES), + (_input_mdmc_prob.preds, _input_mdmc_prob.target, True, True, NUM_CLASSES), + (_input_miss_class.preds, _input_miss_class.target, True, False, NUM_CLASSES), + (_input_mlb.preds, _input_mlb.target, False, False, NUM_CLASSES), + (_input_mlb_logits.preds, _input_mlb_logits.target, False, False, NUM_CLASSES), + (_input_mlb_prob.preds, _input_mlb_prob.target, False, False, NUM_CLASSES), + # (_input_mlmd.preds, _input_mlmd.target, True, True, NUM_CLASSES), + # (_input_mlmd_logits.preds, _input_mlmd_logits.target, False, True, NUM_CLASSES * EXTRA_DIM), + # (_input_mlmd_prob.preds, _input_mlmd_prob.target, True, True, NUM_CLASSES * EXTRA_DIM), + ], +) +@pytest.mark.parametrize("zero_division", [None, 0, 1]) +@pytest.mark.parametrize("ignore_index", [None, 0]) +@pytest.mark.parametrize("weight_type", ["simple", "square", None]) +class TestGeneralizedDiceMulti(MetricTester): + @pytest.mark.parametrize("ddp", [False]) + @pytest.mark.parametrize("dist_sync_on_step", [False]) + def test_generalized_dice_class( + self, + ddp, + dist_sync_on_step, + preds, + target, + multiclass, + multidim, + weight_type, + ignore_index, + num_classes, + zero_division, + ): + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=GeneralizedDiceScore, + sk_metric=partial( + _sk_generalized_dice, + weight_type=weight_type, + ignore_index=ignore_index, + multiclass=multiclass, + num_classes=num_classes, + zero_division=zero_division, + ), + dist_sync_on_step=dist_sync_on_step, + metric_args={ + "ignore_index": ignore_index, + "weight_type": weight_type, + "multiclass": multiclass, + "multidim": multidim, + "num_classes": num_classes, + "zero_division": zero_division, + }, + ) + + def test_generalized_dice_fn( + self, preds, target, multiclass, multidim, weight_type, ignore_index, num_classes, zero_division + ): + self.run_functional_metric_test( + preds, + target, + 
metric_functional=generalized_dice_score, + sk_metric=partial( + _sk_generalized_dice, + weight_type=weight_type, + ignore_index=ignore_index, + multiclass=multiclass, + num_classes=num_classes, + zero_division=zero_division, + ), + metric_args={ + "ignore_index": ignore_index, + "weight_type": weight_type, + "multiclass": multiclass, + "multidim": multidim, + "num_classes": num_classes, + "zero_division": zero_division, + }, + ) diff --git a/tests/unittests/classification/inputs.py b/tests/unittests/classification/inputs.py index ff88b452638..fd6337ae2ca 100644 --- a/tests/unittests/classification/inputs.py +++ b/tests/unittests/classification/inputs.py @@ -26,15 +26,30 @@ preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)) ) +_input_binary_multidim_prob = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), +) + _input_binary = Input( - preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), + preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)).float(), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), ) +_input_binary_multidim = Input( + preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)).float(), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), +) + _input_binary_logits = Input( preds=torch.randn(NUM_BATCHES, BATCH_SIZE), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)) ) +_input_binary_multidim_logits = Input( + preds=torch.randn(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), +) + _input_multilabel_prob = Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), @@ -50,8 +65,13 @@ target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), ) +_input_multilabel_multidim_logits = Input( + preds=torch.randn(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), +) + _input_multilabel = Input( - preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), + preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)).float(), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), ) @@ -82,6 +102,7 @@ target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), ) +__mdmc_logits = 10 * torch.randn(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM) __mdmc_prob_preds = torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM) __mdmc_prob_preds = __mdmc_prob_preds / __mdmc_prob_preds.sum(dim=2, keepdim=True) @@ -89,6 +110,10 @@ preds=__mdmc_prob_preds, target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) ) +_input_multidim_multiclass_logits = Input( + preds=__mdmc_logits, target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) +) + _input_multidim_multiclass = Input( preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), diff --git a/tests/unittests/classification/test_dice.py b/tests/unittests/classification/test_dice.py index c1f1d49de7a..cb4b9224f7c 100644 --- a/tests/unittests/classification/test_dice.py +++ b/tests/unittests/classification/test_dice.py @@ -23,14 +23,25 @@ from 
torchmetrics.functional.classification.stat_scores import _del_column from torchmetrics.utilities.checks import _input_format_classification from torchmetrics.utilities.enums import DataType -from unittests.classification.inputs import _input_binary, _input_binary_logits, _input_binary_prob +from unittests.classification.inputs import ( + _input_binary, + _input_binary_logits, + _input_binary_multidim, + _input_binary_multidim_logits, + _input_binary_multidim_prob, + _input_binary_prob, +) from unittests.classification.inputs import _input_multiclass as _input_mcls from unittests.classification.inputs import _input_multiclass_logits as _input_mcls_logits from unittests.classification.inputs import _input_multiclass_prob as _input_mcls_prob from unittests.classification.inputs import _input_multiclass_with_missing_class as _input_miss_class +from unittests.classification.inputs import _input_multidim_multiclass as _input_mdmc +from unittests.classification.inputs import _input_multidim_multiclass_logits as _input_mdmc_logits +from unittests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob from unittests.classification.inputs import _input_multilabel as _input_mlb from unittests.classification.inputs import _input_multilabel_logits as _input_mlb_logits from unittests.classification.inputs import _input_multilabel_multidim as _input_mlmd +from unittests.classification.inputs import _input_multilabel_multidim_logits as _input_mlmd_logits from unittests.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob from unittests.classification.inputs import _input_multilabel_prob as _input_mlb_prob from unittests.helpers import seed_all @@ -99,6 +110,9 @@ def test_dice(pred, target, expected): (_input_binary.preds, _input_binary.target), (_input_binary_logits.preds, _input_binary_logits.target), (_input_binary_prob.preds, _input_binary_prob.target), + (_input_binary_multidim.preds, _input_binary_multidim.target), + (_input_binary_multidim_logits.preds, _input_binary_multidim_logits.target), + (_input_binary_multidim_prob.preds, _input_binary_multidim_prob.target), ], ) @pytest.mark.parametrize("ignore_index", [None]) @@ -132,12 +146,16 @@ def test_dice_fn(self, preds, target, ignore_index): (_input_mcls.preds, _input_mcls.target), (_input_mcls_logits.preds, _input_mcls_logits.target), (_input_mcls_prob.preds, _input_mcls_prob.target), + (_input_mdmc.preds, _input_mdmc.target), + (_input_mdmc_logits.preds, _input_mdmc_logits.target), + (_input_mdmc_prob.preds, _input_mdmc_prob.target), (_input_miss_class.preds, _input_miss_class.target), (_input_mlb.preds, _input_mlb.target), (_input_mlb_logits.preds, _input_mlb_logits.target), + (_input_mlb_prob.preds, _input_mlb_prob.target), (_input_mlmd.preds, _input_mlmd.target), + (_input_mlmd_logits.preds, _input_mlmd_logits.target), (_input_mlmd_prob.preds, _input_mlmd_prob.target), - (_input_mlb_prob.preds, _input_mlb_prob.target), ], ) @pytest.mark.parametrize("ignore_index", [None, 0]) diff --git a/tests/unittests/classification/test_stat_scores.py b/tests/unittests/classification/test_stat_scores.py index 75e52f66b4a..f19ab810a43 100644 --- a/tests/unittests/classification/test_stat_scores.py +++ b/tests/unittests/classification/test_stat_scores.py @@ -20,7 +20,7 @@ from sklearn.metrics import multilabel_confusion_matrix from torch import Tensor, tensor -from torchmetrics import Accuracy, Dice, FBetaScore, Precision, Recall, Specificity, StatScores +from torchmetrics import Accuracy, Dice, 
FBetaScore, GeneralizedDiceScore, Precision, Recall, Specificity, StatScores from torchmetrics.functional import stat_scores from torchmetrics.utilities.checks import _input_format_classification from unittests.classification.inputs import _input_binary, _input_binary_logits, _input_binary_prob, _input_multiclass @@ -331,15 +331,17 @@ def test_top_k(k: int, preds: Tensor, target: Tensor, reduce: str, expected: Ten @pytest.mark.parametrize( "metric_args", [ - {"reduce": "micro"}, + {"num_classes": 1, "reduce": "micro"}, {"num_classes": 1, "reduce": "macro"}, - {"reduce": "samples"}, - {"mdmc_reduce": None}, - {"mdmc_reduce": "samplewise"}, - {"mdmc_reduce": "global"}, + {"num_classes": 1, "reduce": "samples"}, + {"num_classes": 1, "mdmc_reduce": None}, + {"num_classes": 1, "mdmc_reduce": "samplewise"}, + {"num_classes": 1, "mdmc_reduce": "global"}, ], ) -@pytest.mark.parametrize("metric_cls", [Accuracy, Dice, FBetaScore, Precision, Recall, Specificity]) +@pytest.mark.parametrize( + "metric_cls", [Accuracy, Dice, FBetaScore, GeneralizedDiceScore, Precision, Recall, Specificity] +) def test_provide_superclass_kwargs(metric_cls: StatScores, metric_args: Dict[str, Any]): """Test instantiating subclasses with superclass arguments as kwargs.""" metric_cls(**metric_args) From d6dfe9935a8294aa77095788c266719c6ef47d97 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 29 Jun 2022 14:00:23 +0200 Subject: [PATCH 02/39] Apply suggestions from code review --- src/torchmetrics/classification/generalized_dice.py | 6 +++--- .../functional/classification/generalized_dice.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/torchmetrics/classification/generalized_dice.py b/src/torchmetrics/classification/generalized_dice.py index 42a6c66407c..34aabf0821c 100644 --- a/src/torchmetrics/classification/generalized_dice.py +++ b/src/torchmetrics/classification/generalized_dice.py @@ -46,7 +46,7 @@ class GeneralizedDiceScore(StatScores): zero_division: The value to use for the score if denominator equals zero. If set to None, the score will be 1 if the - numerator is also 0, and 0 otherwise. Defaults to None. + numerator is also 0, and 0 otherwise. threshold: Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case @@ -69,10 +69,10 @@ class GeneralizedDiceScore(StatScores): The default value (``None``) will be interpreted as 1. multiclass: - Determines whether the input is multiclass (if True) or multilabel (if False). Defaults to True. + Determines whether the input is multiclass (if True) or multilabel (if False). multidim: - Determines whether the input is multidim or not. Defaults to True. + Determines whether the input is multidim or not. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. diff --git a/src/torchmetrics/functional/classification/generalized_dice.py b/src/torchmetrics/functional/classification/generalized_dice.py index 8be56bb275d..c56fd08bfbd 100644 --- a/src/torchmetrics/functional/classification/generalized_dice.py +++ b/src/torchmetrics/functional/classification/generalized_dice.py @@ -145,7 +145,7 @@ def generalized_dice_score( zero_division: The value to use for the score if denominator equals zero. If set to None, the score will be 1 if the - numerator is also 0, and 0 otherwise. Defaults to None. + numerator is also 0, and 0 otherwise. average: Defines the reduction that is applied. Should be one of the following: @@ -167,10 +167,10 @@ def generalized_dice_score( Number of classes. 
multiclass: - Determines whether the input is multiclass (if True) or multilabel (if False). Defaults to True. + Determines whether the input is multiclass (if True) or multilabel (if False). multidim: - Determines whether the input is multidim or not. Defaults to True. + Determines whether the input is multidim or not. ignore_index: Integer specifying a target class to ignore. If given, this class index does not contribute From a5985c7d4cd4c171d6f83e1b6c89a58324eebdbc Mon Sep 17 00:00:00 2001 From: Jirka Date: Wed, 29 Jun 2022 14:03:04 +0200 Subject: [PATCH 03/39] move --- .../test_generalized_dice_score.py | 41 ++++++++++--------- 1 file changed, 21 insertions(+), 20 deletions(-) rename tests/{ => unittests}/classification/test_generalized_dice_score.py (88%) diff --git a/tests/classification/test_generalized_dice_score.py b/tests/unittests/classification/test_generalized_dice_score.py similarity index 88% rename from tests/classification/test_generalized_dice_score.py rename to tests/unittests/classification/test_generalized_dice_score.py index 0899d456bd1..d410f1760a0 100644 --- a/tests/classification/test_generalized_dice_score.py +++ b/tests/unittests/classification/test_generalized_dice_score.py @@ -15,14 +15,17 @@ from typing import Optional import pytest -from tests.helpers import seed_all -from tests.helpers.testers import NUM_CLASSES, MetricTester from torch import Tensor, isinf, max, ones_like, reciprocal, tensor, where -# from tests.classification.inputs import _input_multilabel_multidim as _input_mlmd -# from tests.classification.inputs import _input_multilabel_multidim_logits as _input_mlmd_logits -# from tests.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob -from tests.classification.inputs import ( # EXTRA_DIM, +from torchmetrics import GeneralizedDiceScore +from torchmetrics.functional import generalized_dice_score +from torchmetrics.functional.classification.stat_scores import _del_column +from torchmetrics.utilities.checks import _input_format_classification + +# from unittests.classification.inputs import _input_multilabel_multidim as _input_mlmd +# from unittests.classification.inputs import _input_multilabel_multidim_logits as _input_mlmd_logits +# from unittests.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob +from unittests.classification.inputs import ( # EXTRA_DIM, _input_binary, _input_binary_logits, _input_binary_multidim, @@ -30,20 +33,18 @@ _input_binary_multidim_prob, _input_binary_prob, ) -from tests.classification.inputs import _input_multiclass as _input_mcls -from tests.classification.inputs import _input_multiclass_logits as _input_mcls_logits -from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from tests.classification.inputs import _input_multiclass_with_missing_class as _input_miss_class -from tests.classification.inputs import _input_multidim_multiclass as _input_mdmc -from tests.classification.inputs import _input_multidim_multiclass_logits as _input_mdmc_logits -from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from tests.classification.inputs import _input_multilabel as _input_mlb -from tests.classification.inputs import _input_multilabel_logits as _input_mlb_logits -from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob -from torchmetrics import GeneralizedDiceScore -from torchmetrics.functional import generalized_dice_score -from 
torchmetrics.functional.classification.stat_scores import _del_column -from torchmetrics.utilities.checks import _input_format_classification +from unittests.classification.inputs import _input_multiclass as _input_mcls +from unittests.classification.inputs import _input_multiclass_logits as _input_mcls_logits +from unittests.classification.inputs import _input_multiclass_prob as _input_mcls_prob +from unittests.classification.inputs import _input_multiclass_with_missing_class as _input_miss_class +from unittests.classification.inputs import _input_multidim_multiclass as _input_mdmc +from unittests.classification.inputs import _input_multidim_multiclass_logits as _input_mdmc_logits +from unittests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob +from unittests.classification.inputs import _input_multilabel as _input_mlb +from unittests.classification.inputs import _input_multilabel_logits as _input_mlb_logits +from unittests.classification.inputs import _input_multilabel_prob as _input_mlb_prob +from unittests.helpers import seed_all +from unittests.helpers.testers import NUM_CLASSES, MetricTester seed_all(42) From 62662514fc4783b66341556eb252250b18ab4480 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 5 Oct 2022 10:31:33 +0000 Subject: [PATCH 04/39] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/classification/generalized_dice.py | 1 - src/torchmetrics/functional/classification/generalized_dice.py | 1 - 2 files changed, 2 deletions(-) diff --git a/src/torchmetrics/classification/generalized_dice.py b/src/torchmetrics/classification/generalized_dice.py index 34aabf0821c..d9a5ffcc1ac 100644 --- a/src/torchmetrics/classification/generalized_dice.py +++ b/src/torchmetrics/classification/generalized_dice.py @@ -96,7 +96,6 @@ class GeneralizedDiceScore(StatScores): >>> generalized_dice_score = GeneralizedDiceScore(num_classes=3) >>> generalized_dice_score(preds, target) tensor(0.3478) - """ is_differentiable: bool = False higher_is_better: bool = True diff --git a/src/torchmetrics/functional/classification/generalized_dice.py b/src/torchmetrics/functional/classification/generalized_dice.py index c56fd08bfbd..fe18d48f5ed 100644 --- a/src/torchmetrics/functional/classification/generalized_dice.py +++ b/src/torchmetrics/functional/classification/generalized_dice.py @@ -201,7 +201,6 @@ def generalized_dice_score( >>> target = torch.tensor([1, 1, 2, 0]) >>> generalized_dice_score(preds, target, average='samples') tensor(0.3478) - """ allowed_weight_type = ("square", "simple", None) if weight_type not in allowed_weight_type: From fe8be53a3428801edce6b72472933a8e1b07b4c1 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Sat, 5 Nov 2022 19:41:05 +0100 Subject: [PATCH 05/39] fix integration testing --- tests/integrations/test_lightning.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/integrations/test_lightning.py b/tests/integrations/test_lightning.py index 7c06810c1ce..faf3304bd82 100644 --- a/tests/integrations/test_lightning.py +++ b/tests/integrations/test_lightning.py @@ -20,7 +20,8 @@ from integrations.helpers import no_warning_call from integrations.lightning.boring_model import BoringModel, RandomDataset -from torchmetrics import Accuracy, AveragePrecision, MetricCollection, SumMetric +from torchmetrics import MetricCollection, SumMetric +from torchmetrics.classification import 
BinaryAccuracy, BinaryAveragePrecision class DiffMetric(SumMetric): @@ -73,9 +74,9 @@ def __init__(self): self.layer = torch.nn.Linear(32, 1) for stage in ["train", "val", "test"]: - acc = Accuracy() + acc = BinaryAccuracy() acc.reset = mock.Mock(side_effect=acc.reset) - ap = AveragePrecision(num_classes=1, pos_label=1) + ap = BinaryAveragePrecision() ap.reset = mock.Mock(side_effect=ap.reset) self.add_module(f"acc_{stage}", acc) self.add_module(f"ap_{stage}", ap) From c49d6c11c7f60b763f992d4c884e27a5ef725453 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Sun, 6 Nov 2022 12:23:02 +0100 Subject: [PATCH 06/39] revert --- src/torchmetrics/classification/accuracy.py | 18 +++---- src/torchmetrics/classification/f_beta.py | 33 +++--------- .../classification/precision_recall.py | 33 +++--------- .../classification/specificity.py | 15 ++---- .../classification/stat_scores.py | 50 ++++++++--------- .../functional/classification/stat_scores.py | 54 ++++++++++--------- tests/integrations/test_lightning.py | 9 ++-- 7 files changed, 77 insertions(+), 135 deletions(-) diff --git a/src/torchmetrics/classification/accuracy.py b/src/torchmetrics/classification/accuracy.py index 978c25f4395..25cc3e2b45b 100644 --- a/src/torchmetrics/classification/accuracy.py +++ b/src/torchmetrics/classification/accuracy.py @@ -183,7 +183,7 @@ class MulticlassAccuracy(MulticlassStatScores): Example (preds is float tensor): >>> from torchmetrics.classification import MulticlassAccuracy - >>> target = target = torch.tensor([2, 1, 0, 0]) + >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([ ... [0.16, 0.26, 0.58], ... [0.22, 0.61, 0.17], @@ -350,8 +350,6 @@ class Accuracy(StatScores): changed to subset accuracy (which requires all labels or sub-samples in the sample to be correctly predicted) by setting ``subset_accuracy=True``. - Accepts all input types listed in :ref:`pages/classification:input types`. - Args: num_classes: Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. @@ -387,11 +385,10 @@ class Accuracy(StatScores): - ``'samplewise'``: In this case, the statistics are computed separately for each sample on the ``N`` axis, and then averaged over samples. The computation for each sample is done by treating the flattened extra axes ``...`` - (see :ref:`pages/classification:input types`) as the ``N`` dimension within the sample, + as the ``N`` dimension within the sample, and computing the metric for the sample based on that. - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs - (see :ref:`pages/classification:input types`) are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. @@ -409,9 +406,7 @@ class Accuracy(StatScores): multiclass: Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. See the parameter's - :ref:`documentation section ` - for a more detailed explanation and examples. + than what they appear to be. 
subset_accuracy: Whether to compute subset accuracy for multi-label and multi-dimensional @@ -541,6 +536,9 @@ def __init__( **kwargs, ) + if top_k is not None and (not isinstance(top_k, int) or top_k <= 0): + raise ValueError(f"The `top_k` should be an integer larger than 0, got {top_k}") + self.average = average self.threshold = threshold self.top_k = top_k @@ -554,9 +552,7 @@ def __init__( self.add_state("total", default=tensor(0), dist_reduce_fx="sum") def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore - """Update state with predictions and targets. See - :ref:`pages/classification:input types` for more information on input - types. + """Update state with predictions and targets. Args: preds: Predictions from model (logits, probabilities, or labels) diff --git a/src/torchmetrics/classification/f_beta.py b/src/torchmetrics/classification/f_beta.py index 6048fb7764d..19336dcc3f7 100644 --- a/src/torchmetrics/classification/f_beta.py +++ b/src/torchmetrics/classification/f_beta.py @@ -539,8 +539,8 @@ class MulticlassF1Score(MulticlassFBetaScore): Example (preds is float tensor): >>> from torchmetrics.classification import MulticlassF1Score - >>> target = target = torch.tensor([2, 1, 0, 0]) - >>> preds = preds = torch.tensor([ + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([ ... [0.16, 0.26, 0.58], ... [0.22, 0.61, 0.17], ... [0.71, 0.09, 0.20], @@ -773,11 +773,10 @@ class FBetaScore(StatScores): - ``'samplewise'``: In this case, the statistics are computed separately for each sample on the ``N`` axis, and then averaged over samples. The computation for each sample is done by treating the flattened extra axes ``...`` - (see :ref:`pages/classification:input types`) as the ``N`` dimension within the sample, + as the ``N`` dimension within the sample, and computing the metric for the sample based on that. - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs - (see :ref:`pages/classification:input types`) are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. @@ -795,21 +794,13 @@ class FBetaScore(StatScores): multiclass: Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. See the parameter's - :ref:`documentation section ` - for a more detailed explanation and examples. + than what they appear to be. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Raises: ValueError: If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"none"``, ``None``. - ValueError: - If ``num_classes`` is set and is not larger than ``0``. - ValueError: - If ``num_classes`` is set and ``ignore_index`` is not in the range ``[0, num_classes)``. - ValueError: - If ``top_k`` is not an ``integer`` larger than ``0``. Example: >>> import torch @@ -964,11 +955,10 @@ class F1Score(FBetaScore): - ``'samplewise'``: In this case, the statistics are computed separately for each sample on the ``N`` axis, and then averaged over samples. The computation for each sample is done by treating the flattened extra axes ``...`` - (see :ref:`pages/classification:input types`) as the ``N`` dimension within the sample, + as the ``N`` dimension within the sample, and computing the metric for the sample based on that. 
- ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs - (see :ref:`pages/classification:input types`) are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. @@ -985,21 +975,10 @@ class F1Score(FBetaScore): multiclass: Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. See the parameter's - :ref:`documentation section ` - for a more detailed explanation and examples. + than what they appear to be. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. - Raises: - ValueError: - If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"none"``, ``None``. - ValueError: - If ``num_classes`` is set and is not larger than ``0``. - ValueError: - If ``num_classes`` is set and ``ignore_index`` is not in the range ``[0, num_classes)``. - ValueError: - If ``top_k`` is not an ``integer`` larger than ``0``. Example: >>> import torch diff --git a/src/torchmetrics/classification/precision_recall.py b/src/torchmetrics/classification/precision_recall.py index 8ee5f00a9a2..21d20fd1401 100644 --- a/src/torchmetrics/classification/precision_recall.py +++ b/src/torchmetrics/classification/precision_recall.py @@ -177,7 +177,7 @@ class MulticlassPrecision(MulticlassStatScores): Example (preds is float tensor): >>> from torchmetrics.classification import MulticlassPrecision - >>> target = target = torch.tensor([2, 1, 0, 0]) + >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([ ... [0.16, 0.26, 0.58], ... [0.22, 0.61, 0.17], @@ -623,7 +623,7 @@ class Precision(StatScores): The reduction method (how the precision scores are aggregated) is controlled by the ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the - multi-dimensional multi-class case. Accepts all inputs listed in :ref:`pages/classification:input types`. + multi-dimensional multi-class case. Args: num_classes: @@ -657,11 +657,11 @@ class Precision(StatScores): - ``'samplewise'``: In this case, the statistics are computed separately for each sample on the ``N`` axis, and then averaged over samples. The computation for each sample is done by treating the flattened extra axes ``...`` - (see :ref:`pages/classification:input types`) as the ``N`` dimension within the sample, + as the ``N`` dimension within the sample, and computing the metric for the sample based on that. - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs - (see :ref:`pages/classification:input types`) are flattened into a new ``N_X`` sample axis, i.e. + are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. @@ -678,21 +678,13 @@ class Precision(StatScores): multiclass: Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. See the parameter's - :ref:`documentation section ` - for a more detailed explanation and examples. + than what they appear to be. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Raises: ValueError: If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"``, ``None``. - ValueError: - If ``num_classes`` is set and is not larger than ``0``. - ValueError: - If ``num_classes`` is set and ``ignore_index`` is not in the range ``[0, num_classes)``. 
- ValueError: - If ``top_k`` is not an ``integer`` larger than ``0``. Example: >>> import torch @@ -819,7 +811,7 @@ class Recall(StatScores): The reduction method (how the recall scores are aggregated) is controlled by the ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the - multi-dimensional multi-class case. Accepts all inputs listed in :ref:`pages/classification:input types`. + multi-dimensional multi-class case. Args: num_classes: @@ -852,11 +844,10 @@ class Recall(StatScores): - ``'samplewise'``: In this case, the statistics are computed separately for each sample on the ``N`` axis, and then averaged over samples. The computation for each sample is done by treating the flattened extra axes ``...`` - (see :ref:`pages/classification:input types`) as the ``N`` dimension within the sample, + as the ``N`` dimension within the sample, and computing the metric for the sample based on that. - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs - (see :ref:`pages/classification:input types`) are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. @@ -874,21 +865,13 @@ class Recall(StatScores): multiclass: Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. See the parameter's - :ref:`documentation section ` - for a more detailed explanation and examples. + than what they appear to be. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Raises: ValueError: If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"``, ``None``. - ValueError: - If ``num_classes`` is set and is not larger than ``0``. - ValueError: - If ``num_classes`` is set and ``ignore_index`` is not in the range ``[0, num_classes)``. - ValueError: - If ``top_k`` is not an ``integer`` larger than ``0``. Example: >>> import torch diff --git a/src/torchmetrics/classification/specificity.py b/src/torchmetrics/classification/specificity.py index f8bb05168d5..0e618c5fccb 100644 --- a/src/torchmetrics/classification/specificity.py +++ b/src/torchmetrics/classification/specificity.py @@ -314,7 +314,7 @@ class Specificity(StatScores): The reduction method (how the specificity scores are aggregated) is controlled by the ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the - multi-dimensional multi-class case. Accepts all inputs listed in :ref:`pages/classification:input types`. + multi-dimensional multi-class case. Args: num_classes: @@ -348,11 +348,10 @@ class Specificity(StatScores): - ``'samplewise'``: In this case, the statistics are computed separately for each sample on the ``N`` axis, and then averaged over samples. The computation for each sample is done by treating the flattened extra axes ``...`` - (see :ref:`pages/classification:input types`) as the ``N`` dimension within the sample, + as the ``N`` dimension within the sample, and computing the metric for the sample based on that. - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs - (see :ref:`pages/classification:input types`) are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. 
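The samplewise/global distinction documented above is easiest to see on a tiny input. An illustrative sketch against the pre-refactor ``stat_scores`` functional these docstrings refer to, with expected outputs worked out by hand (statistics are returned in ``[tp, fp, tn, fn, support]`` order):

>>> import torch
>>> from torchmetrics.functional import stat_scores
>>> preds = torch.tensor([[0, 1], [1, 1]])    # (N=2, extra dim of size 2)
>>> target = torch.tensor([[0, 0], [1, 1]])
>>> stat_scores(preds, target, reduce="micro", num_classes=2, mdmc_reduce="samplewise")
tensor([[1, 1, 1, 1, 2],
        [2, 0, 2, 0, 2]])
>>> stat_scores(preds, target, reduce="micro", num_classes=2, mdmc_reduce="global")
tensor([3, 1, 3, 1, 4])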
@@ -371,21 +370,13 @@ class Specificity(StatScores): multiclass: Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. See the parameter's - :ref:`documentation section ` - for a more detailed explanation and examples. + than what they appear to be. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Raises: ValueError: If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"``, ``None``. - ValueError: - If ``num_classes`` is set and is not larger than ``0``. - ValueError: - If ``num_classes`` is set and ``ignore_index`` is not in the range ``[0, num_classes)``. - ValueError: - If ``top_k`` is not an ``integer`` larger than ``0``. Example: >>> from torchmetrics import Specificity diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py index cc72457e6be..86f1e3fc714 100644 --- a/src/torchmetrics/classification/stat_scores.py +++ b/src/torchmetrics/classification/stat_scores.py @@ -44,7 +44,11 @@ class _AbstractStatScores(Metric): # define common functions - def _create_state(self, size: int, multidim_average: str) -> None: + def _create_state( + self, + size: int, + multidim_average: Literal["global", "samplewise"] = "global", + ) -> None: """Initialize the states for the different statistics.""" default: Union[Callable[[], list], Callable[[], Tensor]] if multidim_average == "samplewise": @@ -53,6 +57,7 @@ def _create_state(self, size: int, multidim_average: str) -> None: else: default = lambda: torch.zeros(size, dtype=torch.long) dist_reduce_fx = "sum" + self.add_state("tp", default(), dist_reduce_fx=dist_reduce_fx) self.add_state("fp", default(), dist_reduce_fx=dist_reduce_fx) self.add_state("tn", default(), dist_reduce_fx=dist_reduce_fx) @@ -159,7 +164,7 @@ def __init__( self.ignore_index = ignore_index self.validate_args = validate_args - self._create_state(1, multidim_average) + self._create_state(size=1, multidim_average=multidim_average) def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with predictions and targets. @@ -243,8 +248,8 @@ class MulticlassStatScores(_AbstractStatScores): Example (preds is float tensor): >>> from torchmetrics.classification import MulticlassStatScores - >>> target = target = torch.tensor([2, 1, 0, 0]) - >>> preds = preds = torch.tensor([ + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([ ... [0.16, 0.26, 0.58], ... [0.22, 0.61, 0.17], ... [0.71, 0.09, 0.20], @@ -300,7 +305,9 @@ def __init__( self.ignore_index = ignore_index self.validate_args = validate_args - self._create_state(num_classes, multidim_average) + self._create_state( + size=1 if (average == "micro" and top_k == 1) else num_classes, multidim_average=multidim_average + ) def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with predictions and targets. 
@@ -315,7 +322,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore ) preds, target = _multiclass_stat_scores_format(preds, target, self.top_k) tp, fp, tn, fn = _multiclass_stat_scores_update( - preds, target, self.num_classes, self.top_k, self.multidim_average, self.ignore_index + preds, target, self.num_classes, self.top_k, self.average, self.multidim_average, self.ignore_index ) self._update_state(tp, fp, tn, fn) @@ -448,7 +455,7 @@ def __init__( self.ignore_index = ignore_index self.validate_args = validate_args - self._create_state(num_labels, multidim_average) + self._create_state(size=num_labels, multidim_average=multidim_average) def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with predictions and targets. @@ -503,7 +510,7 @@ class StatScores(Metric): ``reduce`` parameter, and additionally by the ``mdmc_reduce`` parameter in the multi-dimensional multi-class case. - Accepts all inputs listed in :ref:`pages/classification:input types`. + Args: threshold: @@ -524,8 +531,6 @@ class StatScores(Metric): Each statistic is represented by a ``(C,)`` tensor. Requires ``num_classes`` to be set. - ``'samples'``: Counts the statistics for each sample separately (over all classes). Each statistic is represented by a ``(N, )`` 1d tensor. - - ``'none'`` or ``None``: Calculate the metric for each class separately, and return - the metric for every class. .. note:: What is considered a sample in the multi-dimensional multi-class case depends on the value of ``mdmc_reduce``. @@ -541,7 +546,7 @@ class StatScores(Metric): mdmc_reduce: Defines how the multi-dimensional multi-class inputs are handeled. Should be one of the following: - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional - multi-class (see :ref:`pages/classification:input types` for the definition of input types). + multi-class. - ``'samplewise'``: In this case, the statistics are computed separately for each sample on the ``N`` axis, and then the outputs are concatenated together. In each @@ -555,25 +560,20 @@ class StatScores(Metric): multiclass: Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. See the parameter's - :ref:`documentation section ` - for a more detailed explanation and examples. + than what they appear to be. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Raises: ValueError: - If ``reduce`` is none of ``"micro"``, ``"macro"``, ``"samples"``, ``"none"`` or None. + If ``reduce`` is none of ``"micro"``, ``"macro"`` or ``"samples"``. ValueError: If ``mdmc_reduce`` is none of ``None``, ``"samplewise"``, ``"global"``. ValueError: If ``reduce`` is set to ``"macro"`` and ``num_classes`` is not provided. ValueError: - If ``num_classes`` is set and is not larger than ``0``. - ValueError: - If ``num_classes`` is set and ``ignore_index`` is not in the range ``[0, num_classes)``. - ValueError: - If ``top_k`` is not an ``integer`` larger than ``0``. + If ``num_classes`` is set + and ``ignore_index`` is not in the range ``0`` <= ``ignore_index`` < ``num_classes``. 
Example: >>> from torchmetrics.classification import StatScores @@ -665,7 +665,7 @@ def __init__( self.ignore_index = ignore_index self.top_k = top_k - if reduce not in ["micro", "macro", "samples", None]: + if reduce not in ["micro", "macro", "samples"]: raise ValueError(f"The `reduce` {reduce} is not valid.") if mdmc_reduce not in [None, "samplewise", "global"]: @@ -674,15 +674,9 @@ def __init__( if reduce == "macro" and (not num_classes or num_classes < 1): raise ValueError("When you set `reduce` as 'macro', you have to provide the number of classes.") - if num_classes and num_classes < 1: - raise ValueError("Number of classes must be larger than 0.") - if num_classes and ignore_index is not None and (not ignore_index < num_classes or num_classes == 1): raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") - if top_k is not None and (not isinstance(top_k, int) or top_k <= 0): - raise ValueError(f"The `top_k` should be an integer larger than 0, got {top_k}") - default: Callable = lambda: [] reduce_fn: Optional[str] = "cat" if mdmc_reduce != "samplewise" and reduce != "samples": @@ -701,8 +695,6 @@ def __init__( def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with predictions and targets. - See :ref:`pages/classification:input types` for more information on input types. - Args: preds: Predictions from model (probabilities, logits or labels) target: Ground truth values diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py index 891a7df9302..22d8de769e4 100644 --- a/src/torchmetrics/functional/classification/stat_scores.py +++ b/src/torchmetrics/functional/classification/stat_scores.py @@ -18,7 +18,7 @@ from typing_extensions import Literal from torchmetrics.utilities.checks import _check_same_shape, _input_format_classification -from torchmetrics.utilities.data import _bincount, _movedim, select_topk +from torchmetrics.utilities.data import _bincount, select_topk from torchmetrics.utilities.enums import AverageMethod, DataType, MDMCAverageMethod from torchmetrics.utilities.prints import rank_zero_warn @@ -351,6 +351,7 @@ def _multiclass_stat_scores_update( target: Tensor, num_classes: int, top_k: int = 1, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: @@ -373,7 +374,7 @@ def _multiclass_stat_scores_update( target[idx] = num_classes if top_k > 1: - preds_oh = _movedim(select_topk(preds, topk=top_k, dim=1), 1, -1) + preds_oh = torch.movedim(select_topk(preds, topk=top_k, dim=1), 1, -1) else: preds_oh = torch.nn.functional.one_hot( preds, num_classes + 1 if ignore_index is not None and not ignore_in else num_classes @@ -393,7 +394,17 @@ def _multiclass_stat_scores_update( fn = ((target_oh != preds_oh) & (target_oh == 1)).sum(sum_dim) fp = ((target_oh != preds_oh) & (target_oh == 0)).sum(sum_dim) tn = ((target_oh == preds_oh) & (target_oh == 0)).sum(sum_dim) - return tp, fp, tn, fn + elif average == "micro": + preds = preds.flatten() + target = target.flatten() + if ignore_index is not None: + idx = target != ignore_index + preds = preds[idx] + target = target[idx] + tp = (preds == target).sum() + fp = (preds != target).sum() + fn = (preds != target).sum() + tn = num_classes * preds.numel() - (fp + fn + tp) else: preds = preds.flatten() target = 
target.flatten() @@ -408,7 +419,7 @@ def _multiclass_stat_scores_update( fp = confmat.sum(0) - tp fn = confmat.sum(1) - tp tn = confmat.sum() - (fp + fn + tp) - return tp, fp, tn, fn + return tp, fp, tn, fn def _multiclass_stat_scores_compute( @@ -426,8 +437,8 @@ def _multiclass_stat_scores_compute( res = torch.stack([tp, fp, tn, fn, tp + fn], dim=-1) sum_dim = 0 if multidim_average == "global" else 1 if average == "micro": - return res.sum(sum_dim) - elif average == "macro": + return res.sum(sum_dim) if res.ndim > 1 else res + if average == "macro": return res.float().mean(sum_dim) elif average == "weighted": weight = tp + fn @@ -549,7 +560,9 @@ def multiclass_stat_scores( _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index) preds, target = _multiclass_stat_scores_format(preds, target, top_k) - tp, fp, tn, fn = _multiclass_stat_scores_update(preds, target, num_classes, top_k, multidim_average, ignore_index) + tp, fp, tn, fn = _multiclass_stat_scores_update( + preds, target, num_classes, top_k, average, multidim_average, ignore_index + ) return _multiclass_stat_scores_compute(tp, fp, tn, fn, average, multidim_average) @@ -864,9 +877,6 @@ def _stat_scores( The shape of the returned tensors depends on the shape of the inputs and the ``reduce`` parameter: - If ``reduce=None`` or ``reduce='none'``. the returned tensors have the same shape - as the input. - If inputs are of the shape ``(N, C)``, then: - If ``reduce='micro'``, the returned tensors are 1 element tensors @@ -879,7 +889,7 @@ def _stat_scores( - If ``reduce='macro'``, the returned tensors are ``(N,C)`` tensors - If ``reduce='samples'``, the returned tensors are ``(N,X)`` tensors """ - dim: Union[int, List[int]] = 1 + dim: Union[int, List[int]] = 1 # for "samples" if reduce == "micro": dim = [0, 1] if preds.ndim == 2 else [1, 2] elif reduce == "macro": @@ -888,17 +898,11 @@ def _stat_scores( true_pred, false_pred = target == preds, target != preds pos_pred, neg_pred = preds == 1, preds == 0 - tp = true_pred * pos_pred - fp = false_pred * pos_pred - - tn = true_pred * neg_pred - fn = false_pred * neg_pred + tp = (true_pred * pos_pred).sum(dim=dim) + fp = (false_pred * pos_pred).sum(dim=dim) - if reduce is not None and reduce != "none": - tp = tp.sum(dim=dim) - fp = fp.sum(dim=dim) - tn = tn.sum(dim=dim) - fn = fn.sum(dim=dim) + tn = (true_pred * neg_pred).sum(dim=dim) + fn = (false_pred * neg_pred).sum(dim=dim) return tp.long(), fp.long(), tn.long(), fn.long() @@ -1116,7 +1120,7 @@ def stat_scores( The reduction method (how the statistics are aggregated) is controlled by the ``reduce`` parameter, and additionally by the ``mdmc_reduce`` parameter in the - multi-dimensional multi-class case. Accepts all inputs listed in :ref:`pages/classification:input types`. + multi-dimensional multi-class case. Args: preds: Predictions from model (probabilities, logits or labels) @@ -1156,7 +1160,7 @@ def stat_scores( one of the following: - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional - multi-class (see :ref:`pages/classification:input types` for the definition of input types). + multi-class. - ``'samplewise'``: In this case, the statistics are computed separately for each sample on the ``N`` axis, and then the outputs are concatenated together. 
In each @@ -1170,9 +1174,7 @@ def stat_scores( multiclass: Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. See the parameter's - :ref:`documentation section ` - for a more detailed explanation and examples. + than what they appear to be. Return: The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds diff --git a/tests/integrations/test_lightning.py b/tests/integrations/test_lightning.py index faf3304bd82..68dc962cc4c 100644 --- a/tests/integrations/test_lightning.py +++ b/tests/integrations/test_lightning.py @@ -20,8 +20,7 @@ from integrations.helpers import no_warning_call from integrations.lightning.boring_model import BoringModel, RandomDataset -from torchmetrics import MetricCollection, SumMetric -from torchmetrics.classification import BinaryAccuracy, BinaryAveragePrecision +from torchmetrics import Accuracy, AveragePrecision, MetricCollection, SumMetric class DiffMetric(SumMetric): @@ -74,9 +73,9 @@ def __init__(self): self.layer = torch.nn.Linear(32, 1) for stage in ["train", "val", "test"]: - acc = BinaryAccuracy() + acc = Accuracy() acc.reset = mock.Mock(side_effect=acc.reset) - ap = BinaryAveragePrecision() + ap = AveragePrecision(num_classes=1, pos_label=1) ap.reset = mock.Mock(side_effect=ap.reset) self.add_module(f"acc_{stage}", acc) self.add_module(f"ap_{stage}", ap) @@ -189,7 +188,7 @@ def __init__(self): self.metric_epoch = SumMetric() self.sum = torch.tensor(0.0) - def on_epoch_start(self): + def on_train_epoch_start(self): self.sum = torch.tensor(0.0) def training_step(self, batch, batch_idx): From 46c7e1f034f8c2a0c7d9b3ddd55867d1957cd1e4 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Sun, 6 Nov 2022 13:28:40 +0100 Subject: [PATCH 07/39] revert some more --- src/torchmetrics/classification/dice.py | 20 ++++++----------- .../functional/classification/dice.py | 20 +++++------------ src/torchmetrics/utilities/checks.py | 2 +- tests/unittests/classification/test_dice.py | 22 ++----------------- 4 files changed, 16 insertions(+), 48 deletions(-) diff --git a/src/torchmetrics/classification/dice.py b/src/torchmetrics/classification/dice.py index d711990619a..9e025dac564 100644 --- a/src/torchmetrics/classification/dice.py +++ b/src/torchmetrics/classification/dice.py @@ -26,14 +26,14 @@ class Dice(StatScores): .. math:: \text{Dice} = \frac{\text{2 * TP}}{\text{2 * TP} + \text{FP} + \text{FN}} - Where :math:`\text{TP}`, :math:`\text{FP}` and :math:`\text{FN}` represent the numbers of - true positives, false positives and false negatives, respectively. + Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and + false positives respecitively. It is recommend set `ignore_index` to index of background class. - The reduction method (how the dice scores are aggregated) is controlled by the + The reduction method (how the precision scores are aggregated) is controlled by the ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the - multi-dimensional multi-class case. Accepts all inputs listed in :ref:`pages/classification:input types`. + multi-dimensional multi-class case. Args: num_classes: @@ -69,11 +69,11 @@ class Dice(StatScores): - ``'samplewise'``: In this case, the statistics are computed separately for each sample on the ``N`` axis, and then averaged over samples. 
The computation for each sample is done by treating the flattened extra axes ``...`` - (see :ref:`pages/classification:input types`) as the ``N`` dimension within the sample, + as the ``N`` dimension within the sample, and computing the metric for the sample based on that. - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs - (see :ref:`pages/classification:input types`) are flattened into a new ``N_X`` sample axis, i.e. + are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. @@ -90,9 +90,7 @@ class Dice(StatScores): multiclass: Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. See the parameter's - :ref:`documentation section ` - for a more detailed explanation and examples. + than what they appear to be. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. @@ -103,12 +101,8 @@ class Dice(StatScores): If ``mdmc_average`` is not one of ``None``, ``"samplewise"``, ``"global"``. ValueError: If ``average`` is set but ``num_classes`` is not provided. - ValueError: - If ``num_classes`` is set and is not larger than ``0``. ValueError: If ``num_classes`` is set and ``ignore_index`` is not in the range ``[0, num_classes)``. - ValueError: - If ``top_k`` is not an ``integer`` larger than ``0``. Example: >>> import torch diff --git a/src/torchmetrics/functional/classification/dice.py b/src/torchmetrics/functional/classification/dice.py index 1c6d135f2a1..3449c182913 100644 --- a/src/torchmetrics/functional/classification/dice.py +++ b/src/torchmetrics/functional/classification/dice.py @@ -63,7 +63,6 @@ def dice_score( Tensor containing dice score Example: - >>> import torch >>> from torchmetrics.functional import dice_score >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05], ... [0.05, 0.85, 0.05, 0.05], @@ -124,7 +123,6 @@ def _dice_compute( ``average`` parameter) Example: - >>> import torch >>> from torchmetrics.functional.classification.stat_scores import _stat_scores_update >>> from torchmetrics.functional.classification.dice import _dice_compute >>> preds = torch.tensor([2, 0, 2, 1]) @@ -173,14 +171,14 @@ def dice( .. math:: \text{Dice} = \frac{\text{2 * TP}}{\text{2 * TP} + \text{FP} + \text{FN}} - Where :math:`\text{TP}`, :math:`\text{FP}` and :math:`\text{FN}` represent the numbers of - true positives, false positives and false negatives, respectively. + Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and + false negatives respecitively. It is recommend set `ignore_index` to index of background class. The reduction method (how the recall scores are aggregated) is controlled by the ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the - multi-dimensional multi-class case. Accepts all inputs listed in :ref:`pages/classification:input types`. + multi-dimensional multi-class case. Args: preds: Predictions from model (probabilities, logits or labels) @@ -215,11 +213,10 @@ def dice( - ``'samplewise'``: In this case, the statistics are computed separately for each sample on the ``N`` axis, and then averaged over samples. The computation for each sample is done by treating the flattened extra axes ``...`` - (see :ref:`pages/classification:input types`) as the ``N`` dimension within the sample, + as the ``N`` dimension within the sample, and computing the metric for the sample based on that. 
- ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs - (see :ref:`pages/classification:input types`) are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. @@ -242,9 +239,7 @@ def dice( Should be left at default (``None``) for all other types of inputs. multiclass: Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. See the parameter's - :ref:`documentation section ` - for a more detailed explanation and examples. + than what they appear to be. Return: The shape of the returned tensor depends on the ``average`` parameter @@ -261,11 +256,8 @@ def dice( If ``average`` is set but ``num_classes`` is not provided. ValueError: If ``num_classes`` is set and ``ignore_index`` is not in the range ``[0, num_classes)``. - ValueError: - If ``top_k`` is not an integer greater than ``0``. Example: - >>> import torch >>> from torchmetrics.functional import dice >>> preds = torch.tensor([2, 0, 2, 1]) >>> target = torch.tensor([1, 1, 2, 0]) @@ -287,7 +279,7 @@ def dice( raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") if top_k is not None and (not isinstance(top_k, int) or top_k <= 0): - raise ValueError(f"The `top_k` should be an integer greater than 0, got {top_k}") + raise ValueError(f"The `top_k` should be an integer larger than 0, got {top_k}") preds, target = _input_squeeze(preds, target) reduce = "macro" if average in ("weighted", "none", None) else average diff --git a/src/torchmetrics/utilities/checks.py b/src/torchmetrics/utilities/checks.py index 4ea8d178f46..dfcb2922147 100644 --- a/src/torchmetrics/utilities/checks.py +++ b/src/torchmetrics/utilities/checks.py @@ -179,7 +179,7 @@ def _check_num_classes_ml(num_classes: int, multiclass: Optional[bool], implied_ if multiclass and num_classes != 2: raise ValueError( - "You have set `multiclass=True`, but `num_classes` is not equal to 2." + "Your have set `multiclass=True`, but `num_classes` is not equal to 2." " If you are trying to transform multi-label data to 2 class multi-dimensional" " multi-class, you should set `num_classes` to either 2 or None." 
) diff --git a/tests/unittests/classification/test_dice.py b/tests/unittests/classification/test_dice.py index cb4b9224f7c..c1f1d49de7a 100644 --- a/tests/unittests/classification/test_dice.py +++ b/tests/unittests/classification/test_dice.py @@ -23,25 +23,14 @@ from torchmetrics.functional.classification.stat_scores import _del_column from torchmetrics.utilities.checks import _input_format_classification from torchmetrics.utilities.enums import DataType -from unittests.classification.inputs import ( - _input_binary, - _input_binary_logits, - _input_binary_multidim, - _input_binary_multidim_logits, - _input_binary_multidim_prob, - _input_binary_prob, -) +from unittests.classification.inputs import _input_binary, _input_binary_logits, _input_binary_prob from unittests.classification.inputs import _input_multiclass as _input_mcls from unittests.classification.inputs import _input_multiclass_logits as _input_mcls_logits from unittests.classification.inputs import _input_multiclass_prob as _input_mcls_prob from unittests.classification.inputs import _input_multiclass_with_missing_class as _input_miss_class -from unittests.classification.inputs import _input_multidim_multiclass as _input_mdmc -from unittests.classification.inputs import _input_multidim_multiclass_logits as _input_mdmc_logits -from unittests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob from unittests.classification.inputs import _input_multilabel as _input_mlb from unittests.classification.inputs import _input_multilabel_logits as _input_mlb_logits from unittests.classification.inputs import _input_multilabel_multidim as _input_mlmd -from unittests.classification.inputs import _input_multilabel_multidim_logits as _input_mlmd_logits from unittests.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob from unittests.classification.inputs import _input_multilabel_prob as _input_mlb_prob from unittests.helpers import seed_all @@ -110,9 +99,6 @@ def test_dice(pred, target, expected): (_input_binary.preds, _input_binary.target), (_input_binary_logits.preds, _input_binary_logits.target), (_input_binary_prob.preds, _input_binary_prob.target), - (_input_binary_multidim.preds, _input_binary_multidim.target), - (_input_binary_multidim_logits.preds, _input_binary_multidim_logits.target), - (_input_binary_multidim_prob.preds, _input_binary_multidim_prob.target), ], ) @pytest.mark.parametrize("ignore_index", [None]) @@ -146,16 +132,12 @@ def test_dice_fn(self, preds, target, ignore_index): (_input_mcls.preds, _input_mcls.target), (_input_mcls_logits.preds, _input_mcls_logits.target), (_input_mcls_prob.preds, _input_mcls_prob.target), - (_input_mdmc.preds, _input_mdmc.target), - (_input_mdmc_logits.preds, _input_mdmc_logits.target), - (_input_mdmc_prob.preds, _input_mdmc_prob.target), (_input_miss_class.preds, _input_miss_class.target), (_input_mlb.preds, _input_mlb.target), (_input_mlb_logits.preds, _input_mlb_logits.target), - (_input_mlb_prob.preds, _input_mlb_prob.target), (_input_mlmd.preds, _input_mlmd.target), - (_input_mlmd_logits.preds, _input_mlmd_logits.target), (_input_mlmd_prob.preds, _input_mlmd_prob.target), + (_input_mlb_prob.preds, _input_mlb_prob.target), ], ) @pytest.mark.parametrize("ignore_index", [None, 0]) From c23fc005232e43f80b90376eb61fa6ef94eda045 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Sun, 6 Nov 2022 13:31:55 +0100 Subject: [PATCH 08/39] changelog --- CHANGELOG.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git 
a/CHANGELOG.md b/CHANGELOG.md
index df6488490b6..6022de2b8ae 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -19,6 +19,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added `KendallRankCorrCoef` to regression package ([#1271](https://github.com/Lightning-AI/metrics/pull/1271))
 
+- Added `GeneralizedDiceScore` to classification package ([#1090](https://github.com/Lightning-AI/metrics/pull/1090))
+
+
 ### Changed
 
 - Changed `MeanAveragePrecision` to vectorize `_find_best_gt_match` operation ([#1259](https://github.com/Lightning-AI/metrics/pull/1259))
@@ -29,7 +32,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Deprecated
 
--
+- Deprecated `Dice` and `dice` from classification ([#1090](https://github.com/Lightning-AI/metrics/pull/1090))
 
 ### Removed
 

From 1d3ac53c238bcf277547e1e3600582f05a249c94 Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Sun, 6 Nov 2022 14:07:54 +0100
Subject: [PATCH 09/39] deprecate dice

---
 src/torchmetrics/classification/dice.py       | 11 +++
 src/torchmetrics/classification/f_beta.py     | 12 ++-
 .../functional/classification/dice.py         | 91 ++-----------------
 .../functional/classification/f_beta.py       | 12 ++-
 4 files changed, 38 insertions(+), 88 deletions(-)

diff --git a/src/torchmetrics/classification/dice.py b/src/torchmetrics/classification/dice.py
index 9e025dac564..4f3bb778e6b 100644
--- a/src/torchmetrics/classification/dice.py
+++ b/src/torchmetrics/classification/dice.py
@@ -19,11 +19,16 @@
 from torchmetrics.classification.stat_scores import StatScores
 from torchmetrics.functional.classification.dice import _dice_compute
 from torchmetrics.utilities.enums import AverageMethod
+from torchmetrics.utilities.prints import rank_zero_warn
 
 
 class Dice(StatScores):
     r"""Computes `Dice`_:
 
+    .. deprecated:: v0.10
+        The `Dice` module was deprecated in v0.10 and will be removed in v0.11. Use `F1Score` module instead which
+        is equivalent.
+
     .. math:: \text{Dice} = \frac{\text{2 * TP}}{\text{2 * TP} + \text{FP} + \text{FN}}
 
     Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and
@@ -129,6 +134,12 @@ def __init__(
         multiclass: Optional[bool] = None,
         **kwargs: Any,
     ) -> None:
+        rank_zero_warn(
+            "The `Dice` module was deprecated in v0.10 and will be removed in v0.11. Use `F1Score` module instead"
+            " which is equivalent.",
+            DeprecationWarning,
+        )
+
         allowed_average = ("micro", "macro", "weighted", "samples", "none", None)
         if average not in allowed_average:
             raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")
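The claimed equivalence is quick to sanity-check by hand. An illustrative sketch (assuming the new ``multiclass_f1_score`` functional introduced by this refactor, and reusing the example data from the ``dice`` docstring):

>>> import torch
>>> from torchmetrics.functional import dice
>>> from torchmetrics.functional.classification import multiclass_f1_score
>>> preds = torch.tensor([2, 0, 2, 1])
>>> target = torch.tensor([1, 1, 2, 0])
>>> dice(preds, target, average='micro')
tensor(0.2500)
>>> multiclass_f1_score(preds, target, num_classes=3, average='micro')
tensor(0.2500)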
diff --git a/src/torchmetrics/classification/f_beta.py b/src/torchmetrics/classification/f_beta.py
index 19336dcc3f7..cd43904617f 100644
--- a/src/torchmetrics/classification/f_beta.py
+++ b/src/torchmetrics/classification/f_beta.py
@@ -387,11 +387,13 @@ def compute(self) -> Tensor:
 
 
 class BinaryF1Score(BinaryFBetaScore):
-    r"""Computes F-1 score for binary tasks:
+    r"""Computes F-1 score (also known as Dice score/similarity) for binary tasks:
 
     .. math:: F_{1} = 2\frac{\text{precision} * \text{recall}}{(\text{precision}) + \text{recall}}
 
+    F-1 score corresponds to the equally weighted average of the precision and recall scores.
+
     Accepts the following input tensors:
 
     - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside
@@ -472,11 +474,13 @@ def __init__(
 
 
 class MulticlassF1Score(MulticlassFBetaScore):
-    r"""Computes F-1 score for multiclass tasks:
+    r"""Computes F-1 score (also known as Dice score/similarity) for multiclass tasks:
 
     .. math:: F_{1} = 2\frac{\text{precision} * \text{recall}}{(\text{precision}) + \text{recall}}
 
+    F-1 score corresponds to the equally weighted average of the precision and recall scores.
+
     Accepts the following input tensors:
 
     - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point
@@ -592,11 +596,13 @@ def __init__(
 
 
 class MultilabelF1Score(MultilabelFBetaScore):
-    r"""Computes F-1 score for multilabel tasks:
+    r"""Computes F-1 score (also known as Dice score/similarity) for multilabel tasks:
 
     .. math:: F_{1} = 2\frac{\text{precision} * \text{recall}}{(\text{precision}) + \text{recall}}
 
+    F-1 score corresponds to the equally weighted average of the precision and recall scores.
+
     Accepts the following input tensors:
 
     - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside
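The "(also known as Dice score/similarity)" wording above is exact rather than loose: substituting precision = TP / (TP + FP) and recall = TP / (TP + FN) into the F-1 definition gives

.. math:: F_{1} = 2\frac{\frac{TP}{TP+FP} \cdot \frac{TP}{TP+FN}}{\frac{TP}{TP+FP} + \frac{TP}{TP+FN}} = \frac{2 \cdot TP}{2 \cdot TP + FP + FN},

which is the Dice coefficient computed by the deprecated ``dice`` functional.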
diff --git a/src/torchmetrics/functional/classification/dice.py b/src/torchmetrics/functional/classification/dice.py
index 3449c182913..38e1f5005ad 100644
--- a/src/torchmetrics/functional/classification/dice.py
+++ b/src/torchmetrics/functional/classification/dice.py
@@ -11,12 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import math
 from typing import Optional
 
 import torch
 from torch import Tensor
-from typing_extensions import Literal
 
 from torchmetrics.functional.classification.stat_scores import _reduce_stat_scores, _stat_scores_update
 from torchmetrics.utilities.checks import _input_squeeze
@@ -24,86 +22,6 @@
 from torchmetrics.utilities.prints import rank_zero_warn
 
 
-def dice_score(
-    preds: Tensor,
-    target: Tensor,
-    bg: bool = False,
-    nan_score: float = 0.0,
-    no_fg_score: float = 0.0,
-    reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
-) -> Tensor:
-    """Compute dice score from prediction scores.
-
-    Supports only "macro" approach, which mean calculate the metric for each class separately,
-    and average the metrics across classes (with equal weights for each class).
-
-    .. deprecated:: v0.9
-        The `dice_score` function was deprecated in v0.9 and will be removed in v0.10. Use `dice` function instead.
-
-    Args:
-        preds: estimated probabilities
-        target: ground-truth labels
-        bg: whether to also compute dice for the background
-        nan_score: score to return, if a NaN occurs during computation
-        no_fg_score: (default, ``0.0``) score to return, if no foreground pixel was found in target
-
-            .. deprecated:: v0.9
-                All different from default options will be changed to default.
-
-        reduction: (default, ``'elementwise_mean'``) a method to reduce metric score over labels.
-
-            .. deprecated:: v0.9
-                All different from default options will be changed to default.
-
-            - ``'elementwise_mean'``: takes the mean (default)
-            - ``'sum'``: takes the sum
-            - ``'none'`` or ``None``: no reduction will be applied
-
-    Return:
-        Tensor containing dice score
-
-    Example:
-        >>> from torchmetrics.functional import dice_score
-        >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
-        ...                      [0.05, 0.85, 0.05, 0.05],
-        ...                      [0.05, 0.05, 0.85, 0.05],
-        ...                      [0.05, 0.05, 0.05, 0.85]])
-        >>> target = torch.tensor([0, 1, 3, 2])
-        >>> dice_score(pred, target)
-        tensor(0.3333)
-    """
-    rank_zero_warn(
-        "The `dice_score` function was deprecated in v0.9 and will be removed in v0.10. Use `dice` function instead.",
-        DeprecationWarning,
-    )
-    num_classes = preds.shape[1]
-
-    if no_fg_score != 0.0:
-        no_fg_score = 0.0
-        rank_zero_warn(f"Deprecated parameter. Switched to default `no_fg_score` = {no_fg_score}.")
-
-    if reduction != "elementwise_mean":
-        reduction = "elementwise_mean"
-        rank_zero_warn(f"Deprecated parameter. Switched to default `reduction` = {reduction}.")
-
-    zero_division = math.floor(nan_score)
-    if zero_division != nan_score:
-        rank_zero_warn(f"Deprecated parameter. `nan_score` converted to integer {zero_division}.")
-
-    ignore_index = None
-    if not bg:
-        ignore_index = 0
-
-    return dice(
-        preds,
-        target,
-        ignore_index=ignore_index,
-        average="macro",
-        num_classes=num_classes,
-        zero_division=zero_division,
-    )
-
-
 def _dice_compute(
     tp: Tensor,
     fp: Tensor,
@@ -87,6 +87,10 @@ def dice(
 ) -> Tensor:
     r"""Computes `Dice`_:
 
+    .. deprecated:: v0.10
+        The `dice` function was deprecated in v0.10 and will be removed in v0.11. Use `f1_score` function instead
+        which is equivalent.
+
     .. math:: \text{Dice} = \frac{\text{2 * TP}}{\text{2 * TP} + \text{FP} + \text{FN}}
 
     Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and
@@ -186,6 +186,11 @@ def dice(
     >>> dice(preds, target, average='micro')
     tensor(0.2500)
     """
+    rank_zero_warn(
+        "The `dice` function was deprecated in v0.10 and will be removed in v0.11. Use `f1_score` function instead"
+        " which is equivalent.",
+        DeprecationWarning,
+    )
     allowed_average = ("micro", "macro", "weighted", "samples", "none", None)
     if average not in allowed_average:
         raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")
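A small worked example of the binary case documented in the functional ``f_beta`` changes that follow; probabilities inside [0,1] are thresholded at 0.5 before the statistics are computed (illustrative, with values chosen so the result can be verified by hand):

>>> import torch
>>> from torchmetrics.functional.classification import binary_f1_score
>>> target = torch.tensor([0, 1, 0, 1, 0, 1])
>>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> binary_f1_score(preds, target)  # tp=2, fp=1, fn=1 -> 2*2 / (2*2 + 1 + 1)
tensor(0.6667)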
diff --git a/src/torchmetrics/functional/classification/f_beta.py b/src/torchmetrics/functional/classification/f_beta.py
index 1cdb1261a3d..57a80066b55 100644
--- a/src/torchmetrics/functional/classification/f_beta.py
+++ b/src/torchmetrics/functional/classification/f_beta.py
@@ -409,11 +409,13 @@ def binary_f1_score(
     ignore_index: Optional[int] = None,
     validate_args: bool = True,
 ) -> Tensor:
-    r"""Computes F-1 score for binary tasks:
+    r"""Computes F-1 score (also known as Dice score/similarity) for binary tasks:
 
     .. math:: F_{1} = 2\frac{\text{precision} * \text{recall}}{(\text{precision}) + \text{recall}}
 
+    F-1 score corresponds to the equally weighted average of the precision and recall scores.
+
     Accepts the following input tensors:
 
     - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside
@@ -491,11 +493,13 @@ def multiclass_f1_score(
     ignore_index: Optional[int] = None,
     validate_args: bool = True,
 ) -> Tensor:
-    r"""Computes F-1 score for multiclass tasks:
+    r"""Computes F-1 score (also known as Dice score/similarity) for multiclass tasks:
 
     .. math:: F_{1} = 2\frac{\text{precision} * \text{recall}}{(\text{precision}) + \text{recall}}
 
+    F-1 score corresponds to the equally weighted average of the precision and recall scores.
+
     Accepts the following input tensors:
 
     - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point
@@ -601,11 +605,13 @@ def multilabel_f1_score(
     ignore_index: Optional[int] = None,
     validate_args: bool = True,
 ) -> Tensor:
-    r"""Computes F-1 score for multilabel tasks:
+    r"""Computes F-1 score (also known as Dice score/similarity) for multilabel tasks:
 
     .. math:: F_{1} = 2\frac{\text{precision} * \text{recall}}{(\text{precision}) + \text{recall}}
 
+    F-1 score corresponds to the equally weighted average of the precision and recall scores.
+
     Accepts the following input tensors:
 
     - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside

From 427a3ff54a63002fcc8727e5352cd311840840ff Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Sun, 6 Nov 2022 14:11:18 +0100
Subject: [PATCH 10/39] transfer to new format

---
 .../classification/generalized_dice.py        | 213 ++++--
 .../classification/generalized_dice.py        | 176 ++++-
 .../test_generalized_dice_score.py            | 669 ++++++++++++------
 3 files changed, 759 insertions(+), 299 deletions(-)

diff --git a/src/torchmetrics/classification/generalized_dice.py b/src/torchmetrics/classification/generalized_dice.py
index d9a5ffcc1ac..d77efeea312 100644
--- a/src/torchmetrics/classification/generalized_dice.py
+++ b/src/torchmetrics/classification/generalized_dice.py
@@ -11,15 +11,132 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import Any, Optional
 
 from torch import Tensor
+from typing_extensions import Literal
 
-from torchmetrics.classification.stat_scores import StatScores
-from torchmetrics.functional.classification.generalized_dice import _generalized_dice_compute
+from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores
+from torchmetrics.functional.classification.generalized_dice import (
+    _binary_generalized_dice_score_arg_validation,
+    _generalized_dice_reduce,
+    _multiclass_generalized_dice_score_arg_validation,
+    _multilabel_generalized_dice_score_arg_validation,
+)
+from torchmetrics.metric import Metric
 
 
-class GeneralizedDiceScore(StatScores):
+class BinaryGeneralizedDiceScore(BinaryStatScores):
+    is_differentiable: bool = False
+    higher_is_better: Optional[bool] = True
+    full_state_update: bool = False
+
+    def __init__(
+        self,
+        threshold: float = 0.5,
+        multidim_average: Literal["global", "samplewise"] = "global",
+        weight_type: Optional[Literal["square", "simple"]] = "square",
+        ignore_index: Optional[int] = None,
+        validate_args: bool = True,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(
+            threshold=threshold,
+            multidim_average=multidim_average,
+            ignore_index=ignore_index,
+            validate_args=False,
+            **kwargs,
+        )
+        if validate_args:
+            _binary_generalized_dice_score_arg_validation(weight_type, threshold, multidim_average, ignore_index)
+        self.validate_args = validate_args
+        self.weight_type = weight_type
+
+    def compute(self) -> Tensor:
+        tp, fp, tn, fn = self._final_state()
+        return _generalized_dice_reduce(
+            tp, fp, tn, fn, self.weight_type, average="binary", multidim_average=self.multidim_average
+        )
+
+
+class MulticlassGeneralizedDiceScore(MulticlassStatScores):
+    is_differentiable: bool = False
+    higher_is_better: Optional[bool] = True
+    full_state_update: bool = False
+
+    def __init__(
+        self,
+        num_classes: int,
+        top_k: int = 1,
+        average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
+        multidim_average: Literal["global", "samplewise"] = 
"global", + weight_type: Optional[Literal["square", "simple"]] = "square", + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__( + num_classes=num_classes, + top_k=top_k, + average=average, + multidim_average=multidim_average, + ignore_index=ignore_index, + validate_args=False, + **kwargs, + ) + if validate_args: + _multiclass_generalized_dice_score_arg_validation( + weight_type, num_classes, top_k, average, multidim_average, ignore_index + ) + self.validate_args = validate_args + self.weight_type = weight_type + + def compute(self) -> Tensor: + tp, fp, tn, fn = self._final_state() + return _generalized_dice_reduce( + tp, fp, tn, fn, self.weight_type, average=self.average, multidim_average=self.multidim_average + ) + + +class MultilabelGeneralizedDiceScore(MultilabelStatScores): + is_differentiable: bool = False + higher_is_better: Optional[bool] = True + full_state_update: bool = False + + def __init__( + self, + num_labels: int, + threshold: float = 0.5, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + multidim_average: Literal["global", "samplewise"] = "global", + weight_type: Optional[Literal["square", "simple"]] = "square", + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__( + num_labels=num_labels, + threshold=threshold, + average=average, + multidim_average=multidim_average, + ignore_index=ignore_index, + validate_args=False, + **kwargs, + ) + if validate_args: + _multilabel_generalized_dice_score_arg_validation( + weight_type, num_labels, threshold, average, multidim_average, ignore_index + ) + self.validate_args = validate_args + self.weight_type = weight_type + + def compute(self) -> Tensor: + tp, fp, tn, fn = self._final_state() + return _generalized_dice_reduce( + tp, fp, tn, fn, self.weight_type, average=self.average, multidim_average=self.multidim_average + ) + + +class GeneralizedDiceScore: r"""Computes the Generalized Dice Score (GDS) metric: .. 
math:: @@ -97,72 +214,34 @@ class GeneralizedDiceScore(StatScores): >>> generalized_dice_score(preds, target) tensor(0.3478) """ - is_differentiable: bool = False - higher_is_better: bool = True - full_state_update: bool = False - def __init__( - self, - num_classes: int, - weight_type: str = "square", - zero_division: Optional[int] = None, + def __new__( + cls, + num_classes: Optional[int] = None, + beta: float = 1.0, threshold: float = 0.5, - average: str = "samples", + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", + mdmc_average: Optional[str] = None, ignore_index: Optional[int] = None, top_k: Optional[int] = None, - multiclass: bool = True, - multidim: bool = True, + multiclass: Optional[bool] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + validate_args: bool = True, **kwargs: Any, - ) -> None: - allowed_weight_type = ("square", "simple", None) - if weight_type not in allowed_weight_type: - raise ValueError(f"The `weight_type` has to be one of {allowed_weight_type}, got {weight_type}.") - - allowed_average = ("samples", "none", None) - if average not in allowed_average: - raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - - if ignore_index is not None and (not ignore_index < num_classes or num_classes == 1): - raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") - - if top_k is not None and (not isinstance(top_k, int) or top_k <= 0): - raise ValueError(f"The `top_k` should be an integer larger than 0, got {top_k}") - - # Provide "mdmc_reduce" and "reduce" as kwargs - kwargs["mdmc_reduce"] = "samplewise" - kwargs["reduce"] = "macro" if multidim else None - - super().__init__( - threshold=threshold, - top_k=top_k, - num_classes=num_classes, - multiclass=multiclass, - ignore_index=ignore_index, - **kwargs, - ) - - self.multidim = multidim - self.average = average - self.weight_type = weight_type - self.zero_division = zero_division - - def compute(self) -> Tensor: - """Computes the generalized dice score based on inputs passed in to ``update`` previously. 
- - Return: - The shape of the returned tensor depends on the ``average`` parameter: - - - If ``average == 'samples'``, a one-element tensor will be returned - - If ``average in ['none', None]``, the shape will be ``(N,)``, where ``N`` stands - for the number of samples - """ - tp, fp, _, fn = self._get_final_stats() - return _generalized_dice_compute( - tp, - fp, - fn, - average=self.average, - ignore_index=None if self.reduce is None else self.ignore_index, - weight_type=self.weight_type, - zero_division=self.zero_division, + ) -> Metric: + assert multidim_average is not None + kwargs.update(dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryGeneralizedDiceScore(beta, threshold, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return MulticlassGeneralizedDiceScore(beta, num_classes, top_k, average, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelGeneralizedDiceScore(beta, num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" ) diff --git a/src/torchmetrics/functional/classification/generalized_dice.py b/src/torchmetrics/functional/classification/generalized_dice.py index fe18d48f5ed..3f7d6c637cc 100644 --- a/src/torchmetrics/functional/classification/generalized_dice.py +++ b/src/torchmetrics/functional/classification/generalized_dice.py @@ -11,13 +11,183 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from typing import Optional import torch from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.functional.classification.stat_scores import ( + _binary_stat_scores_arg_validation, + _binary_stat_scores_format, + _binary_stat_scores_tensor_validation, + _binary_stat_scores_update, + _multiclass_stat_scores_arg_validation, + _multiclass_stat_scores_format, + _multiclass_stat_scores_tensor_validation, + _multiclass_stat_scores_update, + _multilabel_stat_scores_arg_validation, + _multilabel_stat_scores_format, + _multilabel_stat_scores_tensor_validation, + _multilabel_stat_scores_update, + _reduce_stat_scores, + _stat_scores_update, +) +from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.enums import AverageMethod as AvgMethod +from torchmetrics.utilities.enums import MDMCAverageMethod +from torchmetrics.utilities.prints import rank_zero_warn + + +def _generalized_dice_reduce( + tp: Tensor, + fp: Tensor, + tn: Tensor, + fn: Tensor, + weight_type: Optional[Literal["square", "simple"]], + average: Optional[Literal["binary", "micro", "macro", "weighted", "none"]], + multidim_average: Literal["global", "samplewise"] = "global", +) -> Tensor: + target_volume = tp + fn + if weight_type == "simple": + weights = torch.reciprocal(target_volume.float()) + elif weight_type == "square": + weights = torch.reciprocal(target_volume.float() * target_volume.float()) + elif weight_type is None: + weights = torch.ones_like(target_volume.float()) + + if weights.ndim > 1: + for sample_weights in weights: + infs = torch.isinf(sample_weights) + sample_weights[infs] = torch.max(sample_weights[~infs]) if len(sample_weights[~infs]) > 0 else 0 + else: + infs = torch.isinf(weights) + weights[infs] = 
torch.max(weights[~infs]) if len(weights[~infs]) > 0 else 0
+
+    if average == "binary":
+        return _safe_divide(2 * weights * tp, weights * (2 * tp + fp + fn))
+    elif average == "micro":
+        sum_dim = 0 if multidim_average == "global" else 1
+        numerator = 2 * (weights * tp).sum(dim=sum_dim)
+        denominator = (weights * (2 * tp + fp + fn)).sum(dim=sum_dim)
+        return _safe_divide(numerator, denominator)
+    else:
+        generalized_dice_score = _safe_divide(2 * weights * tp, weights * (2 * tp + fp + fn))
+        if average is None or average == "none":
+            return generalized_dice_score
+        if average == "weighted":
+            weights = tp + fn
+        else:
+            weights = torch.ones_like(generalized_dice_score)
+        return _safe_divide(weights * generalized_dice_score, weights.sum(-1, keepdim=True)).sum(-1)
+
+
+def _binary_generalized_dice_score_arg_validation(
+    weight_type: Optional[Literal["square", "simple"]],
+    threshold: float = 0.5,
+    multidim_average: Literal["global", "samplewise"] = "global",
+    ignore_index: Optional[int] = None,
+) -> None:
+    allowed_weight_type = ("square", "simple", None)
+    if weight_type not in allowed_weight_type:
+        raise ValueError(
+            f"Argument `weight_type` needs to be one of the following: {allowed_weight_type} but got {weight_type}"
+        )
+    _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index)
+
 
-from torchmetrics.functional.classification.stat_scores import _reduce_stat_scores, _stat_scores_update
-from torchmetrics.utilities.checks import _input_squeeze
 
+def binary_generalized_dice_score(
+    preds: Tensor,
+    target: Tensor,
+    weight_type: Optional[Literal["square", "simple"]],
+    threshold: float = 0.5,
+    multidim_average: Literal["global", "samplewise"] = "global",
+    ignore_index: Optional[int] = None,
+    validate_args: bool = True,
+) -> Tensor:
+    if validate_args:
+        _binary_generalized_dice_score_arg_validation(weight_type, threshold, multidim_average, ignore_index)
+        _binary_stat_scores_tensor_validation(preds, target, multidim_average, ignore_index)
+    preds, target = _binary_stat_scores_format(preds, target, threshold, ignore_index)
+    tp, fp, tn, fn = _binary_stat_scores_update(preds, target, multidim_average)
+    return _generalized_dice_reduce(tp, fp, tn, fn, weight_type, average="binary", multidim_average=multidim_average)
+
+
+def _multiclass_generalized_dice_score_arg_validation(
+    weight_type: Optional[Literal["square", "simple"]],
+    num_classes: int,
+    top_k: int = 1,
+    average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
+    multidim_average: Literal["global", "samplewise"] = "global",
+    ignore_index: Optional[int] = None,
+) -> None:
+    allowed_weight_type = ("square", "simple", None)
+    if weight_type not in allowed_weight_type:
+        raise ValueError(
+            f"Argument `weight_type` needs to be one of the following: {allowed_weight_type} but got {weight_type}"
+        )
+    _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index)
+
+
+def multiclass_generalized_dice_score(
+    preds: Tensor,
+    target: Tensor,
+    weight_type: Optional[Literal["square", "simple"]],
+    num_classes: int,
+    average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
+    top_k: int = 1,
+    multidim_average: Literal["global", "samplewise"] = "global",
+    ignore_index: Optional[int] = None,
+    validate_args: bool = True,
+) -> Tensor:
+    if validate_args:
+        _multiclass_generalized_dice_score_arg_validation(
+            weight_type, num_classes, top_k, average, multidim_average, ignore_index
+        )
+        _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index)
+    preds, target = _multiclass_stat_scores_format(preds, target, top_k)
+    tp, fp, tn, fn = _multiclass_stat_scores_update(
+        preds, target, num_classes, top_k, average, multidim_average, ignore_index
+    )
+    return _generalized_dice_reduce(tp, fp, tn, fn, weight_type, average=average, multidim_average=multidim_average)
+
+
+def _multilabel_generalized_dice_score_arg_validation(
+    weight_type: Optional[Literal["square", "simple"]],
+    num_labels: int,
+    threshold: float = 0.5,
+    average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
+    multidim_average: Literal["global", "samplewise"] = "global",
+    ignore_index: Optional[int] = None,
+) -> None:
+    allowed_weight_type = ("square", "simple", None)
+    if weight_type not in allowed_weight_type:
+        raise ValueError(
+            f"Argument `weight_type` needs to be one of the following: {allowed_weight_type} but got {weight_type}"
+        )
+    _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index)
+
+
+def multilabel_generalized_dice_score(
+    preds: Tensor,
+    target: Tensor,
+    weight_type: Optional[Literal["square", "simple"]],
+    num_labels: int,
+    threshold: float = 0.5,
+    average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
+    multidim_average: Literal["global", "samplewise"] = "global",
+    ignore_index: Optional[int] = None,
+    validate_args: bool = True,
+) -> Tensor:
+    if validate_args:
+        _multilabel_generalized_dice_score_arg_validation(
+            weight_type, num_labels, threshold, average, multidim_average, ignore_index
+        )
+        _multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index)
+    preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index)
+    tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average)
+    return _generalized_dice_reduce(tp, fp, tn, fn, weight_type, average=average, multidim_average=multidim_average)
 
 
 def _generalized_dice_compute(
diff --git a/tests/unittests/classification/test_generalized_dice_score.py b/tests/unittests/classification/test_generalized_dice_score.py
index d410f1760a0..b7b9d281609 100644
--- a/tests/unittests/classification/test_generalized_dice_score.py
+++ b/tests/unittests/classification/test_generalized_dice_score.py
@@ -12,297 +12,508 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
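# A minimal sketch of the weighted reduction the new functional computes,
# assuming per-class counts tp/fp/fn and ``weight_type="square"``; all tensor
# values below are illustrative, not taken from the test suite.
import torch

_tp = torch.tensor([0.0, 1.0, 1.0])  # per-class true positives (hypothetical)
_fp = torch.tensor([0.0, 1.0, 0.0])  # per-class false positives (hypothetical)
_fn = torch.tensor([1.0, 0.0, 0.0])  # per-class false negatives (hypothetical)
_weights = 1.0 / (_tp + _fn) ** 2  # squared reciprocal of the target volumes
_score = 2 * (_weights * _tp).sum() / (_weights * (2 * _tp + _fp + _fn)).sum()
# _score -> tensor(0.6667), i.e. 2 * sum(w * tp) / sum(w * (2tp + fp + fn))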
 from functools import partial
-from typing import Optional
 
+import numpy as np
 import pytest
-from torch import Tensor, isinf, max, ones_like, reciprocal, tensor, where
-
-from torchmetrics import GeneralizedDiceScore
-from torchmetrics.functional import generalized_dice_score
-from torchmetrics.functional.classification.stat_scores import _del_column
-from torchmetrics.utilities.checks import _input_format_classification
-
-# from unittests.classification.inputs import _input_multilabel_multidim as _input_mlmd
-# from unittests.classification.inputs import _input_multilabel_multidim_logits as _input_mlmd_logits
-# from unittests.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob
-from unittests.classification.inputs import (  # EXTRA_DIM,
-    _input_binary,
-    _input_binary_logits,
-    _input_binary_multidim,
-    _input_binary_multidim_logits,
-    _input_binary_multidim_prob,
-    _input_binary_prob,
+import torch
+from scipy.special import expit as sigmoid
+from sklearn.metrics import confusion_matrix as sk_confusion_matrix
+from sklearn.metrics import f1_score as sk_f1_score
+from sklearn.metrics import generalized_dice_score as sk_generalized_dice_score
+from torch import Tensor
+
+from torchmetrics.classification.f_beta import MulticlassF1Score
+from torchmetrics.classification.generalized_dice import (
+    BinaryGeneralizedDiceScore,
+    MulticlassGeneralizedDiceScore,
+    MultilabelGeneralizedDiceScore,
+)
+from torchmetrics.functional.classification.f_beta import multiclass_f1_score
+from torchmetrics.functional.classification.generalized_dice import (
+    binary_generalized_dice_score,
+    multiclass_generalized_dice_score,
+    multilabel_generalized_dice_score,
 )
-from unittests.classification.inputs import _input_multiclass as _input_mcls
-from unittests.classification.inputs import _input_multiclass_logits as _input_mcls_logits
-from unittests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
-from unittests.classification.inputs import _input_multiclass_with_missing_class as _input_miss_class
-from unittests.classification.inputs import _input_multidim_multiclass as _input_mdmc
-from unittests.classification.inputs import _input_multidim_multiclass_logits as _input_mdmc_logits
-from unittests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
-from unittests.classification.inputs import _input_multilabel as _input_mlb
-from unittests.classification.inputs import _input_multilabel_logits as _input_mlb_logits
-from unittests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
+from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases
 from unittests.helpers import seed_all
-from unittests.helpers.testers import NUM_CLASSES, MetricTester
+from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index, remove_ignore_index
 
 seed_all(42)
 
 
-def _sk_generalized_dice(
-    preds: Tensor,
-    target: Tensor,
-    weight_type: str,
-    multiclass: bool,
-    num_classes: int,
-    ignore_index: Optional[int] = None,
-    zero_division: Optional[int] = None,
-) -> float:
-    """Compute generalized dice score from 1D prediction and target.
-
-    Args:
-        preds: prediction tensor
-        target: target tensor
-        weight_type: type of weight to use.
-        multiclass: whether problem is multiclass.
-        num_classes: number of classes.
-        ignore_index: integer specifying a target class to ignore.
-        zero_division: The value to use for the score if denominator equals zero. 
If set to 0, score will be 1 - if the numerator is also 0, and 0 otherwise - Return: - Float generalized dice score - """ - sk_preds, sk_target, mode = _input_format_classification( - preds, target, multiclass=multiclass, num_classes=num_classes - ) - - if ignore_index is not None: - sk_preds = _del_column(sk_preds, ignore_index) - sk_target = _del_column(sk_target, ignore_index) - - # Compute intersection, target and prediction volumes - intersection = sk_preds * sk_target - target_volume = sk_target - pred_volume = sk_preds - volume = target_volume + pred_volume - - # Reduce over the spatial dimension, if there is one, from (N, C, X) to (N, C) - if sk_preds.ndim == 3: - intersection = intersection.sum(dim=2) - target_volume = target_volume.sum(dim=2) - pred_volume = pred_volume.sum(dim=2) - volume = volume.sum(dim=2) - - # Weight computation per sample per class - if weight_type == "simple": - weights = reciprocal(target_volume.float()) - elif weight_type == "square": - weights = reciprocal(target_volume.float() * target_volume.float()) - elif weight_type is None: - weights = ones_like(target_volume.float()) - - # Replace infinites by maximum weight value for the sample. If all weights are infinite, replace by 0 - if weights.dim() > 1: - for sample_weights in weights: - infs = isinf(sample_weights) - sample_weights[infs] = max(sample_weights[~infs]) if len(sample_weights[~infs]) > 0 else 0 +def _sk_generalized_dice_score_binary(preds, target, sk_fn, ignore_index, multidim_average): + if multidim_average == "global": + preds = preds.view(-1).numpy() + target = target.view(-1).numpy() else: - infs = isinf(weights) - weights[infs] = max(weights[~infs]) - - # Reduce from (N, C) into (N) - numerator = 2 * (weights * intersection).sum(dim=-1) - denominator = (weights * volume).sum(dim=-1) - pred_volume = pred_volume.sum(dim=-1) - - # Compute score and handle zero division - score = numerator / denominator - if zero_division is None: - score = where((denominator == 0) & (pred_volume == 0), tensor(1).float(), score) - score = where((denominator == 0) & (pred_volume != 0), tensor(0).float(), score) + preds = preds.numpy() + target = target.numpy() + + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = (preds >= THRESHOLD).astype(np.uint8) + + if multidim_average == "global": + target, preds = remove_ignore_index(target, preds, ignore_index) + return sk_fn(target, preds) else: - score[denominator == 0] = zero_division + res = [] + for pred, true in zip(preds, target): + pred = pred.flatten() + true = true.flatten() + true, pred = remove_ignore_index(true, pred, ignore_index) + res.append(sk_fn(true, pred)) + return np.stack(res) - # Return mean over samples - return score.mean() +@pytest.mark.parametrize("input", _binary_cases) +class TestBinaryGeneralizedDiceScore(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("ddp", [False, True]) + def test_binary_generalized_dice_score( + self, ddp, input, module, functional, compare, ignore_index, multidim_average + ): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") + if multidim_average == "samplewise" and ddp: + pytest.skip("samplewise and ddp give different order than non ddp") 
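# A small aside on the preds handling in the reference helper above: float
# predictions outside (0, 1) are assumed to be logits and are mapped through a
# sigmoid before thresholding. The values here are illustrative only.
#
#     import numpy as np
#     from scipy.special import expit as sigmoid
#
#     logits = np.array([-2.0, 0.5, 3.0])
#     hard_preds = (sigmoid(logits) >= 0.5).astype(np.uint8)  # -> [0, 1, 1]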
-@pytest.mark.parametrize( - ["pred", "target", "expected"], - [ - ([[0, 0], [1, 1]], [[0, 0], [1, 1]], 1.0), - ([[1, 1], [0, 0]], [[0, 0], [1, 1]], 0.0), - ([[1, 1], [1, 1]], [[1, 1], [0, 0]], 0.5), - ([[1, 1], [0, 0]], [[1, 1], [0, 0]], 1.0), - ], -) -def test_generalized_dice_score(pred, target, expected): - score = generalized_dice_score(tensor(pred), tensor(target)) - assert score == expected + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=module, + sk_metric=partial( + _sk_generalized_dice_score_binary, + sk_fn=compare, + ignore_index=ignore_index, + multidim_average=multidim_average, + ), + metric_args={"threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average}, + ) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + def test_binary_generalized_dice_score_functional( + self, input, module, functional, compare, ignore_index, multidim_average + ): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") -@pytest.mark.parametrize( - "preds, target, multiclass, multidim, num_classes", - [ - (_input_binary_multidim.preds, _input_binary_multidim.target, True, True, 2), - (_input_binary_multidim_logits.preds, _input_binary_multidim_logits.target, True, True, 2), - (_input_binary_multidim_prob.preds, _input_binary_multidim_prob.target, True, True, 2), - (_input_binary.preds, _input_binary.target, True, False, 2), - (_input_binary_logits.preds, _input_binary_logits.target, True, False, 2), - (_input_binary_prob.preds, _input_binary_prob.target, True, False, 2), - ], -) -@pytest.mark.parametrize("zero_division", [None, 0, 1]) -@pytest.mark.parametrize("ignore_index", [None, 0]) -@pytest.mark.parametrize("weight_type", ["simple", "square", None]) -class TestGeneralizedDiceBinary(MetricTester): - @pytest.mark.parametrize("ddp", [False]) - @pytest.mark.parametrize("dist_sync_on_step", [False]) - def test_generalized_dice_class( - self, - ddp, - dist_sync_on_step, - preds, - target, - multiclass, - multidim, - weight_type, - ignore_index, - num_classes, - zero_division, + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=functional, + sk_metric=partial( + _sk_generalized_dice_score_binary, + sk_fn=compare, + ignore_index=ignore_index, + multidim_average=multidim_average, + ), + metric_args={ + "threshold": THRESHOLD, + "ignore_index": ignore_index, + "multidim_average": multidim_average, + }, + ) + + def test_binary_generalized_dice_score_differentiability(self, input, module, functional, compare): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=module, + metric_functional=functional, + metric_args={"threshold": THRESHOLD}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_generalized_dice_score_half_cpu(self, input, module, functional, compare, dtype): + preds, target = input + + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=module, + metric_functional=functional, + metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), 
reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_generalized_dice_score_half_gpu(self, input, module, functional, compare, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=module, + metric_functional=functional, + metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) + + +def _sk_generalized_dice_score_multiclass(preds, target, sk_fn, ignore_index, multidim_average, average): + if preds.ndim == target.ndim + 1: + preds = torch.argmax(preds, 1) + if multidim_average == "global": + preds = preds.numpy().flatten() + target = target.numpy().flatten() + target, preds = remove_ignore_index(target, preds, ignore_index) + return sk_fn(target, preds, average=average) + else: + preds = preds.numpy() + target = target.numpy() + res = [] + for pred, true in zip(preds, target): + pred = pred.flatten() + true = true.flatten() + true, pred = remove_ignore_index(true, pred, ignore_index) + res.append(sk_fn(true, pred, average=average, labels=list(range(NUM_CLASSES)))) + return np.stack(res, 0) + + +@pytest.mark.parametrize("input", _multiclass_cases) +class TestMulticlassGeneralizedDiceScore(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_multiclass_generalized_dice_score( + self, ddp, input, module, functional, compare, ignore_index, multidim_average, average ): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and target.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") + if multidim_average == "samplewise" and ddp: + pytest.skip("samplewise and ddp give different order than non ddp") + self.run_class_metric_test( ddp=ddp, preds=preds, target=target, - metric_class=GeneralizedDiceScore, + metric_class=module, sk_metric=partial( - _sk_generalized_dice, - weight_type=weight_type, + _sk_generalized_dice_score_multiclass, + sk_fn=compare, ignore_index=ignore_index, - multiclass=multiclass, - num_classes=num_classes, - zero_division=zero_division, + multidim_average=multidim_average, + average=average, ), - dist_sync_on_step=dist_sync_on_step, metric_args={ "ignore_index": ignore_index, - "weight_type": weight_type, - "multiclass": multiclass, - "multidim": multidim, - "num_classes": num_classes, - "zero_division": zero_division, + "multidim_average": multidim_average, + "average": average, + "num_classes": NUM_CLASSES, }, ) - def test_generalized_dice_fn( - self, preds, target, multiclass, multidim, weight_type, ignore_index, num_classes, zero_division + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + def test_multiclass_generalized_dice_score_functional( + self, input, module, functional, compare, ignore_index, multidim_average, average ): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and target.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") + self.run_functional_metric_test( - preds, - target, - metric_functional=generalized_dice_score, + preds=preds, + 
target=target, + metric_functional=functional, sk_metric=partial( - _sk_generalized_dice, - weight_type=weight_type, + _sk_generalized_dice_score_multiclass, + sk_fn=compare, ignore_index=ignore_index, - multiclass=multiclass, - num_classes=num_classes, - zero_division=zero_division, + multidim_average=multidim_average, + average=average, ), metric_args={ "ignore_index": ignore_index, - "weight_type": weight_type, - "multiclass": multiclass, - "multidim": multidim, - "num_classes": num_classes, - "zero_division": zero_division, + "multidim_average": multidim_average, + "average": average, + "num_classes": NUM_CLASSES, }, ) + def test_multiclass_generalized_dice_score_differentiability(self, input, module, functional, compare): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=module, + metric_functional=functional, + metric_args={"num_classes": NUM_CLASSES}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_generalized_dice_score_half_cpu(self, input, module, functional, compare, dtype): + preds, target = input + + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=module, + metric_functional=functional, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_generalized_dice_score_half_gpu(self, input, module, functional, compare, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=module, + metric_functional=functional, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) + +_mc_k_target = torch.tensor([0, 1, 2]) +_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) + + +@pytest.mark.parametrize( + "metric_class, metric_fn", + [ + (partial(MulticlassGeneralizedDiceScore, beta=2.0), partial(multiclass_generalized_dice_score, beta=2.0)), + (MulticlassF1Score, multiclass_f1_score), + ], +) @pytest.mark.parametrize( - "preds, target, multiclass, multidim, num_classes", + "k, preds, target, average, expected_generalized_dice, expected_f1", [ - (_input_mcls.preds, _input_mcls.target, True, False, NUM_CLASSES), - (_input_mcls_logits.preds, _input_mcls_logits.target, True, False, NUM_CLASSES), - (_input_mcls_prob.preds, _input_mcls_prob.target, True, False, NUM_CLASSES), - (_input_mdmc.preds, _input_mdmc.target, True, True, NUM_CLASSES), - (_input_mdmc_logits.preds, _input_mdmc_logits.target, True, True, NUM_CLASSES), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, True, True, NUM_CLASSES), - (_input_miss_class.preds, _input_miss_class.target, True, False, NUM_CLASSES), - (_input_mlb.preds, _input_mlb.target, False, False, NUM_CLASSES), - (_input_mlb_logits.preds, _input_mlb_logits.target, False, False, NUM_CLASSES), - (_input_mlb_prob.preds, _input_mlb_prob.target, False, False, NUM_CLASSES), - # (_input_mlmd.preds, _input_mlmd.target, True, True, NUM_CLASSES), - # (_input_mlmd_logits.preds, _input_mlmd_logits.target, False, True, NUM_CLASSES * EXTRA_DIM), - # (_input_mlmd_prob.preds, _input_mlmd_prob.target, True, True, NUM_CLASSES * EXTRA_DIM), + (1, _mc_k_preds, _mc_k_target, "micro", torch.tensor(2 / 3), torch.tensor(2 / 3)), + (2, _mc_k_preds, _mc_k_target, 
"micro", torch.tensor(5 / 6), torch.tensor(2 / 3)), ], ) -@pytest.mark.parametrize("zero_division", [None, 0, 1]) -@pytest.mark.parametrize("ignore_index", [None, 0]) -@pytest.mark.parametrize("weight_type", ["simple", "square", None]) -class TestGeneralizedDiceMulti(MetricTester): - @pytest.mark.parametrize("ddp", [False]) - @pytest.mark.parametrize("dist_sync_on_step", [False]) - def test_generalized_dice_class( - self, - ddp, - dist_sync_on_step, - preds, - target, - multiclass, - multidim, - weight_type, - ignore_index, - num_classes, - zero_division, +def test_top_k( + metric_class, + metric_fn, + k: int, + preds: Tensor, + target: Tensor, + average: str, + expected_generalized_dice: Tensor, + expected_f1: Tensor, +): + """A simple test to check that top_k works as expected.""" + class_metric = metric_class(top_k=k, average=average, num_classes=3) + class_metric.update(preds, target) + + if class_metric.beta != 1.0: + result = expected_generalized_dice + else: + result = expected_f1 + + assert torch.isclose(class_metric.compute(), result) + assert torch.isclose(metric_fn(preds, target, top_k=k, average=average, num_classes=3), result) + + +def _sk_generalized_dice_score_multilabel_global(preds, target, sk_fn, ignore_index, average): + if average == "micro": + preds = preds.flatten() + target = target.flatten() + target, preds = remove_ignore_index(target, preds, ignore_index) + return sk_fn(target, preds) + + generalized_dice_score, weights = [], [] + for i in range(preds.shape[1]): + pred, true = preds[:, i].flatten(), target[:, i].flatten() + true, pred = remove_ignore_index(true, pred, ignore_index) + generalized_dice_score.append(sk_fn(true, pred)) + confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) + weights.append(confmat[1, 1] + confmat[1, 0]) + res = np.stack(generalized_dice_score, axis=0) + + if average == "macro": + return res.mean(0) + elif average == "weighted": + weights = np.stack(weights, 0).astype(float) + weights_norm = weights.sum(-1, keepdims=True) + weights_norm[weights_norm == 0] = 1.0 + return ((weights * res) / weights_norm).sum(-1) + elif average is None or average == "none": + return res + + +def _sk_generalized_dice_score_multilabel_local(preds, target, sk_fn, ignore_index, average): + generalized_dice_score, weights = [], [] + for i in range(preds.shape[0]): + if average == "micro": + pred, true = preds[i].flatten(), target[i].flatten() + true, pred = remove_ignore_index(true, pred, ignore_index) + generalized_dice_score.append(sk_fn(true, pred)) + confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) + weights.append(confmat[1, 1] + confmat[1, 0]) + else: + scores, w = [], [] + for j in range(preds.shape[1]): + pred, true = preds[i, j], target[i, j] + true, pred = remove_ignore_index(true, pred, ignore_index) + scores.append(sk_fn(true, pred)) + confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) + w.append(confmat[1, 1] + confmat[1, 0]) + generalized_dice_score.append(np.stack(scores)) + weights.append(np.stack(w)) + if average == "micro": + return np.array(generalized_dice_score) + res = np.stack(generalized_dice_score, 0) + if average == "macro": + return res.mean(-1) + elif average == "weighted": + weights = np.stack(weights, 0).astype(float) + weights_norm = weights.sum(-1, keepdims=True) + weights_norm[weights_norm == 0] = 1.0 + return ((weights * res) / weights_norm).sum(-1) + elif average is None or average == "none": + return res + + +def _sk_generalized_dice_score_multilabel(preds, target, sk_fn, ignore_index, multidim_average, 
average): + preds = preds.numpy() + target = target.numpy() + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = (preds >= THRESHOLD).astype(np.uint8) + preds = preds.reshape(*preds.shape[:2], -1) + target = target.reshape(*target.shape[:2], -1) + if ignore_index is None and multidim_average == "global": + return sk_fn( + target.transpose(0, 2, 1).reshape(-1, NUM_CLASSES), + preds.transpose(0, 2, 1).reshape(-1, NUM_CLASSES), + average=average, + ) + elif multidim_average == "global": + return _sk_generalized_dice_score_multilabel_global(preds, target, sk_fn, ignore_index, average) + return _sk_generalized_dice_score_multilabel_local(preds, target, sk_fn, ignore_index, average) + + +@pytest.mark.parametrize("input", _multilabel_cases) +class TestMultilabelGeneralizedDiceScore(MetricTester): + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + def test_multilabel_generalized_dice_score( + self, ddp, input, module, functional, compare, ignore_index, multidim_average, average ): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 4: + pytest.skip("samplewise and non-multidim arrays are not valid") + if multidim_average == "samplewise" and ddp: + pytest.skip("samplewise and ddp give different order than non ddp") + self.run_class_metric_test( ddp=ddp, preds=preds, target=target, - metric_class=GeneralizedDiceScore, + metric_class=module, sk_metric=partial( - _sk_generalized_dice, - weight_type=weight_type, + _sk_generalized_dice_score_multilabel, + sk_fn=compare, ignore_index=ignore_index, - multiclass=multiclass, - num_classes=num_classes, - zero_division=zero_division, + multidim_average=multidim_average, + average=average, ), - dist_sync_on_step=dist_sync_on_step, metric_args={ + "num_labels": NUM_CLASSES, + "threshold": THRESHOLD, "ignore_index": ignore_index, - "weight_type": weight_type, - "multiclass": multiclass, - "multidim": multidim, - "num_classes": num_classes, - "zero_division": zero_division, + "multidim_average": multidim_average, + "average": average, }, ) - def test_generalized_dice_fn( - self, preds, target, multiclass, multidim, weight_type, ignore_index, num_classes, zero_division + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + def test_multilabel_generalized_dice_score_functional( + self, input, module, functional, compare, ignore_index, multidim_average, average ): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 4: + pytest.skip("samplewise and non-multidim arrays are not valid") + self.run_functional_metric_test( - preds, - target, - metric_functional=generalized_dice_score, + preds=preds, + target=target, + metric_functional=functional, sk_metric=partial( - _sk_generalized_dice, - weight_type=weight_type, + _sk_generalized_dice_score_multilabel, + sk_fn=compare, ignore_index=ignore_index, - multiclass=multiclass, - num_classes=num_classes, - zero_division=zero_division, + multidim_average=multidim_average, + average=average, ), 
metric_args={ + "num_labels": NUM_CLASSES, + "threshold": THRESHOLD, "ignore_index": ignore_index, - "weight_type": weight_type, - "multiclass": multiclass, - "multidim": multidim, - "num_classes": num_classes, - "zero_division": zero_division, + "multidim_average": multidim_average, + "average": average, }, ) + + def test_multilabel_generalized_dice_score_differentiability(self, input, module, functional, compare): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=module, + metric_functional=functional, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_generalized_dice_score_half_cpu(self, input, module, functional, compare, dtype): + preds, target = input + + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=module, + metric_functional=functional, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_generalized_dice_score_half_gpu(self, input, module, functional, compare, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=module, + metric_functional=functional, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + dtype=dtype, + ) From 54afa4480191081fbcf494500d74786a54cb6bc7 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Mon, 9 Jan 2023 13:13:39 +0100 Subject: [PATCH 11/39] missing import --- src/torchmetrics/functional/classification/generalized_dice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/classification/generalized_dice.py b/src/torchmetrics/functional/classification/generalized_dice.py index 3f7d6c637cc..c919853d418 100644 --- a/src/torchmetrics/functional/classification/generalized_dice.py +++ b/src/torchmetrics/functional/classification/generalized_dice.py @@ -15,7 +15,7 @@ import torch from torch import Tensor -from typing_extensions import Literal +from typing_extensions import Literal, Any from torchmetrics.functional.classification.stat_scores import ( _binary_stat_scores_arg_validation, From 7aa422ec80260276d93d86f5bd096334ecd49a49 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Jan 2023 12:14:12 +0000 Subject: [PATCH 12/39] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/functional/classification/generalized_dice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/classification/generalized_dice.py b/src/torchmetrics/functional/classification/generalized_dice.py index c919853d418..63dfcc069b5 100644 --- a/src/torchmetrics/functional/classification/generalized_dice.py +++ b/src/torchmetrics/functional/classification/generalized_dice.py @@ -15,7 +15,7 @@ import torch from torch import Tensor -from typing_extensions import Literal, Any +from typing_extensions import Any, Literal from torchmetrics.functional.classification.stat_scores import ( _binary_stat_scores_arg_validation, From 
e6f67b022906e9f1ba0482fa4221469de9378803 Mon Sep 17 00:00:00 2001
From: Justus Schock <12886177+justusschock@users.noreply.github.com>
Date: Mon, 9 Jan 2023 15:34:24 +0100
Subject: [PATCH 13/39] missing import

---
 src/torchmetrics/classification/dice.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/torchmetrics/classification/dice.py b/src/torchmetrics/classification/dice.py
index 399cc398f76..08ed139b118 100644
--- a/src/torchmetrics/classification/dice.py
+++ b/src/torchmetrics/classification/dice.py
@@ -20,6 +20,7 @@
 from torchmetrics.functional.classification.dice import _dice_compute
 from torchmetrics.utilities.enums import AverageMethod
 from torchmetrics.utilities.prints import rank_zero_warn
+from torchmetrics.metric import Metric
 
 
 class Dice(Metric):

From 0a931fb5d22df0e92995efe140f3a6496533d516 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 9 Jan 2023 14:35:04 +0000
Subject: [PATCH 14/39] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/torchmetrics/classification/dice.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/torchmetrics/classification/dice.py b/src/torchmetrics/classification/dice.py
index 08ed139b118..4965c5d94b8 100644
--- a/src/torchmetrics/classification/dice.py
+++ b/src/torchmetrics/classification/dice.py
@@ -18,9 +18,9 @@
 from typing_extensions import Literal
 
 from torchmetrics.functional.classification.dice import _dice_compute
+from torchmetrics.metric import Metric
 from torchmetrics.utilities.enums import AverageMethod
 from torchmetrics.utilities.prints import rank_zero_warn
-from torchmetrics.metric import Metric
 
 
 class Dice(Metric):

From e10b91474a0c9509591c57ca7185cf476eabffe8 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 20 Feb 2023 20:44:08 +0000
Subject: [PATCH 15/39] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/torchmetrics/classification/generalized_dice.py   |  2 +-
 .../functional/classification/generalized_dice.py     |  5 +----
 .../classification/test_generalized_dice_score.py     | 11 ++++-------
 3 files changed, 6 insertions(+), 12 deletions(-)

diff --git a/src/torchmetrics/classification/generalized_dice.py b/src/torchmetrics/classification/generalized_dice.py
index d77efeea312..6f092d02dab 100644
--- a/src/torchmetrics/classification/generalized_dice.py
+++ b/src/torchmetrics/classification/generalized_dice.py
@@ -232,7 +232,7 @@ def __new__(
         **kwargs: Any,
     ) -> Metric:
         assert multidim_average is not None
-        kwargs.update(dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args))
+        kwargs.update({"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args})
         if task == "binary":
             return BinaryGeneralizedDiceScore(beta, threshold, **kwargs)
         if task == "multiclass":
diff --git a/src/torchmetrics/functional/classification/generalized_dice.py b/src/torchmetrics/functional/classification/generalized_dice.py
index 63dfcc069b5..bf1c8d930fc 100644
--- a/src/torchmetrics/functional/classification/generalized_dice.py
+++ b/src/torchmetrics/functional/classification/generalized_dice.py
@@ -76,10 +76,7 @@ def _generalized_dice_reduce(
         generalized_dice_score = _safe_divide(2 * weights * tp, weights * (2 * tp + fp + fn))
         if average is None or average == "none":
             return generalized_dice_score
-        if average == "weighted":
- weights = tp + fn - else: - weights = torch.ones_like(generalized_dice_score) + weights = tp + fn if average == "weighted" else torch.ones_like(generalized_dice_score) return _safe_divide(weights * generalized_dice_score, weights.sum(-1, keepdim=True)).sum(-1) diff --git a/tests/unittests/classification/test_generalized_dice_score.py b/tests/unittests/classification/test_generalized_dice_score.py index b7b9d281609..e20a13e9494 100644 --- a/tests/unittests/classification/test_generalized_dice_score.py +++ b/tests/unittests/classification/test_generalized_dice_score.py @@ -48,7 +48,7 @@ def _sk_generalized_dice_score_binary(preds, target, sk_fn, ignore_index, multid target = target.numpy() if np.issubdtype(preds.dtype, np.floating): - if not ((0 < preds) & (preds < 1)).all(): + if not ((preds > 0) & (preds < 1)).all(): preds = sigmoid(preds) preds = (preds >= THRESHOLD).astype(np.uint8) @@ -294,7 +294,7 @@ def test_multiclass_generalized_dice_score_half_gpu(self, input, module, functio @pytest.mark.parametrize( - "metric_class, metric_fn", + ("metric_class", "metric_fn"), [ (partial(MulticlassGeneralizedDiceScore, beta=2.0), partial(multiclass_generalized_dice_score, beta=2.0)), (MulticlassF1Score, multiclass_f1_score), @@ -321,10 +321,7 @@ def test_top_k( class_metric = metric_class(top_k=k, average=average, num_classes=3) class_metric.update(preds, target) - if class_metric.beta != 1.0: - result = expected_generalized_dice - else: - result = expected_f1 + result = expected_generalized_dice if class_metric.beta != 1.0 else expected_f1 assert torch.isclose(class_metric.compute(), result) assert torch.isclose(metric_fn(preds, target, top_k=k, average=average, num_classes=3), result) @@ -394,7 +391,7 @@ def _sk_generalized_dice_score_multilabel(preds, target, sk_fn, ignore_index, mu preds = preds.numpy() target = target.numpy() if np.issubdtype(preds.dtype, np.floating): - if not ((0 < preds) & (preds < 1)).all(): + if not ((preds > 0) & (preds < 1)).all(): preds = sigmoid(preds) preds = (preds >= THRESHOLD).astype(np.uint8) preds = preds.reshape(*preds.shape[:2], -1) From 0e3ab5db11b4f45a001f27920f6309c7319d7289 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Mon, 20 Feb 2023 21:59:18 +0100 Subject: [PATCH 16/39] Apply suggestions from code review --- src/torchmetrics/classification/dice.py | 6 +++--- .../classification/generalized_dice.py | 6 +++--- .../functional/classification/generalized_dice.py | 15 +++------------ 3 files changed, 9 insertions(+), 18 deletions(-) diff --git a/src/torchmetrics/classification/dice.py b/src/torchmetrics/classification/dice.py index 179f61deb48..2fb6a1ddb37 100644 --- a/src/torchmetrics/classification/dice.py +++ b/src/torchmetrics/classification/dice.py @@ -26,8 +26,8 @@ class Dice(Metric): r"""Compute `Dice`_. - .. deprecated:: v0.10 - The `Dice` module was deprecated in v0.10 and will be removed in v0.11. Use `F1Score` module instead which + .. deprecated:: v0.12 + The `Dice` module was deprecated in v0.12 and will be removed in v0.13. Use `F1Score` module instead which is equivalent. .. math:: \text{Dice} = \frac{\text{2 * TP}}{\text{2 * TP} + \text{FP} + \text{FN}} @@ -150,7 +150,7 @@ def __init__( **kwargs: Any, ) -> None: rank_zero_warn( - "The `dice` function was deprecated in v0.10 and will be removed in v0.11. Use `f1score` function instead" + "The `dice` function was deprecated in v0.12 and will be removed in v0.13. 
Use `f1score` function instead" " which is equivalent.", DeprecationWarning, ) diff --git a/src/torchmetrics/classification/generalized_dice.py b/src/torchmetrics/classification/generalized_dice.py index 6f092d02dab..bc9ff64dd2b 100644 --- a/src/torchmetrics/classification/generalized_dice.py +++ b/src/torchmetrics/classification/generalized_dice.py @@ -206,10 +206,10 @@ class GeneralizedDiceScore: If ``top_k`` is not an ``integer`` larger than ``0``. Example: - >>> import torch + >>> from torch import tensor >>> from torchmetrics import GeneralizedDiceScore - >>> preds = torch.tensor([2, 0, 2, 1]) - >>> target = torch.tensor([1, 1, 2, 0]) + >>> preds = tensor([2, 0, 2, 1]) + >>> target = tensor([1, 1, 2, 0]) >>> generalized_dice_score = GeneralizedDiceScore(num_classes=3) >>> generalized_dice_score(preds, target) tensor(0.3478) diff --git a/src/torchmetrics/functional/classification/generalized_dice.py b/src/torchmetrics/functional/classification/generalized_dice.py index bf1c8d930fc..696f639237a 100644 --- a/src/torchmetrics/functional/classification/generalized_dice.py +++ b/src/torchmetrics/functional/classification/generalized_dice.py @@ -207,15 +207,6 @@ def _generalized_dice_compute( ignore_index: Optional index of the class to ignore in the score computation zero_division: The value to use for the score if denominator equals zero. If set to 0, score will be 1 if the numerator is also 0, and 0 otherwise - - Example: - >>> from torchmetrics.functional.classification.stat_scores import _stat_scores_update - >>> from torchmetrics.functional.classification.generalized_dice import _generalized_dice_compute - >>> preds = torch.tensor([2, 0, 2, 1]) - >>> target = torch.tensor([1, 1, 2, 0]) - >>> tp, fp, tn, fn = _stat_scores_update(preds, target, reduce='macro', mdmc_reduce='samplewise') - >>> _generalized_dice_compute(tp, fp, fn, average='samples') - tensor(0.3478) """ # Compute ground-truth class volume and class weights target_volume = tp + fn @@ -362,10 +353,10 @@ def generalized_dice_score( If ``top_k`` is not an integer larger than ``0``. 
Example: - >>> import torch + >>> from torch import tensor >>> from torchmetrics.functional import generalized_dice_score - >>> preds = torch.tensor([2, 0, 2, 1]) - >>> target = torch.tensor([1, 1, 2, 0]) + >>> preds = tensor([2, 0, 2, 1]) + >>> target = tensor([1, 1, 2, 0]) >>> generalized_dice_score(preds, target, average='samples') tensor(0.3478) """ From 5e7fe1c479bbb450c7ae5738a0cac555d0f0ff86 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 20 Feb 2023 20:59:57 +0000 Subject: [PATCH 17/39] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/classification/generalized_dice.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/classification/generalized_dice.py b/src/torchmetrics/classification/generalized_dice.py index bc9ff64dd2b..80f8951fa08 100644 --- a/src/torchmetrics/classification/generalized_dice.py +++ b/src/torchmetrics/classification/generalized_dice.py @@ -232,7 +232,9 @@ def __new__( **kwargs: Any, ) -> Metric: assert multidim_average is not None - kwargs.update({"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args}) + kwargs.update( + {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args} + ) if task == "binary": return BinaryGeneralizedDiceScore(beta, threshold, **kwargs) if task == "multiclass": From 653a9e4f737f1803593ef4d236413c86c488aa2a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 28 Feb 2023 08:39:49 +0000 Subject: [PATCH 18/39] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/unittests/classification/test_generalized_dice_score.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unittests/classification/test_generalized_dice_score.py b/tests/unittests/classification/test_generalized_dice_score.py index e20a13e9494..eef23b8a32e 100644 --- a/tests/unittests/classification/test_generalized_dice_score.py +++ b/tests/unittests/classification/test_generalized_dice_score.py @@ -301,7 +301,7 @@ def test_multiclass_generalized_dice_score_half_gpu(self, input, module, functio ], ) @pytest.mark.parametrize( - "k, preds, target, average, expected_generalized_dice, expected_f1", + ("k", "preds", "target", "average", "expected_generalized_dice", "expected_f1"), [ (1, _mc_k_preds, _mc_k_target, "micro", torch.tensor(2 / 3), torch.tensor(2 / 3)), (2, _mc_k_preds, _mc_k_target, "micro", torch.tensor(5 / 6), torch.tensor(2 / 3)), From d6cde19ba5789ff982539a24ebd6b98f397803d7 Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 28 Feb 2023 10:14:28 +0100 Subject: [PATCH 19/39] Literal --- src/torchmetrics/classification/generalized_dice.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/classification/generalized_dice.py b/src/torchmetrics/classification/generalized_dice.py index 80f8951fa08..b46b50a4466 100644 --- a/src/torchmetrics/classification/generalized_dice.py +++ b/src/torchmetrics/classification/generalized_dice.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
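# The import swap below presumably keeps ``Literal`` available on interpreters
# older than Python 3.8, where ``typing.Literal`` does not exist; a minimal
# illustration of the backport (our reading, not stated in the commit message):
#
#     from typing_extensions import Literal  # also works on Python 3.7
#
#     def set_average(average: Literal["micro", "macro"]) -> None:
#         ...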
-from typing import Any, Literal, Optional +from typing import Any, Optional from torch import Tensor +from typing_extensions import Literal from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores from torchmetrics.functional.classification.generalized_dice import ( From 4b99b1d944b8fac5a0c2baea6e2272c77a98c6e1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 21 Mar 2023 16:09:42 +0000 Subject: [PATCH 20/39] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/classification/dice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/classification/dice.py b/src/torchmetrics/classification/dice.py index 6717830e239..5223e893dd8 100644 --- a/src/torchmetrics/classification/dice.py +++ b/src/torchmetrics/classification/dice.py @@ -19,10 +19,10 @@ from torchmetrics.functional.classification.dice import _dice_compute from torchmetrics.metric import Metric -from torchmetrics.utilities.prints import rank_zero_warn from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE +from torchmetrics.utilities.prints import rank_zero_warn if not _MATPLOTLIB_AVAILABLE: __doctest_skip__ = ["Dice.plot"] From 0435c7cf3b3f4261071cf0d3efe46b3ffa80d1b8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 17 Apr 2023 12:19:55 +0000 Subject: [PATCH 21/39] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/functional/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/__init__.py b/src/torchmetrics/functional/__init__.py index aff0e48ebe5..911e49d0291 100644 --- a/src/torchmetrics/functional/__init__.py +++ b/src/torchmetrics/functional/__init__.py @@ -26,7 +26,6 @@ auroc, average_precision, binary_precision_at_fixed_recall, -generalized_dice_score, calibration_error, cohen_kappa, confusion_matrix, @@ -34,6 +33,7 @@ exact_match, f1_score, fbeta_score, + generalized_dice_score, hamming_distance, hinge_loss, jaccard_index, From 4631b65891b9a7e66ad0c63c25a5ebd568fbb0d3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 May 2023 18:55:56 +0000 Subject: [PATCH 22/39] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/classification/__init__.py | 2 +- tests/unittests/classification/test_generalized_dice_score.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/classification/__init__.py b/src/torchmetrics/classification/__init__.py index 7bcfdeba8e1..d061ae195d1 100644 --- a/src/torchmetrics/classification/__init__.py +++ b/src/torchmetrics/classification/__init__.py @@ -43,7 +43,7 @@ MultilabelF1Score, MultilabelFBetaScore, ) -from torchmetrics.classification.generalized_dice import GeneralizedDiceScore # noqa: F401 +from torchmetrics.classification.generalized_dice import GeneralizedDiceScore from torchmetrics.classification.group_fairness import BinaryFairness, BinaryGroupStatRates from torchmetrics.classification.hamming import ( BinaryHammingDistance, diff --git 
a/tests/unittests/classification/test_generalized_dice_score.py b/tests/unittests/classification/test_generalized_dice_score.py index eef23b8a32e..ca103e4cdfa 100644 --- a/tests/unittests/classification/test_generalized_dice_score.py +++ b/tests/unittests/classification/test_generalized_dice_score.py @@ -21,7 +21,6 @@ from sklearn.metrics import f1_score as sk_f1_score from sklearn.metrics import generalized_dice_score as sk_generalized_dice_score from torch import Tensor - from torchmetrics.classification.generalized_dice import ( BinaryGeneralizedDiceScore, MulticlassGeneralizedDiceScore, @@ -32,6 +31,7 @@ multiclass_generalized_dice_score, multilabel_generalized_dice_score, ) + from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index, remove_ignore_index From 628696dbe9c4f43fc26e65f58a1343b6db4b66d4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 30 Jun 2023 17:54:02 +0000 Subject: [PATCH 23/39] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/unittests/classification/test_generalized_dice_score.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/unittests/classification/test_generalized_dice_score.py b/tests/unittests/classification/test_generalized_dice_score.py index ca103e4cdfa..8fed9892e87 100644 --- a/tests/unittests/classification/test_generalized_dice_score.py +++ b/tests/unittests/classification/test_generalized_dice_score.py @@ -352,6 +352,7 @@ def _sk_generalized_dice_score_multilabel_global(preds, target, sk_fn, ignore_in return ((weights * res) / weights_norm).sum(-1) elif average is None or average == "none": return res + return None def _sk_generalized_dice_score_multilabel_local(preds, target, sk_fn, ignore_index, average): @@ -385,6 +386,7 @@ def _sk_generalized_dice_score_multilabel_local(preds, target, sk_fn, ignore_ind return ((weights * res) / weights_norm).sum(-1) elif average is None or average == "none": return res + return None def _sk_generalized_dice_score_multilabel(preds, target, sk_fn, ignore_index, multidim_average, average): From ef708bf8dca9c269e4d99589f558d703a41ec07e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 7 Aug 2023 11:25:08 +0000 Subject: [PATCH 24/39] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/classification/generalized_dice.py | 1 + src/torchmetrics/functional/classification/generalized_dice.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/torchmetrics/classification/generalized_dice.py b/src/torchmetrics/classification/generalized_dice.py index b46b50a4466..8eeeab99d3c 100644 --- a/src/torchmetrics/classification/generalized_dice.py +++ b/src/torchmetrics/classification/generalized_dice.py @@ -214,6 +214,7 @@ class GeneralizedDiceScore: >>> generalized_dice_score = GeneralizedDiceScore(num_classes=3) >>> generalized_dice_score(preds, target) tensor(0.3478) + """ def __new__( diff --git a/src/torchmetrics/functional/classification/generalized_dice.py b/src/torchmetrics/functional/classification/generalized_dice.py index 696f639237a..f1a503e0505 100644 --- a/src/torchmetrics/functional/classification/generalized_dice.py +++ 
b/src/torchmetrics/functional/classification/generalized_dice.py @@ -359,6 +359,7 @@ def generalized_dice_score( >>> target = tensor([1, 1, 2, 0]) >>> generalized_dice_score(preds, target, average='samples') tensor(0.3478) + """ allowed_weight_type = ("square", "simple", None) if weight_type not in allowed_weight_type: From 32e6d1ac4816e90a0dfc8247b8831c7f13d5a413 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 12 Apr 2024 14:31:29 +0000 Subject: [PATCH 25/39] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/classification/generalized_dice.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/torchmetrics/classification/generalized_dice.py b/src/torchmetrics/classification/generalized_dice.py index 8eeeab99d3c..ddb18b9e648 100644 --- a/src/torchmetrics/classification/generalized_dice.py +++ b/src/torchmetrics/classification/generalized_dice.py @@ -234,9 +234,11 @@ def __new__( **kwargs: Any, ) -> Metric: assert multidim_average is not None - kwargs.update( - {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args} - ) + kwargs.update({ + "multidim_average": multidim_average, + "ignore_index": ignore_index, + "validate_args": validate_args, + }) if task == "binary": return BinaryGeneralizedDiceScore(beta, threshold, **kwargs) if task == "multiclass": From 978054d2b7df21314ebb9ee3a9ae67b7363d33e2 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 12 Apr 2024 16:34:08 +0200 Subject: [PATCH 26/39] move around files --- docs/source/{classification => segmentation}/generalized_dice.rst | 0 .../{classification => segmentation}/generalized_dice.py | 0 .../{classification => segmentation}/generalized_dice.py | 0 .../test_generalized_dice_score.py | 0 4 files changed, 0 insertions(+), 0 deletions(-) rename docs/source/{classification => segmentation}/generalized_dice.rst (100%) rename src/torchmetrics/functional/{classification => segmentation}/generalized_dice.py (100%) rename src/torchmetrics/{classification => segmentation}/generalized_dice.py (100%) rename tests/unittests/{classification => segmentation}/test_generalized_dice_score.py (100%) diff --git a/docs/source/classification/generalized_dice.rst b/docs/source/segmentation/generalized_dice.rst similarity index 100% rename from docs/source/classification/generalized_dice.rst rename to docs/source/segmentation/generalized_dice.rst diff --git a/src/torchmetrics/functional/classification/generalized_dice.py b/src/torchmetrics/functional/segmentation/generalized_dice.py similarity index 100% rename from src/torchmetrics/functional/classification/generalized_dice.py rename to src/torchmetrics/functional/segmentation/generalized_dice.py diff --git a/src/torchmetrics/classification/generalized_dice.py b/src/torchmetrics/segmentation/generalized_dice.py similarity index 100% rename from src/torchmetrics/classification/generalized_dice.py rename to src/torchmetrics/segmentation/generalized_dice.py diff --git a/tests/unittests/classification/test_generalized_dice_score.py b/tests/unittests/segmentation/test_generalized_dice_score.py similarity index 100% rename from tests/unittests/classification/test_generalized_dice_score.py rename to tests/unittests/segmentation/test_generalized_dice_score.py From 70fabaf588d8120131920f791609310f797d05ba Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 12 Apr 2024 16:35:31 +0200 Subject: [PATCH 
27/39] revert some changes from classification --- CHANGELOG.md | 2 +- src/torchmetrics/classification/dice.py | 14 ++------------ src/torchmetrics/classification/f_beta.py | 12 +++--------- src/torchmetrics/functional/classification/dice.py | 10 ---------- .../functional/classification/f_beta.py | 12 +++--------- src/torchmetrics/segmentation/__init__.py | 13 +++++++++++++ 6 files changed, 22 insertions(+), 41 deletions(-) create mode 100644 src/torchmetrics/segmentation/__init__.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 5b5ced3054b..ab9b5bcd806 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- Added `GeneralizedDiceScore` to classification package ([#1090](https://github.com/Lightning-AI/metrics/pull/1090)) +- Added `GeneralizedDiceScore` to segmentation package ([#1090](https://github.com/Lightning-AI/metrics/pull/1090)) - Added `SensitivityAtSpecificity` metric to classification subpackage ([#2217](https://github.com/Lightning-AI/torchmetrics/pull/2217)) diff --git a/src/torchmetrics/classification/dice.py b/src/torchmetrics/classification/dice.py index 23ace338137..39bc2acabcd 100644 --- a/src/torchmetrics/classification/dice.py +++ b/src/torchmetrics/classification/dice.py @@ -18,11 +18,11 @@ from typing_extensions import Literal from torchmetrics.functional.classification.dice import _dice_compute +from torchmetrics.functional.classification.stat_scores import _stat_scores_update from torchmetrics.metric import Metric from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE -from torchmetrics.utilities.prints import rank_zero_warn if not _MATPLOTLIB_AVAILABLE: __doctest_skip__ = ["Dice.plot"] @@ -31,10 +31,6 @@ class Dice(Metric): r"""Compute `Dice`_. - .. deprecated:: v0.12 - The `Dice` module was deprecated in v0.12 and will be removed in v0.13. Use `F1Score` module instead which - is equivalent. - .. math:: \text{Dice} = \frac{\text{2 * TP}}{\text{2 * TP} + \text{FP} + \text{FN}} Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and @@ -159,14 +155,8 @@ def __init__( multiclass: Optional[bool] = None, **kwargs: Any, ) -> None: - rank_zero_warn( - "The `dice` function was deprecated in v0.12 and will be removed in v0.13. Use `f1score` function instead" - " which is equivalent.", - DeprecationWarning, - ) - - allowed_average = ("micro", "macro", "weighted", "samples", "none", None) super().__init__(**kwargs) + allowed_average = ("micro", "macro", "samples", "none", None) if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") diff --git a/src/torchmetrics/classification/f_beta.py b/src/torchmetrics/classification/f_beta.py index 8ca903eebe2..93f26441c2a 100644 --- a/src/torchmetrics/classification/f_beta.py +++ b/src/torchmetrics/classification/f_beta.py @@ -552,13 +552,11 @@ def plot( class BinaryF1Score(BinaryFBetaScore): - r"""Compute F-1 score (also known as Dice score/similarity) for binary tasks: + r"""Compute F-1 score for binary tasks. .. math:: F_{1} = 2\frac{\text{precision} * \text{recall}}{(\text{precision}) + \text{recall}} - F-1 score correspond to equally weighted average of the precision and recall scores. 
- The metric is only proper defined when :math:`\text{TP} + \text{FP} \neq 0 \wedge \text{TP} + \text{FN} \neq 0` where :math:`\text{TP}`, :math:`\text{FP}` and :math:`\text{FN}` represent the number of true positives, false positives and false negatives respectively. If this case is encountered a score of 0 is returned. @@ -690,13 +688,11 @@ def plot( class MulticlassF1Score(MulticlassFBetaScore): - r"""Compute F-1 score (also known as Dice score/similarity) for multiclass tasks: + r"""Compute F-1 score for multiclass tasks. .. math:: F_{1} = 2\frac{\text{precision} * \text{recall}}{(\text{precision}) + \text{recall}} - F-1 score correspond to equally weighted average of the precision and recall scores. - The metric is only proper defined when :math:`\text{TP} + \text{FP} \neq 0 \wedge \text{TP} + \text{FN} \neq 0` where :math:`\text{TP}`, :math:`\text{FP}` and :math:`\text{FN}` represent the number of true positives, false positives and false negatives respectively. If this case is encountered for any class, the metric for that class @@ -865,13 +861,11 @@ def plot( class MultilabelF1Score(MultilabelFBetaScore): - r"""Compute F-1 score (also known as Dice score/similarity) for multilabel tasks: + r"""Compute F-1 score for multilabel tasks. .. math:: F_{1} = 2\frac{\text{precision} * \text{recall}}{(\text{precision}) + \text{recall}} - F-1 score correspond to equally weighted average of the precision and recall scores. - The metric is only proper defined when :math:`\text{TP} + \text{FP} \neq 0 \wedge \text{TP} + \text{FN} \neq 0` where :math:`\text{TP}`, :math:`\text{FP}` and :math:`\text{FN}` represent the number of true positives, false positives and false negatives respectively. If this case is encountered for any label, the metric for that label diff --git a/src/torchmetrics/functional/classification/dice.py b/src/torchmetrics/functional/classification/dice.py index 720150040ac..49d66ea9361 100644 --- a/src/torchmetrics/functional/classification/dice.py +++ b/src/torchmetrics/functional/classification/dice.py @@ -19,7 +19,6 @@ from torchmetrics.functional.classification.stat_scores import _reduce_stat_scores, _stat_scores_update from torchmetrics.utilities.checks import _input_squeeze from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod -from torchmetrics.utilities.prints import rank_zero_warn def _dice_compute( @@ -79,10 +78,6 @@ def dice( ) -> Tensor: r"""Compute `Dice`_. - .. deprecated:: v0.10 - The `dice` function was deprecated in v0.10 and will be removed in v0.11. Use `f1score` function instead which - is equivalent. - .. math:: \text{Dice} = \frac{\text{2 * TP}}{\text{2 * TP} + \text{FP} + \text{FN}} Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and @@ -179,11 +174,6 @@ def dice( tensor(0.2500) """ - rank_zero_warn( - "The `dice` function was deprecated in v0.10 and will be removed in v0.11. 
Use `f1score` function instead" - " which is equivalent.", - DeprecationWarning, - ) allowed_average = ("micro", "macro", "weighted", "samples", "none", None) if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") diff --git a/src/torchmetrics/functional/classification/f_beta.py b/src/torchmetrics/functional/classification/f_beta.py index f5615a25cad..0f0e883266c 100644 --- a/src/torchmetrics/functional/classification/f_beta.py +++ b/src/torchmetrics/functional/classification/f_beta.py @@ -386,13 +386,11 @@ def binary_f1_score( ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: - r"""Compute F-1 score (also known as Dice score/similarity) for binary tasks: + r"""Compute F-1 score for binary tasks. .. math:: F_{1} = 2\frac{\text{precision} * \text{recall}}{(\text{precision}) + \text{recall}} - F-1 score correspond to equally weighted average of the precision and recall scores. - Accepts the following input tensors: - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside @@ -465,13 +463,11 @@ def multiclass_f1_score( ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: - r"""Compute F-1 score (also known as Dice score/similarity) for multiclass tasks: + r"""Compute F-1 score for multiclass tasks. .. math:: F_{1} = 2\frac{\text{precision} * \text{recall}}{(\text{precision}) + \text{recall}} - F-1 score correspond to equally weighted average of the precision and recall scores. - Accepts the following input tensors: - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point @@ -574,13 +570,11 @@ def multilabel_f1_score( ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: - r"""Compute F-1 score (also known as Dice score/similarity) for multilabel tasks: + r"""Compute F-1 score for multilabel tasks. .. math:: F_{1} = 2\frac{\text{precision} * \text{recall}}{(\text{precision}) + \text{recall}} - F-1 score correspond to equally weighted average of the precision and recall scores. - Accepts the following input tensors: - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside diff --git a/src/torchmetrics/segmentation/__init__.py b/src/torchmetrics/segmentation/__init__.py new file mode 100644 index 00000000000..e708df58645 --- /dev/null +++ b/src/torchmetrics/segmentation/__init__.py @@ -0,0 +1,13 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
\ No newline at end of file From bd29835644a283b003c4aaed0b8e3119c6981b2a Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 12 Apr 2024 16:38:39 +0200 Subject: [PATCH 28/39] move content in init files --- src/torchmetrics/__init__.py | 2 -- src/torchmetrics/classification/__init__.py | 1 - src/torchmetrics/functional/__init__.py | 2 -- src/torchmetrics/functional/classification/__init__.py | 1 - src/torchmetrics/functional/segmentation/__init__.py | 4 ++++ src/torchmetrics/segmentation/__init__.py | 6 +++++- 6 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/torchmetrics/__init__.py b/src/torchmetrics/__init__.py index a710e67c9a8..b1549dfaf8b 100644 --- a/src/torchmetrics/__init__.py +++ b/src/torchmetrics/__init__.py @@ -51,7 +51,6 @@ ExactMatch, F1Score, FBetaScore, - GeneralizedDiceScore, HammingDistance, HingeLoss, JaccardIndex, @@ -177,7 +176,6 @@ "ExtendedEditDistance", "F1Score", "FBetaScore", - "GeneralizedDiceScore", "FleissKappa", "HammingDistance", "HingeLoss", diff --git a/src/torchmetrics/classification/__init__.py b/src/torchmetrics/classification/__init__.py index 40b315cdc15..988a01c2947 100644 --- a/src/torchmetrics/classification/__init__.py +++ b/src/torchmetrics/classification/__init__.py @@ -43,7 +43,6 @@ MultilabelF1Score, MultilabelFBetaScore, ) -from torchmetrics.classification.generalized_dice import GeneralizedDiceScore from torchmetrics.classification.group_fairness import BinaryFairness, BinaryGroupStatRates from torchmetrics.classification.hamming import ( BinaryHammingDistance, diff --git a/src/torchmetrics/functional/__init__.py b/src/torchmetrics/functional/__init__.py index 39d303211a2..30a7145aa71 100644 --- a/src/torchmetrics/functional/__init__.py +++ b/src/torchmetrics/functional/__init__.py @@ -32,7 +32,6 @@ exact_match, f1_score, fbeta_score, - generalized_dice_score, hamming_distance, hinge_loss, jaccard_index, @@ -162,7 +161,6 @@ "extended_edit_distance", "f1_score", "fbeta_score", - "generalized_dice_score", "fleiss_kappa", "hamming_distance", "hinge_loss", diff --git a/src/torchmetrics/functional/classification/__init__.py b/src/torchmetrics/functional/classification/__init__.py index ffa474e0eb9..faf523844bc 100644 --- a/src/torchmetrics/functional/classification/__init__.py +++ b/src/torchmetrics/functional/classification/__init__.py @@ -52,7 +52,6 @@ multilabel_f1_score, multilabel_fbeta_score, ) -from torchmetrics.functional.classification.generalized_dice import generalized_dice_score from torchmetrics.functional.classification.group_fairness import ( binary_fairness, binary_groups_stat_rates, diff --git a/src/torchmetrics/functional/segmentation/__init__.py b/src/torchmetrics/functional/segmentation/__init__.py index 94f1dec4a9f..eec2e4dfcf3 100644 --- a/src/torchmetrics/functional/segmentation/__init__.py +++ b/src/torchmetrics/functional/segmentation/__init__.py @@ -11,3 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ +from torchmetrics.functional.segmentation.generalized_dice import generalized_dice_score + +__all__ = ["generalized_dice_score"] diff --git a/src/torchmetrics/segmentation/__init__.py b/src/torchmetrics/segmentation/__init__.py index e708df58645..24275594e4c 100644 --- a/src/torchmetrics/segmentation/__init__.py +++ b/src/torchmetrics/segmentation/__init__.py @@ -10,4 +10,8 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. + +from torchmetrics.segmentation.generalized_dice import GeneralizedDiceScore + +__all__ = ["GeneralizedDiceScore"] From 05ca6e99e65381e8502c86c3267824beadc24992 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 22 Apr 2024 09:56:11 +0200 Subject: [PATCH 29/39] implementation --- .../segmentation/generalized_dice.py | 434 +++--------------- .../functional/segmentation/utils.py | 7 + 2 files changed, 73 insertions(+), 368 deletions(-) diff --git a/src/torchmetrics/functional/segmentation/generalized_dice.py b/src/torchmetrics/functional/segmentation/generalized_dice.py index f1a503e0505..bbefec742fc 100644 --- a/src/torchmetrics/functional/segmentation/generalized_dice.py +++ b/src/torchmetrics/functional/segmentation/generalized_dice.py @@ -11,395 +11,93 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional - import torch from torch import Tensor -from typing_extensions import Any, Literal - -from torchmetrics.functional.classification.stat_scores import ( - _binary_stat_scores_arg_validation, - _binary_stat_scores_format, - _binary_stat_scores_tensor_validation, - _binary_stat_scores_update, - _multiclass_stat_scores_arg_validation, - _multiclass_stat_scores_format, - _multiclass_stat_scores_tensor_validation, - _multiclass_stat_scores_update, - _multilabel_stat_scores_arg_validation, - _multilabel_stat_scores_format, - _multilabel_stat_scores_tensor_validation, - _multilabel_stat_scores_update, - _reduce_stat_scores, - _stat_scores_update, -) +from torchmetrics.utilities.checks import _check_same_shape from torchmetrics.utilities.compute import _safe_divide -from torchmetrics.utilities.enums import AverageMethod as AvgMethod -from torchmetrics.utilities.enums import MDMCAverageMethod -from torchmetrics.utilities.prints import rank_zero_warn - - -def _generalized_dice_reduce( - tp: Tensor, - fp: Tensor, - tn: Tensor, - fn: Tensor, - weight_type: Optional[Literal["square", "simple"]], - average: Optional[Literal["binary", "micro", "macro", "weighted", "none"]], - multidim_average: Literal["global", "samplewise"] = "global", -) -> Tensor: - target_volume = tp + fn - if weight_type == "simple": - weights = torch.reciprocal(target_volume.float()) - elif weight_type == "square": - weights = torch.reciprocal(target_volume.float() * target_volume.float()) - elif weight_type is None: - weights = torch.ones_like(target_volume.float()) - - if weights.ndim > 1: - for sample_weights in weights: - infs = torch.isinf(sample_weights) - sample_weights[infs] = torch.max(sample_weights[~infs]) if len(sample_weights[~infs]) > 0 else 0 - else: - infs = torch.isinf(weights) - weights[infs] = torch.max(weights[~infs]) - - if average == "binary": - return _safe_divide(2 * 
(tp + weights), 2 * tp + fp + fn) - elif average == "micro": - tp = tp.sum(dim=0 if multidim_average == "global" else 1) - fn = fn.sum(dim=0 if multidim_average == "global" else 1) - fp = fp.sum(dim=0 if multidim_average == "global" else 1) - weights = weights.sum(dim=0 if multidim_average == "global" else 1) - return _safe_divide(2 * (tp + weights), 2 * tp + fp + fn) - else: - generalized_dice_score = _safe_divide(2 * (tp + weights), 2 * tp + fp + fn) - if average is None or average == "none": - return generalized_dice_score - weights = tp + fn if average == "weighted" else torch.ones_like(generalized_dice_score) - return _safe_divide(weights * generalized_dice_score, weights.sum(-1, keepdim=True)).sum(-1) - - -def _binary_generalized_dice_score_arg_validation( - weight_type: Optional[Literal["square", "simple"]], - threshold: float = 0.5, - multidim_average: Literal["global", "samplewise"] = "global", - ignore_index: Optional[int] = None, -) -> None: - allowed_weight_type = ("square", "simple", None) - if weight_type not in weight_type: - raise ValueError( - f"Argument `weight_type` needs to one of the following: {allowed_weight_type} but got {weight_type}" - ) - _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index) - - -def binary_generalized_dice_score( - preds: Tensor, - target: Tensor, - weight_type: Optional[Literal["square", "simple"]], - threshold: float = 0.5, - multidim_average: Literal["global", "samplewise"] = "global", - ignore_index: Optional[int] = None, - validate_args: bool = True, -) -> Tensor: - if validate_args: - _binary_generalized_dice_score_arg_validation(weight_type, threshold, multidim_average, ignore_index) - _binary_stat_scores_tensor_validation(preds, target, multidim_average, ignore_index) - preds, target = _binary_stat_scores_format(preds, target, threshold, ignore_index) - tp, fp, tn, fn = _binary_stat_scores_update(preds, target, multidim_average) - return _generalized_dice_reduce(tp, fp, tn, fn, weight_type, average="binary", multidim_average=multidim_average) - +from torchmetrics.functional.segmentation.utils import _ignore_background +from typing_extensions import Literal -def _multiclass_generalized_dice_score_arg_validation( - weight_type: Optional[Literal["square", "simple"]], +def _generalized_dice_validate_args( num_classes: int, - top_k: int = 1, - average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", - multidim_average: Literal["global", "samplewise"] = "global", - ignore_index: Optional[int] = None, + include_background: bool, + per_class: bool, ) -> None: - allowed_weight_type = ("square", "simple", None) - if weight_type not in weight_type: - raise ValueError( - f"Argument `weight_type` needs to one of the following: {allowed_weight_type} but got {weight_type}" - ) - _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) - - -def multiclass_generalized_dice_score( + """Validate the arguments of the metric.""" + if num_classes <= 0: + raise ValueError(f"Expected argument `num_classes` must be a positive integer, but got {num_classes}.") + if not isinstance(include_background, bool): + raise ValueError(f"Expected argument `include_background` must be a boolean, but got {include_background}.") + if not isinstance(per_class, bool): + raise ValueError(f"Expected argument `per_class` must be a boolean, but got {per_class}.") + + +def _generalized_dice_update( preds: Tensor, target: Tensor, - weight_type: Optional[Literal["square", "simple"]], num_classes: int, 
- average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", - top_k: int = 1, - multidim_average: Literal["global", "samplewise"] = "global", - ignore_index: Optional[int] = None, - validate_args: bool = True, -) -> Tensor: - if validate_args: - _multiclass_generalized_dice_score_arg_validation( - weight_type, num_classes, top_k, average, multidim_average, ignore_index - ) - _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index) - preds, target = _multiclass_stat_scores_format(preds, target, top_k) - tp, fp, tn, fn = _multiclass_stat_scores_update( - preds, target, num_classes, top_k, average, multidim_average, ignore_index - ) - return _generalized_dice_reduce(tp, fp, tn, fn, weight_type, average=average, multidim_average=multidim_average) - - -def _multilabel_generalized_dice_score_arg_validation( - weight_type: Optional[Literal["square", "simple"]], - num_labels: int, - threshold: float = 0.5, - average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", - multidim_average: Literal["global", "samplewise"] = "global", - ignore_index: Optional[int] = None, -) -> None: - allowed_weight_type = ("square", "simple", None) - if weight_type not in weight_type: - raise ValueError( - f"Argument `weight_type` needs to one of the following: {allowed_weight_type} but got {weight_type}" - ) - _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index) - - -def multilabel_generalized_dice_score( - preds: Tensor, - target: Tensor, - weight_type: Optional[Literal["square", "simple"]], - num_labels: int, - threshold: float = 0.5, - average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", - multidim_average: Literal["global", "samplewise"] = "global", - ignore_index: Optional[int] = None, - validate_args: bool = True, + include_background: bool, + per_class: bool, + weight_type: Literal["square", "simple", "linear"] = "square", ) -> Tensor: - if validate_args: - _multilabel_generalized_dice_score_arg_validation( - weight_type, num_labels, threshold, average, multidim_average, ignore_index - ) - _multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index) - preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index) - tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average) - return _generalized_dice_reduce(tp, fp, tn, fn, weight_type, average=average, multidim_average=multidim_average) + """Update the state with the current prediction and target.""" + _check_same_shape(preds, target) + if preds.ndim < 3: + + + if (preds.bool() != preds).any(): # preds is an index tensor + preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1) + if (target.bool() != target).any(): # target is an index tensor + target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1) + + if not include_background: + preds, target = _ignore_background(preds, target) + + reduce_axis = list(range(2, target.ndim)) + intersection = torch.sum(preds * target, dim=reduce_axis) + target_sum = torch.sum(target, dim=reduce_axis) + pred_sum = torch.sum(preds, dim=reduce_axis) + cardinality = target_sum + pred_sum - -def _generalized_dice_compute( - tp: Tensor, - fp: Tensor, - fn: Tensor, - average: str = "samples", - weight_type: str = "square", - ignore_index: Optional[int] = None, - zero_division: Optional[int] = None, -) -> Tensor: - """Computes 
generalized dice score from the stat scores: true positives, false positives, false negatives.
-
-    Args:
-        tp: True positives
-        fp: False positives
-        fn: False negatives
-        average: Defines the reduction that is applied
-        weight_type: Defines the type of weights applied different classes
-        ignore_index: Optional index of the class to ignore in the score computation
-        zero_division: The value to use for the score if denominator equals zero. If set to 0, score will be 1
-            if the numerator is also 0, and 0 otherwise
-    """
-    # Compute ground-truth class volume and class weights
-    target_volume = tp + fn
     if weight_type == "simple":
-        weights = torch.reciprocal(target_volume.float())
+        weights = 1.0 / target_sum
+    elif weight_type == "linear":
+        weights = torch.ones_like(target_sum)
     elif weight_type == "square":
-        weights = torch.reciprocal(target_volume.float() * target_volume.float())
-    elif weight_type is None:
-        weights = torch.ones_like(target_volume.float())
-
-    # Replace weights and stats for ignore_index by 0
-    if ignore_index is not None:
-        weights[..., ignore_index] = 0
-        tp[..., ignore_index] = 0
-        fp[..., ignore_index] = 0
-        fn[..., ignore_index]
-
-    # Replace infinite weights for non-appearing classes by the max weight or 0, if all weights are infinite
-    if weights.dim() > 1:
-        for sample_weights in weights:
-            infs = torch.isinf(sample_weights)
-            sample_weights[infs] = torch.max(sample_weights[~infs]) if len(sample_weights[~infs]) > 0 else 0
-    else:
-        infs = torch.isinf(weights)
-        weights[infs] = torch.max(weights[~infs])
-
-    # Compute weighted numerator and denominator
-    numerator = 2 * (tp * weights).sum(dim=-1)
-    denominator = ((2 * tp + fp + fn) * weights).sum(dim=-1)
-
-    # Handle zero division
-    denominator_zeros = denominator == 0
-    denominator[denominator_zeros] = 1
-    if zero_division is not None:
-        # If zero_division score is specified, use it as numerator and set denominator to 1
-        numerator[denominator_zeros] = zero_division
+        weights = 1.0 / (target_sum ** 2)
     else:
-        # If both denominator and total sample prediction volume are 0, score is 1. Otherwise 0.
-        pred_volume = (tp + fp).sum(dim=-1)
-        pred_zeros = pred_volume == 0
-        numerator[denominator_zeros] = torch.where(
-            pred_zeros[denominator_zeros],
-            torch.tensor(1, device=numerator.device).float(),
-            torch.tensor(0, device=numerator.device).float(),
+        raise ValueError(
+            f"Expected argument `weight_type` to be one of 'simple', 'linear', 'square', but got {weight_type}."
         )
+
+    infs = torch.isinf(weights)
+    weights[infs] = 0
+    weights = torch.max(weights, 0)
 
-    return _reduce_stat_scores(
-        numerator=numerator,
-        denominator=denominator,
-        weights=None,
-        average=average,
-        mdmc_average=None,
-    )
+    numerator = 2.0 * (intersection * weights).sum(dim=1)
+    denominator = (cardinality * weights).sum(dim=1)
+    return numerator, denominator
+
+def _generalized_dice_compute(numerator: Tensor, denominator: Tensor) -> Tensor:
+    """Compute the generalized dice score."""
+    return _safe_divide(numerator, denominator)
 
 
 def generalized_dice_score(
     preds: Tensor,
     target: Tensor,
-    weight_type: str = "square",
-    zero_division: Optional[int] = None,
-    average: str = "samples",
-    threshold: float = 0.5,
-    top_k: Optional[int] = None,
-    num_classes: Optional[int] = None,
-    multiclass: bool = True,
-    multidim: bool = True,
-    ignore_index: Optional[int] = None,
-    **kwargs: Any,
+    num_classes: int,
+    include_background: bool = False,
+    per_class: bool = False,
 ) -> Tensor:
-    r"""Computes the Generalized Dice Score (GDS) metric:
-
-    .. 
math:: - \text{GDS}=\sum_{i=1}^{C}\frac{2\cdot\text{TP}_i}{(2\cdot\text{TP}_i+\text{FP}_i+\text{FN}_i)\cdot w_i} - - Where :math:`\text{C}` is the number of classes and :math:`\text{TP}_i`, :math:`\text{FP}_i` and :math:`\text{FN}`_i - represent the numbers of true positives, false positives and false negatives for class :math:`i`, respectively. - :math:`w_i` represents the weight of class :math:`i`. - - The reduction method (how the recall scores are aggregated) is controlled by the - ``average`` parameter. Accepts all inputs listed in :ref:`pages/classification:input types`. - - Args: - preds: Predictions from model (probabilities, logits or labels). - - target: Ground truth values. - - weight_type: Defines the type of weighting to apply. Should be one of the following: - - - ``'square'`` [default]: Weight each class by the squared inverse of its support, - i.e., the inverse of its squared volume - :math:`\frac{1}{(tp + fn)^2}`. - - ``'simple'``: Weight each class by the inverse of its support, i.e., - the inverse of its volume - :math:`\frac{1}{tp + fn}`. - - ``None``: All classes are assigned unitary weight. Equivalent to dice score. - - zero_division: - The value to use for the score if denominator equals zero. If set to None, the score will be 1 if the - numerator is also 0, and 0 otherwise. - - average: - Defines the reduction that is applied. Should be one of the following: - - - ``'samples'`` [default]: Calculate the metric for each sample, and average the metrics - across samples (with equal weights for each sample). - - ``'none'`` or ``None``: Calculate the metric for each sample separately, and return - the metric for every sample. - - threshold: - Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case - of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities. - - top_k: - Number of the highest probability or logit score predictions considered finding the correct label. - The default value (``None``) will be interpreted as 1. - - num_classes: - Number of classes. - - multiclass: - Determines whether the input is multiclass (if True) or multilabel (if False). - - multidim: - Determines whether the input is multidim or not. - - ignore_index: - Integer specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. - - Return: - The shape of the returned tensor depends on the ``average`` parameter - - - If ``average == 'samples'``, a one-element tensor will be returned - - If ``average in ['none', None]``, the shape will be ``(N,)``, where ``N`` stands for the number of samples - - Raises: - ValueError: - If ``weight_type`` is not ``"simple"``, ``"square"`` or ``None``. - ValueError: - If ``average`` is not one of ``"samples"``, ``"none"`` or ``None``. - ValueError: - If ``num_classes`` is provided but is not an integer larger than 0. - ValueError: - If ``num_classes`` is set and ``ignore_index`` is not in the range ``[0, num_classes)``. - ValueError: - If ``top_k`` is not an integer larger than ``0``. 
- + """ Example: - >>> from torch import tensor - >>> from torchmetrics.functional import generalized_dice_score - >>> preds = tensor([2, 0, 2, 1]) - >>> target = tensor([1, 1, 2, 0]) - >>> generalized_dice_score(preds, target, average='samples') - tensor(0.3478) - + >>> import torch + >>> _ = torch.manual_seed(42) + >>> from torchmetrics.functional.segmentation import generalized_dice_score + >>> preds = torch.randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 prediction + >>> target = torch.randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 target """ - allowed_weight_type = ("square", "simple", None) - if weight_type not in allowed_weight_type: - raise ValueError(f"The `weight_type` has to be one of {allowed_weight_type}, got {weight_type}.") - - allowed_average = ("samples", "none", None) - if average not in allowed_average: - raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - - if num_classes and num_classes < 1: - raise ValueError("Number of classes must be larger than 0.") - - if num_classes and ignore_index is not None and (not ignore_index < num_classes or num_classes == 1): - raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") - - if top_k is not None and (not isinstance(top_k, int) or top_k <= 0): - raise ValueError(f"The `top_k` should be an integer larger than 0, got {top_k}") - - preds, target = _input_squeeze(preds, target) - - # Obtain tp, fp and fn per sample per class - reduce = "macro" if multidim else None - tp, fp, _, fn = _stat_scores_update( - preds, - target, - reduce=reduce, - mdmc_reduce="samplewise", - threshold=threshold, - num_classes=num_classes, - top_k=top_k, - multiclass=multiclass, - ignore_index=ignore_index, - ) - - return _generalized_dice_compute( - tp, - fp, - fn, - average=average, - ignore_index=None if reduce is None else ignore_index, - weight_type=weight_type, - zero_division=zero_division, - ) + _generalized_dice_validate_args(num_classes, include_background, per_class) + numerator, denominator = _generalized_dice_update(preds, target, num_classes, include_background, per_class) + return _generalized_dice_compute(numerator, denominator) diff --git a/src/torchmetrics/functional/segmentation/utils.py b/src/torchmetrics/functional/segmentation/utils.py index bbf5c48ded3..e8427a69326 100644 --- a/src/torchmetrics/functional/segmentation/utils.py +++ b/src/torchmetrics/functional/segmentation/utils.py @@ -24,6 +24,13 @@ from torchmetrics.utilities.imports import _SCIPY_AVAILABLE +def _ignore_background(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: + """Ignore the background class in the computation.""" + preds = preds[:, 1:] if preds.shape[1] > 1 else preds + target = target[:, 1:] if target.shape[1] > 1 else target + return preds, target + + def check_if_binarized(x: Tensor) -> None: """Check if the input is binarized. 
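
As a quick sanity check of the weighted formulation added in the patch above, the numerator/denominator pair can be reproduced in a few lines of plain PyTorch. This is a minimal sketch, assuming one-hot `preds`/`target` of shape `(N, C, ...)` and the default "square" weighting; the example tensors, the zeroing of infinite weights, and the `clamp_min(1e-8)` guard are illustrative simplifications rather than code from the patch (the implementation instead guards the division with `_safe_divide` and substitutes the largest finite weight for infinite ones):

import torch

# Hypothetical one-hot masks: 4 samples, 5 classes, 16x16 pixels.
preds = torch.randint(0, 2, (4, 5, 16, 16)).float()
target = torch.randint(0, 2, (4, 5, 16, 16)).float()

spatial = list(range(2, target.ndim))             # reduce over the spatial dimensions
intersection = (preds * target).sum(dim=spatial)  # per-sample, per-class overlap, shape (N, C)
cardinality = preds.sum(dim=spatial) + target.sum(dim=spatial)

weights = 1.0 / target.sum(dim=spatial).pow(2)    # "square" weighting: 1 / volume**2
weights[torch.isinf(weights)] = 0.0               # classes absent from a sample get weight 0 here

numerator = 2.0 * (weights * intersection).sum(dim=1)
denominator = (weights * cardinality).sum(dim=1)
score = numerator / denominator.clamp_min(1e-8)   # one generalized dice score per sample
print(score.shape)  # torch.Size([4])

The infinite weights arise exactly when a class has zero volume in a sample, which is why both this sketch and `_generalized_dice_update` above have to special-case `torch.isinf(weights)`.
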
From 75235d874d9c2df739ab5b369b1e5c70518f8b22 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 22 Apr 2024 11:04:23 +0200 Subject: [PATCH 30/39] more implementation --- .../segmentation/generalized_dice.py | 43 +- .../segmentation/generalized_dice.py | 292 +++------- .../test_generalized_dice_score.py | 536 ++---------------- 3 files changed, 155 insertions(+), 716 deletions(-) diff --git a/src/torchmetrics/functional/segmentation/generalized_dice.py b/src/torchmetrics/functional/segmentation/generalized_dice.py index bbefec742fc..9d70c4dd0f8 100644 --- a/src/torchmetrics/functional/segmentation/generalized_dice.py +++ b/src/torchmetrics/functional/segmentation/generalized_dice.py @@ -22,6 +22,7 @@ def _generalized_dice_validate_args( num_classes: int, include_background: bool, per_class: bool, + weight_type: Literal["square", "simple", "linear"], ) -> None: """Validate the arguments of the metric.""" if num_classes <= 0: @@ -30,6 +31,10 @@ def _generalized_dice_validate_args( raise ValueError(f"Expected argument `include_background` must be a boolean, but got {include_background}.") if not isinstance(per_class, bool): raise ValueError(f"Expected argument `per_class` must be a boolean, but got {per_class}.") + if weight_type not in ["square", "simple", "linear"]: + raise ValueError( + f"Expected argument `weight_type` to be one of 'square', 'simple', 'linear', but got {weight_type}." + ) def _generalized_dice_update( @@ -37,14 +42,13 @@ def _generalized_dice_update( target: Tensor, num_classes: int, include_background: bool, - per_class: bool, weight_type: Literal["square", "simple", "linear"] = "square", ) -> Tensor: """Update the state with the current prediction and target.""" _check_same_shape(preds, target) if preds.ndim < 3: - - + raise ValueError(f"Expected both `preds` and `target` to have at least 3 dimensions, but got {preds.ndim}.") + if (preds.bool() != preds).any(): # preds is an index tensor preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1) if (target.bool() != target).any(): # target is an index tensor @@ -70,25 +74,34 @@ def _generalized_dice_update( f"Expected argument `weight_type` to be one of 'simple', 'linear', 'square', but got {weight_type}." 
) - infs = torch.isinf(weights) - weights[infs] = 0 - weights = torch.max(weights, 0) + w_shape = weights.shape + weights_flatten = weights.flatten() + infs = torch.isinf(weights_flatten) + weights_flatten[infs] = 0 + w_max = torch.max(weights, 0).values.repeat(w_shape[0],1).T.flatten() + weights_flatten[infs] = w_max[infs] + weights = weights_flatten.reshape(w_shape) - numerator = 2.0 * (intersection * weights).sum(dim=1) - denominator = (cardinality * weights).sum(dim=1) + numerator = 2.0 * intersection * weights + denominator = cardinality * weights return numerator, denominator -def _generalized_dice_compute(numerator: Tensor, denominator: Tensor) -> Tensor: +def _generalized_dice_compute(numerator: Tensor, denominator: Tensor, per_class: bool = True) -> Tensor: """Compute the generalized dice score.""" - return _safe_divide(numerator, denominator) + if per_class: + numerator = torch.sum(numerator, 1) + denominator = torch.sum(denominator, 1) + val = _safe_divide(numerator, denominator) + return val def generalized_dice_score( preds: Tensor, target: Tensor, num_classes: int, - include_background: bool = False, + include_background: bool = True, per_class: bool = False, + weight_type: Literal["square", "simple", "linear"] = "square", ) -> Tensor: """ Example: @@ -97,7 +110,9 @@ def generalized_dice_score( >>> from torchmetrics.functional.segmentation import generalized_dice_score >>> preds = torch.randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 prediction >>> target = torch.randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 target + >>> generalized_dice_score(preds, target, num_classes=5) + tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000]) """ - _generalized_dice_validate_args(num_classes, include_background, per_class) - numerator, denominator = _generalized_dice_update(preds, target, num_classes, include_background, per_class) - return _generalized_dice_compute(numerator, denominator) + _generalized_dice_validate_args(num_classes, include_background, per_class, weight_type) + numerator, denominator = _generalized_dice_update(preds, target, num_classes, include_background, weight_type) + return _generalized_dice_compute(numerator, denominator, per_class) diff --git a/src/torchmetrics/segmentation/generalized_dice.py b/src/torchmetrics/segmentation/generalized_dice.py index ddb18b9e648..ebf6e254dc5 100644 --- a/src/torchmetrics/segmentation/generalized_dice.py +++ b/src/torchmetrics/segmentation/generalized_dice.py @@ -1,4 +1,4 @@ -# Copyright The PyTorch Lightning team. +# Copyright The Lightning team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,243 +11,89 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Optional
-
-from torch import Tensor
+from typing import Any, Optional, Sequence, Union
 from typing_extensions import Literal
+import torch
+from torch import Tensor
 
-from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores
-from torchmetrics.functional.classification.generalized_dice import (
-    _binary_generalized_dice_score_arg_validation,
-    _generalized_dice_reduce,
-    _multiclass_generalized_dice_score_arg_validation,
-    _multilabel_generalized_dice_score_arg_validation,
+from torchmetrics.functional.segmentation.generalized_dice import (
+    _generalized_dice_validate_args, _generalized_dice_compute, _generalized_dice_update
 )
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
+from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
 
+if not _MATPLOTLIB_AVAILABLE:
+    __doctest_skip__ = ["GeneralizedDiceScore.plot"]
 
-class BinaryGeneralizedDiceScore(BinaryStatScores):
-    is_differentiable: bool = False
-    higher_is_better: Optional[bool] = True
-    full_state_update: bool = False
-
-    def __init__(
-        self,
-        threshold: float = 0.5,
-        multidim_average: Literal["global", "samplewise"] = "global",
-        weight_type: Optional[Literal["square", "simple"]] = "square",
-        ignore_index: Optional[int] = None,
-        validate_args: bool = True,
-        **kwargs: Any,
-    ) -> None:
-        super().__init__(
-            threshold=threshold,
-            multidim_average=multidim_average,
-            ignore_index=ignore_index,
-            validate_args=False,
-            **kwargs,
-        )
-        if validate_args:
-            _binary_generalized_dice_score_arg_validation(weight_type, threshold, multidim_average, ignore_index)
-        self.validate_args = validate_args
-        self.weight_type = weight_type
-
-    def compute(self) -> Tensor:
-        tp, fp, tn, fn = self._final_state()
-        return _generalized_dice_reduce(
-            tp, fp, tn, fn, self.weight_type, average="binary", multidim_average=self.multidim_average
-        )
+class GeneralizedDiceScore(Metric):
+    """Compute the Generalized Dice Score for semantic segmentation tasks."""
 
-class MulticlassGeneralizedDiceScore(MulticlassStatScores):
-    is_differentiable: bool = False
-    higher_is_better: Optional[bool] = True
+    score: Tensor
+    num_batches: Tensor
     full_state_update: bool = False
-
-    def __init__(
-        self,
-        num_classes: int,
-        top_k: int = 1,
-        average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
-        multidim_average: Literal["global", "samplewise"] = "global",
-        weight_type: Optional[Literal["square", "simple"]] = "square",
-        ignore_index: Optional[int] = None,
-        validate_args: bool = True,
-        **kwargs: Any,
-    ) -> None:
-        super().__init__(
-            num_classes=num_classes,
-            top_k=top_k,
-            average=average,
-            multidim_average=multidim_average,
-            ignore_index=ignore_index,
-            validate_args=False,
-            **kwargs,
-        )
-        if validate_args:
-            _multiclass_generalized_dice_score_arg_validation(
-                weight_type, num_classes, top_k, average, multidim_average, ignore_index
-            )
-        self.validate_args = validate_args
-        self.weight_type = weight_type
-
-    def compute(self) -> Tensor:
-        tp, fp, tn, fn = self._final_state()
-        return _generalized_dice_reduce(
-            tp, fp, tn, fn, self.weight_type, average=self.average, multidim_average=self.multidim_average
-        )
-
-
-class MultilabelGeneralizedDiceScore(MultilabelStatScores):
     is_differentiable: bool = False
-    higher_is_better: Optional[bool] = True
-    full_state_update: bool = False
+    higher_is_better: bool = True
+    plot_lower_bound: float = 0.0
+    plot_upper_bound: float = 1.0
 
     def __init__(
         self,
-        num_labels: int,
-        threshold: float = 0.5,
-        average: 
Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", - multidim_average: Literal["global", "samplewise"] = "global", - weight_type: Optional[Literal["square", "simple"]] = "square", - ignore_index: Optional[int] = None, - validate_args: bool = True, + num_classes: int, + include_background: bool = True, + per_class: bool = False, + weight_type: Literal["square", "simple", "linear"] = "square", **kwargs: Any, - ) -> None: - super().__init__( - num_labels=num_labels, - threshold=threshold, - average=average, - multidim_average=multidim_average, - ignore_index=ignore_index, - validate_args=False, - **kwargs, - ) - if validate_args: - _multilabel_generalized_dice_score_arg_validation( - weight_type, num_labels, threshold, average, multidim_average, ignore_index - ) - self.validate_args = validate_args + ): + super().__init__(**kwargs) + _generalized_dice_validate_args(num_classes, include_background, per_class, weight_type) + self.num_classes = num_classes + self.include_background = include_background + self.per_class = per_class self.weight_type = weight_type - def compute(self) -> Tensor: - tp, fp, tn, fn = self._final_state() - return _generalized_dice_reduce( - tp, fp, tn, fn, self.weight_type, average=self.average, multidim_average=self.multidim_average - ) - - -class GeneralizedDiceScore: - r"""Computes the Generalized Dice Score (GDS) metric: - - .. math:: - \text{GDS}=\sum_{i=1}^{C}\frac{2\cdot\text{TP}_i}{(2\cdot\text{TP}_i+\text{FP}_i+\text{FN}_i)\cdot w_i} - - Where :math:`\text{C}` is the number of classes and :math:`\text{TP}_i`, :math:`\text{FP}_i` and :math:`\text{FN}`_i - represent the numbers of true positives, false positives and false negatives for class :math:`i`, respectively. - :math:`w_i` represents the weight of class :math:`i`. - - The reduction method (how the generalized dice scores are aggregated) is controlled by the - ``average`` parameter. Accepts all inputs listed in :ref:`pages/classification:input types`. - Does not accept multidimensional multi-label data. - - Args: - num_classes: - Number of classes. - - weight_type: Defines the type of weighting to apply. Should be one of the following: - - - ``'square'`` [default]: Weight each class by the squared inverse of its support, - i.e., the inverse of its squared volume - :math:`\frac{1}{(tp + fn)^2}`. - - ``'simple'``: Weight each class by the inverse of its support, i.e., - the inverse of its volume - :math:`\frac{1}{tp + fn}`. - - zero_division: - The value to use for the score if denominator equals zero. If set to None, the score will be 1 if the - numerator is also 0, and 0 otherwise. + self.add_state("score", default=torch.zeros(num_classes if per_class else 1), dist_reduce_fx="mean") - threshold: - Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case - of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities. - - average: - Defines the reduction that is applied. Should be one of the following: - - - ``'samples'`` [default]: Calculate the metric for each sample, and average the metrics - across samples (with equal weights for each sample). - - ``'none'`` or ``None``: Calculate the metric for each sample separately, and return - the metric for every sample. - - ignore_index: - Integer specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. 
- - top_k: - Number of the highest probability or logit score predictions considered finding the correct label. - The default value (``None``) will be interpreted as 1. - - multiclass: - Determines whether the input is multiclass (if True) or multilabel (if False). - - multidim: - Determines whether the input is multidim or not. - - kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. - - Raises: - ValueError: - If ``weight_type`` is not ``"simple"``, ``"square"`` or ``None``. - ValueError: - If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"``, ``None``. - ValueError: - If ``num_classes`` is not larger than ``0``. - ValueError: - If ``ignore_index`` is not in the range ``[0, num_classes)``. - ValueError: - If ``top_k`` is not an ``integer`` larger than ``0``. - - Example: - >>> from torch import tensor - >>> from torchmetrics import GeneralizedDiceScore - >>> preds = tensor([2, 0, 2, 1]) - >>> target = tensor([1, 1, 2, 0]) - >>> generalized_dice_score = GeneralizedDiceScore(num_classes=3) - >>> generalized_dice_score(preds, target) - tensor(0.3478) - - """ - - def __new__( - cls, - num_classes: Optional[int] = None, - beta: float = 1.0, - threshold: float = 0.5, - average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", - mdmc_average: Optional[str] = None, - ignore_index: Optional[int] = None, - top_k: Optional[int] = None, - multiclass: Optional[bool] = None, - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, - num_labels: Optional[int] = None, - multidim_average: Optional[Literal["global", "samplewise"]] = "global", - validate_args: bool = True, - **kwargs: Any, - ) -> Metric: - assert multidim_average is not None - kwargs.update({ - "multidim_average": multidim_average, - "ignore_index": ignore_index, - "validate_args": validate_args, - }) - if task == "binary": - return BinaryGeneralizedDiceScore(beta, threshold, **kwargs) - if task == "multiclass": - assert isinstance(num_classes, int) - assert isinstance(top_k, int) - return MulticlassGeneralizedDiceScore(beta, num_classes, top_k, average, **kwargs) - if task == "multilabel": - assert isinstance(num_labels, int) - return MultilabelGeneralizedDiceScore(beta, num_labels, threshold, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + def update(self, preds: Tensor, target: Tensor) -> None: + """ Update the state with new data.""" + numerator, denominator = _generalized_dice_update( + preds, target, self.num_classes, self.include_background, self.weight_type ) + self.score += _generalized_dice_compute(numerator, denominator, self.per_class) + + def compute(self): + return self.score + + def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + Returns: + Figure and Axes object + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + .. 
plot:: + :scale: 75 + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics.audio import PerceptualEvaluationSpeechQuality + >>> metric = PerceptualEvaluationSpeechQuality(8000, 'nb') + >>> metric.update(torch.rand(8000), torch.rand(8000)) + >>> fig_, ax_ = metric.plot() + .. plot:: + :scale: 75 + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.audio import PerceptualEvaluationSpeechQuality + >>> metric = PerceptualEvaluationSpeechQuality(8000, 'nb') + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(torch.rand(8000), torch.rand(8000))) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/tests/unittests/segmentation/test_generalized_dice_score.py b/tests/unittests/segmentation/test_generalized_dice_score.py index 8fed9892e87..8c3d31aac09 100644 --- a/tests/unittests/segmentation/test_generalized_dice_score.py +++ b/tests/unittests/segmentation/test_generalized_dice_score.py @@ -1,4 +1,4 @@ -# Copyright The PyTorch Lightning team. +# Copyright The Lightning team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,508 +11,86 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from functools import partial -import numpy as np import pytest import torch -from scipy.special import expit as sigmoid -from sklearn.metrics import confusion_matrix as sk_confusion_matrix -from sklearn.metrics import f1_score as sk_f1_score -from sklearn.metrics import generalized_dice_score as sk_generalized_dice_score -from torch import Tensor -from torchmetrics.classification.generalized_dice import ( - BinaryGeneralizedDiceScore, - MulticlassGeneralizedDiceScore, - MultilabelGeneralizedDiceScore, -) -from torchmetrics.functional.classification.generalized_dice import ( - binary_generalized_dice_score, - multiclass_generalized_dice_score, - multilabel_generalized_dice_score, -) - -from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index, remove_ignore_index - -seed_all(42) - - -def _sk_generalized_dice_score_binary(preds, target, sk_fn, ignore_index, multidim_average): - if multidim_average == "global": - preds = preds.view(-1).numpy() - target = target.view(-1).numpy() - else: - preds = preds.numpy() - target = target.numpy() - - if np.issubdtype(preds.dtype, np.floating): - if not ((preds > 0) & (preds < 1)).all(): - preds = sigmoid(preds) - preds = (preds >= THRESHOLD).astype(np.uint8) - - if multidim_average == "global": - target, preds = remove_ignore_index(target, preds, ignore_index) - return sk_fn(target, preds) - else: - res = [] - for pred, true in zip(preds, target): - pred = pred.flatten() - true = true.flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) - res.append(sk_fn(true, pred)) - return np.stack(res) - - -@pytest.mark.parametrize("input", _binary_cases) -class TestBinaryGeneralizedDiceScore(MetricTester): - @pytest.mark.parametrize("ignore_index", [None, 0, -1]) - @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) - @pytest.mark.parametrize("ddp", [False, True]) - def test_binary_generalized_dice_score( - self, ddp, input, module, functional, compare, 
ignore_index, multidim_average - ): - preds, target = input - if ignore_index == -1: - target = inject_ignore_index(target, ignore_index) - if multidim_average == "samplewise" and preds.ndim < 3: - pytest.skip("samplewise and non-multidim arrays are not valid") - if multidim_average == "samplewise" and ddp: - pytest.skip("samplewise and ddp give different order than non ddp") - - self.run_class_metric_test( - ddp=ddp, - preds=preds, - target=target, - metric_class=module, - sk_metric=partial( - _sk_generalized_dice_score_binary, - sk_fn=compare, - ignore_index=ignore_index, - multidim_average=multidim_average, - ), - metric_args={"threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average}, - ) - - @pytest.mark.parametrize("ignore_index", [None, 0, -1]) - @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) - def test_binary_generalized_dice_score_functional( - self, input, module, functional, compare, ignore_index, multidim_average - ): - preds, target = input - if ignore_index == -1: - target = inject_ignore_index(target, ignore_index) - if multidim_average == "samplewise" and preds.ndim < 3: - pytest.skip("samplewise and non-multidim arrays are not valid") - - self.run_functional_metric_test( - preds=preds, - target=target, - metric_functional=functional, - sk_metric=partial( - _sk_generalized_dice_score_binary, - sk_fn=compare, - ignore_index=ignore_index, - multidim_average=multidim_average, - ), - metric_args={ - "threshold": THRESHOLD, - "ignore_index": ignore_index, - "multidim_average": multidim_average, - }, - ) - - def test_binary_generalized_dice_score_differentiability(self, input, module, functional, compare): - preds, target = input - self.run_differentiability_test( - preds=preds, - target=target, - metric_module=module, - metric_functional=functional, - metric_args={"threshold": THRESHOLD}, - ) - - @pytest.mark.parametrize("dtype", [torch.half, torch.double]) - def test_binary_generalized_dice_score_half_cpu(self, input, module, functional, compare, dtype): - preds, target = input - - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") - self.run_precision_test_cpu( - preds=preds, - target=target, - metric_module=module, - metric_functional=functional, - metric_args={"threshold": THRESHOLD}, - dtype=dtype, - ) - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") - @pytest.mark.parametrize("dtype", [torch.half, torch.double]) - def test_binary_generalized_dice_score_half_gpu(self, input, module, functional, compare, dtype): - preds, target = input - self.run_precision_test_gpu( - preds=preds, - target=target, - metric_module=module, - metric_functional=functional, - metric_args={"threshold": THRESHOLD}, - dtype=dtype, - ) - +from monai.metrics.generalized_dice import compute_generalized_dice +from torchmetrics.functional.segmentation.generalized_dice import generalized_dice_score +from torchmetrics.segmentation.generalized_dice import GeneralizedDiceScore -def _sk_generalized_dice_score_multiclass(preds, target, sk_fn, ignore_index, multidim_average, average): - if preds.ndim == target.ndim + 1: - preds = torch.argmax(preds, 1) - if multidim_average == "global": - preds = preds.numpy().flatten() - target = target.numpy().flatten() - target, preds = remove_ignore_index(target, preds, ignore_index) - return sk_fn(target, preds, average=average) - else: - preds = preds.numpy() - target = target.numpy() - res = [] - for 
pred, true in zip(preds, target): - pred = pred.flatten() - true = true.flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) - res.append(sk_fn(true, pred, average=average, labels=list(range(NUM_CLASSES)))) - return np.stack(res, 0) +from unittests import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, _Input +from unittests._helpers.testers import MetricTester +_inputs1 = _Input( + preds=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 16)), + target=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 16)), +) +_inputs2 = _Input( + preds=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 32, 32)), + target=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 32, 32)), +) +_inputs3 = _Input( + preds=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32)), + target=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32)), +) -@pytest.mark.parametrize("input", _multiclass_cases) -class TestMulticlassGeneralizedDiceScore(MetricTester): - @pytest.mark.parametrize("ignore_index", [None, 0, -1]) - @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) - @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) - @pytest.mark.parametrize("ddp", [True, False]) - def test_multiclass_generalized_dice_score( - self, ddp, input, module, functional, compare, ignore_index, multidim_average, average - ): - preds, target = input - if ignore_index == -1: - target = inject_ignore_index(target, ignore_index) - if multidim_average == "samplewise" and target.ndim < 3: - pytest.skip("samplewise and non-multidim arrays are not valid") - if multidim_average == "samplewise" and ddp: - pytest.skip("samplewise and ddp give different order than non ddp") - - self.run_class_metric_test( - ddp=ddp, - preds=preds, - target=target, - metric_class=module, - sk_metric=partial( - _sk_generalized_dice_score_multiclass, - sk_fn=compare, - ignore_index=ignore_index, - multidim_average=multidim_average, - average=average, - ), - metric_args={ - "ignore_index": ignore_index, - "multidim_average": multidim_average, - "average": average, - "num_classes": NUM_CLASSES, - }, - ) - - @pytest.mark.parametrize("ignore_index", [None, 0, -1]) - @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) - @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) - def test_multiclass_generalized_dice_score_functional( - self, input, module, functional, compare, ignore_index, multidim_average, average - ): - preds, target = input - if ignore_index == -1: - target = inject_ignore_index(target, ignore_index) - if multidim_average == "samplewise" and target.ndim < 3: - pytest.skip("samplewise and non-multidim arrays are not valid") - - self.run_functional_metric_test( - preds=preds, - target=target, - metric_functional=functional, - sk_metric=partial( - _sk_generalized_dice_score_multiclass, - sk_fn=compare, - ignore_index=ignore_index, - multidim_average=multidim_average, - average=average, - ), - metric_args={ - "ignore_index": ignore_index, - "multidim_average": multidim_average, - "average": average, - "num_classes": NUM_CLASSES, - }, - ) - - def test_multiclass_generalized_dice_score_differentiability(self, input, module, functional, compare): - preds, target = input - self.run_differentiability_test( - preds=preds, - target=target, - metric_module=module, - metric_functional=functional, - metric_args={"num_classes": NUM_CLASSES}, - ) - - @pytest.mark.parametrize("dtype", [torch.half, torch.double]) - def 
test_multiclass_generalized_dice_score_half_cpu(self, input, module, functional, compare, dtype): - preds, target = input - - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") - self.run_precision_test_cpu( - preds=preds, - target=target, - metric_module=module, - metric_functional=functional, - metric_args={"num_classes": NUM_CLASSES}, - dtype=dtype, - ) - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") - @pytest.mark.parametrize("dtype", [torch.half, torch.double]) - def test_multiclass_generalized_dice_score_half_gpu(self, input, module, functional, compare, dtype): - preds, target = input - self.run_precision_test_gpu( - preds=preds, - target=target, - metric_module=module, - metric_functional=functional, - metric_args={"num_classes": NUM_CLASSES}, - dtype=dtype, - ) +def _reference_generalized_dice( + preds: torch.Tensor, + target: torch.Tensor, + include_background: bool = True, + per_class: bool = True, + reduce: bool = True, +): + """Calculate reference metric for `MeanIoU`.""" + if (preds.bool() != preds).any(): # preds is an index tensor + preds = torch.nn.functional.one_hot(preds, num_classes=NUM_CLASSES).movedim(-1, 1) + if (target.bool() != target).any(): # target is an index tensor + target = torch.nn.functional.one_hot(target, num_classes=NUM_CLASSES).movedim(-1, 1) -_mc_k_target = torch.tensor([0, 1, 2]) -_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) + val = compute_iou(preds, target, include_background=include_background) + if reduce: + return torch.mean(val, 0) if per_class else torch.mean(val) + return val @pytest.mark.parametrize( - ("metric_class", "metric_fn"), + "preds, target", [ - (partial(MulticlassGeneralizedDiceScore, beta=2.0), partial(multiclass_generalized_dice_score, beta=2.0)), - (MulticlassF1Score, multiclass_f1_score), + (_inputs1.preds, _inputs1.target), + (_inputs2.preds, _inputs2.target), + (_inputs3.preds, _inputs3.target), ], ) -@pytest.mark.parametrize( - ("k", "preds", "target", "average", "expected_generalized_dice", "expected_f1"), - [ - (1, _mc_k_preds, _mc_k_target, "micro", torch.tensor(2 / 3), torch.tensor(2 / 3)), - (2, _mc_k_preds, _mc_k_target, "micro", torch.tensor(5 / 6), torch.tensor(2 / 3)), - ], -) -def test_top_k( - metric_class, - metric_fn, - k: int, - preds: Tensor, - target: Tensor, - average: str, - expected_generalized_dice: Tensor, - expected_f1: Tensor, -): - """A simple test to check that top_k works as expected.""" - class_metric = metric_class(top_k=k, average=average, num_classes=3) - class_metric.update(preds, target) - - result = expected_generalized_dice if class_metric.beta != 1.0 else expected_f1 - - assert torch.isclose(class_metric.compute(), result) - assert torch.isclose(metric_fn(preds, target, top_k=k, average=average, num_classes=3), result) - +@pytest.mark.parametrize("include_background", [True, False]) +class TestMeanIoU(MetricTester): + """Test class for `MeanIoU` metric.""" -def _sk_generalized_dice_score_multilabel_global(preds, target, sk_fn, ignore_index, average): - if average == "micro": - preds = preds.flatten() - target = target.flatten() - target, preds = remove_ignore_index(target, preds, ignore_index) - return sk_fn(target, preds) - - generalized_dice_score, weights = [], [] - for i in range(preds.shape[1]): - pred, true = preds[:, i].flatten(), target[:, i].flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) - 
generalized_dice_score.append(sk_fn(true, pred)) - confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) - weights.append(confmat[1, 1] + confmat[1, 0]) - res = np.stack(generalized_dice_score, axis=0) - - if average == "macro": - return res.mean(0) - elif average == "weighted": - weights = np.stack(weights, 0).astype(float) - weights_norm = weights.sum(-1, keepdims=True) - weights_norm[weights_norm == 0] = 1.0 - return ((weights * res) / weights_norm).sum(-1) - elif average is None or average == "none": - return res - return None - - -def _sk_generalized_dice_score_multilabel_local(preds, target, sk_fn, ignore_index, average): - generalized_dice_score, weights = [], [] - for i in range(preds.shape[0]): - if average == "micro": - pred, true = preds[i].flatten(), target[i].flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) - generalized_dice_score.append(sk_fn(true, pred)) - confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) - weights.append(confmat[1, 1] + confmat[1, 0]) - else: - scores, w = [], [] - for j in range(preds.shape[1]): - pred, true = preds[i, j], target[i, j] - true, pred = remove_ignore_index(true, pred, ignore_index) - scores.append(sk_fn(true, pred)) - confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) - w.append(confmat[1, 1] + confmat[1, 0]) - generalized_dice_score.append(np.stack(scores)) - weights.append(np.stack(w)) - if average == "micro": - return np.array(generalized_dice_score) - res = np.stack(generalized_dice_score, 0) - if average == "macro": - return res.mean(-1) - elif average == "weighted": - weights = np.stack(weights, 0).astype(float) - weights_norm = weights.sum(-1, keepdims=True) - weights_norm[weights_norm == 0] = 1.0 - return ((weights * res) / weights_norm).sum(-1) - elif average is None or average == "none": - return res - return None - - -def _sk_generalized_dice_score_multilabel(preds, target, sk_fn, ignore_index, multidim_average, average): - preds = preds.numpy() - target = target.numpy() - if np.issubdtype(preds.dtype, np.floating): - if not ((preds > 0) & (preds < 1)).all(): - preds = sigmoid(preds) - preds = (preds >= THRESHOLD).astype(np.uint8) - preds = preds.reshape(*preds.shape[:2], -1) - target = target.reshape(*target.shape[:2], -1) - if ignore_index is None and multidim_average == "global": - return sk_fn( - target.transpose(0, 2, 1).reshape(-1, NUM_CLASSES), - preds.transpose(0, 2, 1).reshape(-1, NUM_CLASSES), - average=average, - ) - elif multidim_average == "global": - return _sk_generalized_dice_score_multilabel_global(preds, target, sk_fn, ignore_index, average) - return _sk_generalized_dice_score_multilabel_local(preds, target, sk_fn, ignore_index, average) - - -@pytest.mark.parametrize("input", _multilabel_cases) -class TestMultilabelGeneralizedDiceScore(MetricTester): - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("ignore_index", [None, 0, -1]) - @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) - @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) - def test_multilabel_generalized_dice_score( - self, ddp, input, module, functional, compare, ignore_index, multidim_average, average - ): - preds, target = input - if ignore_index == -1: - target = inject_ignore_index(target, ignore_index) - if multidim_average == "samplewise" and preds.ndim < 4: - pytest.skip("samplewise and non-multidim arrays are not valid") - if multidim_average == "samplewise" and ddp: - pytest.skip("samplewise and ddp give different order than non ddp") 
+ atol = 1e-4 + @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) + @pytest.mark.parametrize("per_class", [True, False]) + def test_mean_iou_class(self, preds, target, include_background, per_class, ddp): + """Test class implementation of metric.""" self.run_class_metric_test( ddp=ddp, preds=preds, target=target, - metric_class=module, - sk_metric=partial( - _sk_generalized_dice_score_multilabel, - sk_fn=compare, - ignore_index=ignore_index, - multidim_average=multidim_average, - average=average, + metric_class=MeanIoU, + reference_metric=partial( + _reference_mean_iou, include_background=include_background, per_class=per_class, reduce=True ), - metric_args={ - "num_labels": NUM_CLASSES, - "threshold": THRESHOLD, - "ignore_index": ignore_index, - "multidim_average": multidim_average, - "average": average, - }, + metric_args={"num_classes": NUM_CLASSES, "include_background": include_background, "per_class": per_class}, ) - @pytest.mark.parametrize("ignore_index", [None, 0, -1]) - @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) - @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) - def test_multilabel_generalized_dice_score_functional( - self, input, module, functional, compare, ignore_index, multidim_average, average - ): - preds, target = input - if ignore_index == -1: - target = inject_ignore_index(target, ignore_index) - if multidim_average == "samplewise" and preds.ndim < 4: - pytest.skip("samplewise and non-multidim arrays are not valid") - + def test_mean_iou_functional(self, preds, target, include_background): + """Test functional implementation of metric.""" self.run_functional_metric_test( preds=preds, target=target, - metric_functional=functional, - sk_metric=partial( - _sk_generalized_dice_score_multilabel, - sk_fn=compare, - ignore_index=ignore_index, - multidim_average=multidim_average, - average=average, - ), - metric_args={ - "num_labels": NUM_CLASSES, - "threshold": THRESHOLD, - "ignore_index": ignore_index, - "multidim_average": multidim_average, - "average": average, - }, - ) - - def test_multilabel_generalized_dice_score_differentiability(self, input, module, functional, compare): - preds, target = input - self.run_differentiability_test( - preds=preds, - target=target, - metric_module=module, - metric_functional=functional, - metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, - ) - - @pytest.mark.parametrize("dtype", [torch.half, torch.double]) - def test_multilabel_generalized_dice_score_half_cpu(self, input, module, functional, compare, dtype): - preds, target = input - - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") - self.run_precision_test_cpu( - preds=preds, - target=target, - metric_module=module, - metric_functional=functional, - metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, - dtype=dtype, - ) - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") - @pytest.mark.parametrize("dtype", [torch.half, torch.double]) - def test_multilabel_generalized_dice_score_half_gpu(self, input, module, functional, compare, dtype): - preds, target = input - self.run_precision_test_gpu( - preds=preds, - target=target, - metric_module=module, - metric_functional=functional, - metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, - dtype=dtype, - ) + metric_functional=mean_iou, + reference_metric=partial(_reference_mean_iou, 
include_background=include_background, reduce=False), + metric_args={"num_classes": NUM_CLASSES, "include_background": include_background, "per_class": True}, + ) \ No newline at end of file From 23df55f94e0de94fc22cc2d7cc2571fd82ccd844 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 22 Apr 2024 17:09:09 +0200 Subject: [PATCH 31/39] docstrings + fixing of testing framework --- .../segmentation/generalized_dice.py | 4 +- .../segmentation/generalized_dice.py | 84 +++++++++++++++++-- .../test_generalized_dice_score.py | 17 ++-- 3 files changed, 87 insertions(+), 18 deletions(-) diff --git a/src/torchmetrics/functional/segmentation/generalized_dice.py b/src/torchmetrics/functional/segmentation/generalized_dice.py index 9d70c4dd0f8..0a819279c54 100644 --- a/src/torchmetrics/functional/segmentation/generalized_dice.py +++ b/src/torchmetrics/functional/segmentation/generalized_dice.py @@ -89,8 +89,8 @@ def _generalized_dice_update( def _generalized_dice_compute(numerator: Tensor, denominator: Tensor, per_class: bool = True) -> Tensor: """Compute the generalized dice score.""" if per_class: - numerator = torch.sum(numerator, 1) - denominator = torch.sum(denominator, 1) + numerator = torch.sum(numerator, 0) + denominator = torch.sum(denominator, 0) val = _safe_divide(numerator, denominator) return val diff --git a/src/torchmetrics/segmentation/generalized_dice.py b/src/torchmetrics/segmentation/generalized_dice.py index ebf6e254dc5..cdaef050329 100644 --- a/src/torchmetrics/segmentation/generalized_dice.py +++ b/src/torchmetrics/segmentation/generalized_dice.py @@ -27,8 +27,75 @@ __doctest_skip__ = ["GeneralizedDiceScore.plot"] -class GeneralizedDice(Metric): - """ +class GeneralizedDiceScore(Metric): + """ Compute `Generalized Dice Score`_. + + The metric can be used to evaluate the performance of image segmentation models. The Generalized Dice Score is + defined as: + + .. math:: + GDS = \\frac{2 \\sum_{i=1}^{N} w_i \\sum_{j} t_{ij} p_{ij}}{\\sum_{i=1}^{N} w_i \\sum_{j} t_{ij} + \\sum_{i=1}^{N} w_i \\sum_{j} p_{ij}} + + where :math:`N` is the number of classes, :math:`t_{ij}` is the target tensor, :math:`p_{ij}` is the prediction + tensor, and :math:`w_i` is the weight for class :math:`i`. The weight can be computed in three different ways: + + - `square`: :math:`w_i = 1 / (\\sum_{j} t_{ij})^2` + - `simple`: :math:`w_i = 1 / \\sum_{j} t_{ij}` + - `linear`: :math:`w_i = 1` + + Note that the generalized dice loss can be computed as one minus the generalized dice score. + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~torch.Tensor`): An one-hot boolean tensor of shape ``(N, C, ...)`` with ``N`` being + the number of samples and ``C`` the number of classes. Alternatively, an integer tensor of shape ``(N, ...)`` + can be provided, where the integer values correspond to the class index. That format will be automatically + converted to a one-hot tensor. + - ``target`` (:class:`~torch.Tensor`): An one-hot boolean tensor of shape ``(N, C, ...)`` with ``N`` being + the number of samples and ``C`` the number of classes. Alternatively, an integer tensor of shape ``(N, ...)`` + can be provided, where the integer values correspond to the class index. That format will be automatically + converted to a one-hot tensor. + + As output to ``forward`` and ``compute`` the metric returns the following output: + + - ``miou`` (:class:`~torch.Tensor`): The mean Intersection over Union (mIoU) score. 
If ``per_class`` is set to + ``True``, the output will be a tensor of shape ``(C,)`` with the IoU score for each class. If ``per_class`` is + set to ``False``, the output will be a scalar tensor. + + Args: + num_classes: The number of classes in the segmentation problem. + include_background: Whether to include the background class in the computation + per_class: Whether to compute the IoU for each class separately. If set to ``False``, the metric will + compute the mean IoU over all classes. + weight_type: The type of weight to apply to each class. Can be one of ``"square"``, ``"simple"``, or + ``"linear"``. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Raises: + ValueError: + If ``num_classes`` is not a positive integer + ValueError: + If ``include_background`` is not a boolean + ValueError: + If ``per_class`` is not a boolean + + Example: + >>> import torch + >>> _ = torch.manual_seed(0) + >>> from torchmetrics.segmentation import GeneralizedDiceScore + >>> miou = GeneralizedDiceScore(num_classes=3) + >>> preds = torch.randint(0, 2, (10, 3, 128, 128)) + >>> target = torch.randint(0, 2, (10, 3, 128, 128)) + >>> miou(preds, target) + tensor(0.3318) + >>> miou = GeneralizedDiceScore(num_classes=3, per_class=True) + >>> miou(preds, target) + tensor([0.3322, 0.3303, 0.3329]) + >>> miou = GeneralizedDiceScore(num_classes=3, per_class=True, include_background=False) + >>> miou(preds, target) + tensor([0.3303, 0.3329]) + + """ score: Tensor @@ -68,29 +135,34 @@ def compute(self): def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. + Args: val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. If no value is provided, will automatically call `metric.compute` and plot that result. ax: An matplotlib axis object. If provided will add plot to that axis + Returns: Figure and Axes object + Raises: ModuleNotFoundError: If `matplotlib` is not installed + .. plot:: :scale: 75 >>> # Example plotting a single value >>> import torch - >>> from torchmetrics.audio import PerceptualEvaluationSpeechQuality - >>> metric = PerceptualEvaluationSpeechQuality(8000, 'nb') + >>> from torchmetrics.segmentation import GeneralizedDiceScore + >>> metric = GeneralizedDiceScore(8000, 'nb') >>> metric.update(torch.rand(8000), torch.rand(8000)) >>> fig_, ax_ = metric.plot() + .. plot:: :scale: 75 >>> # Example plotting multiple values >>> import torch - >>> from torchmetrics.audio import PerceptualEvaluationSpeechQuality - >>> metric = PerceptualEvaluationSpeechQuality(8000, 'nb') + >>> from torchmetrics.segmentation import GeneralizedDiceScore + >>> metric = GeneralizedDiceScore(8000, 'nb') >>> values = [ ] >>> for _ in range(10): ... 
values.append(metric(torch.rand(8000), torch.rand(8000))) diff --git a/tests/unittests/segmentation/test_generalized_dice_score.py b/tests/unittests/segmentation/test_generalized_dice_score.py index 8c3d31aac09..cf6ac204879 100644 --- a/tests/unittests/segmentation/test_generalized_dice_score.py +++ b/tests/unittests/segmentation/test_generalized_dice_score.py @@ -42,7 +42,6 @@ def _reference_generalized_dice( target: torch.Tensor, include_background: bool = True, per_class: bool = True, - reduce: bool = True, ): """Calculate reference metric for `MeanIoU`.""" if (preds.bool() != preds).any(): # preds is an index tensor @@ -50,10 +49,8 @@ def _reference_generalized_dice( if (target.bool() != target).any(): # target is an index tensor target = torch.nn.functional.one_hot(target, num_classes=NUM_CLASSES).movedim(-1, 1) - val = compute_iou(preds, target, include_background=include_background) - if reduce: - return torch.mean(val, 0) if per_class else torch.mean(val) - return val + val = compute_generalized_dice(preds, target, include_background=include_background) + return val.mean() @pytest.mark.parametrize( @@ -78,9 +75,9 @@ def test_mean_iou_class(self, preds, target, include_background, per_class, ddp) ddp=ddp, preds=preds, target=target, - metric_class=MeanIoU, + metric_class=GeneralizedDiceScore, reference_metric=partial( - _reference_mean_iou, include_background=include_background, per_class=per_class, reduce=True + _reference_generalized_dice, include_background=include_background, per_class=per_class, reduce=True ), metric_args={"num_classes": NUM_CLASSES, "include_background": include_background, "per_class": per_class}, ) @@ -90,7 +87,7 @@ def test_mean_iou_functional(self, preds, target, include_background): self.run_functional_metric_test( preds=preds, target=target, - metric_functional=mean_iou, - reference_metric=partial(_reference_mean_iou, include_background=include_background, reduce=False), - metric_args={"num_classes": NUM_CLASSES, "include_background": include_background, "per_class": True}, + metric_functional=generalized_dice_score, + reference_metric=partial(_reference_generalized_dice, include_background=include_background), + metric_args={"num_classes": NUM_CLASSES, "include_background": include_background}, ) \ No newline at end of file From 7726927e6a0f7b94c408e79835e87be17d76526c Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 23 Apr 2024 10:10:58 +0200 Subject: [PATCH 32/39] links --- docs/source/links.rst | 1 + src/torchmetrics/functional/segmentation/generalized_dice.py | 1 + 2 files changed, 2 insertions(+) diff --git a/docs/source/links.rst b/docs/source/links.rst index 7034f764d65..fc7e7f3a425 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -170,3 +170,4 @@ .. _FLORES-200: https://arxiv.org/abs/2207.04672 .. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html .. _SCC: https://www.ingentaconnect.com/content/tandf/tres/1998/00000019/00000004/art00013 +.. 
_Generalized Dice Score: https://arxiv.org/abs/1707.03237 \ No newline at end of file diff --git a/src/torchmetrics/functional/segmentation/generalized_dice.py b/src/torchmetrics/functional/segmentation/generalized_dice.py index 0a819279c54..85ea69ef3ad 100644 --- a/src/torchmetrics/functional/segmentation/generalized_dice.py +++ b/src/torchmetrics/functional/segmentation/generalized_dice.py @@ -112,6 +112,7 @@ def generalized_dice_score( >>> target = torch.randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 target >>> generalized_dice_score(preds, target, num_classes=5) tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000]) + >>> generalized_dice_score(preds, target, num_classes=5, per_class=True) """ _generalized_dice_validate_args(num_classes, include_background, per_class, weight_type) numerator, denominator = _generalized_dice_update(preds, target, num_classes, include_background, weight_type) From 547f890af1f8c7a84f8c413f170fb9cf67ad43d0 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 23 Apr 2024 10:47:43 +0200 Subject: [PATCH 33/39] fix remaining tests --- .../segmentation/generalized_dice.py | 45 +++++++--- .../segmentation/generalized_dice.py | 89 +++++++++++-------- .../test_generalized_dice_score.py | 24 +++-- 3 files changed, 92 insertions(+), 66 deletions(-) diff --git a/src/torchmetrics/functional/segmentation/generalized_dice.py b/src/torchmetrics/functional/segmentation/generalized_dice.py index 85ea69ef3ad..2c0ef2ee0f5 100644 --- a/src/torchmetrics/functional/segmentation/generalized_dice.py +++ b/src/torchmetrics/functional/segmentation/generalized_dice.py @@ -13,10 +13,12 @@ # limitations under the License. import torch from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.functional.segmentation.utils import _ignore_background from torchmetrics.utilities.checks import _check_same_shape from torchmetrics.utilities.compute import _safe_divide -from torchmetrics.functional.segmentation.utils import _ignore_background -from typing_extensions import Literal + def _generalized_dice_validate_args( num_classes: int, @@ -35,7 +37,7 @@ def _generalized_dice_validate_args( raise ValueError( f"Expected argument `weight_type` to be one of 'square', 'simple', 'linear', but got {weight_type}." ) - + def _generalized_dice_update( preds: Tensor, @@ -68,17 +70,17 @@ def _generalized_dice_update( elif weight_type == "linear": weights = torch.ones_like(target_sum) elif weight_type == "square": - weights = 1.0 / (target_sum ** 2) + weights = 1.0 / (target_sum**2) else: raise ValueError( f"Expected argument `weight_type` to be one of 'simple', 'linear', 'square', but got {weight_type}." 
)

     w_shape = weights.shape
     weights_flatten = weights.flatten()
     infs = torch.isinf(weights_flatten)
     weights_flatten[infs] = 0
-    w_max = torch.max(weights, 0).values.repeat(w_shape[0],1).T.flatten()
+    w_max = torch.max(weights, 0).values.repeat(w_shape[0], 1).T.flatten()
     weights_flatten[infs] = w_max[infs]
     weights = weights_flatten.reshape(w_shape)

@@ -86,13 +88,13 @@ def _generalized_dice_update(
     denominator = cardinality * weights
     return numerator, denominator

+
 def _generalized_dice_compute(numerator: Tensor, denominator: Tensor, per_class: bool = True) -> Tensor:
     """Compute the generalized dice score."""
-    if per_class:
-        numerator = torch.sum(numerator, 0)
-        denominator = torch.sum(denominator, 0)
-    val = _safe_divide(numerator, denominator)
-    return val
+    if not per_class:
+        numerator = torch.sum(numerator, 1)
+        denominator = torch.sum(denominator, 1)
+    return _safe_divide(numerator, denominator)


 def generalized_dice_score(
@@ -103,7 +105,19 @@ def generalized_dice_score(
     per_class: bool = False,
     weight_type: Literal["square", "simple", "linear"] = "square",
 ) -> Tensor:
-    """
+    """Compute the Generalized Dice Score for semantic segmentation.
+
+    Args:
+        preds: Predictions from model
+        target: Ground truth values
+        num_classes: Number of classes
+        include_background: Whether to include the background class in the computation
+        per_class: Whether to compute the score for each class separately, else average over all classes
+        weight_type: Type of weight factor to apply to the classes. One of ``"square"``, ``"simple"``, or ``"linear"``
+
+    Returns:
+        The Generalized Dice Score
+
     Example:
         >>> import torch
         >>> _ = torch.manual_seed(42)
@@ -111,8 +125,13 @@ def generalized_dice_score(
         >>> preds = torch.randint(0, 2, (4, 5, 16, 16))  # 4 samples, 5 classes, 16x16 prediction
         >>> target = torch.randint(0, 2, (4, 5, 16, 16))  # 4 samples, 5 classes, 16x16 target
         >>> generalized_dice_score(preds, target, num_classes=5)
-        tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000])
+        tensor([0.4830, 0.4935, 0.5044, 0.4880])
         >>> generalized_dice_score(preds, target, num_classes=5, per_class=True)
+        tensor([[0.4724, 0.5185, 0.4710, 0.5062, 0.4500],
+                [0.4571, 0.4980, 0.5191, 0.4380, 0.5649],
+                [0.5428, 0.4904, 0.5358, 0.4830, 0.4724],
+                [0.4715, 0.4925, 0.4797, 0.5267, 0.4788]])
+
     """
     _generalized_dice_validate_args(num_classes, include_background, per_class, weight_type)
     numerator, denominator = _generalized_dice_update(preds, target, num_classes, include_background, weight_type)
diff --git a/src/torchmetrics/segmentation/generalized_dice.py b/src/torchmetrics/segmentation/generalized_dice.py
index cdaef050329..0cfa26626d4 100644
--- a/src/torchmetrics/segmentation/generalized_dice.py
+++ b/src/torchmetrics/segmentation/generalized_dice.py
@@ -12,12 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import Any, Optional, Sequence, Union
-from typing_extensions import Literal
+
 import torch
 from torch import Tensor
+from typing_extensions import Literal

 from torchmetrics.functional.segmentation.generalized_dice import (
-    _generalized_dice_validate_args, _generalized_dice_compute, _generalized_dice_update
+    _generalized_dice_compute,
+    _generalized_dice_update,
+    _generalized_dice_validate_args,
 )
 from torchmetrics.metric import Metric
 from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
@@ -28,15 +31,16 @@


 class GeneralizedDiceScore(Metric):
-    """ Compute `Generalized Dice Score`_.
+    r"""Compute `Generalized Dice Score`_.

     The metric can be used to evaluate the performance of image segmentation models. The Generalized Dice Score is
     defined as:

     .. math::
-        GDS = \\frac{2 \\sum_{i=1}^{N} w_i \\sum_{j} t_{ij} p_{ij}}{\\sum_{i=1}^{N} w_i \\sum_{j} t_{ij} +
-        \\sum_{i=1}^{N} w_i \\sum_{j} p_{ij}}
+        GDS = \frac{2 \sum_{i=1}^{N} w_i \sum_{j} t_{ij} p_{ij}}{
+            \sum_{i=1}^{N} w_i \sum_{j} t_{ij} + \sum_{i=1}^{N} w_i \sum_{j} p_{ij}}

-    where :math:`N` is the number of classes, :math:`t_{ij}` is the target tensor, :math:`p_{ij}` is the prediction
+    where :math:`N` is the number of classes, :math:`t_{ij}` is the target tensor, :math:`p_{ij}` is the prediction
     tensor, and :math:`w_i` is the weight for class :math:`i`. The weight can be computed in three different ways:

-    - `square`: :math:`w_i = 1 / (\\sum_{j} t_{ij})^2`
-    - `simple`: :math:`w_i = 1 / \\sum_{j} t_{ij}`
+    - `square`: :math:`w_i = 1 / (\sum_{j} t_{ij})^2`
+    - `simple`: :math:`w_i = 1 / \sum_{j} t_{ij}`
     - `linear`: :math:`w_i = 1`

     Note that the generalized dice loss can be computed as one minus the generalized dice score.

     As input to ``forward`` and ``update`` the metric accepts the following input:
-
+
     - ``preds`` (:class:`~torch.Tensor`): An one-hot boolean tensor of shape ``(N, C, ...)`` with ``N`` being
       the number of samples and ``C`` the number of classes. Alternatively, an integer tensor of shape ``(N, ...)``
       can be provided, where the integer values correspond to the class index. That format will be automatically
       converted to a one-hot tensor.
     - ``target`` (:class:`~torch.Tensor`): An one-hot boolean tensor of shape ``(N, C, ...)`` with ``N`` being
       the number of samples and ``C`` the number of classes. Alternatively, an integer tensor of shape ``(N, ...)``
       can be provided, where the integer values correspond to the class index. That format will be automatically
       converted to a one-hot tensor.
-
+
     As output to ``forward`` and ``compute`` the metric returns the following output:
-
-    - ``miou`` (:class:`~torch.Tensor`): The mean Intersection over Union (mIoU) score. If ``per_class`` is set to
-      ``True``, the output will be a tensor of shape ``(C,)`` with the IoU score for each class. If ``per_class`` is
-      set to ``False``, the output will be a scalar tensor.
+
+    - ``gds`` (:class:`~torch.Tensor`): The generalized dice score. If ``per_class`` is set to ``True``, the output
+      will be a tensor of shape ``(C,)`` with the generalized dice score for each class. If ``per_class`` is
+      set to ``False``, the output will be a scalar tensor.

     Args:
         num_classes: The number of classes in the segmentation problem.
         include_background: Whether to include the background class in the computation
-        per_class: Whether to compute the IoU for each class separately. If set to ``False``, the metric will
-            compute the mean IoU over all classes.
-        weight_type: The type of weight to apply to each class. Can be one of ``"square"``, ``"simple"``, or
+        per_class: Whether to compute the metric for each class separately.
+        weight_type: The type of weight to apply to each class. Can be one of ``"square"``, ``"simple"``, or
             ``"linear"``.
         kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
- + Raises: ValueError: If ``num_classes`` is not a positive integer @@ -78,28 +81,29 @@ class GeneralizedDiceScore(Metric): If ``include_background`` is not a boolean ValueError: If ``per_class`` is not a boolean - + ValueError: + If ``weight_type`` is not one of ``"square"``, ``"simple"``, or ``"linear"`` + Example: >>> import torch >>> _ = torch.manual_seed(0) >>> from torchmetrics.segmentation import GeneralizedDiceScore - >>> miou = GeneralizedDiceScore(num_classes=3) + >>> gds = GeneralizedDiceScore(num_classes=3) >>> preds = torch.randint(0, 2, (10, 3, 128, 128)) >>> target = torch.randint(0, 2, (10, 3, 128, 128)) - >>> miou(preds, target) - tensor(0.3318) - >>> miou = GeneralizedDiceScore(num_classes=3, per_class=True) - >>> miou(preds, target) - tensor([0.3322, 0.3303, 0.3329]) - >>> miou = GeneralizedDiceScore(num_classes=3, per_class=True, include_background=False) - >>> miou(preds, target) - tensor([0.3303, 0.3329]) - + >>> gds(preds, target) + tensor(0.4983) + >>> gds = GeneralizedDiceScore(num_classes=3, per_class=True) + >>> gds(preds, target) + tensor([0.4987, 0.4966, 0.4995]) + >>> gds = GeneralizedDiceScore(num_classes=3, per_class=True, include_background=False) + >>> gds(preds, target) + tensor([0.4966, 0.4995]) """ score: Tensor - num_batches: Tensor + samples: Tensor full_state_update: bool = False is_differentiable: bool = False higher_is_better: bool = True @@ -113,7 +117,7 @@ def __init__( per_class: bool = False, weight_type: Literal["square", "simple", "linear"] = "square", **kwargs: Any, - ): + ) -> None: super().__init__(**kwargs) _generalized_dice_validate_args(num_classes, include_background, per_class, weight_type) self.num_classes = num_classes @@ -121,21 +125,25 @@ def __init__( self.per_class = per_class self.weight_type = weight_type - self.add_state("score", default=torch.zeros(num_classes if per_class else 1), dist_reduce_fx="mean") + num_classes = num_classes - 1 if not include_background else num_classes + self.add_state("score", default=torch.zeros(num_classes if per_class else 1), dist_reduce_fx="sum") + self.add_state("samples", default=torch.zeros(1), dist_reduce_fx="sum") def update(self, preds: Tensor, target: Tensor) -> None: - """ Update the state with new data.""" + """Update the state with new data.""" numerator, denominator = _generalized_dice_update( preds, target, self.num_classes, self.include_background, self.weight_type ) - self.score += _generalized_dice_compute(numerator, denominator, self.per_class) + self.score += _generalized_dice_compute(numerator, denominator, self.per_class).sum(dim=0) + self.samples += preds.shape[0] + + def compute(self) -> Tensor: + """Compute the final generalized dice score.""" + return self.score / self.samples - def compute(self): - return self.score - def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. - + Args: val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. If no value is provided, will automatically call `metric.compute` and plot that result. 
@@ -143,7 +151,7 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_

         Returns:
             Figure and Axes object
-
+
         Raises:
             ModuleNotFoundError:
                 If `matplotlib` is not installed
@@ -153,8 +161,8 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_
             >>> # Example plotting a single value
             >>> import torch
             >>> from torchmetrics.segmentation import GeneralizedDiceScore
-            >>> metric = GeneralizedDiceScore(8000, 'nb')
-            >>> metric.update(torch.rand(8000), torch.rand(8000))
+            >>> metric = GeneralizedDiceScore(num_classes=3)
+            >>> metric.update(torch.randint(0, 2, (10, 3, 128, 128)), torch.randint(0, 2, (10, 3, 128, 128)))
             >>> fig_, ax_ = metric.plot()
@@ -162,10 +170,13 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_
             >>> # Example plotting multiple values
             >>> import torch
             >>> from torchmetrics.segmentation import GeneralizedDiceScore
-            >>> metric = GeneralizedDiceScore(8000, 'nb')
+            >>> metric = GeneralizedDiceScore(num_classes=3)
             >>> values = [ ]
             >>> for _ in range(10):
-            ...     values.append(metric(torch.rand(8000), torch.rand(8000)))
+            ...     values.append(
+            ...         metric(torch.randint(0, 2, (10, 3, 128, 128)), torch.randint(0, 2, (10, 3, 128, 128)))
+            ...     )
             >>> fig_, ax_ = metric.plot(values)
+
         """
         return self._plot(val, ax)
diff --git a/tests/unittests/segmentation/test_generalized_dice_score.py b/tests/unittests/segmentation/test_generalized_dice_score.py
index cf6ac204879..ed80e6fd6d7 100644
--- a/tests/unittests/segmentation/test_generalized_dice_score.py
+++ b/tests/unittests/segmentation/test_generalized_dice_score.py
@@ -41,16 +41,17 @@ def _reference_generalized_dice(
     preds: torch.Tensor,
     target: torch.Tensor,
     include_background: bool = True,
-    per_class: bool = True,
+    reduce: bool = True,
 ):
     """Calculate reference metric for `MeanIoU`."""
     if (preds.bool() != preds).any():  # preds is an index tensor
         preds = torch.nn.functional.one_hot(preds, num_classes=NUM_CLASSES).movedim(-1, 1)
     if (target.bool() != target).any():  # target is an index tensor
         target = torch.nn.functional.one_hot(target, num_classes=NUM_CLASSES).movedim(-1, 1)
-
     val = compute_generalized_dice(preds, target, include_background=include_background)
-    return val.mean()
+    if reduce:
+        val = val.mean()
+    return val


 @pytest.mark.parametrize(
@@ -65,21 +66,16 @@ def _reference_generalized_dice(
 class TestMeanIoU(MetricTester):
     """Test class for `MeanIoU` metric."""

-    atol = 1e-4
-
     @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False])
-    @pytest.mark.parametrize("per_class", [True, False])
-    def test_mean_iou_class(self, preds, target, include_background, per_class, ddp):
+    def test_mean_iou_class(self, preds, target, include_background, ddp):
         """Test class implementation of metric."""
         self.run_class_metric_test(
             ddp=ddp,
             preds=preds,
             target=target,
             metric_class=GeneralizedDiceScore,
-            reference_metric=partial(
-                _reference_generalized_dice, include_background=include_background, per_class=per_class, reduce=True
-            ),
-            metric_args={"num_classes": NUM_CLASSES, "include_background": include_background, "per_class": per_class},
+            reference_metric=partial(_reference_generalized_dice, include_background=include_background, reduce=True),
+            metric_args={"num_classes": NUM_CLASSES, "include_background": include_background},
         )

     def test_mean_iou_functional(self, preds, target, include_background):
         """Test functional implementation of metric."""
         self.run_functional_metric_test(
             preds=preds,
             target=target,
metric_functional=generalized_dice_score, - reference_metric=partial(_reference_generalized_dice, include_background=include_background), - metric_args={"num_classes": NUM_CLASSES, "include_background": include_background}, - ) \ No newline at end of file + reference_metric=partial(_reference_generalized_dice, include_background=include_background, reduce=False), + metric_args={"num_classes": NUM_CLASSES, "include_background": include_background, "per_class": False}, + ) From 77f151eb7005f3688d6c259c530f975f2fd70a67 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 23 Apr 2024 10:51:22 +0200 Subject: [PATCH 34/39] revert some changes --- tests/unittests/classification/_inputs.py | 31 +++-------------------- 1 file changed, 3 insertions(+), 28 deletions(-) diff --git a/tests/unittests/classification/_inputs.py b/tests/unittests/classification/_inputs.py index 9c2d6f3e161..9a4981f041c 100644 --- a/tests/unittests/classification/_inputs.py +++ b/tests/unittests/classification/_inputs.py @@ -35,30 +35,15 @@ def _logsoftmax(x: Tensor, dim: int = -1) -> Tensor: preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)) ) -_input_binary_multidim_prob = _Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), - target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), -) - _input_binary = _Input( - preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)).float(), + preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), ) -_input_binary_multidim = _Input( - preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)).float(), - target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), -) - _input_binary_logits = _Input( preds=torch.randn(NUM_BATCHES, BATCH_SIZE), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)) ) -_input_binary_multidim_logits = _Input( - preds=torch.randn(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), - target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), -) - _input_multilabel_prob = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), @@ -74,13 +59,8 @@ def _logsoftmax(x: Tensor, dim: int = -1) -> Tensor: target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), ) -_input_multilabel_multidim_logits = _Input( - preds=torch.randn(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM), - target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), -) - -_input_multilabel = Input( - preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)).float(), +_input_multilabel = _Input( + preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), ) @@ -296,7 +276,6 @@ def _multiclass_with_missing_class(*shape: Any, num_classes=NUM_CLASSES): target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), ) -__mdmc_logits = 10 * torch.randn(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM) __mdmc_prob_preds = torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM) __mdmc_prob_preds = __mdmc_prob_preds / __mdmc_prob_preds.sum(dim=2, keepdim=True) @@ -304,10 +283,6 @@ def _multiclass_with_missing_class(*shape: Any, num_classes=NUM_CLASSES): preds=__mdmc_prob_preds, target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) ) 
-_input_multidim_multiclass_logits = _Input( - preds=__mdmc_logits, target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) -) - _input_multidim_multiclass = _Input( preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), From 5ba5d7d4ca5bd480654871a057208b209b0fe607 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Apr 2024 08:52:07 +0000 Subject: [PATCH 35/39] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/source/links.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/links.rst b/docs/source/links.rst index fc7e7f3a425..04b53797c61 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -170,4 +170,4 @@ .. _FLORES-200: https://arxiv.org/abs/2207.04672 .. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html .. _SCC: https://www.ingentaconnect.com/content/tandf/tres/1998/00000019/00000004/art00013 -.. _Generalized Dice Score: https://arxiv.org/abs/1707.03237 \ No newline at end of file +.. _Generalized Dice Score: https://arxiv.org/abs/1707.03237 From 16eaf2eeb4fc931452c0a215e55a6dfb8c0845fd Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 23 Apr 2024 10:53:13 +0200 Subject: [PATCH 36/39] revert changelog --- CHANGELOG.md | 1 - 1 file changed, 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ab9b5bcd806..ad769e2e706 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,7 +28,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Deprecated -- Deprecated `Dice` and `dice` from classification ([#1090](https://github.com/Lightning-AI/metrics/pull/1090)) ### Fixed From 42862834469a03c2f17f33648a4295a4bddfc366 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 23 Apr 2024 11:21:07 +0200 Subject: [PATCH 37/39] fixes to docs --- README.md | 1 + docs/source/index.rst | 8 ++++++++ docs/source/segmentation/generalized_dice.rst | 4 ++-- src/torchmetrics/segmentation/generalized_dice.py | 2 ++ 4 files changed, 13 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 2144d89e010..82e370c91f5 100644 --- a/README.md +++ b/README.md @@ -288,6 +288,7 @@ covers the following domains: - Multimodal (Image-Text) - Nominal - Regression +- Segmentation - Text Each domain may require some additional dependencies which can be installed with `pip install torchmetrics[audio]`, diff --git a/docs/source/index.rst b/docs/source/index.rst index a51a1184a78..880a6a2657e 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -222,6 +222,14 @@ Or directly from conda retrieval/* +.. toctree:: + :maxdepth: 2 + :name: segmentation + :caption: Segmentation + :glob: + + segmentation/* + .. toctree:: :maxdepth: 2 :name: text diff --git a/docs/source/segmentation/generalized_dice.rst b/docs/source/segmentation/generalized_dice.rst index f662a5dfd89..5c48fc670d1 100644 --- a/docs/source/segmentation/generalized_dice.rst +++ b/docs/source/segmentation/generalized_dice.rst @@ -12,11 +12,11 @@ Generalized Dice Score Module Interface ________________ -.. autoclass:: torchmetrics.GeneralizedDiceScore +.. autoclass:: torchmetrics.segmentation.GeneralizedDiceScore :noindex: Functional Interface ____________________ -.. autofunction:: torchmetrics.functional.generalized_dice_score +.. 
autofunction:: torchmetrics.functional.segmentation.generalized_dice_score :noindex: diff --git a/src/torchmetrics/segmentation/generalized_dice.py b/src/torchmetrics/segmentation/generalized_dice.py index 0cfa26626d4..646ba63fbcf 100644 --- a/src/torchmetrics/segmentation/generalized_dice.py +++ b/src/torchmetrics/segmentation/generalized_dice.py @@ -158,6 +158,7 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_ .. plot:: :scale: 75 + >>> # Example plotting a single value >>> import torch >>> from torchmetrics.segmentation import GeneralizedDiceScore @@ -167,6 +168,7 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_ .. plot:: :scale: 75 + >>> # Example plotting multiple values >>> import torch >>> from torchmetrics.segmentation import GeneralizedDiceScore From dc72724d39763b385b86603db690eb5a4385b2c0 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Tue, 23 Apr 2024 13:30:41 +0200 Subject: [PATCH 38/39] mypy --- src/torchmetrics/functional/segmentation/generalized_dice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/segmentation/generalized_dice.py b/src/torchmetrics/functional/segmentation/generalized_dice.py index 2c0ef2ee0f5..65ed68a28a2 100644 --- a/src/torchmetrics/functional/segmentation/generalized_dice.py +++ b/src/torchmetrics/functional/segmentation/generalized_dice.py @@ -86,7 +86,7 @@ def _generalized_dice_update( numerator = 2.0 * intersection * weights denominator = cardinality * weights - return numerator, denominator + return numerator, denominator#type:ignore[return-value] def _generalized_dice_compute(numerator: Tensor, denominator: Tensor, per_class: bool = True) -> Tensor: From bd375c0a188497784aa8352e86286f8f04e6fea9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Apr 2024 11:31:47 +0000 Subject: [PATCH 39/39] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/functional/segmentation/generalized_dice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/segmentation/generalized_dice.py b/src/torchmetrics/functional/segmentation/generalized_dice.py index 65ed68a28a2..6b740bcea53 100644 --- a/src/torchmetrics/functional/segmentation/generalized_dice.py +++ b/src/torchmetrics/functional/segmentation/generalized_dice.py @@ -86,7 +86,7 @@ def _generalized_dice_update( numerator = 2.0 * intersection * weights denominator = cardinality * weights - return numerator, denominator#type:ignore[return-value] + return numerator, denominator # type:ignore[return-value] def _generalized_dice_compute(numerator: Tensor, denominator: Tensor, per_class: bool = True) -> Tensor:
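
For reference, a minimal usage sketch of the interfaces this series converges on, assembled from the doctest examples in the patches above. The import paths match the final patches; the printed values are seed-dependent illustrations, not guaranteed outputs.

import torch

from torchmetrics.functional.segmentation.generalized_dice import generalized_dice_score
from torchmetrics.segmentation import GeneralizedDiceScore

torch.manual_seed(0)

# One-hot boolean inputs of shape (N, C, ...); integer index tensors of
# shape (N, ...) are also accepted and converted to one-hot internally.
preds = torch.randint(0, 2, (10, 3, 128, 128))
target = torch.randint(0, 2, (10, 3, 128, 128))

# Module interface: sums the per-sample scores across update() calls and
# divides by the number of seen samples in compute().
gds = GeneralizedDiceScore(num_classes=3)
print(gds(preds, target))  # scalar tensor, e.g. tensor(0.4983)

gds_per_class = GeneralizedDiceScore(num_classes=3, per_class=True, include_background=False)
print(gds_per_class(preds, target))  # shape (C - 1,): background channel dropped

# Functional interface: returns one score per sample unless per_class=True.
print(generalized_dice_score(preds, target, num_classes=3))  # shape (N,)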