segmentation.MeanIoU is wrong #2558
Comments
Hi! Thanks for your contribution, great first issue!
I think now that zero_division has been added to JaccardIndex, there's no real need for MeanIoU at all. MeanIoU also doesn't follow the classical API. The original motivation for it was that JaccardIndex used to assign a score of 0 to absent and ignored classes, so you couldn't do classwise and macro averaging correctly. Now you can just set zero_division to NaN and average to None to get correct class scores, and from there take a nanmean to get the correct macro average.
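For illustration, a minimal sketch of that recipe, assuming a torchmetrics version whose `MulticlassJaccardIndex` accepts a NaN `zero_division` (as described above):

```python
import torch
from torchmetrics.classification import MulticlassJaccardIndex

# Assumes `zero_division` accepts NaN, per the comment above.
jaccard = MulticlassJaccardIndex(num_classes=3, average=None, zero_division=float("nan"))
preds = torch.randint(0, 3, (4, 16, 16))
target = torch.randint(0, 3, (4, 16, 16))
per_class = jaccard(preds, target)  # NaN for absent/ignored classes
macro = per_class.nanmean()         # macro average over present classes only
```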
Hi, I also noticed that MeanIoU was wrong during my experiments. I developed something that seems to work on my side and returns the same results as evaluate's mean IoU; however, based on @DimitrisMantas's comment, I wonder whether it is relevant to submit a PR, as I am not familiar enough with the JaccardIndex implementation in torchmetrics. For reference, here is the code I developed (which has not been rigorously tested for now); let me know if submitting a PR is something interesting for you, I'd gladly contribute to this repo.

from typing import Any, Literal

import torch
from torch import Tensor
from torchmetrics import Metric


def _compute_intersection_and_union(
    preds: Tensor,
    target: Tensor,
    num_classes: int,
    include_background: bool = False,
    input_format: Literal["one-hot", "index", "predictions"] = "index",
) -> tuple[Tensor, Tensor]:
    # Convert index maps (or logits, for "predictions") to one-hot encodings;
    # one-hot inputs are expected channels-last, i.e. (B, ..., C).
    if input_format in ["index", "predictions"]:
        if input_format == "predictions":
            preds = preds.argmax(1)
        preds = torch.nn.functional.one_hot(preds, num_classes=num_classes)
        target = torch.nn.functional.one_hot(target, num_classes=num_classes)
    if not include_background:
        # Zero out the background channel so it never contributes.
        preds[..., 0] = 0
        target[..., 0] = 0
    # Reduce over all spatial dimensions, keeping the batch and class axes.
    reduce_axis = list(range(1, preds.ndim - 1))
    intersection = torch.sum(torch.logical_and(preds, target), dim=reduce_axis)
    target_sum = torch.sum(target, dim=reduce_axis)
    pred_sum = torch.sum(preds, dim=reduce_axis)
    union = target_sum + pred_sum - intersection
    return intersection, union


class MeanIoU(Metric):
    def __init__(
        self,
        num_classes: int,
        include_background: bool = True,
        per_class: bool = False,
        input_format: Literal["one-hot", "index", "predictions"] = "index",
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.include_background = include_background
        self.per_class = per_class
        self.input_format = input_format
        # Accumulate raw intersection/union counts so the final IoU is
        # computed over the whole dataset rather than averaged per batch.
        self.add_state("intersection", default=torch.zeros(num_classes), dist_reduce_fx="sum")
        self.add_state("union", default=torch.zeros(num_classes), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:
        intersection, union = _compute_intersection_and_union(
            preds, target, self.num_classes, self.include_background, self.input_format
        )
        self.intersection += intersection.sum(0)
        self.union += union.sum(0)

    def compute(self) -> Tensor:
        # Classes never seen in either preds or target (union == 0) become
        # NaN and are excluded from the mean.
        iou_valid = torch.gt(self.union, 0)
        iou = torch.where(
            iou_valid,
            torch.divide(self.intersection, self.union),
            torch.nan,
        )
        if self.per_class:
            return iou
        return torch.mean(iou[iou_valid])
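A quick usage sketch for the class above (shapes and values are illustrative):

```python
metric = MeanIoU(num_classes=3, per_class=True, input_format="index")
metric.update(torch.randint(0, 3, (2, 8, 8)), torch.randint(0, 3, (2, 8, 8)))
print(metric.compute())  # per-class IoU; NaN for classes with zero union
```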
Hi, I have noticed that MeanIoU is incorrect when updating it, whether via calling `update` or via the forward call. The same behavior is observed for all input formats and whether or not `per_class` is set. Here is the code to reproduce the results:

import torch
from torchmetrics.segmentation import MeanIoU


def run():
    bs = 16
    num_classes = 3
    h = w = 128

    # one-hot, per_class=False
    miou_update = MeanIoU(num_classes=num_classes, input_format="one-hot", per_class=False)
    miou_call = MeanIoU(num_classes=num_classes, input_format="one-hot", per_class=False)

    # test 1, all ones
    img_pred = torch.ones(bs, num_classes, h, w, dtype=torch.long)
    img_target = torch.ones(bs, num_classes, h, w, dtype=torch.long)
    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # test 2, square in the middle, 100% overlap
    img_pred = torch.zeros(bs, num_classes, h, w, dtype=torch.long)
    img_target = torch.zeros(bs, num_classes, h, w, dtype=torch.long)
    img_pred[:, :, 50:80, 50:80] = 1
    img_target[:, :, 50:80, 50:80] = 1
    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # test 3, square in the middle, 50% overlap
    img_pred = torch.zeros(bs, num_classes, h, w, dtype=torch.long)
    img_target = torch.zeros(bs, num_classes, h, w, dtype=torch.long)
    img_pred[:, :, 50:80, 50:80] = 1
    img_target[:, :, 65:80, 50:80] = 1
    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # test 4, square in the middle, 0% overlap
    img_pred = torch.zeros(bs, num_classes, h, w, dtype=torch.long)
    img_target = torch.zeros(bs, num_classes, h, w, dtype=torch.long)
    img_pred[:, :, 50:80, 50:80] = 1
    img_target[:, :, 80:100, 80:100] = 1
    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # one-hot, per_class=True
    miou_update = MeanIoU(num_classes=num_classes, input_format="one-hot", per_class=True)
    miou_call = MeanIoU(num_classes=num_classes, input_format="one-hot", per_class=True)

    # test 5, all ones
    img_pred = torch.ones(bs, num_classes, h, w, dtype=torch.long)
    img_target = torch.ones(bs, num_classes, h, w, dtype=torch.long)
    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # test 6, square in the middle, 100% overlap
    img_pred = torch.zeros(bs, num_classes, h, w, dtype=torch.long)
    img_target = torch.zeros(bs, num_classes, h, w, dtype=torch.long)
    img_pred[:, :, 50:80, 50:80] = 1
    img_target[:, :, 50:80, 50:80] = 1
    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # test 7, square in the middle, 50% overlap
    img_pred = torch.zeros(bs, num_classes, h, w, dtype=torch.long)
    img_target = torch.zeros(bs, num_classes, h, w, dtype=torch.long)
    img_pred[:, :, 50:80, 50:80] = 1
    img_target[:, :, 65:80, 50:80] = 1
    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # test 8, square in the middle, 0% overlap
    img_pred = torch.zeros(bs, num_classes, h, w, dtype=torch.long)
    img_target = torch.zeros(bs, num_classes, h, w, dtype=torch.long)
    img_pred[:, :, 50:80, 50:80] = 1
    img_target[:, :, 80:100, 80:100] = 1
    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # index, per_class=True
    miou_update = MeanIoU(num_classes=num_classes, input_format="index", per_class=True)
    miou_call = MeanIoU(num_classes=num_classes, input_format="index", per_class=True)

    # test 9, all ones
    img_pred = torch.ones(bs, h, w, dtype=torch.long)
    img_target = torch.ones(bs, h, w, dtype=torch.long)
    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # test 10, square in the middle, 100% overlap
    img_pred = torch.zeros(bs, h, w, dtype=torch.long)
    img_target = torch.zeros(bs, h, w, dtype=torch.long)
    img_pred[:, 50:80, 50:80] = 1
    img_target[:, 50:80, 50:80] = 1
    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # test 11, square in the middle, 50% overlap
    img_pred = torch.zeros(bs, h, w, dtype=torch.long)
    img_target = torch.zeros(bs, h, w, dtype=torch.long)
    img_pred[:, 50:80, 50:80] = 1
    img_target[:, 65:80, 50:80] = 1
    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # test 12, square in the middle, 0% overlap
    img_pred = torch.zeros(bs, h, w, dtype=torch.long)
    img_target = torch.zeros(bs, h, w, dtype=torch.long)
    img_pred[:, 50:80, 50:80] = 1
    img_target[:, 80:100, 80:100] = 1
    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")


run()
Proposed fix (see the sketch below):
- use a sum reduce function for `score`
- add a state `num_batches` to keep the number of processed batches
- increment `num_batches` in every `update` call
- in `compute`, return the sum of scores divided by the number of processed batches
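For concreteness, a minimal sketch of those four changes on a generic metric (the class name is illustrative; this is not the actual torchmetrics patch):

```python
import torch
from torch import Tensor
from torchmetrics import Metric


class BatchAveragedScore(Metric):
    def __init__(self, num_classes: int, **kwargs) -> None:
        super().__init__(**kwargs)
        # sum-reduce the accumulated score across processes
        self.add_state("score", default=torch.zeros(num_classes), dist_reduce_fx="sum")
        # new state: number of processed batches
        self.add_state("num_batches", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, batch_scores: Tensor) -> None:
        self.score += batch_scores
        self.num_batches += 1  # incremented on every update call

    def compute(self) -> Tensor:
        # mean over processed batches instead of the raw running sum
        return self.score / self.num_batches
```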
Hi, I'm new to open-source contributions, and I want to start by working on this issue. @Borda could you please assign it to me?
🐛 Bug
MeanIoU scored 56(!) over the validation dataset.
Let's look at the source code: torchmetrics/src/torchmetrics/segmentation/mean_iou.py, line 109 at commit 596ed09.
There are several issues there:
- `self.score` is accumulated with each call of the `update` method.
- The `compute` method just returns the accumulated score.
- The `compute` method is copy-pasted from the `update` method.
- `self.num_batches` is commented out.
- `num_batches` is defined at class level and not used anywhere else.

Obviously, that code was neither reviewed nor tested, but somehow was released.
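To see how a value like 56 can come out, here is a paraphrase of the accumulation pattern described above (the class name is illustrative, not the literal source):

```python
import torch
from torch import Tensor
from torchmetrics import Metric


class BuggyRunningSum(Metric):
    def __init__(self) -> None:
        super().__init__()
        self.add_state("score", default=torch.tensor(0.0), dist_reduce_fx="mean")

    def update(self, batch_iou: Tensor) -> None:
        self.score += batch_iou.mean()  # running sum, no batch counter

    def compute(self) -> Tensor:
        return self.score  # returns the sum, never divides


m = BuggyRunningSum()
for _ in range(100):
    m.update(torch.ones(4))  # 100 batches with perfect IoU
print(m.compute())  # tensor(100.), not 1.0
```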
To Reproduce
1. Call `metric.update(y_hat, y)` in `validation_step`.
2. Log the metric in `on_validation_epoch_end`.
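A hypothetical LightningModule sketch of these two steps (module and attribute names are illustrative; assumes the Lightning 2.x import path and index-format predictions):

```python
import lightning.pytorch as pl
from torchmetrics.segmentation import MeanIoU


class SegModel(pl.LightningModule):
    def __init__(self, model, num_classes: int) -> None:
        super().__init__()
        self.model = model
        self.val_miou = MeanIoU(num_classes=num_classes, input_format="index")

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x).argmax(dim=1)  # index-format predictions
        self.val_miou.update(y_hat, y)

    def on_validation_epoch_end(self):
        # With the bug, this logs the running sum over all batches,
        # which is how a "MeanIoU" of 56 appears.
        self.log("val_miou", self.val_miou.compute())
        self.val_miou.reset()
```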
Expected behavior
MeanIoU computes a correct value in the [0, 1] range.
Environment