From 9be617258fe764c83c2a1340943105fc216895e9 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Sat, 14 Sep 2024 09:48:02 +0200 Subject: [PATCH] more code --- .../functional/segmentation/dice.py | 17 ++++++++++++++--- .../functional/segmentation/generalized_dice.py | 15 ++++++++++++++- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/src/torchmetrics/functional/segmentation/dice.py b/src/torchmetrics/functional/segmentation/dice.py index e4c48ca3482..6b6a79dd44e 100644 --- a/src/torchmetrics/functional/segmentation/dice.py +++ b/src/torchmetrics/functional/segmentation/dice.py @@ -74,8 +74,8 @@ def _dice_score_compute( average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", ) -> Tensor: if average == "micro": - numerator = torch.sum(numerator, dim=0) - denominator = torch.sum(denominator, dim=0) + numerator = torch.sum(numerator, dim=1) + denominator = torch.sum(denominator, dim=1) dice = _safe_divide(numerator, denominator, zero_division=1.0) if average == "macro": dice = torch.mean(dice) @@ -95,7 +95,8 @@ def dice_score( ) -> Tensor: """Compute the Dice score for semantic segmentation. - preds: Predictions from model + Args: + preds: Predictions from model target: Ground truth values num_classes: Number of classes include_background: Whether to include the background class in the computation @@ -105,6 +106,16 @@ def dice_score( Returns: The Dice score. + Example (with one-hot encoded tensors): + >>> from torch import randint + >>> from torchmetrics.functional.segmentation import dice_score + >>> preds = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 prediction + >>> target = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 target + >>> dice_score(preds, target, num_classes=5) + tensor([0.4872, 0.5000, 0.5019, 0.4891, 0.4926]) + + Example (with index tensors): + """ _dice_score_validate_args(num_classes, include_background, average, input_format) numerator, denominator = _dice_score_update(preds, target, num_classes, include_background, input_format) diff --git a/src/torchmetrics/functional/segmentation/generalized_dice.py b/src/torchmetrics/functional/segmentation/generalized_dice.py index e0de9f1821e..69a417bfdd8 100644 --- a/src/torchmetrics/functional/segmentation/generalized_dice.py +++ b/src/torchmetrics/functional/segmentation/generalized_dice.py @@ -124,7 +124,7 @@ def generalized_dice_score( Returns: The Generalized Dice Score - Example: + Example (with one-hot encoded tensors): >>> from torch import randint >>> from torchmetrics.functional.segmentation import generalized_dice_score >>> preds = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 prediction @@ -136,6 +136,19 @@ def generalized_dice_score( [0.4571, 0.4980, 0.5191, 0.4380, 0.5649], [0.5428, 0.4904, 0.5358, 0.4830, 0.4724], [0.4715, 0.4925, 0.4797, 0.5267, 0.4788]]) + + Example (with index tensors): + >>> from torch import randint + >>> from torchmetrics.functional.segmentation import generalized_dice_score + >>> preds = randint(0, 5, (4, 16, 16)) # 4 samples, 5 classes, 16x16 prediction + >>> target = randint(0, 5, (4, 16, 16)) # 4 samples, 5 classes, 16x16 target + >>> generalized_dice_score(preds, target, num_classes=5) + tensor([0.4830, 0.4935, 0.5044, 0.4880]) + >>> generalized_dice_score(preds, target, num_classes=5, per_class=True) + tensor([[0.4724, 0.5185, 0.4710, 0.5062, 0.4500], + [0.4571, 0.4980, 0.5191, 0.4380, 0.5649], + [0.5428, 0.4904, 0.5358, 0.4830, 0.4724], + [0.4715, 0.4925, 0.4797, 0.5267, 0.4788]]) """ _generalized_dice_validate_args(num_classes, include_background, per_class, weight_type, input_format)