Dice score as metric #1021

Merged · 42 commits · May 16, 2022
Changes from 2 commits
Commits (42)
bf2fad2
dice metric as `statescores`
May 10, 2022
0050886
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 10, 2022
b78f6ae
doctest fix
May 10, 2022
ac77a9d
Merge remote-tracking branch 'origin/dice_metric_as_state_scores' int…
May 10, 2022
411dc10
Apply suggestions from code review
justusschock May 10, 2022
cd532cb
Merge branch 'master' into dice_metric_as_state_scores
Borda May 10, 2022
f6c8e78
test dice from scipy
May 10, 2022
25470d8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 10, 2022
dbb2306
fix doctest
May 10, 2022
fd7ecd4
Merge remote-tracking branch 'origin/dice_metric_as_state_scores' int…
May 10, 2022
064f91c
fix docs
May 11, 2022
70ce63a
Merge branch 'master' into dice_metric_as_state_scores
Borda May 12, 2022
2011192
Merge branch 'master' into dice_metric_as_state_scores
Borda May 12, 2022
ef7a841
Update torchmetrics/classification/dice.py
MrShevan May 12, 2022
e762fd7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 12, 2022
b6bd1f2
Merge branch 'master' into dice_metric_as_state_scores
SkafteNicki May 13, 2022
00cb642
changelog
SkafteNicki May 13, 2022
157b665
update docs
SkafteNicki May 13, 2022
1ec33de
add links to docs
May 13, 2022
9f977aa
fis dice docs
May 13, 2022
7b42b5b
removed unnecessary
May 13, 2022
2fc195e
fix links.rst
May 13, 2022
65d8a87
rename dice.rst
May 13, 2022
a316b17
add ::incude:: in dice.rst
May 13, 2022
4bd4888
Merge branch 'master' into dice_metric_as_state_scores
Borda May 14, 2022
767affb
changename functional Interface
May 14, 2022
9f16c51
fix too short sphinx error
May 14, 2022
aecf684
Merge branch 'master' into dice_metric_as_state_scores
MrShevan May 15, 2022
0fcfc3b
add deprecation warning for `dice_score`
May 15, 2022
7f69ce8
Merge remote-tracking branch 'origin/dice_metric_as_state_scores' int…
May 15, 2022
c40f262
fix doc
May 15, 2022
19a40f9
Update torchmetrics/functional/classification/dice.py
MrShevan May 15, 2022
a126af2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 15, 2022
4ef2cb1
Update torchmetrics/functional/classification/dice.py
MrShevan May 15, 2022
cfb6870
Update torchmetrics/functional/classification/dice.py
MrShevan May 15, 2022
fe4c0d4
Update torchmetrics/classification/dice.py
MrShevan May 15, 2022
e3d37f1
Update torchmetrics/classification/dice.py
MrShevan May 15, 2022
a97b712
replace `dice_score` logic to `dice` func
May 15, 2022
618da06
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 15, 2022
7785894
fix pep
May 15, 2022
a26a4bb
Merge remote-tracking branch 'origin/dice_metric_as_state_scores' int…
May 15, 2022
508fbe1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 15, 2022
97 changes: 95 additions & 2 deletions tests/classification/test_dice.py
@@ -11,10 +11,60 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial

import pytest
from torch import tensor
import torch
from torch import Tensor, tensor

from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.helpers import seed_all
from tests.helpers.testers import MetricTester
from torchmetrics import Dice
from torchmetrics.functional import dice, dice_score
from torchmetrics.functional.classification.dice import _stat_scores

seed_all(42)


def _dice_score(
preds: Tensor,
target: Tensor,
background: bool = False,
nan_score: float = 0.0,
) -> Tensor:
"""
Compute dice score from prediction scores.
There is no implementation of Dice in sklearn. I used public information about
metric: `https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient`

Args:
preds: prediction tensor
target: target tensor
background: whether to also compute dice for the background
nan_score: the value to use for the score if denominator equals zero

Return:
Tensor containing dice score

from torchmetrics.functional import dice_score
"""
num_classes = preds.shape[1]
bg_inv = 1 - int(background)

tp = tensor(0, device=preds.device)
fp = tensor(0, device=preds.device)
fn = tensor(0, device=preds.device)

for i in range(bg_inv, num_classes):
tp_cls, fp_cls, _, fn_cls, _ = _stat_scores(preds=preds, target=target, class_index=i)

tp += tp_cls
fp += fp_cls
fn += fn_cls

denom = (2 * tp + fp + fn).to(torch.float)
score = (2 * tp).to(torch.float) / denom if torch.is_nonzero(denom) else nan_score
return score


@pytest.mark.parametrize(
@@ -29,3 +79,46 @@
def test_dice_score(pred, target, expected):
score = dice_score(tensor(pred), tensor(target))
assert score == expected


@pytest.mark.parametrize(
["pred", "target", "expected"],
[
([[0, 0], [1, 1]], [[0, 0], [1, 1]], 1.0),
([[1, 1], [0, 0]], [[0, 0], [1, 1]], 0.0),
([[1, 1], [1, 1]], [[1, 1], [0, 0]], 2 / 3),
([[1, 1], [0, 0]], [[1, 1], [0, 0]], 1.0),
],
)
def test_dice(pred, target, expected):
score = dice(tensor(pred), tensor(target))
assert score == expected


@pytest.mark.parametrize(
"preds, target",
[(_input_mcls_prob.preds, _input_mcls_prob.target)],
)
@pytest.mark.parametrize("background", [False, True])
class TestDice(MetricTester):
@pytest.mark.parametrize("ddp", [False])
@pytest.mark.parametrize("dist_sync_on_step", [False])
def test_dice_class(self, ddp, dist_sync_on_step, preds, target, background):
self.run_class_metric_test(
ddp=ddp,
preds=preds,
target=target,
metric_class=Dice,
sk_metric=partial(_dice_score, background=background),
dist_sync_on_step=dist_sync_on_step,
metric_args={"background": background},
)

def test_dice_fn(self, preds, target, background):
self.run_functional_metric_test(
preds,
target,
metric_functional=dice,
sk_metric=partial(_dice_score, background=background),
metric_args={"background": background},
)
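
For reference, below is a minimal standalone sketch of the micro-averaged Dice coefficient computed directly from per-class confusion counts, in the spirit of the `_dice_score` test helper above. The toy tensors and the hard-label input format are illustrative assumptions, not part of the test suite.

import torch

# Toy inputs (illustrative only): hard class predictions and targets over 3 classes.
preds = torch.tensor([2, 0, 2, 1])
target = torch.tensor([1, 1, 2, 0])

num_classes = 3
tp = fp = fn = 0
for c in range(1, num_classes):  # start at 1: class 0 is treated as background, as with background=False
    pred_c = preds == c
    target_c = target == c
    tp += int((pred_c & target_c).sum())
    fp += int((pred_c & ~target_c).sum())
    fn += int((~pred_c & target_c).sum())

denom = 2 * tp + fp + fn
score = (2 * tp) / denom if denom else 0.0  # 0.0 plays the role of nan_score when the denominator is zero
print(score)  # ~0.333 for these toy tensors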
1 change: 1 addition & 0 deletions torchmetrics/__init__.py
@@ -33,6 +33,7 @@
CohenKappa,
ConfusionMatrix,
CoverageError,
Dice,
F1Score,
FBetaScore,
HammingDistance,
1 change: 1 addition & 0 deletions torchmetrics/classification/__init__.py
@@ -21,6 +21,7 @@
from torchmetrics.classification.calibration_error import CalibrationError # noqa: F401
from torchmetrics.classification.cohen_kappa import CohenKappa # noqa: F401
from torchmetrics.classification.confusion_matrix import ConfusionMatrix # noqa: F401
from torchmetrics.classification.dice import Dice # noqa: F401
from torchmetrics.classification.f_beta import F1Score, FBetaScore # noqa: F401
from torchmetrics.classification.hamming import HammingDistance # noqa: F401
from torchmetrics.classification.hinge import HingeLoss # noqa: F401
163 changes: 163 additions & 0 deletions torchmetrics/classification/dice.py
@@ -0,0 +1,163 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional

from torch import Tensor

from torchmetrics.functional.classification.dice import _dice_compute

from torchmetrics.classification.stat_scores import StatScores # isort:skip


class Dice(StatScores):
r"""Computes `Dice`_:

.. math:: \text{Dice} = \frac{\text{2 * TP}}{\text{2 * TP} + \text{FP} + \text{FN}}

Where :math:`\text{TP}`, :math:`\text{FP}` and :math:`\text{FN}` represent the number of true positives,
false positives and false negatives respectively.

The reduction method (how the Dice scores are aggregated) is controlled by the
``average`` parameter, and additionally by the ``mdmc_average`` parameter in the
multi-dimensional multi-class case. Accepts all inputs listed in :ref:`pages/classification:input types`.

Args:
background:
Whether to also compute dice for the background.
num_classes:
Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods.
threshold:
Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case
of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.
zero_division:
The value to use for the score if denominator equals zero.
average:
Defines the reduction that is applied. Should be one of the following:

- ``'micro'`` [default]: Calculate the metric globally, across all samples and classes.
- ``'macro'``: Calculate the metric for each class separately, and average the
metrics across classes (with equal weights for each class).
- ``'weighted'``: Calculate the metric for each class separately, and average the
metrics across classes, weighting each class by its support (``tp + fn``).
- ``'none'`` or ``None``: Calculate the metric for each class separately, and return
the metric for every class.
- ``'samples'``: Calculate the metric for each sample, and average the metrics
across samples (with equal weights for each sample).

.. note:: What is considered a sample in the multi-dimensional multi-class case
depends on the value of ``mdmc_average``.

mdmc_average:
Defines how averaging is done for multi-dimensional multi-class inputs (on top of the
``average`` parameter). Should be one of the following:

- ``None`` [default]: Should be left unchanged if your data is not multi-dimensional
multi-class.

- ``'samplewise'``: In this case, the statistics are computed separately for each
sample on the ``N`` axis, and then averaged over samples.
The computation for each sample is done by treating the flattened extra axes ``...``
(see :ref:`pages/classification:input types`) as the ``N`` dimension within the sample,
and computing the metric for the sample based on that.

- ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs
(see :ref:`pages/classification:input types`) are flattened into a new ``N_X`` sample axis, i.e.
the inputs are treated as if they were ``(N_X, C)``.
From here on the ``average`` parameter applies as usual.

ignore_index:
Integer specifying a target class to ignore. If given, this class index does not contribute
to the returned score, regardless of reduction method. If an index is ignored, and ``average=None``
or ``'none'``, the score for the ignored class will be returned as ``nan``.

top_k:
Number of the highest probability or logit score predictions considered when finding the correct label,
relevant only for (multi-dimensional) multi-class inputs. The
default value (``None``) will be interpreted as 1 for these inputs.
Should be left at default (``None``) for all other types of inputs.

multiclass:
Used only in certain special cases, where you want to treat inputs as a different type
than what they appear to be. See the parameter's
:ref:`documentation section <pages/classification:using the multiclass parameter>`
for a more detailed explanation and examples.

kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
ValueError:
If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"``, ``None``.

Example:
>>> import torch
>>> from torchmetrics import Dice
>>> preds = torch.tensor([2, 0, 2, 1])
>>> target = torch.tensor([1, 1, 2, 0])
>>> dice = Dice(average='micro')
>>> dice(preds, target)
tensor(0.3333)

"""
is_differentiable = False
higher_is_better = True

def __init__(
self,
background: bool = False,
zero_division: int = 0,
num_classes: Optional[int] = None,
threshold: float = 0.5,
average: str = "micro",
mdmc_average: Optional[str] = None,
ignore_index: Optional[int] = None,
top_k: Optional[int] = None,
multiclass: Optional[bool] = None,
**kwargs: Dict[str, Any],
) -> None:
allowed_average = ["micro", "macro", "weighted", "samples", "none", None]
if average not in allowed_average:
raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

if not background and ignore_index is None:
# do not compute dice for the background class
ignore_index = 0
elif ignore_index is not None:
raise ValueError("When you set `ignore_index`, you have to set `background` to True.")

super().__init__(
reduce="macro" if average in ["weighted", "none", None] else average,
mdmc_reduce=mdmc_average,
threshold=threshold,
top_k=top_k,
num_classes=num_classes,
multiclass=multiclass,
ignore_index=ignore_index,
**kwargs,
)

self.average = average
self.zero_division = zero_division

def compute(self) -> Tensor:
"""Computes the dice score based on inputs passed in to ``update`` previously.

Return:
The shape of the returned tensor depends on the ``average`` parameter:

- If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned
- If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number
of classes
"""
tp, fp, _, fn = self._get_final_stats()
return _dice_compute(tp, fp, fn, self.average, self.mdmc_reduce, self.zero_division)
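
Below is a short usage sketch for the module interface introduced in this file, assuming the constructor arguments shown above (`average`, `background`, ...); the two toy batches are illustrative only.

import torch
from torchmetrics import Dice

dice = Dice(average="micro")  # class 0 is excluded by default (background=False)

# Two toy batches of hard class predictions and targets (illustrative values only).
batches = [
    (torch.tensor([2, 0, 2, 1]), torch.tensor([1, 1, 2, 0])),
    (torch.tensor([1, 1, 2, 2]), torch.tensor([1, 2, 2, 2])),
]
for preds, target in batches:
    dice.update(preds, target)  # accumulates tp/fp/fn state; no per-batch value is returned

print(dice.compute())  # single aggregated score over both batches
dice.reset()  # clear the accumulated state before reusing the metric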
3 changes: 2 additions & 1 deletion torchmetrics/functional/__init__.py
@@ -21,7 +21,7 @@
from torchmetrics.functional.classification.calibration_error import calibration_error
from torchmetrics.functional.classification.cohen_kappa import cohen_kappa
from torchmetrics.functional.classification.confusion_matrix import confusion_matrix
from torchmetrics.functional.classification.dice import dice_score
from torchmetrics.functional.classification.dice import dice, dice_score
from torchmetrics.functional.classification.f_beta import f1_score, fbeta_score
from torchmetrics.functional.classification.hamming import hamming_distance
from torchmetrics.functional.classification.hinge import hinge_loss
@@ -105,6 +105,7 @@
"coverage_error",
"tweedie_deviance_score",
"dice_score",
"dice",
"error_relative_global_dimensionless_synthesis",
"explained_variance",
"extended_edit_distance",
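
The functional interface follows the same pattern; a minimal sketch is shown below, assuming the `dice` signature mirrors the module interface, with `dice_score` kept only for backwards compatibility (a deprecation warning for it is added in commit 0fcfc3b).

import torch
from torchmetrics.functional import dice

preds = torch.tensor([2, 0, 2, 1])
target = torch.tensor([1, 1, 2, 0])

# Micro-averaged Dice; the `average` keyword is assumed to mirror the module interface.
score = dice(preds, target, average="micro")
print(score)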
2 changes: 1 addition & 1 deletion torchmetrics/functional/classification/__init__.py
@@ -18,7 +18,7 @@
from torchmetrics.functional.classification.calibration_error import calibration_error # noqa: F401
from torchmetrics.functional.classification.cohen_kappa import cohen_kappa # noqa: F401
from torchmetrics.functional.classification.confusion_matrix import confusion_matrix # noqa: F401
from torchmetrics.functional.classification.dice import dice_score # noqa: F401
from torchmetrics.functional.classification.dice import dice, dice_score # noqa: F401
from torchmetrics.functional.classification.f_beta import f1_score, fbeta_score # noqa: F401
from torchmetrics.functional.classification.hamming import hamming_distance # noqa: F401
from torchmetrics.functional.classification.hinge import hinge_loss # noqa: F401