diff --git a/docs/source/links.rst b/docs/source/links.rst
index 88009ba5ff1..2d0000115d2 100644
--- a/docs/source/links.rst
+++ b/docs/source/links.rst
@@ -154,5 +154,7 @@
 .. _Normalized Mutual Information Score: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.normalized_mutual_info_score.html
 .. _pycocotools: https://github.com/cocodataset/cocoapi/tree/master/PythonAPI/pycocotools
 .. _Rand Score: https://link.springer.com/article/10.1007/BF01908075
+.. _faster-coco-eval: https://github.com/MiXaiLL76/faster_coco_eval
+.. _fork of pycocotools: https://github.com/ppwwyyxx/cocoapi
 .. _Adjusted Rand Score: https://en.wikipedia.org/wiki/Rand_index#Adjusted_Rand_index
 .. _Dunn Index: https://en.wikipedia.org/wiki/Dunn_index
diff --git a/requirements/detection_test.txt b/requirements/detection_test.txt
index e69de29bb2d..6515620c715 100644
--- a/requirements/detection_test.txt
+++ b/requirements/detection_test.txt
@@ -0,0 +1,4 @@
+# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
+# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
+
+faster-coco-eval >=1.3.3
diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py
index 6cbe4f2b16a..9e60b2aa867 100644
--- a/src/torchmetrics/detection/mean_ap.py
+++ b/src/torchmetrics/detection/mean_ap.py
@@ -14,7 +14,8 @@
 import contextlib
 import io
 import json
-from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+from types import ModuleType
+from typing import Any, Callable, ClassVar, Dict, List, Optional, Sequence, Tuple, Union
 
 import numpy as np
 import torch
@@ -27,6 +28,7 @@
 from torchmetrics.metric import Metric
 from torchmetrics.utilities import rank_zero_warn
 from torchmetrics.utilities.imports import (
+    _FASTER_COCO_EVAL_AVAILABLE,
     _MATPLOTLIB_AVAILABLE,
     _PYCOCOTOOLS_AVAILABLE,
     _TORCHVISION_GREATER_EQUAL_0_8,
@@ -48,14 +50,7 @@
     "MeanAveragePrecision.coco_to_tm",
 ]
 
-
-if _PYCOCOTOOLS_AVAILABLE:
-    import pycocotools.mask as mask_utils
-    from pycocotools.coco import COCO
-    from pycocotools.cocoeval import COCOeval
-else:
-    COCO, COCOeval = None, None
-    mask_utils = None
+if not _PYCOCOTOOLS_AVAILABLE:
     __doctest_skip__ = [
         "MeanAveragePrecision.plot",
         "MeanAveragePrecision",
@@ -64,6 +59,32 @@
     ]
 
 
+def _load_backend_tools(backend: Literal["pycocotools", "faster_coco_eval"]) -> Tuple[object, object, ModuleType]:
+    """Load the backend tools for the given backend."""
+    if backend == "pycocotools":
+        if not _PYCOCOTOOLS_AVAILABLE:
+            raise ModuleNotFoundError(
+                "Backend `pycocotools` in metric `MeanAveragePrecision` requires that `pycocotools` is"
+                " installed. Please install with `pip install pycocotools` or `pip install torchmetrics[detection]`"
+            )
+        import pycocotools.mask as mask_utils
+        from pycocotools.coco import COCO
+        from pycocotools.cocoeval import COCOeval
+
+        return COCO, COCOeval, mask_utils
+
+    if not _FASTER_COCO_EVAL_AVAILABLE:
+        raise ModuleNotFoundError(
+            "Backend `faster_coco_eval` in metric `MeanAveragePrecision` requires that `faster-coco-eval` is"
+            " installed. Please install with `pip install faster-coco-eval`."
+        )
+    from faster_coco_eval import COCO
+    from faster_coco_eval import COCOeval_faster as COCOeval
+    from faster_coco_eval.core import mask as mask_utils
+
+    return COCO, COCOeval, mask_utils
+
+
 class MeanAveragePrecision(Metric):
     r"""Compute the `Mean-Average-Precision (mAP) and Mean-Average-Recall (mAR)`_ for object detection predictions.
@@ -142,9 +163,16 @@ class MeanAveragePrecision(Metric):
         Caution: If the initialization parameters are changed, dictionary keys for mAR can change as well.
 
     .. note::
-        This metric utilizes the official `pycocotools` implementation as its backend. This means that the metric
-        requires you to have `pycocotools` installed. In addition we require `torchvision` version 0.8.0 or newer.
-        Please install with ``pip install torchmetrics[detection]``.
+        This metric currently supports two different backends for evaluation. The default backend is
+        ``"pycocotools"``, which requires either the official `pycocotools`_ implementation or this
+        `fork of pycocotools`_ to be installed. We recommend using the fork, as it is better maintained and easily
+        available to install via pip: ``pip install pycocotools``. It is also this fork that will be installed if you
+        install ``torchmetrics[detection]``. The second backend is the `faster-coco-eval`_ implementation, which can
+        be installed with ``pip install faster-coco-eval``. This is a maintained open-source implementation that is
+        faster and fixes certain corner cases of the official implementation. Our own testing has shown that its
+        results are identical to those of the official implementation. Regardless of the backend, we also require
+        `torchvision` version 0.8.0 or newer to be installed. Please install with ``pip install torchvision>=0.8`` or
+        ``pip install torchmetrics[detection]``.
 
     Args:
         box_format:
@@ -188,7 +216,9 @@ class MeanAveragePrecision(Metric):
             of max detections per image.
         average:
-            Method for averaging scores over labels. Choose between "``macro``"" and "``micro``". Default is "macro"
+            Method for averaging scores over labels. Choose between ``"macro"`` and ``"micro"``.
+        backend:
+            Backend to use for the evaluation. Choose between ``"pycocotools"`` and ``"faster_coco_eval"``.
         kwargs:
             Additional keyword arguments, see :ref:`Metric kwargs` for more info.
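The note and the new ``backend`` argument documented above translate into usage along these lines. This is a hedged sketch based only on the signature and docstring added in this patch; the box, score and label values are borrowed from the test data further down, and it assumes ``torchvision>=0.8`` plus both backends are installed:

```python
# Sketch: comparing the two evaluation backends on a single detection.
import torch
from torchmetrics.detection.mean_ap import MeanAveragePrecision

preds = [
    {
        "boxes": torch.tensor([[258.0, 41.0, 606.0, 285.0]]),
        "scores": torch.tensor([0.536]),
        "labels": torch.tensor([0]),
    }
]
target = [{"boxes": torch.tensor([[214.0, 41.0, 562.0, 285.0]]), "labels": torch.tensor([0])}]

for backend in ("pycocotools", "faster_coco_eval"):
    metric = MeanAveragePrecision(iou_type="bbox", backend=backend)
    metric.update(preds, target)
    # both backends are expected to produce identical results; only the runtime differs
    print(backend, metric.compute()["map"])
```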
@@ -323,6 +353,19 @@ class MeanAveragePrecision(Metric):
 
     warn_on_many_detections: bool = True
 
+    __jit_unused_properties__: ClassVar[List[str]] = [
+        "is_differentiable",
+        "higher_is_better",
+        "plot_lower_bound",
+        "plot_upper_bound",
+        "plot_legend_name",
+        "metric_state",
+        # below is added specifically for this metric
+        "coco",
+        "cocoeval",
+        "mask_utils",
+    ]
+
     def __init__(
         self,
         box_format: Literal["xyxy", "xywh", "cxcywh"] = "xyxy",
@@ -333,6 +376,7 @@ def __init__(
         class_metrics: bool = False,
         extended_summary: bool = False,
         average: Literal["macro", "micro"] = "macro",
+        backend: Literal["pycocotools", "faster_coco_eval"] = "pycocotools",
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -387,6 +431,12 @@ def __init__(
             raise ValueError(f"Expected argument `average` to be one of ('macro', 'micro') but got {average}")
         self.average = average
 
+        if backend not in ("pycocotools", "faster_coco_eval"):
+            raise ValueError(
+                f"Expected argument `backend` to be one of ('pycocotools', 'faster_coco_eval') but got {backend}"
+            )
+        self.backend = backend
+
         self.add_state("detection_box", default=[], dist_reduce_fx=None)
         self.add_state("detection_mask", default=[], dist_reduce_fx=None)
         self.add_state("detection_scores", default=[], dist_reduce_fx=None)
@@ -397,6 +447,24 @@ def __init__(
         self.add_state("groundtruth_crowds", default=[], dist_reduce_fx=None)
         self.add_state("groundtruth_area", default=[], dist_reduce_fx=None)
 
+    @property
+    def coco(self) -> object:
+        """Returns the coco module for the given backend, done in this way to make the metric picklable."""
+        coco, _, _ = _load_backend_tools(self.backend)
+        return coco
+
+    @property
+    def cocoeval(self) -> object:
+        """Returns the coco eval module for the given backend, done in this way to make the metric picklable."""
+        _, cocoeval, _ = _load_backend_tools(self.backend)
+        return cocoeval
+
+    @property
+    def mask_utils(self) -> object:
+        """Returns the mask utils object for the given backend, done in this way to make the metric picklable."""
+        _, _, mask_utils = _load_backend_tools(self.backend)
+        return mask_utils
+
     def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]) -> None:
         """Update metric state.
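The ``coco``, ``cocoeval`` and ``mask_utils`` properties above resolve the backend modules lazily on every access, so no module object is ever stored on the instance; as the docstrings note, that is what keeps the metric picklable. A minimal sketch of what this buys (a hypothetical check, not part of the patch):

```python
# Sketch: the lazily resolved backend keeps the metric picklable, which matters for
# DDP / multiprocessing. Assumes this patch and the chosen backend are installed.
import pickle

from torchmetrics.detection.mean_ap import MeanAveragePrecision

metric = MeanAveragePrecision(backend="faster_coco_eval")
# only tensors and plain Python state are serialized, never the backend modules themselves
restored = pickle.loads(pickle.dumps(metric))
assert restored.backend == "faster_coco_eval"
```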
@@ -454,7 +522,7 @@ def compute(self) -> dict: for anno in coco_preds.dataset["annotations"]: anno["area"] = anno[f"area_{i_type}"] - coco_eval = COCOeval(coco_target, coco_preds, iouType=i_type) + coco_eval = self.cocoeval(coco_target, coco_preds, iouType=i_type) coco_eval.params.iouThrs = np.array(self.iou_thresholds, dtype=np.float64) coco_eval.params.recThrs = np.array(self.rec_thresholds, dtype=np.float64) coco_eval.params.maxDets = self.max_detection_thresholds @@ -482,7 +550,7 @@ def compute(self) -> dict: # since micro averaging have all the data in one class, we need to reinitialize the coco_eval # object in macro mode to get the per class stats coco_preds, coco_target = self._get_coco_datasets(average="macro") - coco_eval = COCOeval(coco_target, coco_preds, iouType=i_type) + coco_eval = self.cocoeval(coco_target, coco_preds, iouType=i_type) coco_eval.params.iouThrs = np.array(self.iou_thresholds, dtype=np.float64) coco_eval.params.recThrs = np.array(self.rec_thresholds, dtype=np.float64) coco_eval.params.maxDets = self.max_detection_thresholds @@ -516,7 +584,7 @@ def compute(self) -> dict: return result_dict - def _get_coco_datasets(self, average: Literal["macro", "micro"]) -> Tuple[COCO, COCO]: + def _get_coco_datasets(self, average: Literal["macro", "micro"]) -> Tuple[object, object]: """Returns the coco datasets for the target and the predictions.""" if average == "micro": # for micro averaging we set everything to be the same class @@ -526,7 +594,7 @@ def _get_coco_datasets(self, average: Literal["macro", "micro"]) -> Tuple[COCO, groundtruth_labels = self.groundtruth_labels detection_labels = self.detection_labels - coco_target, coco_preds = COCO(), COCO() + coco_target, coco_preds = self.coco(), self.coco() coco_target.dataset = self._get_coco_format( labels=groundtruth_labels, @@ -571,6 +639,7 @@ def coco_to_tm( coco_preds: str, coco_target: str, iou_type: Union[Literal["bbox", "segm"], List[str]] = "bbox", + backend: Literal["pycocotools", "faster_coco_eval"] = "pycocotools", ) -> Tuple[List[Dict[str, Tensor]], List[Dict[str, Tensor]]]: """Utility function for converting .json coco format files to the input format of this metric. @@ -581,6 +650,7 @@ def coco_to_tm( coco_preds: Path to the json file containing the predictions in coco format coco_target: Path to the json file containing the targets in coco format iou_type: Type of input, either `bbox` for bounding boxes or `segm` for segmentation masks + backend: Backend to use for the conversion. Either `pycocotools` or `faster_coco_eval`. Returns: A tuple containing the predictions and targets in the input format of this metric. 
Each element of the @@ -599,9 +669,10 @@ def coco_to_tm( """ iou_type = _validate_iou_type_arg(iou_type) + coco, _, _ = _load_backend_tools(backend) with contextlib.redirect_stdout(io.StringIO()): - gt = COCO(coco_target) + gt = coco(coco_target) dt = gt.loadRes(coco_preds) gt_dataset = gt.dataset["annotations"] @@ -748,7 +819,7 @@ def _get_safe_item_values( if "segm" in self.iou_type: masks = [] for i in item["masks"].cpu().numpy(): - rle = mask_utils.encode(np.asfortranarray(i)) + rle = self.mask_utils.encode(np.asfortranarray(i)) masks.append((tuple(rle["size"]), rle["counts"])) output[1] = tuple(masks) if (output[0] is not None and len(output[0]) > self.max_detection_thresholds[-1]) or ( @@ -819,10 +890,12 @@ def _get_coco_format( if area is not None and area[image_id][k].cpu().tolist() > 0: area_stat = area[image_id][k].cpu().tolist() else: - area_stat = mask_utils.area(image_mask) if "segm" in self.iou_type else image_box[2] * image_box[3] + area_stat = ( + self.mask_utils.area(image_mask) if "segm" in self.iou_type else image_box[2] * image_box[3] + ) if len(self.iou_type) > 1: area_stat_box = image_box[2] * image_box[3] - area_stat_mask = mask_utils.area(image_mask) + area_stat_mask = self.mask_utils.area(image_mask) annotation = { "id": annotation_id, diff --git a/src/torchmetrics/utilities/imports.py b/src/torchmetrics/utilities/imports.py index ee1dc505477..fa6e8bfea69 100644 --- a/src/torchmetrics/utilities/imports.py +++ b/src/torchmetrics/utilities/imports.py @@ -57,5 +57,6 @@ _MULTIPROCESSING_AVAILABLE: bool = package_available("multiprocessing") _XLA_AVAILABLE: bool = package_available("torch_xla") _PIQ_GREATER_EQUAL_0_8: Optional[bool] = compare_version("piq", operator.ge, "0.8.0") +_FASTER_COCO_EVAL_AVAILABLE: bool = package_available("faster_coco_eval") _LATEX_AVAILABLE: bool = shutil.which("latex") is not None diff --git a/tests/unittests/detection/test_map.py b/tests/unittests/detection/test_map.py index 438e5dea952..a3b2b762b45 100644 --- a/tests/unittests/detection/test_map.py +++ b/tests/unittests/detection/test_map.py @@ -27,7 +27,11 @@ from pycocotools.cocoeval import COCOeval from torch import IntTensor, Tensor from torchmetrics.detection.mean_ap import MeanAveragePrecision -from torchmetrics.utilities.imports import _PYCOCOTOOLS_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_8 +from torchmetrics.utilities.imports import ( + _FASTER_COCO_EVAL_AVAILABLE, + _PYCOCOTOOLS_AVAILABLE, + _TORCHVISION_GREATER_EQUAL_0_8, +) from unittests.detection import _DETECTION_BBOX, _DETECTION_SEGM, _DETECTION_VAL from unittests.helpers.testers import MetricTester @@ -35,6 +39,11 @@ _pytest_condition = not (_PYCOCOTOOLS_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_8) +def _skip_if_faster_coco_eval_missing(backend): + if backend == "faster_coco_eval" and not _FASTER_COCO_EVAL_AVAILABLE: + pytest.skip("test requires that faster_coco_eval is installed") + + def _generate_coco_inputs(iou_type): """Generates inputs for the MAP metric. 
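The ``_skip_if_faster_coco_eval_missing`` helper above is combined with a ``backend`` parametrization throughout the tests that follow. As a rough illustration of the pattern, a hypothetical standalone test (not part of the patch) could look like this:

```python
# Sketch of the parametrize-then-skip pattern used by the tests below.
import pytest

from torchmetrics.detection.mean_ap import MeanAveragePrecision
from torchmetrics.utilities.imports import _FASTER_COCO_EVAL_AVAILABLE


@pytest.mark.parametrize("backend", ["pycocotools", "faster_coco_eval"])
def test_backend_smoke(backend):
    """Run a trivial computation with each backend, skipping the optional one if it is absent."""
    if backend == "faster_coco_eval" and not _FASTER_COCO_EVAL_AVAILABLE:
        pytest.skip("test requires that faster_coco_eval is installed")
    MeanAveragePrecision(backend=backend).compute()  # empty metric, should not raise
```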
@@ -118,13 +127,16 @@ def _compare_against_coco_fn(preds, target, iou_type, iou_thresholds=None, rec_t @pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 and pycocotools is installed") @pytest.mark.parametrize("iou_type", ["bbox", "segm"]) @pytest.mark.parametrize("ddp", [False, True]) +@pytest.mark.parametrize("backend", ["pycocotools", "faster_coco_eval"]) class TestMAPUsingCOCOReference(MetricTester): """Test map metric on the reference coco data.""" @pytest.mark.parametrize("iou_thresholds", [None, [0.25, 0.5, 0.75]]) @pytest.mark.parametrize("rec_thresholds", [None, [0.25, 0.5, 0.75]]) - def test_map(self, iou_type, iou_thresholds, rec_thresholds, ddp): + def test_map(self, iou_type, iou_thresholds, rec_thresholds, ddp, backend): """Test modular implementation for correctness.""" + _skip_if_faster_coco_eval_missing(backend) + preds, target = _coco_bbox_input if iou_type == "bbox" else _coco_segm_input self.run_class_metric_test( ddp=ddp, @@ -144,17 +156,19 @@ def test_map(self, iou_type, iou_thresholds, rec_thresholds, ddp): "rec_thresholds": rec_thresholds, "class_metrics": False, "box_format": "xywh", + "backend": backend, }, check_batch=False, atol=1e-2, ) - def test_map_classwise(self, iou_type, ddp): + def test_map_classwise(self, iou_type, ddp, backend): """Test modular implementation for correctness with classwise=True. Needs bigger atol to be stable. """ + _skip_if_faster_coco_eval_missing(backend) preds, target = _coco_bbox_input if iou_type == "bbox" else _coco_segm_input self.run_class_metric_test( ddp=ddp, @@ -162,14 +176,17 @@ def test_map_classwise(self, iou_type, ddp): target=target, metric_class=MeanAveragePrecision, reference_metric=partial(_compare_against_coco_fn, iou_type=iou_type, class_metrics=True), - metric_args={"box_format": "xywh", "iou_type": iou_type, "class_metrics": True}, + metric_args={"box_format": "xywh", "iou_type": iou_type, "class_metrics": True, "backend": backend}, check_batch=False, atol=1e-1, ) -def test_compare_both_same_time(tmpdir): +@pytest.mark.parametrize("backend", ["pycocotools", "faster_coco_eval"]) +def test_compare_both_same_time(tmpdir, backend): """Test that the class support evaluating both bbox and segm at the same time.""" + _skip_if_faster_coco_eval_missing(backend) + with open(_DETECTION_BBOX) as f: boxes = json.load(f) with open(_DETECTION_SEGM) as f: @@ -183,7 +200,7 @@ def test_compare_both_same_time(tmpdir): batched_preds = [batched_preds[10 * i : 10 * (i + 1)] for i in range(10)] batched_target = [batched_target[10 * i : 10 * (i + 1)] for i in range(10)] - metric = MeanAveragePrecision(iou_type=["bbox", "segm"], box_format="xywh") + metric = MeanAveragePrecision(iou_type=["bbox", "segm"], box_format="xywh", backend=backend) for bp, bt in zip(batched_preds, batched_target): metric.update(bp, bt) res = metric.compute() @@ -376,452 +393,468 @@ def test_compare_both_same_time(tmpdir): ) -@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed") -def test_error_on_wrong_init(): - """Test class raises the expected errors.""" - MeanAveragePrecision() # no error - - with pytest.raises(ValueError, match="Expected argument `class_metrics` to be a boolean"): - MeanAveragePrecision(class_metrics=0) - - -@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed") -def test_empty_preds(): - """Test empty predictions.""" - metric = MeanAveragePrecision() - - metric.update( - [{"boxes": Tensor([]), "scores": Tensor([]), 
"labels": IntTensor([])}], - [{"boxes": Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]), "labels": IntTensor([4])}], - ) - metric.compute() - - -@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed") -def test_empty_ground_truths(): - """Test empty ground truths.""" - metric = MeanAveragePrecision() - - metric.update( - [ - { - "boxes": Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]), - "scores": Tensor([0.5]), - "labels": IntTensor([4]), - } - ], - [{"boxes": Tensor([]), "labels": IntTensor([])}], - ) - metric.compute() - - -@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed") -def test_empty_ground_truths_xywh(): - """Test empty ground truths in xywh format.""" - metric = MeanAveragePrecision(box_format="xywh") - - metric.update( - [ - { - "boxes": Tensor([[214.1500, 41.2900, 348.2600, 243.7800]]), - "scores": Tensor([0.5]), - "labels": IntTensor([4]), - } - ], - [{"boxes": Tensor([]), "labels": IntTensor([])}], - ) - metric.compute() - - -@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed") -def test_empty_preds_xywh(): - """Test empty predictions in xywh format.""" - metric = MeanAveragePrecision(box_format="xywh") - - metric.update( - [{"boxes": Tensor([]), "scores": Tensor([]), "labels": IntTensor([])}], - [{"boxes": Tensor([[214.1500, 41.2900, 348.2600, 243.7800]]), "labels": IntTensor([4])}], - ) - metric.compute() +def _generate_random_segm_input(device, batch_size=2, num_preds_size=10, num_gt_size=10, random_size=True): + """Generate random inputs for mAP when iou_type=segm.""" + preds = [] + targets = [] + for _ in range(batch_size): + result = {} + num_preds = torch.randint(0, num_preds_size, (1,)).item() if random_size else num_preds_size + result["scores"] = torch.rand((num_preds,), device=device) + result["labels"] = torch.randint(0, 10, (num_preds,), device=device) + result["masks"] = torch.randint(0, 2, (num_preds, 10, 10), device=device).bool() + preds.append(result) + gt = {} + num_gt = torch.randint(0, num_gt_size, (1,)).item() if random_size else num_gt_size + gt["labels"] = torch.randint(0, 10, (num_gt,), device=device) + gt["masks"] = torch.randint(0, 2, (num_gt, 10, 10), device=device).bool() + targets.append(gt) + return preds, targets @pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed") -def test_empty_ground_truths_cxcywh(): - """Test empty ground truths in cxcywh format.""" - metric = MeanAveragePrecision(box_format="cxcywh") - - metric.update( - [ - { - "boxes": Tensor([[388.2800, 163.1800, 348.2600, 243.7800]]), - "scores": Tensor([0.5]), - "labels": IntTensor([4]), - } - ], - [{"boxes": Tensor([]), "labels": IntTensor([])}], - ) - metric.compute() - +@pytest.mark.parametrize( + "backend", + [ + pytest.param("pycocotools"), + pytest.param( + "faster_coco_eval", + marks=pytest.mark.skipif( + not _FASTER_COCO_EVAL_AVAILABLE, reason="test requires that faster_coco_eval is installed" + ), + ), + ], +) +class TestMapProperties: + """Test class collection different tests for different properties parametrized by backend argument.""" -@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed") -def test_empty_preds_cxcywh(): - """Test empty predictions in cxcywh format.""" - metric = MeanAveragePrecision(box_format="cxcywh") + def test_error_on_wrong_init(self, backend): + """Test class raises the expected errors.""" + 
MeanAveragePrecision(backend=backend) # no error - metric.update( - [{"boxes": Tensor([]), "scores": Tensor([]), "labels": IntTensor([])}], - [{"boxes": Tensor([[388.2800, 163.1800, 348.2600, 243.7800]]), "labels": IntTensor([4])}], - ) - metric.compute() + with pytest.raises(ValueError, match="Expected argument `class_metrics` to be a boolean"): + MeanAveragePrecision(class_metrics=0, backend=backend) + def test_empty_preds(self, backend): + """Test empty predictions.""" + metric = MeanAveragePrecision(backend=backend) -@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed") -@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires CUDA availability") -@pytest.mark.parametrize("inputs", [_inputs, _inputs2, _inputs3]) -def test_map_gpu(inputs): - """Test predictions on single gpu.""" - metric = MeanAveragePrecision() - metric = metric.to("cuda") - for preds, targets in zip(deepcopy(inputs.preds), deepcopy(inputs.target)): metric.update( - apply_to_collection(preds, Tensor, lambda x: x.to("cuda")), - apply_to_collection(targets, Tensor, lambda x: x.to("cuda")), + [{"boxes": Tensor([]), "scores": Tensor([]), "labels": IntTensor([])}], + [{"boxes": Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]), "labels": IntTensor([4])}], ) - metric.compute() + metric.compute() + def test_empty_ground_truths(self, backend): + """Test empty ground truths.""" + metric = MeanAveragePrecision(backend=backend) -@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed") -@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires CUDA availability") -def test_map_with_custom_thresholds(): - """Test that map works with custom iou thresholds.""" - metric = MeanAveragePrecision(iou_thresholds=[0.1, 0.2]) - metric = metric.to("cuda") - for preds, targets in zip(deepcopy(_inputs.preds), deepcopy(_inputs.target)): metric.update( - apply_to_collection(preds, Tensor, lambda x: x.to("cuda")), - apply_to_collection(targets, Tensor, lambda x: x.to("cuda")), + [ + { + "boxes": Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]), + "scores": Tensor([0.5]), + "labels": IntTensor([4]), + } + ], + [{"boxes": Tensor([]), "labels": IntTensor([])}], ) - res = metric.compute() - assert res["map_50"].item() == -1 - assert res["map_75"].item() == -1 - - -@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed") -def test_empty_metric(): - """Test empty metric.""" - metric = MeanAveragePrecision() - metric.compute() - - -@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed") -def test_missing_pred(): - """One good detection, one false negative. - - Map should be lower than 1. Actually it is 0.5, but the exact value depends on where we are sampling (i.e. recall's - values) - - """ - gts = [ - {"boxes": Tensor([[10, 20, 15, 25]]), "labels": IntTensor([0])}, - {"boxes": Tensor([[10, 20, 15, 25]]), "labels": IntTensor([0])}, - ] - preds = [ - {"boxes": Tensor([[10, 20, 15, 25]]), "scores": Tensor([0.9]), "labels": IntTensor([0])}, - # Empty prediction - {"boxes": Tensor([]), "scores": Tensor([]), "labels": IntTensor([])}, - ] - metric = MeanAveragePrecision() - metric.update(preds, gts) - result = metric.compute() - assert result["map"] < 1, "MAP cannot be 1, as there is a missing prediction." 
- - -@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed") -def test_missing_gt(): - """The symmetric case of test_missing_pred. - - One good detection, one false positive. Map should be lower than 1. Actually it is 0.5, but the exact value depends - on where we are sampling (i.e. recall's values) - - """ - gts = [ - {"boxes": Tensor([[10, 20, 15, 25]]), "labels": IntTensor([0])}, - {"boxes": Tensor([]), "labels": IntTensor([])}, - ] - preds = [ - {"boxes": Tensor([[10, 20, 15, 25]]), "scores": Tensor([0.9]), "labels": IntTensor([0])}, - {"boxes": Tensor([[10, 20, 15, 25]]), "scores": Tensor([0.95]), "labels": IntTensor([0])}, - ] - - metric = MeanAveragePrecision() - metric.update(preds, gts) - result = metric.compute() - assert result["map"] < 1, "MAP cannot be 1, as there is an image with no ground truth, but some predictions." - - -@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed") -def test_segm_iou_empty_gt_mask(): - """Test empty ground truths.""" - metric = MeanAveragePrecision(iou_type="segm") - - metric.update( - [{"masks": torch.randint(0, 1, (1, 10, 10)).bool(), "scores": Tensor([0.5]), "labels": IntTensor([4])}], - [{"masks": Tensor([]), "labels": IntTensor([])}], - ) - - metric.compute() + metric.compute() + def test_empty_ground_truths_xywh(self, backend): + """Test empty ground truths in xywh format.""" + metric = MeanAveragePrecision(box_format="xywh", backend=backend) -@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed") -def test_segm_iou_empty_pred_mask(): - """Test empty predictions.""" - metric = MeanAveragePrecision(iou_type="segm") - - metric.update( - [{"masks": torch.BoolTensor([]), "scores": Tensor([]), "labels": IntTensor([])}], - [{"masks": torch.randint(0, 1, (1, 10, 10)).bool(), "labels": IntTensor([4])}], - ) - - metric.compute() - - -@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed") -def test_error_on_wrong_input(): - """Test class input validation.""" - metric = MeanAveragePrecision() - - metric.update([], []) # no error - - with pytest.raises(ValueError, match="Expected argument `preds` to be of type Sequence"): - metric.update(Tensor(), []) # type: ignore - - with pytest.raises(ValueError, match="Expected argument `target` to be of type Sequence"): - metric.update([], Tensor()) # type: ignore - - with pytest.raises(ValueError, match="Expected argument `preds` and `target` to have the same length"): - metric.update([{}], [{}, {}]) - - with pytest.raises(ValueError, match="Expected all dicts in `preds` to contain the `boxes` key"): metric.update( - [{"scores": Tensor(), "labels": IntTensor}], - [{"boxes": Tensor(), "labels": IntTensor()}], + [ + { + "boxes": Tensor([[214.1500, 41.2900, 348.2600, 243.7800]]), + "scores": Tensor([0.5]), + "labels": IntTensor([4]), + } + ], + [{"boxes": Tensor([]), "labels": IntTensor([])}], ) + metric.compute() - with pytest.raises(ValueError, match="Expected all dicts in `preds` to contain the `scores` key"): - metric.update( - [{"boxes": Tensor(), "labels": IntTensor}], - [{"boxes": Tensor(), "labels": IntTensor()}], - ) + def test_empty_preds_xywh(self, backend): + """Test empty predictions in xywh format.""" + metric = MeanAveragePrecision(box_format="xywh", backend=backend) - with pytest.raises(ValueError, match="Expected all dicts in `preds` to contain the `labels` key"): metric.update( - [{"boxes": Tensor(), 
"scores": IntTensor}], - [{"boxes": Tensor(), "labels": IntTensor()}], + [{"boxes": Tensor([]), "scores": Tensor([]), "labels": IntTensor([])}], + [{"boxes": Tensor([[214.1500, 41.2900, 348.2600, 243.7800]]), "labels": IntTensor([4])}], ) + metric.compute() - with pytest.raises(ValueError, match="Expected all dicts in `target` to contain the `boxes` key"): - metric.update( - [{"boxes": Tensor(), "scores": IntTensor, "labels": IntTensor}], - [{"labels": IntTensor()}], - ) + def test_empty_ground_truths_cxcywh(self, backend): + """Test empty ground truths in cxcywh format.""" + metric = MeanAveragePrecision(box_format="cxcywh", backend=backend) - with pytest.raises(ValueError, match="Expected all dicts in `target` to contain the `labels` key"): metric.update( - [{"boxes": Tensor(), "scores": IntTensor, "labels": IntTensor}], - [{"boxes": IntTensor()}], + [ + { + "boxes": Tensor([[388.2800, 163.1800, 348.2600, 243.7800]]), + "scores": Tensor([0.5]), + "labels": IntTensor([4]), + } + ], + [{"boxes": Tensor([]), "labels": IntTensor([])}], ) + metric.compute() - with pytest.raises(ValueError, match="Expected all boxes in `preds` to be of type Tensor"): - metric.update( - [{"boxes": [], "scores": Tensor(), "labels": IntTensor()}], - [{"boxes": Tensor(), "labels": IntTensor()}], - ) + def test_empty_preds_cxcywh(self, backend): + """Test empty predictions in cxcywh format.""" + metric = MeanAveragePrecision(box_format="cxcywh", backend=backend) - with pytest.raises(ValueError, match="Expected all scores in `preds` to be of type Tensor"): metric.update( - [{"boxes": Tensor(), "scores": [], "labels": IntTensor()}], - [{"boxes": Tensor(), "labels": IntTensor()}], + [{"boxes": Tensor([]), "scores": Tensor([]), "labels": IntTensor([])}], + [{"boxes": Tensor([[388.2800, 163.1800, 348.2600, 243.7800]]), "labels": IntTensor([4])}], ) + metric.compute() + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires CUDA availability") + @pytest.mark.parametrize("inputs", [_inputs, _inputs2, _inputs3]) + def test_map_gpu(self, backend, inputs): + """Test predictions on single gpu.""" + metric = MeanAveragePrecision(backend=backend) + metric = metric.to("cuda") + for preds, targets in zip(deepcopy(inputs.preds), deepcopy(inputs.target)): + metric.update( + apply_to_collection(preds, Tensor, lambda x: x.to("cuda")), + apply_to_collection(targets, Tensor, lambda x: x.to("cuda")), + ) + metric.compute() + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires CUDA availability") + def test_map_with_custom_thresholds(self, backend): + """Test that map works with custom iou thresholds.""" + metric = MeanAveragePrecision(iou_thresholds=[0.1, 0.2], backend=backend) + metric = metric.to("cuda") + for preds, targets in zip(deepcopy(_inputs.preds), deepcopy(_inputs.target)): + metric.update( + apply_to_collection(preds, Tensor, lambda x: x.to("cuda")), + apply_to_collection(targets, Tensor, lambda x: x.to("cuda")), + ) + res = metric.compute() + assert res["map_50"].item() == -1 + assert res["map_75"].item() == -1 + + def test_empty_metric(self, backend): + """Test empty metric.""" + metric = MeanAveragePrecision(backend=backend) + metric.compute() + + def test_missing_pred(self, backend): + """One good detection, one false negative. + + Map should be lower than 1. Actually it is 0.5, but the exact value depends on where we are sampling (i.e. 
+ recall's values) - with pytest.raises(ValueError, match="Expected all labels in `preds` to be of type Tensor"): - metric.update( - [{"boxes": Tensor(), "scores": Tensor(), "labels": []}], - [{"boxes": Tensor(), "labels": IntTensor()}], - ) + """ + gts = [ + {"boxes": Tensor([[10, 20, 15, 25]]), "labels": IntTensor([0])}, + {"boxes": Tensor([[10, 20, 15, 25]]), "labels": IntTensor([0])}, + ] + preds = [ + {"boxes": Tensor([[10, 20, 15, 25]]), "scores": Tensor([0.9]), "labels": IntTensor([0])}, + # Empty prediction + {"boxes": Tensor([]), "scores": Tensor([]), "labels": IntTensor([])}, + ] + metric = MeanAveragePrecision(backend=backend) + metric.update(preds, gts) + result = metric.compute() + assert result["map"] < 1, "MAP cannot be 1, as there is a missing prediction." + + def test_missing_gt(self, backend): + """The symmetric case of test_missing_pred. + + One good detection, one false positive. Map should be lower than 1. Actually it is 0.5, but the exact value + depends on where we are sampling (i.e. recall's values) - with pytest.raises(ValueError, match="Expected all boxes in `target` to be of type Tensor"): + """ + gts = [ + {"boxes": Tensor([[10, 20, 15, 25]]), "labels": IntTensor([0])}, + {"boxes": Tensor([]), "labels": IntTensor([])}, + ] + preds = [ + {"boxes": Tensor([[10, 20, 15, 25]]), "scores": Tensor([0.9]), "labels": IntTensor([0])}, + {"boxes": Tensor([[10, 20, 15, 25]]), "scores": Tensor([0.95]), "labels": IntTensor([0])}, + ] + + metric = MeanAveragePrecision(backend=backend) + metric.update(preds, gts) + result = metric.compute() + assert result["map"] < 1, "MAP cannot be 1, as there is an image with no ground truth, but some predictions." + + def test_segm_iou_empty_gt_mask(self, backend): + """Test empty ground truths.""" + metric = MeanAveragePrecision(iou_type="segm", backend=backend) metric.update( - [{"boxes": Tensor(), "scores": Tensor(), "labels": IntTensor()}], - [{"boxes": [], "labels": IntTensor()}], + [{"masks": torch.randint(0, 1, (1, 10, 10)).bool(), "scores": Tensor([0.5]), "labels": IntTensor([4])}], + [{"masks": Tensor([]), "labels": IntTensor([])}], ) + metric.compute() - with pytest.raises(ValueError, match="Expected all labels in `target` to be of type Tensor"): + def test_segm_iou_empty_pred_mask(self, backend): + """Test empty predictions.""" + metric = MeanAveragePrecision(iou_type="segm", backend=backend) metric.update( - [{"boxes": Tensor(), "scores": Tensor(), "labels": IntTensor()}], - [{"boxes": Tensor(), "labels": []}], + [{"masks": torch.BoolTensor([]), "scores": Tensor([]), "labels": IntTensor([])}], + [{"masks": torch.randint(0, 1, (1, 10, 10)).bool(), "labels": IntTensor([4])}], ) + metric.compute() + + def test_error_on_wrong_input(self, backend): + """Test class input validation.""" + metric = MeanAveragePrecision(backend=backend) + + metric.update([], []) # no error + + with pytest.raises(ValueError, match="Expected argument `preds` to be of type Sequence"): + metric.update(Tensor(), []) # type: ignore + + with pytest.raises(ValueError, match="Expected argument `target` to be of type Sequence"): + metric.update([], Tensor()) # type: ignore + + with pytest.raises(ValueError, match="Expected argument `preds` and `target` to have the same length"): + metric.update([{}], [{}, {}]) + + with pytest.raises(ValueError, match="Expected all dicts in `preds` to contain the `boxes` key"): + metric.update( + [{"scores": Tensor(), "labels": IntTensor}], + [{"boxes": Tensor(), "labels": IntTensor()}], + ) + + with pytest.raises(ValueError, 
match="Expected all dicts in `preds` to contain the `scores` key"): + metric.update( + [{"boxes": Tensor(), "labels": IntTensor}], + [{"boxes": Tensor(), "labels": IntTensor()}], + ) + + with pytest.raises(ValueError, match="Expected all dicts in `preds` to contain the `labels` key"): + metric.update( + [{"boxes": Tensor(), "scores": IntTensor}], + [{"boxes": Tensor(), "labels": IntTensor()}], + ) + + with pytest.raises(ValueError, match="Expected all dicts in `target` to contain the `boxes` key"): + metric.update( + [{"boxes": Tensor(), "scores": IntTensor, "labels": IntTensor}], + [{"labels": IntTensor()}], + ) + + with pytest.raises(ValueError, match="Expected all dicts in `target` to contain the `labels` key"): + metric.update( + [{"boxes": Tensor(), "scores": IntTensor, "labels": IntTensor}], + [{"boxes": IntTensor()}], + ) + + with pytest.raises(ValueError, match="Expected all boxes in `preds` to be of type Tensor"): + metric.update( + [{"boxes": [], "scores": Tensor(), "labels": IntTensor()}], + [{"boxes": Tensor(), "labels": IntTensor()}], + ) + + with pytest.raises(ValueError, match="Expected all scores in `preds` to be of type Tensor"): + metric.update( + [{"boxes": Tensor(), "scores": [], "labels": IntTensor()}], + [{"boxes": Tensor(), "labels": IntTensor()}], + ) + + with pytest.raises(ValueError, match="Expected all labels in `preds` to be of type Tensor"): + metric.update( + [{"boxes": Tensor(), "scores": Tensor(), "labels": []}], + [{"boxes": Tensor(), "labels": IntTensor()}], + ) + + with pytest.raises(ValueError, match="Expected all boxes in `target` to be of type Tensor"): + metric.update( + [{"boxes": Tensor(), "scores": Tensor(), "labels": IntTensor()}], + [{"boxes": [], "labels": IntTensor()}], + ) + + with pytest.raises(ValueError, match="Expected all labels in `target` to be of type Tensor"): + metric.update( + [{"boxes": Tensor(), "scores": Tensor(), "labels": IntTensor()}], + [{"boxes": Tensor(), "labels": []}], + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + def test_device_changing(self, backend): + """See issue: https://github.com/Lightning-AI/torchmetrics/issues/1743. + + Checks that the custom apply function of the metric works as expected. 
+ """ + device = "cuda" + metric = MeanAveragePrecision(iou_type="segm", backend=backend).to(device) + for _ in range(2): + preds, targets = _generate_random_segm_input(device) + metric.update(preds, targets) -def _generate_random_segm_input(device, batch_size=2, num_preds_size=10, num_gt_size=10, random_size=True): - """Generate random inputs for mAP when iou_type=segm.""" - preds = [] - targets = [] - for _ in range(batch_size): - result = {} - num_preds = torch.randint(0, num_preds_size, (1,)).item() if random_size else num_preds_size - result["scores"] = torch.rand((num_preds,), device=device) - result["labels"] = torch.randint(0, 10, (num_preds,), device=device) - result["masks"] = torch.randint(0, 2, (num_preds, 10, 10), device=device).bool() - preds.append(result) - gt = {} - num_gt = torch.randint(0, num_gt_size, (1,)).item() if random_size else num_gt_size - gt["labels"] = torch.randint(0, 10, (num_gt,), device=device) - gt["masks"] = torch.randint(0, 2, (num_gt, 10, 10), device=device).bool() - targets.append(gt) - return preds, targets + metric = metric.cpu() + val = metric.compute() + assert isinstance(val, dict) + @pytest.mark.parametrize( + ("box_format", "iou_val_expected", "map_val_expected"), + [ + ("xyxy", 0.25, 1), + ("xywh", 0.143, 0.0), + ("cxcywh", 0.143, 0.0), + ], + ) + def test_for_box_format(self, box_format, iou_val_expected, map_val_expected, backend): + """Test that only the correct box format lead to a score of 1. -@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") -def test_device_changing(): - """See issue: https://github.com/Lightning-AI/torchmetrics/issues/1743. + See issue: https://github.com/Lightning-AI/torchmetrics/issues/1908. - Checks that the custom apply function of the metric works as expected. 
- """ - device = "cuda" - metric = MeanAveragePrecision(iou_type="segm").to(device) + """ + predictions = [ + {"boxes": torch.tensor([[0.5, 0.5, 1, 1]]), "scores": torch.tensor([1.0]), "labels": torch.tensor([0])} + ] - for _ in range(2): - preds, targets = _generate_random_segm_input(device) - metric.update(preds, targets) + targets = [{"boxes": torch.tensor([[0, 0, 1, 1]]), "labels": torch.tensor([0])}] - metric = metric.cpu() - val = metric.compute() - assert isinstance(val, dict) + metric = MeanAveragePrecision( + box_format=box_format, iou_thresholds=[0.2], extended_summary=True, backend=backend + ) + metric.update(predictions, targets) + result = metric.compute() + assert result["map"].item() == map_val_expected + assert round(float(result["ious"][(0, 0)]), 3) == iou_val_expected + + @pytest.mark.parametrize("iou_type", ["bbox", "segm"]) + def test_warning_on_many_detections(self, iou_type, backend): + """Test that a warning is raised when there are many detections.""" + if iou_type == "bbox": + preds = [ + { + "boxes": torch.tensor([[0.5, 0.5, 1, 1]]).repeat(101, 1), + "scores": torch.tensor([1.0]).repeat(101), + "labels": torch.tensor([0]).repeat(101), + } + ] + targets = [{"boxes": torch.tensor([[0, 0, 1, 1]]), "labels": torch.tensor([0])}] + else: + preds, targets = _generate_random_segm_input("cpu", 1, 101, 10, False) + + metric = MeanAveragePrecision(iou_type=iou_type, backend=backend) + with pytest.warns(UserWarning, match="Encountered more than 100 detections in a single image.*"): + metric.update(preds, targets) + + @pytest.mark.parametrize( + ("preds", "target", "expected_iou_len", "iou_keys", "precision_shape", "recall_shape"), + [ + ( + [ + [ + { + "boxes": torch.tensor([[0.5, 0.5, 1, 1]]), + "scores": torch.tensor([1.0]), + "labels": torch.tensor([0]), + } + ] + ], + [[{"boxes": torch.tensor([[0, 0, 1, 1]]), "labels": torch.tensor([0])}]], + 1, # 1 image x 1 class = 1 + [(0, 0)], + (10, 101, 1, 4, 3), + (10, 1, 4, 3), + ), + ( + _inputs.preds, + _inputs.target, + 24, # 4 images x 6 classes = 24 + list(product([0, 1, 2, 3], [0, 1, 2, 3, 4, 49])), + (10, 101, 6, 4, 3), + (10, 6, 4, 3), + ), + ], + ) + def test_for_extended_stats( + self, preds, target, expected_iou_len, iou_keys, precision_shape, recall_shape, backend + ): + """Test that extended stats are computed correctly.""" + metric = MeanAveragePrecision(extended_summary=True, backend=backend) + for p, t in zip(preds, target): + metric.update(p, t) + result = metric.compute() + ious = result["ious"] -@pytest.mark.parametrize( - ("box_format", "iou_val_expected", "map_val_expected"), - [ - ("xyxy", 0.25, 1), - ("xywh", 0.143, 0.0), - ("cxcywh", 0.143, 0.0), - ], -) -def test_for_box_format(box_format, iou_val_expected, map_val_expected): - """Test that only the correct box format lead to a score of 1. + assert isinstance(ious, dict) + assert len(ious) == expected_iou_len + for key in ious: + assert key in iou_keys - See issue: https://github.com/Lightning-AI/torchmetrics/issues/1908. 
+ precision = result["precision"] + assert isinstance(precision, Tensor) + assert precision.shape == precision_shape - """ - predictions = [ - {"boxes": torch.tensor([[0.5, 0.5, 1, 1]]), "scores": torch.tensor([1.0]), "labels": torch.tensor([0])} - ] + recall = result["recall"] + assert isinstance(recall, Tensor) + assert recall.shape == recall_shape - targets = [{"boxes": torch.tensor([[0, 0, 1, 1]]), "labels": torch.tensor([0])}] + @pytest.mark.parametrize("class_metrics", [False, True]) + def test_average_argument(self, class_metrics, backend): + """Test that average argument works. - metric = MeanAveragePrecision(box_format=box_format, iou_thresholds=[0.2], extended_summary=True) - metric.update(predictions, targets) - result = metric.compute() - assert result["map"].item() == map_val_expected - assert round(float(result["ious"][(0, 0)]), 3) == iou_val_expected + Calculating macro on inputs that only have one label should be the same as micro. Calculating class metrics + should be the same regardless of average argument. + """ + if class_metrics: + _preds = _inputs.preds + _target = _inputs.target + else: + _preds = apply_to_collection(deepcopy(_inputs.preds), IntTensor, lambda x: torch.ones_like(x)) + _target = apply_to_collection(deepcopy(_inputs.target), IntTensor, lambda x: torch.ones_like(x)) + + metric_macro = MeanAveragePrecision(average="macro", class_metrics=class_metrics, backend=backend) + metric_macro.update(_preds[0], _target[0]) + metric_macro.update(_preds[1], _target[1]) + result_macro = metric_macro.compute() + + metric_micro = MeanAveragePrecision(average="micro", class_metrics=class_metrics, backend=backend) + metric_micro.update(_inputs.preds[0], _inputs.target[0]) + metric_micro.update(_inputs.preds[1], _inputs.target[1]) + result_micro = metric_micro.compute() + + if class_metrics: + assert torch.allclose(result_macro["map_per_class"], result_micro["map_per_class"]) + assert torch.allclose(result_macro["mar_100_per_class"], result_micro["mar_100_per_class"]) + else: + for key in result_macro: + if key == "classes": + continue + assert torch.allclose(result_macro[key], result_micro[key]) + + def test_many_detection_thresholds(self, backend): + """Test how metric behaves when there are many detection thresholds. + + Known to fail with the default pycocotools backend. 
+ See issue: https://github.com/Lightning-AI/torchmetrics/issues/1153 -@pytest.mark.parametrize("iou_type", ["bbox", "segm"]) -def test_warning_on_many_detections(iou_type): - """Test that a warning is raised when there are many detections.""" - if iou_type == "bbox": + """ preds = [ { - "boxes": torch.tensor([[0.5, 0.5, 1, 1]]).repeat(101, 1), - "scores": torch.tensor([1.0]).repeat(101), - "labels": torch.tensor([0]).repeat(101), + "boxes": torch.tensor([[258.0, 41.0, 606.0, 285.0]]), + "scores": torch.tensor([0.536]), + "labels": torch.tensor([0]), } ] - targets = [{"boxes": torch.tensor([[0, 0, 1, 1]]), "labels": torch.tensor([0])}] - else: - preds, targets = _generate_random_segm_input("cpu", 1, 101, 10, False) - - metric = MeanAveragePrecision(iou_type=iou_type) - with pytest.warns(UserWarning, match="Encountered more than 100 detections in a single image.*"): - metric.update(preds, targets) - - -@pytest.mark.parametrize( - ("preds", "target", "expected_iou_len", "iou_keys", "precision_shape", "recall_shape"), - [ - ( - [[{"boxes": torch.tensor([[0.5, 0.5, 1, 1]]), "scores": torch.tensor([1.0]), "labels": torch.tensor([0])}]], - [[{"boxes": torch.tensor([[0, 0, 1, 1]]), "labels": torch.tensor([0])}]], - 1, # 1 image x 1 class = 1 - [(0, 0)], - (10, 101, 1, 4, 3), - (10, 1, 4, 3), - ), - ( - _inputs.preds, - _inputs.target, - 24, # 4 images x 6 classes = 24 - product([0, 1, 2, 3], [0, 1, 2, 3, 4, 49]), - (10, 101, 6, 4, 3), - (10, 6, 4, 3), - ), - ], -) -def test_for_extended_stats(preds, target, expected_iou_len, iou_keys, precision_shape, recall_shape): - """Test that extended stats are computed correctly.""" - metric = MeanAveragePrecision(extended_summary=True) - for ( - p, - t, - ) in zip(preds, target): - metric.update(p, t) - result = metric.compute() - - ious = result["ious"] - assert isinstance(ious, dict) - assert len(ious) == expected_iou_len - for key in ious: - assert key in iou_keys - - precision = result["precision"] - assert isinstance(precision, Tensor) - assert precision.shape == precision_shape - - recall = result["recall"] - assert isinstance(recall, Tensor) - assert recall.shape == recall_shape - - -@pytest.mark.parametrize("class_metrics", [False, True]) -def test_average_argument(class_metrics): - """Test that average argument works. - - Calculating macro on inputs that only have one label should be the same as micro. Calculating class metrics should - be the same regardless of average argument. 
- - """ - if class_metrics: - _preds = _inputs.preds - _target = _inputs.target - else: - _preds = apply_to_collection(deepcopy(_inputs.preds), IntTensor, lambda x: torch.ones_like(x)) - _target = apply_to_collection(deepcopy(_inputs.target), IntTensor, lambda x: torch.ones_like(x)) - - metric_macro = MeanAveragePrecision(average="macro", class_metrics=class_metrics) - metric_macro.update(_preds[0], _target[0]) - metric_macro.update(_preds[1], _target[1]) - result_macro = metric_macro.compute() - - metric_micro = MeanAveragePrecision(average="micro", class_metrics=class_metrics) - metric_micro.update(_inputs.preds[0], _inputs.target[0]) - metric_micro.update(_inputs.preds[1], _inputs.target[1]) - result_micro = metric_micro.compute() + target = [ + { + "boxes": torch.tensor([[214.0, 41.0, 562.0, 285.0]]), + "labels": torch.tensor([0]), + } + ] + metric = MeanAveragePrecision(max_detection_thresholds=[1, 10, 1000], backend=backend) + res = metric(preds, target) - if class_metrics: - assert torch.allclose(result_macro["map_per_class"], result_micro["map_per_class"]) - assert torch.allclose(result_macro["mar_100_per_class"], result_micro["mar_100_per_class"]) - else: - for key in result_macro: - if key == "classes": - continue - assert torch.allclose(result_macro[key], result_micro[key]) + if backend == "pycocotools": + assert round(res["map"].item(), 5) != 0.6 + else: + assert round(res["map"].item(), 5) == 0.6