segmentation.MeanIoU is wrong #2558

Closed · nrudakov opened this issue May 23, 2024 · 5 comments · Fixed by #2698

Labels: bug / fix (Something isn't working) · help wanted (Extra attention is needed) · v1.4.x

Comments

nrudakov commented May 23, 2024

🐛 Bug

MeanIoU scored 56(!) over the validation dataset:

[screenshot: logged validation MeanIoU value of about 56]

Let's look at the source code:

def update(self, preds: Tensor, target: Tensor) -> None:
    """Update the state with the new data."""
    intersection, union = _mean_iou_update(preds, target, self.num_classes, self.include_background)
    score = _mean_iou_compute(intersection, union, per_class=self.per_class)
    self.score += score.mean(0) if self.per_class else score.mean()

def compute(self) -> Tensor:
    """Update the state with the new data."""
    return self.score  # / self.num_batches

There are several issues here:

  • self.score is accumulated on every call to the update method.
  • The compute method just returns the accumulated score.
  • The docstring of the compute method is copy-pasted from the update method.
  • The division by self.num_batches is commented out.
  • num_batches is defined at the class level and not used anywhere else.

Obviously, this code was neither reviewed nor tested, yet it was somehow released.
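For reference, a minimal sketch of what restoring the commented-out averaging could look like. This is an illustration only, not the released fix; it assumes num_batches is registered as a metric state with a sum reduction and reuses the private helpers quoted above:

def update(self, preds: Tensor, target: Tensor) -> None:
    """Update the state with the new data."""
    intersection, union = _mean_iou_update(preds, target, self.num_classes, self.include_background)
    score = _mean_iou_compute(intersection, union, per_class=self.per_class)
    self.score += score.mean(0) if self.per_class else score.mean()
    self.num_batches += 1  # hypothetical state: count batches so compute() can average

def compute(self) -> Tensor:
    """Return the score averaged over all processed batches."""
    return self.score / self.num_batches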

To Reproduce

 1. Call metric.update(y_hat, y) in validation_step.
 2. Log the metric in on_validation_epoch_end.
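A minimal reproduction sketch, assuming a standard LightningModule; the SegModel name and the dummy 1x1 conv head are illustrative, not from the report:

import pytorch_lightning as pl
import torch
from torchmetrics.segmentation import MeanIoU


class SegModel(pl.LightningModule):  # hypothetical module for illustration
    def __init__(self, num_classes: int = 3):
        super().__init__()
        self.net = torch.nn.Conv2d(3, num_classes, kernel_size=1)  # dummy segmentation head
        self.val_miou = MeanIoU(num_classes=num_classes, input_format="index")

    def forward(self, x):
        return self.net(x)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x).argmax(dim=1)   # class-index predictions of shape (N, H, W)
        self.val_miou.update(y_hat, y)  # accumulate per batch

    def on_validation_epoch_end(self):
        # With the accumulation bug, this value grows with the number of
        # validation batches instead of staying in [0, 1].
        self.log("val/miou", self.val_miou.compute())
        self.val_miou.reset()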

Expected behavior

MeanIoU computes a correct value in the [0, 1] range.

Environment

  • TorchMetrics 1.4.0.post0 from pip
  • Python 3.11.9, PyTorch 2.3.0+cu121
  • pytorch-lightning 2.2.4
  • Windows 11
nrudakov added the bug / fix and help wanted labels May 23, 2024
Hi! Thanks for your contribution, great first issue!

Borda added the v1.4.x label May 24, 2024
DimitrisMantas commented May 31, 2024

Now that zero_division has been added to JaccardIndex, I think there's no real need for MeanIoU at all. MeanIoU also doesn't follow the classical API.

The original motivation for MeanIoU was that JaccardIndex used to assign a score of 0 to absent and ignored classes, so you couldn't do classwise and macro averaging correctly. Now you can just set zero_division to NaN and average to None to get correct class scores, and from there take a nanmean to get the correct macro average.
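A sketch of that workaround, assuming a torchmetrics version where the multiclass Jaccard index accepts zero_division as described above:

import torch
from torchmetrics.classification import MulticlassJaccardIndex

# Per-class IoU where absent/ignored classes become NaN instead of 0.
jaccard = MulticlassJaccardIndex(
    num_classes=3,
    average=None,                # one score per class
    zero_division=float("nan"),  # absent classes -> NaN, not 0
)

preds = torch.randint(0, 3, (4, 32, 32))
target = torch.randint(0, 3, (4, 32, 32))

per_class = jaccard(preds, target)  # shape (num_classes,), may contain NaN
macro = torch.nanmean(per_class)    # macro average over classes actually present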

juliendenize commented

Hi, I also noticed that MeanIoU was wrong during my experiments.

I developed something that seems to work on my side and returns the same results as evaluate's mean IoU; however, given @DimitrisMantas's comment, I wonder whether it is relevant to submit a PR. I am not familiar enough with the JaccardIndex implementation in torchmetrics.

For reference, here is the undocumented code I developed (it has not been rigorously tested yet). Let me know if submitting a PR is something of interest; I'd gladly contribute to this repo.

from typing import Any, Literal

import torch
from torch import Tensor
from torchmetrics import Metric

def _compute_intersection_and_union(
    preds: Tensor,
    target: Tensor,
    num_classes: int,
    include_background: bool = False,
    input_format: Literal["one-hot", "index", "predictions"] = "index",
) -> tuple[Tensor, Tensor]:
    """Compute per-sample, per-class intersection and union counts."""
    if input_format in ["index", "predictions"]:
        if input_format == "predictions":
            # Logits/probabilities of shape (N, C, ...): take the argmax over classes.
            preds = preds.argmax(1)
        # Convert index masks to one-hot of shape (N, ..., C).
        preds = torch.nn.functional.one_hot(preds, num_classes=num_classes)
        target = torch.nn.functional.one_hot(target, num_classes=num_classes)

    if not include_background:
        # Zero out the background channel. Clone first to avoid mutating the
        # caller's tensors when the input is already one-hot.
        preds = preds.clone()
        target = target.clone()
        preds[..., 0] = 0
        target[..., 0] = 0

    # Sum over all spatial dimensions, keeping shape (N, C).
    reduce_axis = list(range(1, preds.ndim - 1))
    intersection = torch.sum(torch.logical_and(preds, target), dim=reduce_axis)
    target_sum = torch.sum(target, dim=reduce_axis)
    pred_sum = torch.sum(preds, dim=reduce_axis)
    union = target_sum + pred_sum - intersection

    return intersection, union


class MeanIoU(Metric):
    """Mean IoU that accumulates intersection/union counts rather than averaged scores."""

    def __init__(
        self,
        num_classes: int,
        include_background: bool = True,
        per_class: bool = False,
        input_format: Literal["one-hot", "index", "predictions"] = "index",
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)

        self.num_classes = num_classes
        self.include_background = include_background
        self.per_class = per_class
        self.input_format = input_format

        # Accumulate raw counts so the final ratio is exact regardless of batch sizes.
        self.add_state("intersection", default=torch.zeros(num_classes), dist_reduce_fx="sum")
        self.add_state("union", default=torch.zeros(num_classes), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:
        intersection, union = _compute_intersection_and_union(
            preds, target, self.num_classes, self.include_background, self.input_format
        )
        # Sum over the batch dimension into the running per-class totals.
        self.intersection += intersection.sum(0)
        self.union += union.sum(0)

    def compute(self) -> Tensor:
        # A class is valid only if it appears in preds or target at least once.
        iou_valid = torch.gt(self.union, 0)

        iou = torch.where(
            iou_valid,
            torch.divide(self.intersection, self.union),
            torch.nan,
        )

        if self.per_class:
            return iou  # per-class IoU, NaN for classes never seen
        return torch.mean(iou[iou_valid])  # mean over valid classes only
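A quick usage check of the sketch above, with illustrative shapes and values:

preds = torch.randint(0, 3, (2, 16, 16))   # index masks, batch of 2
target = torch.randint(0, 3, (2, 16, 16))

metric = MeanIoU(num_classes=3, input_format="index")
metric.update(preds, target)
print(metric.compute())  # scalar tensor in [0, 1]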

vkinakh (Contributor) commented Aug 21, 2024

Hi, I have noticed that MeanIoU is incorrect when updated via the update method; if I update it via the forward method, it works correctly. I looked into it, and the reason is that with default MeanIoU parameters the forward method calls the _reduce_states method, which updates the score using the following formula:

reduced = ((self._update_count - 1) * global_state + local_state).float() / self._update_count

where global_state is the score accumulated over previous steps and local_state is the score for the current batch. This running mean effectively performs the division that compute forgot: for example, if every batch scores 1.0, three update calls leave the score at 3.0, while three forward calls keep it at the correct 1.0.

The same behavior is observed for all input formats and for both per_class=True and per_class=False.

Here is the code to reproduce the results:

import torch
from torchmetrics.segmentation import MeanIoU


def run():
    bs = 16
    num_classes = 3
    h = w = 128

    # one-hot, per_class=False
    miou_update = MeanIoU(num_classes=num_classes, input_format="one-hot", per_class=False)
    miou_call = MeanIoU(num_classes=num_classes, input_format="one-hot", per_class=False)

    # test 1, all ones
    img_pred = torch.ones(bs, num_classes, h, w, dtype=torch.long)
    img_target = torch.ones(bs, num_classes, h, w, dtype=torch.long)

    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # test 2, square in the middle, 100% overlap
    img_pred = torch.zeros(bs, num_classes, h, w, dtype=torch.long)
    img_target = torch.zeros(bs, num_classes, h, w, dtype=torch.long)

    img_pred[:, :, 50:80, 50:80] = 1
    img_target[:, :, 50:80, 50:80] = 1

    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # test 3, square in the middle, 50% overlap
    img_pred = torch.zeros(bs, num_classes, h, w, dtype=torch.long)
    img_target = torch.zeros(bs, num_classes, h, w, dtype=torch.long)

    img_pred[:, :, 50:80, 50:80] = 1
    img_target[:, :, 65:80, 50:80] = 1

    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # test 4, square in the middle, 0% overlap
    img_pred = torch.zeros(bs, num_classes, h, w, dtype=torch.long)
    img_target = torch.zeros(bs, num_classes, h, w, dtype=torch.long)

    img_pred[:, :, 50:80, 50:80] = 1
    img_target[:, :, 80:100, 80:100] = 1

    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, { miou_update.score}, { miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, { miou_call.score}, { miou_call.update_count}")

    # one-hot, per_class=True
    miou_update = MeanIoU(num_classes=num_classes, input_format="one-hot", per_class=True)
    miou_call = MeanIoU(num_classes=num_classes, input_format="one-hot", per_class=True)

    # test 5, all ones
    img_pred = torch.ones(bs, num_classes, h, w, dtype=torch.long)
    img_target = torch.ones(bs, num_classes, h, w, dtype=torch.long)

    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # test 6, square in the middle, 100% overlap
    img_pred = torch.zeros(bs, num_classes, h, w, dtype=torch.long)
    img_target = torch.zeros(bs, num_classes, h, w, dtype=torch.long)

    img_pred[:, :, 50:80, 50:80] = 1
    img_target[:, :, 50:80, 50:80] = 1

    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # test 7, square in the middle, 50% overlap
    img_pred = torch.zeros(bs, num_classes, h, w, dtype=torch.long)
    img_target = torch.zeros(bs, num_classes, h, w, dtype=torch.long)

    img_pred[:, :, 50:80, 50:80] = 1
    img_target[:, :, 65:80, 50:80] = 1

    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # test 8, square in the middle, 0% overlap
    img_pred = torch.zeros(bs, num_classes, h, w, dtype=torch.long)
    img_target = torch.zeros(bs, num_classes, h, w, dtype=torch.long)

    img_pred[:, :, 50:80, 50:80] = 1
    img_target[:, :, 80:100, 80:100] = 1

    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # index, per_class=True
    miou_update = MeanIoU(num_classes=num_classes, input_format="index", per_class=True)
    miou_call = MeanIoU(num_classes=num_classes, input_format="index", per_class=True)

    # test 9, all ones
    img_pred = torch.ones(bs, h, w, dtype=torch.long)
    img_target = torch.ones(bs, h, w, dtype=torch.long)

    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # test 10, square in the middle, 100% overlap
    img_pred = torch.zeros(bs, h, w, dtype=torch.long)
    img_target = torch.zeros(bs, h, w, dtype=torch.long)

    img_pred[:, 50:80, 50:80] = 1
    img_target[:, 50:80, 50:80] = 1

    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # test 11, square in the middle, 50% overlap
    img_pred = torch.zeros(bs, h, w, dtype=torch.long)
    img_target = torch.zeros(bs, h, w, dtype=torch.long)

    img_pred[:, 50:80, 50:80] = 1
    img_target[:, 65:80, 50:80] = 1

    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # test 12, square in the middle, 0% overlap
    img_pred = torch.zeros(bs, h, w, dtype=torch.long)
    img_target = torch.zeros(bs, h, w, dtype=torch.long)

    img_pred[:, 50:80, 50:80] = 1
    img_target[:, 80:100, 80:100] = 1

    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")


if __name__ == "__main__":
    run()

vkinakh added a commit to vkinakh/torchmetrics that referenced this issue Aug 21, 2024
- use sum reduce function for score
- add state `num_batches` to keep number of processed batches
- add increment of `num_batches` in every `update` call
- in `compute` return sum of scores divided by number of processed batches
vkinakh mentioned this issue Aug 21, 2024
rittik9 (Contributor) commented Aug 28, 2024

Hi, I'm new to open-source contributions and would like to start by working on this issue. @Borda, could you please assign it to me?

Borda pushed a commit to vkinakh/torchmetrics that referenced this issue Sep 11, 2024