diff --git a/CHANGELOG.md b/CHANGELOG.md index 9379fc630b1..45b335f9481 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added support for nested metric collections ([#1003](https://github.com/PyTorchLightning/metrics/pull/1003)) +- Added `Dice` to classification package ([#1021](https://github.com/PyTorchLightning/metrics/pull/1021)) + + ### Changed - diff --git a/docs/source/classification/dice.rst b/docs/source/classification/dice.rst new file mode 100644 index 00000000000..5b3a6bdf17b --- /dev/null +++ b/docs/source/classification/dice.rst @@ -0,0 +1,33 @@ +.. customcarditem:: + :header: Dice + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg + :tags: Classification + +.. include:: ../links.rst + +#### +Dice +#### + +Module Interface +________________ + +.. autoclass:: torchmetrics.Dice + :noindex: + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.dice + :noindex: + + +########## +Dice Score +########## + +Functional Interface (was deprecated in v0.9) +_____________________________________________ + +.. autofunction:: torchmetrics.functional.dice_score + :noindex: diff --git a/docs/source/classification/dice_score.rst b/docs/source/classification/dice_score.rst deleted file mode 100644 index 37605eb795e..00000000000 --- a/docs/source/classification/dice_score.rst +++ /dev/null @@ -1,14 +0,0 @@ -.. customcarditem:: - :header: Dice Score - :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg - :tags: Classification - -########## -Dice Score -########## - -Functional Interface -____________________ - -.. autofunction:: torchmetrics.functional.dice_score - :noindex: diff --git a/docs/source/links.rst b/docs/source/links.rst index 44c0dffe250..c1f9c8160fa 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -22,6 +22,7 @@ .. _Matthews correlation coefficient: https://en.wikipedia.org/wiki/Matthews_correlation_coefficient .. _Precision: https://en.wikipedia.org/wiki/Precision_and_recall .. _Recall: https://en.wikipedia.org/wiki/Precision_and_recall +.. _Dice: https://en.wikipedia.org/wiki/Sørensen–Dice_coefficient .. _Specificity: https://en.wikipedia.org/wiki/Sensitivity_and_specificity .. _Type I and Type II errors: https://en.wikipedia.org/wiki/Type_I_and_type_II_errors .. _confusion matrix: https://en.wikipedia.org/wiki/Confusion_matrix#Table_of_confusion diff --git a/tests/classification/test_dice.py b/tests/classification/test_dice.py index 542ec189936..445f926d70e 100644 --- a/tests/classification/test_dice.py +++ b/tests/classification/test_dice.py @@ -11,10 +11,58 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from functools import partial +from typing import Optional + import pytest -from torch import tensor +from scipy.spatial.distance import dice as _sc_dice +from torch import Tensor, tensor + +from tests.classification.inputs import _input_binary, _input_binary_logits, _input_binary_prob +from tests.classification.inputs import _input_multiclass as _input_mcls +from tests.classification.inputs import _input_multiclass_logits as _input_mcls_logits +from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob +from tests.classification.inputs import _input_multiclass_with_missing_class as _input_miss_class +from tests.classification.inputs import _input_multilabel as _input_mlb +from tests.classification.inputs import _input_multilabel_logits as _input_mlb_logits +from tests.classification.inputs import _input_multilabel_multidim as _input_mlmd +from tests.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob +from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob +from tests.helpers import seed_all +from tests.helpers.testers import MetricTester +from torchmetrics import Dice +from torchmetrics.functional import dice, dice_score +from torchmetrics.functional.classification.stat_scores import _del_column +from torchmetrics.utilities.checks import _input_format_classification +from torchmetrics.utilities.enums import DataType + +seed_all(42) + + +def _sk_dice( + preds: Tensor, + target: Tensor, + ignore_index: Optional[int] = None, +) -> float: + """Compute the dice score from predictions and targets, using the scipy implementation for the core dice logic. -from torchmetrics.functional import dice_score + Args: + preds: prediction tensor + target: target tensor + ignore_index: + Integer specifying a target class to ignore. It is recommended to set this to the index of the background class.
+ Return: + Float dice score + """ + sk_preds, sk_target, mode = _input_format_classification(preds, target) + + if ignore_index is not None and mode != DataType.BINARY: + sk_preds = _del_column(sk_preds, ignore_index) + sk_target = _del_column(sk_target, ignore_index) + + sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy() + + return 1 - _sc_dice(sk_preds.reshape(-1), sk_target.reshape(-1)) @pytest.mark.parametrize( @@ -29,3 +77,89 @@ def test_dice_score(pred, target, expected): score = dice_score(tensor(pred), tensor(target)) assert score == expected + + +@pytest.mark.parametrize( + ["pred", "target", "expected"], + [ + ([[0, 0], [1, 1]], [[0, 0], [1, 1]], 1.0), + ([[1, 1], [0, 0]], [[0, 0], [1, 1]], 0.0), + ([[1, 1], [1, 1]], [[1, 1], [0, 0]], 2 / 3), + ([[1, 1], [0, 0]], [[1, 1], [0, 0]], 1.0), + ], +) +def test_dice(pred, target, expected): + score = dice(tensor(pred), tensor(target), ignore_index=0) + assert score == expected + + +@pytest.mark.parametrize( + "preds, target", + [ + (_input_binary.preds, _input_binary.target), + (_input_binary_logits.preds, _input_binary_logits.target), + (_input_binary_prob.preds, _input_binary_prob.target), + ], +) +@pytest.mark.parametrize("ignore_index", [None]) +class TestDiceBinary(MetricTester): + @pytest.mark.parametrize("ddp", [False]) + @pytest.mark.parametrize("dist_sync_on_step", [False]) + def test_dice_class(self, ddp, dist_sync_on_step, preds, target, ignore_index): + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=Dice, + sk_metric=partial(_sk_dice, ignore_index=ignore_index), + dist_sync_on_step=dist_sync_on_step, + metric_args={"ignore_index": ignore_index}, + ) + + def test_dice_fn(self, preds, target, ignore_index): + self.run_functional_metric_test( + preds, + target, + metric_functional=dice, + sk_metric=partial(_sk_dice, ignore_index=ignore_index), + metric_args={"ignore_index": ignore_index}, + ) + + +@pytest.mark.parametrize( + "preds, target", + [ + (_input_mcls.preds, _input_mcls.target), + (_input_mcls_logits.preds, _input_mcls_logits.target), + (_input_mcls_prob.preds, _input_mcls_prob.target), + (_input_miss_class.preds, _input_miss_class.target), + (_input_mlb.preds, _input_mlb.target), + (_input_mlb_logits.preds, _input_mlb_logits.target), + (_input_mlmd.preds, _input_mlmd.target), + (_input_mlmd_prob.preds, _input_mlmd_prob.target), + (_input_mlb_prob.preds, _input_mlb_prob.target), + ], +) +@pytest.mark.parametrize("ignore_index", [None, 0]) +class TestDiceMulti(MetricTester): + @pytest.mark.parametrize("ddp", [False]) + @pytest.mark.parametrize("dist_sync_on_step", [False]) + def test_dice_class(self, ddp, dist_sync_on_step, preds, target, ignore_index): + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=Dice, + sk_metric=partial(_sk_dice, ignore_index=ignore_index), + dist_sync_on_step=dist_sync_on_step, + metric_args={"ignore_index": ignore_index}, + ) + + def test_dice_fn(self, preds, target, ignore_index): + self.run_functional_metric_test( + preds, + target, + metric_functional=dice, + sk_metric=partial(_sk_dice, ignore_index=ignore_index), + metric_args={"ignore_index": ignore_index}, + ) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 5857d68b0be..549367ce4da 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -33,6 +33,7 @@ CohenKappa, ConfusionMatrix, CoverageError, + Dice, F1Score, FBetaScore, HammingDistance, @@ -126,6 +127,7 @@ "ConfusionMatrix", "CosineSimilarity", "CoverageError", + 
"Dice", "TweedieDevianceScore", "ErrorRelativeGlobalDimensionlessSynthesis", "ExplainedVariance", diff --git a/torchmetrics/classification/__init__.py b/torchmetrics/classification/__init__.py index 1ac4bde756a..70ae4d5179c 100644 --- a/torchmetrics/classification/__init__.py +++ b/torchmetrics/classification/__init__.py @@ -21,6 +21,7 @@ from torchmetrics.classification.calibration_error import CalibrationError # noqa: F401 from torchmetrics.classification.cohen_kappa import CohenKappa # noqa: F401 from torchmetrics.classification.confusion_matrix import ConfusionMatrix # noqa: F401 +from torchmetrics.classification.dice import Dice # noqa: F401 from torchmetrics.classification.f_beta import F1Score, FBetaScore # noqa: F401 from torchmetrics.classification.hamming import HammingDistance # noqa: F401 from torchmetrics.classification.hinge import HingeLoss # noqa: F401 diff --git a/torchmetrics/classification/dice.py b/torchmetrics/classification/dice.py new file mode 100644 index 00000000000..09ff5be205b --- /dev/null +++ b/torchmetrics/classification/dice.py @@ -0,0 +1,162 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Optional + +from torch import Tensor + +from torchmetrics.classification.stat_scores import StatScores +from torchmetrics.functional.classification.dice import _dice_compute + + +class Dice(StatScores): + r"""Computes `Dice`_: + + .. math:: \text{Dice} = \frac{\text{2 * TP}}{\text{2 * TP} + \text{FP} + \text{FN}} + + Where :math:`\text{TP}`, :math:`\text{FP}` and :math:`\text{FN}` represent the number of true positives, + false positives and false negatives respectively. + + It is recommended to set `ignore_index` to the index of the background class. + + The reduction method (how the dice scores are aggregated) is controlled by the + ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the + multi-dimensional multi-class case. Accepts all inputs listed in :ref:`pages/classification:input types`. + + Args: + num_classes: + Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. + threshold: + Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case + of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities. + zero_division: + The value to use for the score if the denominator equals zero. + average: + Defines the reduction that is applied. Should be one of the following: + + - ``'micro'`` [default]: Calculate the metric globally, across all samples and classes. + - ``'macro'``: Calculate the metric for each class separately, and average the + metrics across classes (with equal weights for each class). + - ``'weighted'``: Calculate the metric for each class separately, and average the + metrics across classes, weighting each class by its support (``tp + fn``). + - ``'none'`` or ``None``: Calculate the metric for each class separately, and return + the metric for every class.
+ - ``'samples'``: Calculate the metric for each sample, and average the metrics + across samples (with equal weights for each sample). + + .. note:: What is considered a sample in the multi-dimensional multi-class case + depends on the value of ``mdmc_average``. + + mdmc_average: + Defines how averaging is done for multi-dimensional multi-class inputs (on top of the + ``average`` parameter). Should be one of the following: + + - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional + multi-class. + + - ``'samplewise'``: In this case, the statistics are computed separately for each + sample on the ``N`` axis, and then averaged over samples. + The computation for each sample is done by treating the flattened extra axes ``...`` + (see :ref:`pages/classification:input types`) as the ``N`` dimension within the sample, + and computing the metric for the sample based on that. + + - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs + (see :ref:`pages/classification:input types`) are flattened into a new ``N_X`` sample axis, i.e. + the inputs are treated as if they were ``(N_X, C)``. + From here on the ``average`` parameter applies as usual. + + ignore_index: + Integer specifying a target class to ignore. If given, this class index does not contribute + to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` + or ``'none'``, the score for the ignored class will be returned as ``nan``. + + top_k: + Number of the highest probability or logit score predictions considered to find the correct label, + relevant only for (multi-dimensional) multi-class inputs. The + default value (``None``) will be interpreted as 1 for these inputs. + Should be left at default (``None``) for all other types of inputs. + + multiclass: + Used only in certain special cases, where you want to treat inputs as a different type + than what they appear to be. See the parameter's + :ref:`documentation section ` + for a more detailed explanation and examples. + + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Raises: + ValueError: + If ``average`` is not one of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"``, ``None``. + ValueError: + If ``mdmc_average`` is not one of ``None``, ``"samplewise"``, ``"global"``. + ValueError: + If ``average`` is set but ``num_classes`` is not provided. + ValueError: + If ``num_classes`` is set and ``ignore_index`` is not in the range ``[0, num_classes)``.
+ + Example: + >>> import torch + >>> from torchmetrics import Dice + >>> preds = torch.tensor([2, 0, 2, 1]) + >>> target = torch.tensor([1, 1, 2, 0]) + >>> dice = Dice(average='micro') + >>> dice(preds, target) + tensor(0.2500) + + """ + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = False + + def __init__( + self, + zero_division: int = 0, + num_classes: Optional[int] = None, + threshold: float = 0.5, + average: str = "micro", + mdmc_average: Optional[str] = "global", + ignore_index: Optional[int] = None, + top_k: Optional[int] = None, + multiclass: Optional[bool] = None, + **kwargs: Dict[str, Any], + ) -> None: + allowed_average = ("micro", "macro", "weighted", "samples", "none", None) + if average not in allowed_average: + raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") + + super().__init__( + reduce="macro" if average in ("weighted", "none", None) else average, + mdmc_reduce=mdmc_average, + threshold=threshold, + top_k=top_k, + num_classes=num_classes, + multiclass=multiclass, + ignore_index=ignore_index, + **kwargs, + ) + + self.average = average + self.zero_division = zero_division + + def compute(self) -> Tensor: + """Computes the dice score based on inputs passed in to ``update`` previously. + + Return: + The shape of the returned tensor depends on the ``average`` parameter: + + - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned + - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number + of classes + """ + tp, fp, _, fn = self._get_final_stats() + return _dice_compute(tp, fp, fn, self.average, self.mdmc_reduce, self.zero_division) diff --git a/torchmetrics/functional/__init__.py b/torchmetrics/functional/__init__.py index a49f69b6038..08f0ac2e46c 100644 --- a/torchmetrics/functional/__init__.py +++ b/torchmetrics/functional/__init__.py @@ -21,7 +21,7 @@ from torchmetrics.functional.classification.calibration_error import calibration_error from torchmetrics.functional.classification.cohen_kappa import cohen_kappa from torchmetrics.functional.classification.confusion_matrix import confusion_matrix -from torchmetrics.functional.classification.dice import dice_score +from torchmetrics.functional.classification.dice import dice, dice_score from torchmetrics.functional.classification.f_beta import f1_score, fbeta_score from torchmetrics.functional.classification.hamming import hamming_distance from torchmetrics.functional.classification.hinge import hinge_loss @@ -105,6 +105,7 @@ "coverage_error", "tweedie_deviance_score", "dice_score", + "dice", "error_relative_global_dimensionless_synthesis", "explained_variance", "extended_edit_distance", diff --git a/torchmetrics/functional/classification/__init__.py b/torchmetrics/functional/classification/__init__.py index afda37c8961..70f777b56e0 100644 --- a/torchmetrics/functional/classification/__init__.py +++ b/torchmetrics/functional/classification/__init__.py @@ -18,7 +18,7 @@ from torchmetrics.functional.classification.calibration_error import calibration_error # noqa: F401 from torchmetrics.functional.classification.cohen_kappa import cohen_kappa # noqa: F401 from torchmetrics.functional.classification.confusion_matrix import confusion_matrix # noqa: F401 -from torchmetrics.functional.classification.dice import dice_score # noqa: F401 +from torchmetrics.functional.classification.dice import dice, dice_score # noqa: F401 from 
torchmetrics.functional.classification.f_beta import f1_score, fbeta_score # noqa: F401 from torchmetrics.functional.classification.hamming import hamming_distance # noqa: F401 from torchmetrics.functional.classification.hinge import hinge_loss # noqa: F401 diff --git a/torchmetrics/functional/classification/dice.py b/torchmetrics/functional/classification/dice.py index ee1e42c8a50..441ff8e69bc 100644 --- a/torchmetrics/functional/classification/dice.py +++ b/torchmetrics/functional/classification/dice.py @@ -11,52 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple +import math +from typing import Optional import torch from torch import Tensor from typing_extensions import Literal -from torchmetrics.utilities.data import to_categorical -from torchmetrics.utilities.distributed import reduce - - -def _stat_scores( - preds: Tensor, - target: Tensor, - class_index: int, - argmax_dim: int = 1, -) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: - """Calculates the number of true positive, false positive, true negative and false negative for a specific - class. - - Args: - preds: prediction tensor - target: target tensor - class_index: class to calculate over - argmax_dim: if pred is a tensor of probabilities, this indicates the - axis the argmax transformation will be applied over - - Return: - True Positive, False Positive, True Negative, False Negative, Support - - Example: - >>> x = torch.tensor([1, 2, 3]) - >>> y = torch.tensor([0, 2, 3]) - >>> tp, fp, tn, fn, sup = _stat_scores(x, y, class_index=1) - >>> tp, fp, tn, fn, sup - (tensor(0), tensor(1), tensor(2), tensor(0), tensor(0)) - """ - if preds.ndim == target.ndim + 1: - preds = to_categorical(preds, argmax_dim=argmax_dim) - - tp = ((preds == class_index) * (target == class_index)).to(torch.long).sum() - fp = ((preds == class_index) * (target != class_index)).to(torch.long).sum() - tn = ((preds != class_index) * (target != class_index)).to(torch.long).sum() - fn = ((preds != class_index) * (target == class_index)).to(torch.long).sum() - sup = (target == class_index).to(torch.long).sum() - - return tp, fp, tn, fn, sup +from torchmetrics.functional.classification.stat_scores import _reduce_stat_scores, _stat_scores_update +from torchmetrics.utilities.checks import _input_squeeze +from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod +from torchmetrics.utilities.prints import rank_zero_warn def dice_score( @@ -69,13 +34,26 @@ ) -> Tensor: """Compute dice score from prediction scores. + Supports only the "macro" approach, which means calculating the metric for each class separately, + and averaging the metrics across classes (with equal weights for each class). + + .. deprecated:: v0.9 + The `dice_score` function was deprecated in v0.9 and will be removed in v0.10. Use `dice` function instead. + Args: preds: estimated probabilities target: ground-truth labels bg: whether to also compute dice for the background nan_score: score to return, if a NaN occurs during computation - no_fg_score: score to return, if no foreground pixel was found in target - reduction: a method to reduce metric score over labels. + no_fg_score: (default, ``0.0``) score to return if no foreground pixel was found in target + + .. deprecated:: v0.9 + Any value different from the default will be changed back to the default.
+ + reduction: (default, ``'elementwise_mean'``) a method to reduce metric score over labels. + + .. deprecated:: v0.9 + Any value different from the default will be changed back to the default. - ``'elementwise_mean'``: takes the mean (default) - ``'sum'``: takes the sum @@ -94,20 +72,232 @@ >>> dice_score(pred, target) tensor(0.3333) """ + rank_zero_warn( + "The `dice_score` function was deprecated in v0.9 and will be removed in v0.10. Use `dice` function instead.", + DeprecationWarning, + ) num_classes = preds.shape[1] - bg_inv = 1 - int(bg) - scores = torch.zeros(num_classes - bg_inv, device=preds.device, dtype=torch.float32) - for i in range(bg_inv, num_classes): - if not (target == i).any(): - # no foreground class - scores[i - bg_inv] += no_fg_score - continue - - # TODO: rewrite to use general `stat_scores` - tp, fp, _, fn, _ = _stat_scores(preds=preds, target=target, class_index=i) - denom = (2 * tp + fp + fn).to(torch.float) - # nan result - score_cls = (2 * tp).to(torch.float) / denom if torch.is_nonzero(denom) else nan_score - - scores[i - bg_inv] += score_cls - return reduce(scores, reduction=reduction) + + if no_fg_score != 0.0: + no_fg_score = 0.0 + rank_zero_warn(f"Deprecated parameter. Switched to default `no_fg_score` = {no_fg_score}.") + + if reduction != "elementwise_mean": + reduction = "elementwise_mean" + rank_zero_warn(f"Deprecated parameter. Switched to default `reduction` = {reduction}.") + + zero_division = math.floor(nan_score) + if zero_division != nan_score: + rank_zero_warn(f"Deprecated parameter. `nan_score` converted to integer {zero_division}.") + + ignore_index = None + if not bg: + ignore_index = 0 + + return dice( + preds, + target, + ignore_index=ignore_index, + average="macro", + num_classes=num_classes, + zero_division=zero_division, + ) + + +def _dice_compute( + tp: Tensor, + fp: Tensor, + fn: Tensor, + average: str, + mdmc_average: Optional[str], + zero_division: int = 0, +) -> Tensor: + """Computes dice from the stat scores: true positives, false positives, false negatives. + + Args: + tp: True positives + fp: False positives + fn: False negatives + average: Defines the reduction that is applied + mdmc_average: Defines how averaging is done for multi-dimensional multi-class inputs (on top of the + ``average`` parameter) + zero_division: The value to use for the score if the denominator equals zero + + Example: + >>> from torchmetrics.functional.classification.stat_scores import _stat_scores_update + >>> from torchmetrics.functional.classification.dice import _dice_compute + >>> preds = torch.tensor([2, 0, 2, 1]) + >>> target = torch.tensor([1, 1, 2, 0]) + >>> tp, fp, tn, fn = _stat_scores_update(preds, target, reduce='micro') + >>> _dice_compute(tp, fp, fn, average='micro', mdmc_average=None) + tensor(0.2500) + """ + numerator = 2 * tp + denominator = 2 * tp + fp + fn + + if average == AverageMethod.MACRO and mdmc_average != MDMCAverageMethod.SAMPLEWISE: + cond = tp + fp + fn == 0 + numerator = numerator[~cond] + denominator = denominator[~cond] + + if average == AverageMethod.NONE and mdmc_average != MDMCAverageMethod.SAMPLEWISE: + # a class is not present if there exists no TPs, no FPs, and no FNs + meaningless_indices = torch.nonzero((tp | fn | fp) == 0).cpu() + numerator[meaningless_indices, ...] = -1 + denominator[meaningless_indices, ...]
= -1 + + return _reduce_stat_scores( + numerator=numerator, + denominator=denominator, + weights=None if average != "weighted" else tp + fn, + average=average, + mdmc_average=mdmc_average, + zero_division=zero_division, + ) + + +def dice( + preds: Tensor, + target: Tensor, + zero_division: int = 0, + average: str = "micro", + mdmc_average: Optional[str] = "global", + threshold: float = 0.5, + top_k: Optional[int] = None, + num_classes: Optional[int] = None, + multiclass: Optional[bool] = None, + ignore_index: Optional[int] = None, +) -> Tensor: + r"""Computes `Dice`_: + + .. math:: \text{Dice} = \frac{\text{2 * TP}}{\text{2 * TP} + \text{FP} + \text{FN}} + + Where :math:`\text{TP}`, :math:`\text{FP}` and :math:`\text{FN}` represent the number of true positives, + false positives and false negatives respectively. + + It is recommended to set `ignore_index` to the index of the background class. + + The reduction method (how the dice scores are aggregated) is controlled by the + ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the + multi-dimensional multi-class case. Accepts all inputs listed in :ref:`pages/classification:input types`. + + Args: + preds: Predictions from model (probabilities, logits or labels) + target: Ground truth values + zero_division: The value to use for the score if the denominator equals zero + average: + Defines the reduction that is applied. Should be one of the following: + + - ``'micro'`` [default]: Calculate the metric globally, across all samples and classes. + - ``'macro'``: Calculate the metric for each class separately, and average the + metrics across classes (with equal weights for each class). + - ``'weighted'``: Calculate the metric for each class separately, and average the + metrics across classes, weighting each class by its support (``tp + fn``). + - ``'none'`` or ``None``: Calculate the metric for each class separately, and return + the metric for every class. + - ``'samples'``: Calculate the metric for each sample, and average the metrics + across samples (with equal weights for each sample). + + .. note:: What is considered a sample in the multi-dimensional multi-class case + depends on the value of ``mdmc_average``. + + .. note:: If ``'none'`` and a given class doesn't occur in the ``preds`` or ``target``, + the value for the class will be ``nan``. + + mdmc_average: + Defines how averaging is done for multi-dimensional multi-class inputs (on top of the + ``average`` parameter). Should be one of the following: + + - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional + multi-class. + + - ``'samplewise'``: In this case, the statistics are computed separately for each + sample on the ``N`` axis, and then averaged over samples. + The computation for each sample is done by treating the flattened extra axes ``...`` + (see :ref:`pages/classification:input types`) as the ``N`` dimension within the sample, + and computing the metric for the sample based on that. + + - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs + (see :ref:`pages/classification:input types`) + are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they + were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. + + ignore_index: + Integer specifying a target class to ignore. If given, this class index does not contribute + to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` + or ``'none'``, the score for the ignored class will be returned as ``nan``.
+ + num_classes: + Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. + + threshold: + Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case + of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities. + top_k: + Number of the highest probability or logit score predictions considered to find the correct label, + relevant only for (multi-dimensional) multi-class inputs. The + default value (``None``) will be interpreted as 1 for these inputs. + + Should be left at default (``None``) for all other types of inputs. + multiclass: + Used only in certain special cases, where you want to treat inputs as a different type + than what they appear to be. See the parameter's + :ref:`documentation section ` + for a more detailed explanation and examples. + + Return: + The shape of the returned tensor depends on the ``average`` parameter + + - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned + - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number of classes + + Raises: + ValueError: + If ``average`` is not one of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"`` or ``None``. + ValueError: + If ``mdmc_average`` is not one of ``None``, ``"samplewise"``, ``"global"``. + ValueError: + If ``average`` is set but ``num_classes`` is not provided. + ValueError: + If ``num_classes`` is set and ``ignore_index`` is not in the range ``[0, num_classes)``. + + Example: + >>> import torch + >>> from torchmetrics.functional import dice + >>> preds = torch.tensor([2, 0, 2, 1]) + >>> target = torch.tensor([1, 1, 2, 0]) + >>> dice(preds, target, average='micro') + tensor(0.2500) + + """ + allowed_average = ("micro", "macro", "weighted", "samples", "none", None) + if average not in allowed_average: + raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") + + if average in ["macro", "weighted", "none", None] and (not num_classes or num_classes < 1): + raise ValueError(f"When you set `average` as {average}, you have to provide the number of classes.") + + allowed_mdmc_average = [None, "samplewise", "global"] + if mdmc_average not in allowed_mdmc_average: + raise ValueError(f"The `mdmc_average` has to be one of {allowed_mdmc_average}, got {mdmc_average}.") + + if num_classes and ignore_index is not None and (not ignore_index < num_classes or num_classes == 1): + raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") + + if top_k is not None and (not isinstance(top_k, int) or top_k <= 0): + raise ValueError(f"The `top_k` should be an integer larger than 0, got {top_k}") + + preds, target = _input_squeeze(preds, target) + reduce = "macro" if average in ("weighted", "none", None) else average + + tp, fp, _, fn = _stat_scores_update( + preds, + target, + reduce=reduce, + mdmc_reduce=mdmc_average, + threshold=threshold, + num_classes=num_classes, + top_k=top_k, + multiclass=multiclass, + ignore_index=ignore_index, + ) + + return _dice_compute(tp, fp, fn, average, mdmc_average, zero_division)
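
A minimal, self-contained sketch for locally sanity-checking the interfaces this patch adds (it is not part of the diff itself). The expected ``tensor(0.2500)`` value mirrors the docstring examples above; the per-class call with ``num_classes=3`` is an illustrative assumption rather than something exercised by the patch:

import torch

from torchmetrics import Dice
from torchmetrics.functional import dice

# Same toy data as the docstring examples in this patch.
preds = torch.tensor([2, 0, 2, 1])
target = torch.tensor([1, 1, 2, 0])

# Functional interface: 'micro' pools TP/FP/FN across all samples and classes.
assert torch.allclose(dice(preds, target, average="micro"), torch.tensor(0.25))

# Per-class scores: average=None (or 'none') requires num_classes to be given.
per_class = dice(preds, target, average=None, num_classes=3)
assert per_class.shape == (3,)

# Module interface: accumulates stat scores across update() calls,
# then reduces them in compute().
metric = Dice(average="micro")
metric.update(preds, target)
assert torch.allclose(metric.compute(), torch.tensor(0.25))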