Skip to content

Commit

Permalink
Fix segmentation.MeanIoU (#2698)
Browse files Browse the repository at this point in the history
- use sum reduce function for score
- add state `num_batches` to keep number of processed batches
- add increment of `num_batches` in every `update` call
- in `compute` return sum of scores divided by number of processed batches

---------

Co-authored-by: Nicki Skafte <[email protected]>
(cherry picked from commit cb1ab37)
  • Loading branch information
vkinakh authored and Borda committed Sep 13, 2024
1 parent 327f01c commit 400aa91
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 11 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fixed wrong aggregation in `segmentation.MeanIoU` ([#2698](https://github.com/Lightning-AI/torchmetrics/pull/2698))


- Fixed handling zero division error in binary IoU (Jaccard index) calculation ([#2726](https://github.com/Lightning-AI/torchmetrics/pull/2726))


Expand Down
12 changes: 6 additions & 6 deletions src/torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
"""Aggregate and evaluate batch input directly.
Serves the dual purpose of both computing the metric on the current batch of inputs but also add the batch
statistics to the overall accumululating metric state. Input arguments are the exact same as corresponding
statistics to the overall accumulating metric state. Input arguments are the exact same as corresponding
``update`` method. The returned output is the exact same as the output of ``compute``.
Args:
Expand Down Expand Up @@ -361,7 +361,7 @@ def _forward_full_state_update(self, *args: Any, **kwargs: Any) -> Any:
def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any:
"""Forward computation using single call to `update`.
This can be done when the global metric state is a sinple reduction of batch states. This can be unsafe for
This can be done when the global metric state is a simple reduction of batch states. This can be unsafe for
certain metric cases but is also the fastest way to both accumulate globally and compute locally.
"""
Expand Down Expand Up @@ -802,7 +802,7 @@ def _apply(self, fn: Callable, exclude_state: Sequence[str] = "") -> Module:
"""Overwrite `_apply` function such that we can also move metric states to the correct device.
This method is called by the base ``nn.Module`` class whenever `.to`, `.cuda`, `.float`, `.half` etc. methods
are called. Dtype conversion is garded and will only happen through the special `set_dtype` method.
are called. Dtype conversion is guarded and will only happen through the special `set_dtype` method.
Args:
fn: the function to apply
Expand Down Expand Up @@ -1166,15 +1166,15 @@ def _sync_dist(self, dist_sync_fn: Optional[Callable] = None, process_group: Opt
"""

def update(self, *args: Any, **kwargs: Any) -> None:
"""Redirect the call to the input which the conposition was formed from."""
"""Redirect the call to the input which the composition was formed from."""
if isinstance(self.metric_a, Metric):
self.metric_a.update(*args, **self.metric_a._filter_kwargs(**kwargs))

if isinstance(self.metric_b, Metric):
self.metric_b.update(*args, **self.metric_b._filter_kwargs(**kwargs))

def compute(self) -> Any:
"""Redirect the call to the input which the conposition was formed from."""
"""Redirect the call to the input which the composition was formed from."""
# also some parsing for kwargs?
val_a = self.metric_a.compute() if isinstance(self.metric_a, Metric) else self.metric_a
val_b = self.metric_b.compute() if isinstance(self.metric_b, Metric) else self.metric_b
Expand Down Expand Up @@ -1216,7 +1216,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
return self._forward_cache

def reset(self) -> None:
"""Redirect the call to the input which the conposition was formed from."""
"""Redirect the call to the input which the composition was formed from."""
if isinstance(self.metric_a, Metric):
self.metric_a.reset()

Expand Down
8 changes: 5 additions & 3 deletions src/torchmetrics/segmentation/mean_iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,17 +103,19 @@ def __init__(
self.per_class = per_class

num_classes = num_classes - 1 if not include_background else num_classes
self.add_state("score", default=torch.zeros(num_classes if per_class else 1), dist_reduce_fx="mean")
self.add_state("score", default=torch.zeros(num_classes if per_class else 1), dist_reduce_fx="sum")
self.add_state("num_batches", default=torch.tensor(0), dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update the state with the new data."""
intersection, union = _mean_iou_update(preds, target, self.num_classes, self.include_background)
score = _mean_iou_compute(intersection, union, per_class=self.per_class)
self.score += score.mean(0) if self.per_class else score.mean()
self.num_batches += 1

def compute(self) -> Tensor:
"""Update the state with the new data."""
return self.score # / self.num_batches
"""Compute the final Mean Intersection over Union (mIoU)."""
return self.score / self.num_batches

def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Expand Down
25 changes: 23 additions & 2 deletions tests/unittests/_helpers/testers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,25 @@ def _assert_allclose(tm_result: Any, ref_result: Any, atol: float = 1e-8, key: O
"""Recursively assert that two results are within a certain tolerance."""
# single output compare
if isinstance(tm_result, Tensor):
assert np.allclose(tm_result.detach().cpu().numpy(), ref_result, atol=atol, equal_nan=True)
assert np.allclose(
tm_result.detach().cpu().numpy() if isinstance(tm_result, Tensor) else tm_result,
ref_result.detach().cpu().numpy() if isinstance(ref_result, Tensor) else ref_result,
atol=atol,
equal_nan=True,
)
# multi output compare
elif isinstance(tm_result, Sequence):
for pl_res, ref_res in zip(tm_result, ref_result):
_assert_allclose(pl_res, ref_res, atol=atol)
elif isinstance(tm_result, Dict):
if key is None:
raise KeyError("Provide Key for Dict based metric results.")
assert np.allclose(tm_result[key].detach().cpu().numpy(), ref_result, atol=atol, equal_nan=True)
assert np.allclose(
tm_result[key].detach().cpu().numpy() if isinstance(tm_result[key], Tensor) else tm_result[key],
ref_result.detach().cpu().numpy() if isinstance(ref_result, Tensor) else ref_result,
atol=atol,
equal_nan=True,
)
else:
raise ValueError("Unknown format for comparison")

Expand Down Expand Up @@ -147,13 +157,24 @@ def _class_test(
# verify metrics work after being loaded from pickled state
pickled_metric = pickle.dumps(metric)
metric = pickle.loads(pickled_metric)
metric_clone = deepcopy(metric)

for i in range(rank, num_batches, world_size):
batch_kwargs_update = {k: v[i] if isinstance(v, Tensor) else v for k, v in kwargs_update.items()}

# compute batch stats and aggregate for global stats
batch_result = metric(preds[i], target[i], **batch_kwargs_update)

if rank == 0 and world_size == 1 and i == 0: # check only in non-ddp mode and first batch
# dummy check to make sure that forward/update works as expected
metric_clone.update(preds[i], target[i], **batch_kwargs_update)
update_result = metric_clone.compute()
if isinstance(batch_result, dict):
for key in batch_result:
_assert_allclose(batch_result, update_result[key], key=key)
else:
_assert_allclose(batch_result, update_result)

if metric.dist_sync_on_step and check_dist_sync_on_step and rank == 0:
if isinstance(preds, Tensor):
ddp_preds = torch.cat([preds[i + r] for r in range(world_size)]).cpu()
Expand Down

0 comments on commit 400aa91

Please sign in to comment.