Skip to content

Commit 365f5da

Browse files
SkafteNickiBorda
authored andcommitted
Bugfix for empty preds or target in iou scores (#2806)
Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Jirka B <[email protected]> (cherry picked from commit fbc7877)
1 parent 7b2138e commit 365f5da

File tree

7 files changed

+74
-5
lines changed

7 files changed

+74
-5
lines changed

CHANGELOG.md

+9
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
66

77
**Note: we move fast, but still we preserve 0.1 version (one feature release) back compatibility.**
88

9+
---
10+
11+
## [UnReleased] - 2024-MM-DD
12+
13+
### Fixed
14+
15+
- Fixed iou scores in detection for either empty predictions/targets leading to wrong scores ([#2805](https://github.com/Lightning-AI/torchmetrics/pull/2805))
16+
17+
918
---
1019

1120
## [1.5.1] - 2024-10-22

src/torchmetrics/detection/iou.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -182,14 +182,17 @@ def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]
182182
"""Update state with predictions and targets."""
183183
_input_validator(preds, target, ignore_score=True)
184184

185-
for p, t in zip(preds, target):
186-
det_boxes = self._get_safe_item_values(p["boxes"])
187-
gt_boxes = self._get_safe_item_values(t["boxes"])
188-
self.groundtruth_labels.append(t["labels"])
185+
for p_i, t_i in zip(preds, target):
186+
det_boxes = self._get_safe_item_values(p_i["boxes"])
187+
gt_boxes = self._get_safe_item_values(t_i["boxes"])
188+
self.groundtruth_labels.append(t_i["labels"])
189189

190190
iou_matrix = self._iou_update_fn(det_boxes, gt_boxes, self.iou_threshold, self._invalid_val) # N x M
191191
if self.respect_labels:
192-
label_eq = p["labels"].unsqueeze(1) == t["labels"].unsqueeze(0) # N x M
192+
if det_boxes.numel() > 0 and gt_boxes.numel() > 0:
193+
label_eq = p_i["labels"].unsqueeze(1) == t_i["labels"].unsqueeze(0) # N x M
194+
else:
195+
label_eq = torch.eye(iou_matrix.shape[0], dtype=bool, device=iou_matrix.device) # type: ignore[call-overload]
193196
iou_matrix[~label_eq] = self._invalid_val
194197
self.iou_matrix.append(iou_matrix)
195198

src/torchmetrics/functional/detection/ciou.py

+5
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ def _ciou_update(
3131

3232
from torchvision.ops import complete_box_iou
3333

34+
if preds.numel() == 0: # if no boxes are predicted
35+
return torch.zeros(target.shape[0], target.shape[0], device=target.device, dtype=torch.float32)
36+
if target.numel() == 0: # if no boxes are true
37+
return torch.zeros(preds.shape[0], preds.shape[0], device=preds.device, dtype=torch.float32)
38+
3439
iou = complete_box_iou(preds, target)
3540
if iou_threshold is not None:
3641
iou[iou < iou_threshold] = replacement_val

src/torchmetrics/functional/detection/diou.py

+5
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ def _diou_update(
3131

3232
from torchvision.ops import distance_box_iou
3333

34+
if preds.numel() == 0: # if no boxes are predicted
35+
return torch.zeros(target.shape[0], target.shape[0], device=target.device, dtype=torch.float32)
36+
if target.numel() == 0: # if no boxes are true
37+
return torch.zeros(preds.shape[0], preds.shape[0], device=preds.device, dtype=torch.float32)
38+
3439
iou = distance_box_iou(preds, target)
3540
if iou_threshold is not None:
3641
iou[iou < iou_threshold] = replacement_val

src/torchmetrics/functional/detection/giou.py

+5
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ def _giou_update(
3131

3232
from torchvision.ops import generalized_box_iou
3333

34+
if preds.numel() == 0: # if no boxes are predicted
35+
return torch.zeros(target.shape[0], target.shape[0], device=target.device, dtype=torch.float32)
36+
if target.numel() == 0: # if no boxes are true
37+
return torch.zeros(preds.shape[0], preds.shape[0], device=preds.device, dtype=torch.float32)
38+
3439
iou = generalized_box_iou(preds, target)
3540
if iou_threshold is not None:
3641
iou[iou < iou_threshold] = replacement_val

src/torchmetrics/functional/detection/iou.py

+5
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ def _iou_update(
3232

3333
from torchvision.ops import box_iou
3434

35+
if preds.numel() == 0: # if no boxes are predicted
36+
return torch.zeros(target.shape[0], target.shape[0], device=target.device, dtype=torch.float32)
37+
if target.numel() == 0: # if no boxes are true
38+
return torch.zeros(preds.shape[0], preds.shape[0], device=preds.device, dtype=torch.float32)
39+
3540
iou = box_iou(preds, target)
3641
if iou_threshold is not None:
3742
iou[iou < iou_threshold] = replacement_val

tests/unittests/detection/test_intersection.py

+37
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,43 @@ def test_corner_case_only_one_empty_prediction(self, class_metric, functional_me
355355
for val in res.values():
356356
assert val == torch.tensor(0.0)
357357

358+
def test_empty_preds_and_target(self, class_metric, functional_metric, reference_metric):
359+
"""Check that for either empty preds and targets that the metric returns 0 in these cases before averaging."""
360+
x = [
361+
{
362+
"boxes": torch.empty(size=(0, 4), dtype=torch.float32),
363+
"labels": torch.tensor([], dtype=torch.long),
364+
},
365+
{
366+
"boxes": torch.FloatTensor([[0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4]]),
367+
"labels": torch.LongTensor([1, 2]),
368+
},
369+
]
370+
371+
y = [
372+
{
373+
"boxes": torch.FloatTensor([[0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4]]),
374+
"labels": torch.LongTensor([1, 2]),
375+
"scores": torch.FloatTensor([0.9, 0.8]),
376+
},
377+
{
378+
"boxes": torch.FloatTensor([[0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4]]),
379+
"labels": torch.LongTensor([1, 2]),
380+
"scores": torch.FloatTensor([0.9, 0.8]),
381+
},
382+
]
383+
metric = class_metric()
384+
metric.update(x, y)
385+
res = metric.compute()
386+
for val in res.values():
387+
assert val == torch.tensor(0.5)
388+
389+
metric = class_metric()
390+
metric.update(y, x)
391+
res = metric.compute()
392+
for val in res.values():
393+
assert val == torch.tensor(0.5)
394+
358395

359396
def test_corner_case():
360397
"""See issue: https://github.com/Lightning-AI/torchmetrics/issues/1921."""

0 commit comments

Comments
 (0)