From 2792da5c1981998bb4a81e6aeed09cdfa2f70ed6 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Tue, 4 Jun 2024 12:50:06 +0200 Subject: [PATCH] Add better errors when wrong input is received in detection intersection metrics (#2577) * checks * tests * changelog (cherry picked from commit f1a2be74d2ff61eadcb79ccf28b653ed40ec2f0c) --- CHANGELOG.md | 3 +++ src/torchmetrics/functional/detection/ciou.py | 5 +++++ src/torchmetrics/functional/detection/diou.py | 5 +++++ src/torchmetrics/functional/detection/giou.py | 5 +++++ src/torchmetrics/functional/detection/iou.py | 6 ++++++ tests/unittests/detection/test_intersection.py | 14 ++++++++++++++ 6 files changed, 38 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 34d6cf354e0..ea0e9313762 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - +- Added better error messages for intersection detection metrics for wrong user input ([#2577](https://github.com/Lightning-AI/torchmetrics/pull/2577)) + + ### Changed - diff --git a/src/torchmetrics/functional/detection/ciou.py b/src/torchmetrics/functional/detection/ciou.py index 1dbc6f768f2..9669029ba73 100644 --- a/src/torchmetrics/functional/detection/ciou.py +++ b/src/torchmetrics/functional/detection/ciou.py @@ -24,6 +24,11 @@ def _ciou_update( preds: torch.Tensor, target: torch.Tensor, iou_threshold: Optional[float], replacement_val: float = 0 ) -> torch.Tensor: + if preds.ndim != 2 or preds.shape[-1] != 4: + raise ValueError(f"Expected preds to be of shape (N, 4) but got {preds.shape}") + if target.ndim != 2 or target.shape[-1] != 4: + raise ValueError(f"Expected target to be of shape (N, 4) but got {target.shape}") + from torchvision.ops import complete_box_iou iou = complete_box_iou(preds, target) diff --git a/src/torchmetrics/functional/detection/diou.py b/src/torchmetrics/functional/detection/diou.py index 05bcd96ab58..13fb0071fed 100644 --- a/src/torchmetrics/functional/detection/diou.py +++ b/src/torchmetrics/functional/detection/diou.py @@ -24,6 +24,11 @@ def _diou_update( preds: torch.Tensor, target: torch.Tensor, iou_threshold: Optional[float], replacement_val: float = 0 ) -> torch.Tensor: + if preds.ndim != 2 or preds.shape[-1] != 4: + raise ValueError(f"Expected preds to be of shape (N, 4) but got {preds.shape}") + if target.ndim != 2 or target.shape[-1] != 4: + raise ValueError(f"Expected target to be of shape (N, 4) but got {target.shape}") + from torchvision.ops import distance_box_iou iou = distance_box_iou(preds, target) diff --git a/src/torchmetrics/functional/detection/giou.py b/src/torchmetrics/functional/detection/giou.py index b4532eef80b..cc39f813b41 100644 --- a/src/torchmetrics/functional/detection/giou.py +++ b/src/torchmetrics/functional/detection/giou.py @@ -24,6 +24,11 @@ def _giou_update( preds: torch.Tensor, target: torch.Tensor, iou_threshold: Optional[float], replacement_val: float = 0 ) -> torch.Tensor: + if preds.ndim != 2 or preds.shape[-1] != 4: + raise ValueError(f"Expected preds to be of shape (N, 4) but got {preds.shape}") + if target.ndim != 2 or target.shape[-1] != 4: + raise ValueError(f"Expected target to be of shape (N, 4) but got {target.shape}") + from torchvision.ops import generalized_box_iou iou = generalized_box_iou(preds, target) diff --git a/src/torchmetrics/functional/detection/iou.py b/src/torchmetrics/functional/detection/iou.py index 8c1f22b4096..3d3cef26bb2 100644 --- a/src/torchmetrics/functional/detection/iou.py +++ b/src/torchmetrics/functional/detection/iou.py @@ -24,6 +24,12 @@ def _iou_update( preds: torch.Tensor, target: torch.Tensor, iou_threshold: Optional[float], replacement_val: float = 0 ) -> torch.Tensor: + """Compute the IoU matrix between two sets of boxes.""" + if preds.ndim != 2 or preds.shape[-1] != 4: + raise ValueError(f"Expected preds to be of shape (N, 4) but got {preds.shape}") + if target.ndim != 2 or target.shape[-1] != 4: + raise ValueError(f"Expected target to be of shape (N, 4) but got {target.shape}") + from torchvision.ops import box_iou iou = box_iou(preds, target) diff --git a/tests/unittests/detection/test_intersection.py b/tests/unittests/detection/test_intersection.py index 9d99d2d55c5..c42a6763ba9 100644 --- a/tests/unittests/detection/test_intersection.py +++ b/tests/unittests/detection/test_intersection.py @@ -314,6 +314,20 @@ def test_error_on_wrong_input(self, class_metric, functional_metric, reference_m [{"boxes": Tensor(), "labels": []}], ) + def test_functional_error_on_wrong_input_shape(self, class_metric, functional_metric, reference_metric): + """Test functional input validation.""" + with pytest.raises(ValueError, match="Expected preds to be of shape.*"): + functional_metric(torch.randn(25), torch.randn(25, 4)) + + with pytest.raises(ValueError, match="Expected target to be of shape.*"): + functional_metric(torch.randn(25, 4), torch.randn(25)) + + with pytest.raises(ValueError, match="Expected preds to be of shape.*"): + functional_metric(torch.randn(25, 25), torch.randn(25, 4)) + + with pytest.raises(ValueError, match="Expected target to be of shape.*"): + functional_metric(torch.randn(25, 4), torch.randn(25, 25)) + def test_corner_case(): """See issue: https://github.com/Lightning-AI/torchmetrics/issues/1921."""