Skip to content

Dice score as metric #1021

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 42 commits into from
May 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
bf2fad2
dice metric as `statescores`
May 10, 2022
0050886
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 10, 2022
b78f6ae
doctest fix
May 10, 2022
ac77a9d
Merge remote-tracking branch 'origin/dice_metric_as_state_scores' int…
May 10, 2022
411dc10
Apply suggestions from code review
justusschock May 10, 2022
cd532cb
Merge branch 'master' into dice_metric_as_state_scores
Borda May 10, 2022
f6c8e78
test dice from scipy
May 10, 2022
25470d8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 10, 2022
dbb2306
fix doctest
May 10, 2022
fd7ecd4
Merge remote-tracking branch 'origin/dice_metric_as_state_scores' int…
May 10, 2022
064f91c
fix docs
May 11, 2022
70ce63a
Merge branch 'master' into dice_metric_as_state_scores
Borda May 12, 2022
2011192
Merge branch 'master' into dice_metric_as_state_scores
Borda May 12, 2022
ef7a841
Update torchmetrics/classification/dice.py
MrShevan May 12, 2022
e762fd7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 12, 2022
b6bd1f2
Merge branch 'master' into dice_metric_as_state_scores
SkafteNicki May 13, 2022
00cb642
changelog
SkafteNicki May 13, 2022
157b665
update docs
SkafteNicki May 13, 2022
1ec33de
add links to docs
May 13, 2022
9f977aa
fis dice docs
May 13, 2022
7b42b5b
removed unnecessary
May 13, 2022
2fc195e
fix links.rst
May 13, 2022
65d8a87
rename dice.rst
May 13, 2022
a316b17
add ::incude:: in dice.rst
May 13, 2022
4bd4888
Merge branch 'master' into dice_metric_as_state_scores
Borda May 14, 2022
767affb
changename functional Interface
May 14, 2022
9f16c51
fix too short sphinx error
May 14, 2022
aecf684
Merge branch 'master' into dice_metric_as_state_scores
MrShevan May 15, 2022
0fcfc3b
add deprecation warning for `dice_score`
May 15, 2022
7f69ce8
Merge remote-tracking branch 'origin/dice_metric_as_state_scores' int…
May 15, 2022
c40f262
fix doc
May 15, 2022
19a40f9
Update torchmetrics/functional/classification/dice.py
MrShevan May 15, 2022
a126af2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 15, 2022
4ef2cb1
Update torchmetrics/functional/classification/dice.py
MrShevan May 15, 2022
cfb6870
Update torchmetrics/functional/classification/dice.py
MrShevan May 15, 2022
fe4c0d4
Update torchmetrics/classification/dice.py
MrShevan May 15, 2022
e3d37f1
Update torchmetrics/classification/dice.py
MrShevan May 15, 2022
a97b712
replace `dice_score` logic to `dice` func
May 15, 2022
618da06
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 15, 2022
7785894
fix pep
May 15, 2022
a26a4bb
Merge remote-tracking branch 'origin/dice_metric_as_state_scores' int…
May 15, 2022
508fbe1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 15, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added support for nested metric collections ([#1003](https://github.com/PyTorchLightning/metrics/pull/1003))


- Added `Dice` to classification package ([#1021](https://github.com/PyTorchLightning/metrics/pull/1021))


### Changed

-
Expand Down
33 changes: 33 additions & 0 deletions docs/source/classification/dice.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
.. customcarditem::
:header: Dice
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg
:tags: Classification

.. include:: ../links.rst

####
Dice
####

Module Interface
________________

.. autoclass:: torchmetrics.Dice
:noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.dice
:noindex:


##########
Dice Score
##########

Functional Interface (was deprecated in v0.9)
_____________________________________________

.. autofunction:: torchmetrics.functional.dice_score
:noindex:
14 changes: 0 additions & 14 deletions docs/source/classification/dice_score.rst

This file was deleted.

1 change: 1 addition & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
.. _Matthews correlation coefficient: https://en.wikipedia.org/wiki/Matthews_correlation_coefficient
.. _Precision: https://en.wikipedia.org/wiki/Precision_and_recall
.. _Recall: https://en.wikipedia.org/wiki/Precision_and_recall
.. _Dice: https://en.wikipedia.org/wiki/Sørensen–Dice_coefficient
.. _Specificity: https://en.wikipedia.org/wiki/Sensitivity_and_specificity
.. _Type I and Type II errors: https://en.wikipedia.org/wiki/Type_I_and_type_II_errors
.. _confusion matrix: https://en.wikipedia.org/wiki/Confusion_matrix#Table_of_confusion
Expand Down
138 changes: 136 additions & 2 deletions tests/classification/test_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,58 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Optional

import pytest
from torch import tensor
from scipy.spatial.distance import dice as _sc_dice
from torch import Tensor, tensor

from tests.classification.inputs import _input_binary, _input_binary_logits, _input_binary_prob
from tests.classification.inputs import _input_multiclass as _input_mcls
from tests.classification.inputs import _input_multiclass_logits as _input_mcls_logits
from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.classification.inputs import _input_multiclass_with_missing_class as _input_miss_class
from tests.classification.inputs import _input_multilabel as _input_mlb
from tests.classification.inputs import _input_multilabel_logits as _input_mlb_logits
from tests.classification.inputs import _input_multilabel_multidim as _input_mlmd
from tests.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob
from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.helpers import seed_all
from tests.helpers.testers import MetricTester
from torchmetrics import Dice
from torchmetrics.functional import dice, dice_score
from torchmetrics.functional.classification.stat_scores import _del_column
from torchmetrics.utilities.checks import _input_format_classification
from torchmetrics.utilities.enums import DataType

seed_all(42)


def _sk_dice(
preds: Tensor,
target: Tensor,
ignore_index: Optional[int] = None,
) -> float:
"""Compute dice score from prediction and target. Used scipy implementation of main dice logic.

from torchmetrics.functional import dice_score
Args:
preds: prediction tensor
target: target tensor
ignore_index:
Integer specifying a target class to ignore. Recommend set to index of background class.
Return:
Float dice score
"""
sk_preds, sk_target, mode = _input_format_classification(preds, target)

if ignore_index is not None and mode != DataType.BINARY:
sk_preds = _del_column(sk_preds, ignore_index)
sk_target = _del_column(sk_target, ignore_index)

sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy()

return 1 - _sc_dice(sk_preds.reshape(-1), sk_target.reshape(-1))


@pytest.mark.parametrize(
Expand All @@ -29,3 +77,89 @@
def test_dice_score(pred, target, expected):
score = dice_score(tensor(pred), tensor(target))
assert score == expected


@pytest.mark.parametrize(
["pred", "target", "expected"],
[
([[0, 0], [1, 1]], [[0, 0], [1, 1]], 1.0),
([[1, 1], [0, 0]], [[0, 0], [1, 1]], 0.0),
([[1, 1], [1, 1]], [[1, 1], [0, 0]], 2 / 3),
([[1, 1], [0, 0]], [[1, 1], [0, 0]], 1.0),
],
)
def test_dice(pred, target, expected):
score = dice(tensor(pred), tensor(target), ignore_index=0)
assert score == expected


@pytest.mark.parametrize(
"preds, target",
[
(_input_binary.preds, _input_binary.target),
(_input_binary_logits.preds, _input_binary_logits.target),
(_input_binary_prob.preds, _input_binary_prob.target),
],
)
@pytest.mark.parametrize("ignore_index", [None])
class TestDiceBinary(MetricTester):
@pytest.mark.parametrize("ddp", [False])
@pytest.mark.parametrize("dist_sync_on_step", [False])
def test_dice_class(self, ddp, dist_sync_on_step, preds, target, ignore_index):
self.run_class_metric_test(
ddp=ddp,
preds=preds,
target=target,
metric_class=Dice,
sk_metric=partial(_sk_dice, ignore_index=ignore_index),
dist_sync_on_step=dist_sync_on_step,
metric_args={"ignore_index": ignore_index},
)

def test_dice_fn(self, preds, target, ignore_index):
self.run_functional_metric_test(
preds,
target,
metric_functional=dice,
sk_metric=partial(_sk_dice, ignore_index=ignore_index),
metric_args={"ignore_index": ignore_index},
)


@pytest.mark.parametrize(
"preds, target",
[
(_input_mcls.preds, _input_mcls.target),
(_input_mcls_logits.preds, _input_mcls_logits.target),
(_input_mcls_prob.preds, _input_mcls_prob.target),
(_input_miss_class.preds, _input_miss_class.target),
(_input_mlb.preds, _input_mlb.target),
(_input_mlb_logits.preds, _input_mlb_logits.target),
(_input_mlmd.preds, _input_mlmd.target),
(_input_mlmd_prob.preds, _input_mlmd_prob.target),
(_input_mlb_prob.preds, _input_mlb_prob.target),
],
)
@pytest.mark.parametrize("ignore_index", [None, 0])
class TestDiceMulti(MetricTester):
@pytest.mark.parametrize("ddp", [False])
@pytest.mark.parametrize("dist_sync_on_step", [False])
def test_dice_class(self, ddp, dist_sync_on_step, preds, target, ignore_index):
self.run_class_metric_test(
ddp=ddp,
preds=preds,
target=target,
metric_class=Dice,
sk_metric=partial(_sk_dice, ignore_index=ignore_index),
dist_sync_on_step=dist_sync_on_step,
metric_args={"ignore_index": ignore_index},
)

def test_dice_fn(self, preds, target, ignore_index):
self.run_functional_metric_test(
preds,
target,
metric_functional=dice,
sk_metric=partial(_sk_dice, ignore_index=ignore_index),
metric_args={"ignore_index": ignore_index},
)
2 changes: 2 additions & 0 deletions torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
CohenKappa,
ConfusionMatrix,
CoverageError,
Dice,
F1Score,
FBetaScore,
HammingDistance,
Expand Down Expand Up @@ -126,6 +127,7 @@
"ConfusionMatrix",
"CosineSimilarity",
"CoverageError",
"Dice",
"TweedieDevianceScore",
"ErrorRelativeGlobalDimensionlessSynthesis",
"ExplainedVariance",
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from torchmetrics.classification.calibration_error import CalibrationError # noqa: F401
from torchmetrics.classification.cohen_kappa import CohenKappa # noqa: F401
from torchmetrics.classification.confusion_matrix import ConfusionMatrix # noqa: F401
from torchmetrics.classification.dice import Dice # noqa: F401
from torchmetrics.classification.f_beta import F1Score, FBetaScore # noqa: F401
from torchmetrics.classification.hamming import HammingDistance # noqa: F401
from torchmetrics.classification.hinge import HingeLoss # noqa: F401
Expand Down
Loading