
Commit cff5fa7

SkafteNicki, Borda, and mergify[bot] authored and committed
MeanAveragePrecision backend (Lightning-AI#2034)
* implementation
* imports + requirements + links
* add tests
* clarify language
* Apply suggestions from code review
* fix line length

---------

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent 33d1240 commit cff5fa7
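
Below is a minimal usage sketch (not part of the commit) of the `backend` argument this change adds to `MeanAveragePrecision`. The boxes, scores, and labels are illustrative toy values, and the example assumes `faster-coco-eval` is installed:

import torch
from torchmetrics.detection.mean_ap import MeanAveragePrecision

# Toy prediction/target pair in the metric's usual list-of-dicts format.
preds = [
    {
        "boxes": torch.tensor([[258.0, 41.0, 606.0, 285.0]]),  # xyxy (default box_format)
        "scores": torch.tensor([0.536]),
        "labels": torch.tensor([0]),
    }
]
target = [
    {
        "boxes": torch.tensor([[214.0, 41.0, 562.0, 285.0]]),
        "labels": torch.tensor([0]),
    }
]

# "pycocotools" stays the default; "faster_coco_eval" opts into the new backend.
metric = MeanAveragePrecision(iou_type="bbox", backend="faster_coco_eval")
metric.update(preds, target)
print(metric.compute()["map"])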

File tree: 5 files changed (+533, -420 lines)


docs/source/links.rst (+2)

@@ -154,5 +154,7 @@
 .. _Normalized Mutual Information Score: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.normalized_mutual_info_score.html
 .. _pycocotools: https://github.com/cocodataset/cocoapi/tree/master/PythonAPI/pycocotools
 .. _Rand Score: https://link.springer.com/article/10.1007/BF01908075
+.. _faster-coco-eval: https://github.com/MiXaiLL76/faster_coco_eval
+.. _fork of pycocotools: https://github.com/ppwwyyxx/cocoapi
 .. _Adjusted Rand Score: https://en.wikipedia.org/wiki/Rand_index#Adjusted_Rand_index
 .. _Dunn Index: https://en.wikipedia.org/wiki/Dunn_index

requirements/detection_test.txt (+4)

@@ -0,0 +1,4 @@
+# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
+# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
+
+faster-coco-eval >=1.3.3

src/torchmetrics/detection/mean_ap.py (+94, -21)

@@ -14,7 +14,8 @@
 import contextlib
 import io
 import json
-from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+from types import ModuleType
+from typing import Any, Callable, ClassVar, Dict, List, Optional, Sequence, Tuple, Union

 import numpy as np
 import torch
@@ -27,6 +28,7 @@
 from torchmetrics.metric import Metric
 from torchmetrics.utilities import rank_zero_warn
 from torchmetrics.utilities.imports import (
+    _FASTER_COCO_EVAL_AVAILABLE,
     _MATPLOTLIB_AVAILABLE,
     _PYCOCOTOOLS_AVAILABLE,
     _TORCHVISION_GREATER_EQUAL_0_8,
@@ -48,14 +50,7 @@
     "MeanAveragePrecision.coco_to_tm",
 ]

-
-if _PYCOCOTOOLS_AVAILABLE:
-    import pycocotools.mask as mask_utils
-    from pycocotools.coco import COCO
-    from pycocotools.cocoeval import COCOeval
-else:
-    COCO, COCOeval = None, None
-    mask_utils = None
+if not _PYCOCOTOOLS_AVAILABLE:
     __doctest_skip__ = [
         "MeanAveragePrecision.plot",
         "MeanAveragePrecision",
@@ -64,6 +59,32 @@
     ]


+def _load_backend_tools(backend: Literal["pycocotools", "faster_coco_eval"]) -> Tuple[object, object, ModuleType]:
+    """Load the backend tools for the given backend."""
+    if backend == "pycocotools":
+        if not _PYCOCOTOOLS_AVAILABLE:
+            raise ModuleNotFoundError(
+                "Backend `pycocotools` in metric `MeanAveragePrecision` metric requires that `pycocotools` is"
+                " installed. Please install with `pip install pycocotools` or `pip install torchmetrics[detection]`"
+            )
+        import pycocotools.mask as mask_utils
+        from pycocotools.coco import COCO
+        from pycocotools.cocoeval import COCOeval
+
+        return COCO, COCOeval, mask_utils
+
+    if not _FASTER_COCO_EVAL_AVAILABLE:
+        raise ModuleNotFoundError(
+            "Backend `faster_coco_eval` in metric `MeanAveragePrecision` metric requires that `faster-coco-eval` is"
+            " installed. Please install with `pip install faster-coco-eval`."
+        )
+    from faster_coco_eval import COCO
+    from faster_coco_eval import COCOeval_faster as COCOeval
+    from faster_coco_eval.core import mask as mask_utils
+
+    return COCO, COCOeval, mask_utils
+
+
 class MeanAveragePrecision(Metric):
     r"""Compute the `Mean-Average-Precision (mAP) and Mean-Average-Recall (mAR)`_ for object detection predictions.

@@ -142,9 +163,16 @@ class MeanAveragePrecision(Metric):
         Caution: If the initialization parameters are changed, dictionary keys for mAR can change as well.

     .. note::
-        This metric utilizes the official `pycocotools` implementation as its backend. This means that the metric
-        requires you to have `pycocotools` installed. In addition we require `torchvision` version 0.8.0 or newer.
-        Please install with ``pip install torchmetrics[detection]``.
+        This metric supports, at the moment, two different backends for the evaluation. The default backend is
+        ``"pycocotools"``, which requires either the official `pycocotools`_ implementation or this
+        `fork of pycocotools`_ to be installed. We recommend using the fork as it is better maintained and easily
+        available to install via pip: `pip install pycocotools`. It is also this fork that will be installed if you
+        install ``torchmetrics[detection]``. The second backend is the `faster-coco-eval`_ implementation, which can
+        be installed with ``pip install faster-coco-eval``. This is a maintained open-source implementation that is
+        faster and corrects certain corner cases that the official implementation has. Our own testing has shown
+        that the results are identical to the official implementation. Regardless of the backend, we also require
+        you to have `torchvision` version 0.8.0 or newer installed. Please install with
+        ``pip install torchvision>=0.8`` or ``pip install torchmetrics[detection]``.

     Args:
         box_format:
@@ -188,7 +216,9 @@ class MeanAveragePrecision(Metric):
             of max detections per image.

         average:
-            Method for averaging scores over labels. Choose between "``macro``"" and "``micro``". Default is "macro"
+            Method for averaging scores over labels. Choose between ``"macro"`` and ``"micro"``.
+        backend:
+            Backend to use for the evaluation. Choose between ``"pycocotools"`` and ``"faster_coco_eval"``.

         kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

@@ -323,6 +353,19 @@ class MeanAveragePrecision(Metric):

     warn_on_many_detections: bool = True

+    __jit_unused_properties__: ClassVar[List[str]] = [
+        "is_differentiable",
+        "higher_is_better",
+        "plot_lower_bound",
+        "plot_upper_bound",
+        "plot_legend_name",
+        "metric_state",
+        # below is added specifically for this metric
+        "coco",
+        "cocoeval",
+        "mask_utils",
+    ]
+
     def __init__(
         self,
         box_format: Literal["xyxy", "xywh", "cxcywh"] = "xyxy",
@@ -333,6 +376,7 @@ def __init__(
         class_metrics: bool = False,
         extended_summary: bool = False,
         average: Literal["macro", "micro"] = "macro",
+        backend: Literal["pycocotools", "faster_coco_eval"] = "pycocotools",
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -387,6 +431,12 @@ def __init__(
             raise ValueError(f"Expected argument `average` to be one of ('macro', 'micro') but got {average}")
         self.average = average

+        if backend not in ("pycocotools", "faster_coco_eval"):
+            raise ValueError(
+                f"Expected argument `backend` to be one of ('pycocotools', 'faster_coco_eval') but got {backend}"
+            )
+        self.backend = backend
+
         self.add_state("detection_box", default=[], dist_reduce_fx=None)
         self.add_state("detection_mask", default=[], dist_reduce_fx=None)
         self.add_state("detection_scores", default=[], dist_reduce_fx=None)
@@ -397,6 +447,24 @@ def __init__(
         self.add_state("groundtruth_crowds", default=[], dist_reduce_fx=None)
         self.add_state("groundtruth_area", default=[], dist_reduce_fx=None)

+    @property
+    def coco(self) -> object:
+        """Returns the coco module for the given backend, done in this way to make metric picklable."""
+        coco, _, _ = _load_backend_tools(self.backend)
+        return coco
+
+    @property
+    def cocoeval(self) -> object:
+        """Returns the coco eval module for the given backend, done in this way to make metric picklable."""
+        _, cocoeval, _ = _load_backend_tools(self.backend)
+        return cocoeval
+
+    @property
+    def mask_utils(self) -> object:
+        """Returns the mask utils object for the given backend, done in this way to make metric picklable."""
+        _, _, mask_utils = _load_backend_tools(self.backend)
+        return mask_utils
+
     def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]) -> None:
         """Update metric state.

@@ -454,7 +522,7 @@ def compute(self) -> dict:
                     for anno in coco_preds.dataset["annotations"]:
                         anno["area"] = anno[f"area_{i_type}"]

-                coco_eval = COCOeval(coco_target, coco_preds, iouType=i_type)
+                coco_eval = self.cocoeval(coco_target, coco_preds, iouType=i_type)
                 coco_eval.params.iouThrs = np.array(self.iou_thresholds, dtype=np.float64)
                 coco_eval.params.recThrs = np.array(self.rec_thresholds, dtype=np.float64)
                 coco_eval.params.maxDets = self.max_detection_thresholds
@@ -482,7 +550,7 @@ def compute(self) -> dict:
                         # since micro averaging have all the data in one class, we need to reinitialize the coco_eval
                         # object in macro mode to get the per class stats
                         coco_preds, coco_target = self._get_coco_datasets(average="macro")
-                        coco_eval = COCOeval(coco_target, coco_preds, iouType=i_type)
+                        coco_eval = self.cocoeval(coco_target, coco_preds, iouType=i_type)
                         coco_eval.params.iouThrs = np.array(self.iou_thresholds, dtype=np.float64)
                         coco_eval.params.recThrs = np.array(self.rec_thresholds, dtype=np.float64)
                         coco_eval.params.maxDets = self.max_detection_thresholds
@@ -516,7 +584,7 @@ def compute(self) -> dict:

         return result_dict

-    def _get_coco_datasets(self, average: Literal["macro", "micro"]) -> Tuple[COCO, COCO]:
+    def _get_coco_datasets(self, average: Literal["macro", "micro"]) -> Tuple[object, object]:
         """Returns the coco datasets for the target and the predictions."""
         if average == "micro":
             # for micro averaging we set everything to be the same class
@@ -526,7 +594,7 @@ def _get_coco_datasets(self, average: Literal["macro", "micro"]) -> Tuple[COCO,
             groundtruth_labels = self.groundtruth_labels
             detection_labels = self.detection_labels

-        coco_target, coco_preds = COCO(), COCO()
+        coco_target, coco_preds = self.coco(), self.coco()

         coco_target.dataset = self._get_coco_format(
             labels=groundtruth_labels,
@@ -571,6 +639,7 @@ def coco_to_tm(
         coco_preds: str,
         coco_target: str,
         iou_type: Union[Literal["bbox", "segm"], List[str]] = "bbox",
+        backend: Literal["pycocotools", "faster_coco_eval"] = "pycocotools",
     ) -> Tuple[List[Dict[str, Tensor]], List[Dict[str, Tensor]]]:
         """Utility function for converting .json coco format files to the input format of this metric.

@@ -581,6 +650,7 @@ def coco_to_tm(
             coco_preds: Path to the json file containing the predictions in coco format
             coco_target: Path to the json file containing the targets in coco format
             iou_type: Type of input, either `bbox` for bounding boxes or `segm` for segmentation masks
+            backend: Backend to use for the conversion. Either `pycocotools` or `faster_coco_eval`.

         Returns:
             A tuple containing the predictions and targets in the input format of this metric. Each element of the
@@ -599,9 +669,10 @@ def coco_to_tm(

         """
         iou_type = _validate_iou_type_arg(iou_type)
+        coco, _, _ = _load_backend_tools(backend)

         with contextlib.redirect_stdout(io.StringIO()):
-            gt = COCO(coco_target)
+            gt = coco(coco_target)
             dt = gt.loadRes(coco_preds)

         gt_dataset = gt.dataset["annotations"]
@@ -748,7 +819,7 @@ def _get_safe_item_values(
         if "segm" in self.iou_type:
             masks = []
             for i in item["masks"].cpu().numpy():
-                rle = mask_utils.encode(np.asfortranarray(i))
+                rle = self.mask_utils.encode(np.asfortranarray(i))
                 masks.append((tuple(rle["size"]), rle["counts"]))
             output[1] = tuple(masks)
         if (output[0] is not None and len(output[0]) > self.max_detection_thresholds[-1]) or (
@@ -819,10 +890,12 @@ def _get_coco_format(
                 if area is not None and area[image_id][k].cpu().tolist() > 0:
                     area_stat = area[image_id][k].cpu().tolist()
                 else:
-                    area_stat = mask_utils.area(image_mask) if "segm" in self.iou_type else image_box[2] * image_box[3]
+                    area_stat = (
+                        self.mask_utils.area(image_mask) if "segm" in self.iou_type else image_box[2] * image_box[3]
+                    )
                 if len(self.iou_type) > 1:
                     area_stat_box = image_box[2] * image_box[3]
-                    area_stat_mask = mask_utils.area(image_mask)
+                    area_stat_mask = self.mask_utils.area(image_mask)

                 annotation = {
                     "id": annotation_id,

src/torchmetrics/utilities/imports.py (+1)

@@ -57,5 +57,6 @@
 _MULTIPROCESSING_AVAILABLE: bool = package_available("multiprocessing")
 _XLA_AVAILABLE: bool = package_available("torch_xla")
 _PIQ_GREATER_EQUAL_0_8: Optional[bool] = compare_version("piq", operator.ge, "0.8.0")
+_FASTER_COCO_EVAL_AVAILABLE: bool = package_available("faster_coco_eval")

 _LATEX_AVAILABLE: bool = shutil.which("latex") is not None
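
Rounding out the change, the coco_to_tm helper shown earlier also gains the backend parameter. A hedged usage sketch of converting existing COCO-format json files into the metric's input format; the file paths are placeholders and the chosen backend must be installed:

from torchmetrics.detection.mean_ap import MeanAveragePrecision

# Placeholder paths to COCO-format json files.
preds, target = MeanAveragePrecision.coco_to_tm(
    "path/to/preds.json",    # predictions in COCO format
    "path/to/targets.json",  # ground truth in COCO format
    iou_type="bbox",
    backend="faster_coco_eval",
)

metric = MeanAveragePrecision(iou_type="bbox", backend="faster_coco_eval")
metric.update(preds, target)
scores = metric.compute()
print(scores["map"], scores["map_50"])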
