Dice score as metric #1021

Merged
merged 42 commits into from May 16, 2022
Changes from 31 commits
Commits (42)
bf2fad2
dice metric as `statescores`
May 10, 2022
0050886
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 10, 2022
b78f6ae
doctest fix
May 10, 2022
ac77a9d
Merge remote-tracking branch 'origin/dice_metric_as_state_scores' int…
May 10, 2022
411dc10
Apply suggestions from code review
justusschock May 10, 2022
cd532cb
Merge branch 'master' into dice_metric_as_state_scores
Borda May 10, 2022
f6c8e78
test dice from scipy
May 10, 2022
25470d8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 10, 2022
dbb2306
fix doctest
May 10, 2022
fd7ecd4
Merge remote-tracking branch 'origin/dice_metric_as_state_scores' int…
May 10, 2022
064f91c
fix docs
May 11, 2022
70ce63a
Merge branch 'master' into dice_metric_as_state_scores
Borda May 12, 2022
2011192
Merge branch 'master' into dice_metric_as_state_scores
Borda May 12, 2022
ef7a841
Update torchmetrics/classification/dice.py
MrShevan May 12, 2022
e762fd7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 12, 2022
b6bd1f2
Merge branch 'master' into dice_metric_as_state_scores
SkafteNicki May 13, 2022
00cb642
changelog
SkafteNicki May 13, 2022
157b665
update docs
SkafteNicki May 13, 2022
1ec33de
add links to docs
May 13, 2022
9f977aa
fis dice docs
May 13, 2022
7b42b5b
removed unnecessary
May 13, 2022
2fc195e
fix links.rst
May 13, 2022
65d8a87
rename dice.rst
May 13, 2022
a316b17
add ::incude:: in dice.rst
May 13, 2022
4bd4888
Merge branch 'master' into dice_metric_as_state_scores
Borda May 14, 2022
767affb
changename functional Interface
May 14, 2022
9f16c51
fix too short sphinx error
May 14, 2022
aecf684
Merge branch 'master' into dice_metric_as_state_scores
MrShevan May 15, 2022
0fcfc3b
add deprecation warning for `dice_score`
May 15, 2022
7f69ce8
Merge remote-tracking branch 'origin/dice_metric_as_state_scores' int…
May 15, 2022
c40f262
fix doc
May 15, 2022
19a40f9
Update torchmetrics/functional/classification/dice.py
MrShevan May 15, 2022
a126af2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 15, 2022
4ef2cb1
Update torchmetrics/functional/classification/dice.py
MrShevan May 15, 2022
cfb6870
Update torchmetrics/functional/classification/dice.py
MrShevan May 15, 2022
fe4c0d4
Update torchmetrics/classification/dice.py
MrShevan May 15, 2022
e3d37f1
Update torchmetrics/classification/dice.py
MrShevan May 15, 2022
a97b712
replace `dice_score` logic to `dice` func
May 15, 2022
618da06
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 15, 2022
7785894
fix pep
May 15, 2022
a26a4bb
Merge remote-tracking branch 'origin/dice_metric_as_state_scores' int…
May 15, 2022
508fbe1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 15, 2022
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -26,6 +26,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added support for nested metric collections ([#1003](https://github.com/PyTorchLightning/metrics/pull/1003))


- Added `Dice` to classification package ([#1021](https://github.com/PyTorchLightning/metrics/pull/1021))


### Changed

-
33 changes: 33 additions & 0 deletions docs/source/classification/dice.rst
@@ -0,0 +1,33 @@
.. customcarditem::
:header: Dice
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg
:tags: Classification

.. include:: ../links.rst

####
Dice
####

Module Interface
________________

.. autoclass:: torchmetrics.Dice
:noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.dice
:noindex:


##########
Dice Score
##########

Functional Interface (deprecated in v0.9)
_____________________________________________

.. autofunction:: torchmetrics.functional.dice_score
:noindex:
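
For context on how the two new interfaces documented above are meant to be used, here is a minimal usage sketch (not part of the diff). It assumes the ``average='micro'`` default shown in the class docstring further down and the functional ``dice`` entry point added by this PR; the expected value comes from the docstring example.

import torch
from torchmetrics import Dice
from torchmetrics.functional import dice

preds = torch.tensor([2, 0, 2, 1])
target = torch.tensor([1, 1, 2, 0])

# Module interface: stateful metric that can accumulate statistics across batches.
metric = Dice(average="micro")
print(metric(preds, target))  # tensor(0.2500), matching the docstring example below

# Functional interface: one-shot computation on a single batch.
print(dice(preds, target))  # should match the module result under the default micro averaging

The older ``torchmetrics.functional.dice_score`` remains importable but is deprecated in favour of ``dice``, per the commits above.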
14 changes: 0 additions & 14 deletions docs/source/classification/dice_score.rst

This file was deleted.

1 change: 1 addition & 0 deletions docs/source/links.rst
@@ -22,6 +22,7 @@
.. _Matthews correlation coefficient: https://en.wikipedia.org/wiki/Matthews_correlation_coefficient
.. _Precision: https://en.wikipedia.org/wiki/Precision_and_recall
.. _Recall: https://en.wikipedia.org/wiki/Precision_and_recall
.. _Dice: https://en.wikipedia.org/wiki/Sørensen–Dice_coefficient
.. _Specificity: https://en.wikipedia.org/wiki/Sensitivity_and_specificity
.. _Type I and Type II errors: https://en.wikipedia.org/wiki/Type_I_and_type_II_errors
.. _confusion matrix: https://en.wikipedia.org/wiki/Confusion_matrix#Table_of_confusion
138 changes: 136 additions & 2 deletions tests/classification/test_dice.py
@@ -11,10 +11,58 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Optional

import pytest
from torch import tensor
from scipy.spatial.distance import dice as _sc_dice
from torch import Tensor, tensor

from tests.classification.inputs import _input_binary, _input_binary_logits, _input_binary_prob
from tests.classification.inputs import _input_multiclass as _input_mcls
from tests.classification.inputs import _input_multiclass_logits as _input_mcls_logits
from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.classification.inputs import _input_multiclass_with_missing_class as _input_miss_class
from tests.classification.inputs import _input_multilabel as _input_mlb
from tests.classification.inputs import _input_multilabel_logits as _input_mlb_logits
from tests.classification.inputs import _input_multilabel_multidim as _input_mlmd
from tests.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob
from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.helpers import seed_all
from tests.helpers.testers import MetricTester
from torchmetrics import Dice
from torchmetrics.functional import dice, dice_score
from torchmetrics.functional.classification.stat_scores import _del_column
from torchmetrics.utilities.checks import _input_format_classification
from torchmetrics.utilities.enums import DataType

seed_all(42)


def _sk_dice(
preds: Tensor,
target: Tensor,
ignore_index: Optional[int] = None,
) -> float:
"""Compute dice score from prediction and target. Used scipy implementation of main dice logic.

from torchmetrics.functional import dice_score
Args:
preds: prediction tensor
target: target tensor
ignore_index:
Integer specifying a target class to ignore. Recommend set to index of background class.
Return:
Float dice score
"""
sk_preds, sk_target, mode = _input_format_classification(preds, target)

if ignore_index is not None and mode != DataType.BINARY:
sk_preds = _del_column(sk_preds, ignore_index)
sk_target = _del_column(sk_target, ignore_index)

sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy()

return 1 - _sc_dice(sk_preds.reshape(-1), sk_target.reshape(-1))


@pytest.mark.parametrize(
@@ -29,3 +77,89 @@
def test_dice_score(pred, target, expected):
score = dice_score(tensor(pred), tensor(target))
assert score == expected


@pytest.mark.parametrize(
["pred", "target", "expected"],
[
([[0, 0], [1, 1]], [[0, 0], [1, 1]], 1.0),
([[1, 1], [0, 0]], [[0, 0], [1, 1]], 0.0),
([[1, 1], [1, 1]], [[1, 1], [0, 0]], 2 / 3),
([[1, 1], [0, 0]], [[1, 1], [0, 0]], 1.0),
],
)
def test_dice(pred, target, expected):
score = dice(tensor(pred), tensor(target), ignore_index=0)
assert score == expected


@pytest.mark.parametrize(
"preds, target",
[
(_input_binary.preds, _input_binary.target),
(_input_binary_logits.preds, _input_binary_logits.target),
(_input_binary_prob.preds, _input_binary_prob.target),
],
)
@pytest.mark.parametrize("ignore_index", [None])
class TestDiceBinary(MetricTester):
@pytest.mark.parametrize("ddp", [False])
@pytest.mark.parametrize("dist_sync_on_step", [False])
def test_dice_class(self, ddp, dist_sync_on_step, preds, target, ignore_index):
self.run_class_metric_test(
ddp=ddp,
preds=preds,
target=target,
metric_class=Dice,
sk_metric=partial(_sk_dice, ignore_index=ignore_index),
dist_sync_on_step=dist_sync_on_step,
metric_args={"ignore_index": ignore_index},
)

def test_dice_fn(self, preds, target, ignore_index):
self.run_functional_metric_test(
preds,
target,
metric_functional=dice,
sk_metric=partial(_sk_dice, ignore_index=ignore_index),
metric_args={"ignore_index": ignore_index},
)


@pytest.mark.parametrize(
"preds, target",
[
(_input_mcls.preds, _input_mcls.target),
(_input_mcls_logits.preds, _input_mcls_logits.target),
(_input_mcls_prob.preds, _input_mcls_prob.target),
(_input_miss_class.preds, _input_miss_class.target),
(_input_mlb.preds, _input_mlb.target),
(_input_mlb_logits.preds, _input_mlb_logits.target),
(_input_mlmd.preds, _input_mlmd.target),
(_input_mlmd_prob.preds, _input_mlmd_prob.target),
(_input_mlb_prob.preds, _input_mlb_prob.target),
],
)
@pytest.mark.parametrize("ignore_index", [None, 0])
class TestDiceMulti(MetricTester):
@pytest.mark.parametrize("ddp", [False])
@pytest.mark.parametrize("dist_sync_on_step", [False])
def test_dice_class(self, ddp, dist_sync_on_step, preds, target, ignore_index):
self.run_class_metric_test(
ddp=ddp,
preds=preds,
target=target,
metric_class=Dice,
sk_metric=partial(_sk_dice, ignore_index=ignore_index),
dist_sync_on_step=dist_sync_on_step,
metric_args={"ignore_index": ignore_index},
)

def test_dice_fn(self, preds, target, ignore_index):
self.run_functional_metric_test(
preds,
target,
metric_functional=dice,
sk_metric=partial(_sk_dice, ignore_index=ignore_index),
metric_args={"ignore_index": ignore_index},
)
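
A side note on the reference implementation used in the tests above: ``scipy.spatial.distance.dice`` returns the Dice *dissimilarity*, which is why ``_sk_dice`` subtracts it from one. A minimal standalone sketch of that relationship, with made-up inputs for illustration only:

import numpy as np
from scipy.spatial.distance import dice as sc_dice

preds = np.array([1, 0, 1, 1], dtype=bool)
target = np.array([1, 1, 1, 0], dtype=bool)

# For these arrays: TP = 2, FP = 1, FN = 1.
dissimilarity = sc_dice(preds, target)  # (FP + FN) / (2*TP + FP + FN) = 2/6
score = 1 - dissimilarity               # 2*TP / (2*TP + FP + FN) = 4/6 ≈ 0.6667
print(score)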
2 changes: 2 additions & 0 deletions torchmetrics/__init__.py
@@ -33,6 +33,7 @@
CohenKappa,
ConfusionMatrix,
CoverageError,
Dice,
F1Score,
FBetaScore,
HammingDistance,
@@ -126,6 +127,7 @@
"ConfusionMatrix",
"CosineSimilarity",
"CoverageError",
"Dice",
"TweedieDevianceScore",
"ErrorRelativeGlobalDimensionlessSynthesis",
"ExplainedVariance",
1 change: 1 addition & 0 deletions torchmetrics/classification/__init__.py
@@ -21,6 +21,7 @@
from torchmetrics.classification.calibration_error import CalibrationError # noqa: F401
from torchmetrics.classification.cohen_kappa import CohenKappa # noqa: F401
from torchmetrics.classification.confusion_matrix import ConfusionMatrix # noqa: F401
from torchmetrics.classification.dice import Dice # noqa: F401
from torchmetrics.classification.f_beta import F1Score, FBetaScore # noqa: F401
from torchmetrics.classification.hamming import HammingDistance # noqa: F401
from torchmetrics.classification.hinge import HingeLoss # noqa: F401
155 changes: 155 additions & 0 deletions torchmetrics/classification/dice.py
@@ -0,0 +1,155 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional

from torch import Tensor

from torchmetrics.classification.stat_scores import StatScores
from torchmetrics.functional.classification.dice import _dice_compute


class Dice(StatScores):
r"""Computes `Dice`_:

.. math:: \text{Dice} = \frac{\text{2 * TP}}{\text{2 * TP} + \text{FP} + \text{FN}}

Where :math:`\text{TP}`, :math:`\text{FP}` and :math:`\text{FN}` represent the number of true positives,
false positives and false negatives respectively.

It is recommended to set ``ignore_index`` to the index of the background class.

The reduction method (how the Dice scores are aggregated) is controlled by the
``average`` parameter, and additionally by the ``mdmc_average`` parameter in the
multi-dimensional multi-class case. Accepts all inputs listed in :ref:`pages/classification:input types`.

Args:
num_classes:
Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods.
threshold:
Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case
of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.
zero_division:
The value to use for the score if denominator equals zero.
average:
Defines the reduction that is applied. Should be one of the following:

- ``'micro'`` [default]: Calculate the metric globally, across all samples and classes.
- ``'macro'``: Calculate the metric for each class separately, and average the
metrics across classes (with equal weights for each class).
- ``'weighted'``: Calculate the metric for each class separately, and average the
metrics across classes, weighting each class by its support (``tp + fn``).
- ``'none'`` or ``None``: Calculate the metric for each class separately, and return
the metric for every class.
- ``'samples'``: Calculate the metric for each sample, and average the metrics
across samples (with equal weights for each sample).

.. note:: What is considered a sample in the multi-dimensional multi-class case
depends on the value of ``mdmc_average``.

mdmc_average:
Defines how averaging is done for multi-dimensional multi-class inputs (on top of the
``average`` parameter). Should be one of the following:

- ``None`` [default]: Should be left unchanged if your data is not multi-dimensional
multi-class.

- ``'samplewise'``: In this case, the statistics are computed separately for each
sample on the ``N`` axis, and then averaged over samples.
The computation for each sample is done by treating the flattened extra axes ``...``
(see :ref:`pages/classification:input types`) as the ``N`` dimension within the sample,
and computing the metric for the sample based on that.

- ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs
(see :ref:`pages/classification:input types`) are flattened into a new ``N_X`` sample axis, i.e.
the inputs are treated as if they were ``(N_X, C)``.
From here on the ``average`` parameter applies as usual.

ignore_index:
Integer specifying a target class to ignore. If given, this class index does not contribute
to the returned score, regardless of reduction method. If an index is ignored, and ``average=None``
or ``'none'``, the score for the ignored class will be returned as ``nan``.

top_k:
Number of highest probability or logit score predictions considered when finding the correct label,
relevant only for (multi-dimensional) multi-class inputs. The
default value (``None``) will be interpreted as 1 for these inputs.
Should be left at default (``None``) for all other types of inputs.

multiclass:
Used only in certain special cases, where you want to treat inputs as a different type
than what they appear to be. See the parameter's
:ref:`documentation section <pages/classification:using the multiclass parameter>`
for a more detailed explanation and examples.

kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
ValueError:
If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"``, ``None``.

Example:
>>> import torch
>>> from torchmetrics import Dice
>>> preds = torch.tensor([2, 0, 2, 1])
>>> target = torch.tensor([1, 1, 2, 0])
>>> dice = Dice(average='micro')
>>> dice(preds, target)
tensor(0.2500)

"""
is_differentiable = False
higher_is_better = True

def __init__(
self,
zero_division: int = 0,
num_classes: Optional[int] = None,
threshold: float = 0.5,
average: str = "micro",
mdmc_average: Optional[str] = "global",
ignore_index: Optional[int] = None,
top_k: Optional[int] = None,
multiclass: Optional[bool] = None,
**kwargs: Dict[str, Any],
) -> None:
allowed_average = ("micro", "macro", "weighted", "samples", "none", None)
if average not in allowed_average:
raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

super().__init__(
reduce="macro" if average in ("weighted", "none", None) else average,
mdmc_reduce=mdmc_average,
threshold=threshold,
top_k=top_k,
num_classes=num_classes,
multiclass=multiclass,
ignore_index=ignore_index,
**kwargs,
)

self.average = average
self.zero_division = zero_division

def compute(self) -> Tensor:
"""Computes the dice score based on inputs passed in to ``update`` previously.

Return:
The shape of the returned tensor depends on the ``average`` parameter:

- If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned
- If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number
of classes
"""
tp, fp, _, fn = self._get_final_stats()
return _dice_compute(tp, fp, fn, self.average, self.mdmc_reduce, self.zero_division)
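
``compute`` above delegates to ``_dice_compute``, whose body is not part of this section of the diff. For the micro-average case, the reduction amounts to the formula from the class docstring; the following is a small hypothetical sketch of that arithmetic (``micro_dice`` is an illustrative helper, not a torchmetrics function):

import torch

def micro_dice(tp: torch.Tensor, fp: torch.Tensor, fn: torch.Tensor, zero_division: float = 0.0) -> torch.Tensor:
    # Micro averaging: sum the statistics over all classes, then apply 2*TP / (2*TP + FP + FN).
    tp, fp, fn = tp.sum(), fp.sum(), fn.sum()
    denominator = 2 * tp + fp + fn
    if denominator == 0:
        return torch.tensor(float(zero_division))
    return 2 * tp / denominator

# Example: TP=2, FP=1, FN=1 summed over all classes -> 4/6
print(micro_dice(torch.tensor([2.0]), torch.tensor([1.0]), torch.tensor([1.0])))  # tensor(0.6667)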