Skip to content

Commit

Permalink
Merge branch 'master' into bugfix/map_tm_to_coco
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda authored Jun 4, 2024
2 parents 170da01 + 703423a commit 7e8f61d
Show file tree
Hide file tree
Showing 8 changed files with 42 additions and 3 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `input_format` argument to segmentation metrics ([#2572](https://github.com/Lightning-AI/torchmetrics/pull/2572))


- Added better error messages for intersection detection metrics for wrong user input ([#2577](https://github.com/Lightning-AI/torchmetrics/pull/2577))


### Changed

-
Expand Down
2 changes: 1 addition & 1 deletion requirements/_tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@ fire <=0.6.0

cloudpickle >1.3, <=3.0.0
scikit-learn >=1.1.1, <1.3.0; python_version < "3.9"
scikit-learn >=1.4.0, <1.5.0; python_version >= "3.9"
scikit-learn >=1.4.0, <1.6.0; python_version >= "3.9"
cachier ==3.0.0
5 changes: 5 additions & 0 deletions src/torchmetrics/functional/detection/ciou.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@
def _ciou_update(
preds: torch.Tensor, target: torch.Tensor, iou_threshold: Optional[float], replacement_val: float = 0
) -> torch.Tensor:
if preds.ndim != 2 or preds.shape[-1] != 4:
raise ValueError(f"Expected preds to be of shape (N, 4) but got {preds.shape}")
if target.ndim != 2 or target.shape[-1] != 4:
raise ValueError(f"Expected target to be of shape (N, 4) but got {target.shape}")

from torchvision.ops import complete_box_iou

iou = complete_box_iou(preds, target)
Expand Down
5 changes: 5 additions & 0 deletions src/torchmetrics/functional/detection/diou.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@
def _diou_update(
preds: torch.Tensor, target: torch.Tensor, iou_threshold: Optional[float], replacement_val: float = 0
) -> torch.Tensor:
if preds.ndim != 2 or preds.shape[-1] != 4:
raise ValueError(f"Expected preds to be of shape (N, 4) but got {preds.shape}")
if target.ndim != 2 or target.shape[-1] != 4:
raise ValueError(f"Expected target to be of shape (N, 4) but got {target.shape}")

from torchvision.ops import distance_box_iou

iou = distance_box_iou(preds, target)
Expand Down
5 changes: 5 additions & 0 deletions src/torchmetrics/functional/detection/giou.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@
def _giou_update(
preds: torch.Tensor, target: torch.Tensor, iou_threshold: Optional[float], replacement_val: float = 0
) -> torch.Tensor:
if preds.ndim != 2 or preds.shape[-1] != 4:
raise ValueError(f"Expected preds to be of shape (N, 4) but got {preds.shape}")
if target.ndim != 2 or target.shape[-1] != 4:
raise ValueError(f"Expected target to be of shape (N, 4) but got {target.shape}")

from torchvision.ops import generalized_box_iou

iou = generalized_box_iou(preds, target)
Expand Down
6 changes: 6 additions & 0 deletions src/torchmetrics/functional/detection/iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@
def _iou_update(
preds: torch.Tensor, target: torch.Tensor, iou_threshold: Optional[float], replacement_val: float = 0
) -> torch.Tensor:
"""Compute the IoU matrix between two sets of boxes."""
if preds.ndim != 2 or preds.shape[-1] != 4:
raise ValueError(f"Expected preds to be of shape (N, 4) but got {preds.shape}")
if target.ndim != 2 or target.shape[-1] != 4:
raise ValueError(f"Expected target to be of shape (N, 4) but got {target.shape}")

from torchvision.ops import box_iou

iou = box_iou(preds, target)
Expand Down
5 changes: 3 additions & 2 deletions src/torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ def __init__(
# see (https://github.com/pytorch/pytorch/blob/3e6bb5233f9ca2c5aa55d9cda22a7ee85439aa6e/
# torch/nn/modules/module.py#L227)
torch._C._log_api_usage_once(f"torchmetrics.metric.{self.__class__.__name__}")

# magic patch for `RuntimeError: DataLoader worker (pid(s) 104) exited unexpectedly`
self._TORCH_GREATER_EQUAL_2_1 = bool(_TORCH_GREATER_EQUAL_2_1)
self._device = torch.device("cpu")
self._dtype = torch.get_default_dtype()

Expand Down Expand Up @@ -441,7 +442,7 @@ def _sync_dist(self, dist_sync_fn: Callable = gather_all_tensors, process_group:

# cornor case in distributed settings where a rank have not received any data, create empty to concatenate
if (
_TORCH_GREATER_EQUAL_2_1
self._TORCH_GREATER_EQUAL_2_1
and reduction_fn == dim_zero_cat
and isinstance(input_dict[attr], list)
and len(input_dict[attr]) == 0
Expand Down
14 changes: 14 additions & 0 deletions tests/unittests/detection/test_intersection.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,20 @@ def test_error_on_wrong_input(self, class_metric, functional_metric, reference_m
[{"boxes": Tensor(), "labels": []}],
)

def test_functional_error_on_wrong_input_shape(self, class_metric, functional_metric, reference_metric):
"""Test functional input validation."""
with pytest.raises(ValueError, match="Expected preds to be of shape.*"):
functional_metric(torch.randn(25), torch.randn(25, 4))

with pytest.raises(ValueError, match="Expected target to be of shape.*"):
functional_metric(torch.randn(25, 4), torch.randn(25))

with pytest.raises(ValueError, match="Expected preds to be of shape.*"):
functional_metric(torch.randn(25, 25), torch.randn(25, 4))

with pytest.raises(ValueError, match="Expected target to be of shape.*"):
functional_metric(torch.randn(25, 4), torch.randn(25, 25))


def test_corner_case():
"""See issue: https://github.com/Lightning-AI/torchmetrics/issues/1921."""
Expand Down

0 comments on commit 7e8f61d

Please sign in to comment.