diff --git a/src/torchmetrics/functional/segmentation/generalized_dice.py b/src/torchmetrics/functional/segmentation/generalized_dice.py index 9aa3c7c7bf4..04f28584b10 100644 --- a/src/torchmetrics/functional/segmentation/generalized_dice.py +++ b/src/torchmetrics/functional/segmentation/generalized_dice.py @@ -57,7 +57,6 @@ def _generalized_dice_update( if input_format == "index": preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1) - if input_format == "index": target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1) if not include_background: diff --git a/src/torchmetrics/functional/segmentation/mean_iou.py b/src/torchmetrics/functional/segmentation/mean_iou.py index 8b749d1d463..278257d04b1 100644 --- a/src/torchmetrics/functional/segmentation/mean_iou.py +++ b/src/torchmetrics/functional/segmentation/mean_iou.py @@ -51,7 +51,6 @@ def _mean_iou_update( if input_format == "index": preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1) - if input_format == "index": target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1) if not include_background: diff --git a/tests/unittests/segmentation/test_generalized_dice_score.py b/tests/unittests/segmentation/test_generalized_dice_score.py index 1d0d58babe6..a2bbab7b921 100644 --- a/tests/unittests/segmentation/test_generalized_dice_score.py +++ b/tests/unittests/segmentation/test_generalized_dice_score.py @@ -47,7 +47,6 @@ def _reference_generalized_dice( """Calculate reference metric for `MeanIoU`.""" if input_format == "index": preds = torch.nn.functional.one_hot(preds, num_classes=NUM_CLASSES).movedim(-1, 1) - if input_format == "index": target = torch.nn.functional.one_hot(target, num_classes=NUM_CLASSES).movedim(-1, 1) val = compute_generalized_dice(preds, target, include_background=include_background) if reduce: diff --git a/tests/unittests/segmentation/test_mean_iou.py b/tests/unittests/segmentation/test_mean_iou.py index da110bb7106..68c2b060a9e 100644 --- a/tests/unittests/segmentation/test_mean_iou.py +++ b/tests/unittests/segmentation/test_mean_iou.py @@ -48,7 +48,6 @@ def _reference_mean_iou( """Calculate reference metric for `MeanIoU`.""" if input_format == "index": preds = torch.nn.functional.one_hot(preds, num_classes=NUM_CLASSES).movedim(-1, 1) - if input_format == "index": target = torch.nn.functional.one_hot(target, num_classes=NUM_CLASSES).movedim(-1, 1) val = compute_iou(preds, target, include_background=include_background)