Skip to content

Commit

Permalink
more code
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki committed Sep 14, 2024
1 parent c57c11f commit 9be6172
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 4 deletions.
17 changes: 14 additions & 3 deletions src/torchmetrics/functional/segmentation/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def _dice_score_compute(
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
) -> Tensor:
if average == "micro":
numerator = torch.sum(numerator, dim=0)
denominator = torch.sum(denominator, dim=0)
numerator = torch.sum(numerator, dim=1)
denominator = torch.sum(denominator, dim=1)
dice = _safe_divide(numerator, denominator, zero_division=1.0)
if average == "macro":
dice = torch.mean(dice)
Expand All @@ -95,7 +95,8 @@ def dice_score(
) -> Tensor:
"""Compute the Dice score for semantic segmentation.
preds: Predictions from model
Args:
preds: Predictions from model
target: Ground truth values
num_classes: Number of classes
include_background: Whether to include the background class in the computation
Expand All @@ -105,6 +106,16 @@ def dice_score(
Returns:
The Dice score.
Example (with one-hot encoded tensors):
>>> from torch import randint
>>> from torchmetrics.functional.segmentation import dice_score
>>> preds = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 prediction
>>> target = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 target
>>> dice_score(preds, target, num_classes=5)
tensor([0.4872, 0.5000, 0.5019, 0.4891, 0.4926])
Example (with index tensors):
"""
_dice_score_validate_args(num_classes, include_background, average, input_format)
numerator, denominator = _dice_score_update(preds, target, num_classes, include_background, input_format)
Expand Down
15 changes: 14 additions & 1 deletion src/torchmetrics/functional/segmentation/generalized_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def generalized_dice_score(
Returns:
The Generalized Dice Score
Example:
Example (with one-hot encoded tensors):
>>> from torch import randint
>>> from torchmetrics.functional.segmentation import generalized_dice_score
>>> preds = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 prediction
Expand All @@ -136,6 +136,19 @@ def generalized_dice_score(
[0.4571, 0.4980, 0.5191, 0.4380, 0.5649],
[0.5428, 0.4904, 0.5358, 0.4830, 0.4724],
[0.4715, 0.4925, 0.4797, 0.5267, 0.4788]])
Example (with index tensors):
>>> from torch import randint
>>> from torchmetrics.functional.segmentation import generalized_dice_score
>>> preds = randint(0, 5, (4, 16, 16)) # 4 samples, 5 classes, 16x16 prediction
>>> target = randint(0, 5, (4, 16, 16)) # 4 samples, 5 classes, 16x16 target
>>> generalized_dice_score(preds, target, num_classes=5)
tensor([0.4830, 0.4935, 0.5044, 0.4880])
>>> generalized_dice_score(preds, target, num_classes=5, per_class=True)
tensor([[0.4724, 0.5185, 0.4710, 0.5062, 0.4500],
[0.4571, 0.4980, 0.5191, 0.4380, 0.5649],
[0.5428, 0.4904, 0.5358, 0.4830, 0.4724],
[0.4715, 0.4925, 0.4797, 0.5267, 0.4788]])
"""
_generalized_dice_validate_args(num_classes, include_background, per_class, weight_type, input_format)
Expand Down

0 comments on commit 9be6172

Please sign in to comment.