Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix dice score when zero overlap between preds and target #2860

Open
wants to merge 33 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
50982ab
fix implementation
SkafteNicki Dec 3, 2024
059f6df
add tests
SkafteNicki Dec 3, 2024
4b07905
changelog
SkafteNicki Dec 3, 2024
3bf4f09
introduce zero_division argument
SkafteNicki Dec 4, 2024
37d2530
add tests
SkafteNicki Dec 4, 2024
539ca53
changelog
SkafteNicki Dec 4, 2024
5002b6c
Merge branch 'master' into bugfix/segmentation_dice_zero_overlap
Borda Dec 16, 2024
57a8370
Merge branch 'master' into bugfix/segmentation_dice_zero_overlap
mergify[bot] Dec 17, 2024
ef614d1
Merge branch 'master' into bugfix/segmentation_dice_zero_overlap
mergify[bot] Dec 17, 2024
e744160
Merge branch 'master' into bugfix/segmentation_dice_zero_overlap
mergify[bot] Dec 17, 2024
de0d4e5
Merge branch 'master' into bugfix/segmentation_dice_zero_overlap
mergify[bot] Dec 17, 2024
ffc9f0b
Merge branch 'master' into bugfix/segmentation_dice_zero_overlap
mergify[bot] Dec 19, 2024
f05b156
Merge branch 'master' into bugfix/segmentation_dice_zero_overlap
Borda Dec 21, 2024
d59fad8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 21, 2024
484f8ad
Merge branch 'master' into bugfix/segmentation_dice_zero_overlap
mergify[bot] Dec 21, 2024
1af846d
Merge branch 'master' into bugfix/segmentation_dice_zero_overlap
Borda Dec 21, 2024
35c2abc
Merge branch 'master' into bugfix/segmentation_dice_zero_overlap
mergify[bot] Dec 21, 2024
6ce1542
Merge branch 'master' into bugfix/segmentation_dice_zero_overlap
mergify[bot] Dec 24, 2024
9958d82
Merge branch 'master' into bugfix/segmentation_dice_zero_overlap
mergify[bot] Dec 24, 2024
5f7efad
Merge branch 'master' into bugfix/segmentation_dice_zero_overlap
mergify[bot] Dec 24, 2024
fe6dbe4
Merge branch 'master' into bugfix/segmentation_dice_zero_overlap
mergify[bot] Dec 25, 2024
6aef881
Merge branch 'master' into bugfix/segmentation_dice_zero_overlap
mergify[bot] Dec 31, 2024
26dee23
Merge branch 'master' into bugfix/segmentation_dice_zero_overlap
mergify[bot] Jan 2, 2025
ac3e51e
Merge branch 'master' into bugfix/segmentation_dice_zero_overlap
mergify[bot] Jan 6, 2025
b22f6d7
Merge branch 'master' into bugfix/segmentation_dice_zero_overlap
mergify[bot] Jan 6, 2025
dfabd85
Merge branch 'master' into bugfix/segmentation_dice_zero_overlap
mergify[bot] Jan 6, 2025
5486901
Merge branch 'master' into bugfix/segmentation_dice_zero_overlap
mergify[bot] Jan 7, 2025
ab5070d
Merge branch 'master' into bugfix/segmentation_dice_zero_overlap
Borda Jan 8, 2025
7984f44
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 8, 2025
38bebe8
Merge branch 'master' into bugfix/segmentation_dice_zero_overlap
mergify[bot] Jan 13, 2025
05fa2a6
Merge branch 'master' into bugfix/segmentation_dice_zero_overlap
mergify[bot] Jan 13, 2025
4b42206
Merge branch 'master' into bugfix/segmentation_dice_zero_overlap
mergify[bot] Jan 21, 2025
221d15b
Merge branch 'master' into bugfix/segmentation_dice_zero_overlap
mergify[bot] Jan 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

-
- Added `zero_division` argument to `DiceScore` in segmentation package ([#2860](https://github.com/PyTorchLightning/metrics/pull/2860))


### Changed
Expand All @@ -30,6 +30,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fixed `DiceScore` when there is zero overlap between predictions and targets ([#2860](https://github.com/PyTorchLightning/metrics/pull/2860))


- Fixed plotting of multilabel confusion matrix ([#2858](https://github.com/PyTorchLightning/metrics/pull/2858))


Expand Down
16 changes: 13 additions & 3 deletions src/torchmetrics/functional/segmentation/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import Optional, Union

import torch
from torch import Tensor
Expand All @@ -27,6 +27,7 @@ def _dice_score_validate_args(
include_background: bool,
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
input_format: Literal["one-hot", "index"] = "one-hot",
zero_divide: Union[float, Literal["warn", "nan"]] = 1.0,
) -> None:
"""Validate the arguments of the metric."""
if not isinstance(num_classes, int) or num_classes <= 0:
Expand All @@ -38,6 +39,10 @@ def _dice_score_validate_args(
raise ValueError(f"Expected argument `average` to be one of {allowed_average} or None, but got {average}.")
if input_format not in ["one-hot", "index"]:
raise ValueError(f"Expected argument `input_format` to be one of 'one-hot', 'index', but got {input_format}.")
if zero_divide not in [1.0, 0.0, "warn", "nan"]:
raise ValueError(
f"Expected argument `zero_divide` to be one of 1.0, 0.0, 'warn', 'nan', but got {zero_divide}."
)


def _dice_score_update(
Expand Down Expand Up @@ -76,16 +81,21 @@ def _dice_score_compute(
denominator: Tensor,
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
support: Optional[Tensor] = None,
zero_division: Union[float, Literal["warn", "nan"]] = 1.0,
) -> Tensor:
"""Compute the Dice score from the numerator and denominator."""
# If both numerator and denominator are 0, the dice score is 0
if torch.all(numerator == 0) and torch.all(denominator == 0):
return torch.tensor(0.0, device=numerator.device, dtype=torch.float)

if average == "micro":
numerator = torch.sum(numerator, dim=-1)
denominator = torch.sum(denominator, dim=-1)
dice = _safe_divide(numerator, denominator, zero_division=1.0)
dice = _safe_divide(numerator, denominator, zero_division=zero_division)
if average == "macro":
dice = torch.mean(dice, dim=-1)
elif average == "weighted" and support is not None:
weights = _safe_divide(support, torch.sum(support, dim=-1, keepdim=True), zero_division=1.0)
weights = _safe_divide(support, torch.sum(support, dim=-1, keepdim=True), zero_division=zero_division)
dice = torch.sum(dice * weights, dim=-1)
return dice

Expand Down
9 changes: 7 additions & 2 deletions src/torchmetrics/segmentation/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ class DiceScore(Metric):
or ``None``. This determines how to average the dice score across different classes.
input_format: What kind of input the function receives. Choose between ``"one-hot"`` for one-hot encoded tensors
or ``"index"`` for index tensors
zero_division: The value to return when there is a division by zero. Options are 1.0, 0.0, "warn" or "nan".
Setting it to "warn" behaves like 0.0 but will also create a warning.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
Expand Down Expand Up @@ -110,14 +112,16 @@ def __init__(
include_background: bool = True,
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
input_format: Literal["one-hot", "index"] = "one-hot",
zero_division: Union[float, Literal["warn", "nan"]] = 0.0,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
_dice_score_validate_args(num_classes, include_background, average, input_format)
_dice_score_validate_args(num_classes, include_background, average, input_format, zero_division)
self.num_classes = num_classes
self.include_background = include_background
self.average = average
self.input_format = input_format
self.zero_division = zero_division

num_classes = num_classes - 1 if not include_background else num_classes
self.add_state("numerator", [], dist_reduce_fx="cat")
Expand All @@ -140,7 +144,8 @@ def compute(self) -> Tensor:
dim_zero_cat(self.denominator),
self.average,
support=dim_zero_cat(self.support) if self.average == "weighted" else None,
).mean(dim=0)
zero_division=self.zero_division,
).nanmean(dim=0)

def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Expand Down
19 changes: 15 additions & 4 deletions src/torchmetrics/utilities/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import Optional, Union

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.utilities import rank_zero_warn


def _safe_matmul(x: Tensor, y: Tensor) -> Tensor:
"""Safe calculation of matrix multiplication.
Expand Down Expand Up @@ -44,7 +46,11 @@ def _safe_xlogy(x: Tensor, y: Tensor) -> Tensor:
return res


def _safe_divide(num: Tensor, denom: Tensor, zero_division: float = 0.0) -> Tensor:
def _safe_divide(
num: Tensor,
denom: Tensor,
zero_division: Union[float, Literal["warn", "nan"]] = 0.0,
) -> Tensor:
"""Safe division, by preventing division by zero.

Function will cast to float if input is not already to secure backwards compatibility.
Expand All @@ -64,8 +70,13 @@ def _safe_divide(num: Tensor, denom: Tensor, zero_division: float = 0.0) -> Tens
"""
num = num if num.is_floating_point() else num.float()
denom = denom if denom.is_floating_point() else denom.float()
zero_division_tensor = torch.tensor(zero_division, dtype=num.dtype).to(num.device, non_blocking=True)
return torch.where(denom != 0, num / denom, zero_division_tensor)
if isinstance(zero_division, float) or zero_division == "warn":
if zero_division == "warn" and torch.any(denom == 0):
rank_zero_warn("Detected zero division in _safe_divide. Setting 0/0 to 0.0")
zero_division = 0.0 if zero_division == "warn" else zero_division
zero_division_tensor = torch.tensor(zero_division, dtype=num.dtype).to(num.device, non_blocking=True)
return torch.where(denom != 0, num / denom, zero_division_tensor)
return torch.true_divide(num, denom)


def _adjust_weights_safe_divide(
Expand Down
43 changes: 43 additions & 0 deletions tests/unittests/segmentation/test_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import pytest
import torch
from sklearn.metrics import f1_score
from torch import tensor

from torchmetrics import MetricCollection
from torchmetrics.functional.segmentation.dice import dice_score
Expand Down Expand Up @@ -109,6 +110,48 @@ def test_dice_score_functional(self, preds, target, input_format, include_backgr
)


@pytest.mark.parametrize("average", ["micro", "macro", "weighted", None])
def test_corner_case_no_overlap(average):
"""Check that if no overlap and intersection between target and preds, the dice score is 0.

See issue: https://github.com/Lightning-AI/torchmetrics/issues/2851

"""
target = torch.full((4, 4, 128, 128), 0, dtype=torch.int8)
preds = torch.full((4, 4, 128, 128), 0, dtype=torch.int8)
target[0, 0] = 1
preds[0, 0] = 1
dice = DiceScore(num_classes=3, average=average, include_background=False)
assert dice(preds, target) == 0.0


@pytest.mark.parametrize("average", ["micro", "macro", "weighted", None])
@pytest.mark.parametrize("zero_division", [1.0, 0.0, "warn", "nan"])
def test_zero_division(zero_division, average):
"""Test different zero_division values."""
target = torch.full((1, 3, 128, 128), 0, dtype=torch.int8)
preds = torch.full((1, 3, 128, 128), 0, dtype=torch.int8)
target[0, 0] = 1
dice = DiceScore(num_classes=3, average=average, zero_division=zero_division)
score = dice(preds, target)

res_dict = {
"micro": {1.0: tensor(0.0), 0.0: tensor(0.0), "warn": tensor(0.0), "nan": tensor(0.0)},
"macro": {1.0: tensor(0.6667), 0.0: tensor(0.0), "warn": tensor(0.0), "nan": tensor(float("nan"))},
"weighted": {1.0: tensor(0.0), 0.0: tensor(0.0), "warn": tensor(0.0), "nan": tensor(float("nan"))},
None: {
1.0: tensor([0.0, 1.0, 1.0]),
0.0: tensor([0.0, 0.0, 0.0]),
"warn": tensor([0.0, 0.0, 0.0]),
"nan": tensor([0.0, float("nan"), float("nan")]),
},
}

assert torch.allclose(score, res_dict[average][zero_division], atol=1e-4, equal_nan=True), (
f"Expected {res_dict[average][zero_division]} but got {score}"
)


@pytest.mark.parametrize("compute_groups", [True, False])
def test_dice_score_metric_collection(compute_groups: bool, num_batches: int = 4):
"""Test that the metric works within a metric collection with and without compute groups."""
Expand Down
Loading