MeanAveragePrecision backend #2034

Merged: 13 commits, Sep 8, 2023
2 changes: 2 additions & 0 deletions docs/source/links.rst
@@ -154,3 +154,5 @@
.. _Normalized Mutual Information Score: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.normalized_mutual_info_score.html
.. _pycocotools: https://github.com/cocodataset/cocoapi/tree/master/PythonAPI/pycocotools
.. _Rand Score: https://link.springer.com/article/10.1007/BF01908075
.. _faster-coco-eval: https://github.com/MiXaiLL76/faster_coco_eval
.. _fork of pycocotools: https://github.com/ppwwyyxx/cocoapi
4 changes: 4 additions & 0 deletions requirements/detection_test.txt
@@ -0,0 +1,4 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

faster-coco-eval >=1.3.3
115 changes: 94 additions & 21 deletions src/torchmetrics/detection/mean_ap.py
@@ -14,7 +14,8 @@
import contextlib
import io
import json
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from types import ModuleType
from typing import Any, Callable, ClassVar, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
@@ -27,6 +28,7 @@
from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.imports import (
_FASTER_COCO_EVAL_AVAILABLE,
_MATPLOTLIB_AVAILABLE,
_PYCOCOTOOLS_AVAILABLE,
_TORCHVISION_GREATER_EQUAL_0_8,
@@ -48,14 +50,7 @@
"MeanAveragePrecision.coco_to_tm",
]


if _PYCOCOTOOLS_AVAILABLE:
import pycocotools.mask as mask_utils
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
else:
COCO, COCOeval = None, None
mask_utils = None
if not _PYCOCOTOOLS_AVAILABLE:
__doctest_skip__ = [
"MeanAveragePrecision.plot",
"MeanAveragePrecision",
@@ -64,6 +59,32 @@
]


def _load_backend_tools(backend: Literal["pycocotools", "faster_coco_eval"]) -> Tuple[object, object, ModuleType]:
"""Load the backend tools for the given backend."""
if backend == "pycocotools":
if not _PYCOCOTOOLS_AVAILABLE:
raise ModuleNotFoundError(
"Backend `pycocotools` in metric `MeanAveragePrecision` metric requires that `pycocotools` is"
" installed. Please install with `pip install pycocotools` or `pip install torchmetrics[detection]`"
)
import pycocotools.mask as mask_utils
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

return COCO, COCOeval, mask_utils

if not _FASTER_COCO_EVAL_AVAILABLE:
raise ModuleNotFoundError(
"Backend `faster_coco_eval` in metric `MeanAveragePrecision` metric requires that `faster-coco-eval` is"
" installed. Please install with `pip install faster-coco-eval`."
)
from faster_coco_eval import COCO
from faster_coco_eval import COCOeval_faster as COCOeval
from faster_coco_eval.core import mask as mask_utils

return COCO, COCOeval, mask_utils
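
A minimal usage sketch of the new loader. The import path and backend names come from the diff above; the fallback to ``pycocotools`` shown here is only an illustration, since the loader itself simply raises ``ModuleNotFoundError`` when the requested backend is not installed.

# Sketch: resolve the evaluation tools for a preferred backend, falling back to
# pycocotools when faster-coco-eval is not installed.
from torchmetrics.detection.mean_ap import _load_backend_tools

try:
    COCO, COCOeval, mask_utils = _load_backend_tools("faster_coco_eval")
except ModuleNotFoundError:
    COCO, COCOeval, mask_utils = _load_backend_tools("pycocotools")

coco_gt, coco_dt = COCO(), COCO()  # empty COCO objects, filled later via their `dataset` attribute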


class MeanAveragePrecision(Metric):
r"""Compute the `Mean-Average-Precision (mAP) and Mean-Average-Recall (mAR)`_ for object detection predictions.

@@ -142,9 +163,16 @@ class MeanAveragePrecision(Metric):
Caution: If the initialization parameters are changed, dictionary keys for mAR can change as well.

.. note::
This metric utilizes the official `pycocotools` implementation as its backend. This means that the metric
requires you to have `pycocotools` installed. In addition we require `torchvision` version 0.8.0 or newer.
Please install with ``pip install torchmetrics[detection]``.
This metric currently supports two different backends for the evaluation. The default backend is
``"pycocotools"``, which requires either the official `pycocotools`_ implementation or this
`fork of pycocotools`_ to be installed. We recommend the fork, as it is better maintained and easily
installable via pip (``pip install pycocotools``); it is also the variant that is installed by
``torchmetrics[detection]``. The second backend is the `faster-coco-eval`_ implementation, which can be
installed with ``pip install faster-coco-eval``. This is a maintained open-source implementation that is
faster and corrects certain corner cases present in the official implementation. Our own testing has shown
that the results are identical to the official implementation. Regardless of the backend, we also require
`torchvision` version 0.8.0 or newer. Please install it with ``pip install torchvision>=0.8`` or
``pip install torchmetrics[detection]``.

Args:
box_format:
@@ -188,7 +216,9 @@ class MeanAveragePrecision(Metric):
of max detections per image.

average:
Method for averaging scores over labels. Choose between "``macro``"" and "``micro``". Default is "macro"
Method for averaging scores over labels. Choose between ``"macro"`` and ``"micro"``.
backend:
Backend to use for the evaluation. Choose between ``"pycocotools"`` and ``"faster_coco_eval"``.

kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
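
A short usage sketch of the new ``backend`` argument documented above, illustrating the backend choice described in the note. The argument name and its allowed values are taken from the diff; the boxes, scores, and labels are made-up illustrative data.

# Sketch: selecting the evaluation backend at construction time.
import torch
from torchmetrics.detection import MeanAveragePrecision

metric = MeanAveragePrecision(iou_type="bbox", backend="faster_coco_eval")
preds = [{
    "boxes": torch.tensor([[10.0, 10.0, 50.0, 50.0]]),
    "scores": torch.tensor([0.9]),
    "labels": torch.tensor([0]),
}]
target = [{
    "boxes": torch.tensor([[12.0, 8.0, 48.0, 52.0]]),
    "labels": torch.tensor([0]),
}]
metric.update(preds, target)
print(metric.compute()["map"])  # overall mAP as a scalar tensor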

@@ -323,6 +353,19 @@ class MeanAveragePrecision(Metric):

warn_on_many_detections: bool = True

__jit_unused_properties__: ClassVar[List[str]] = [
"is_differentiable",
"higher_is_better",
"plot_lower_bound",
"plot_upper_bound",
"plot_legend_name",
"metric_state",
# below is added specifically for this metric
"coco",
"cocoeval",
"mask_utils",
]

def __init__(
self,
box_format: Literal["xyxy", "xywh", "cxcywh"] = "xyxy",
@@ -333,6 +376,7 @@ def __init__(
class_metrics: bool = False,
extended_summary: bool = False,
average: Literal["macro", "micro"] = "macro",
backend: Literal["pycocotools", "faster_coco_eval"] = "pycocotools",
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
@@ -387,6 +431,12 @@ def __init__(
raise ValueError(f"Expected argument `average` to be one of ('macro', 'micro') but got {average}")
self.average = average

if backend not in ("pycocotools", "faster_coco_eval"):
raise ValueError(
f"Expected argument `backend` to be one of ('pycocotools', 'faster_coco_eval') but got {backend}"
)
self.backend = backend

self.add_state("detection_box", default=[], dist_reduce_fx=None)
self.add_state("detection_mask", default=[], dist_reduce_fx=None)
self.add_state("detection_scores", default=[], dist_reduce_fx=None)
@@ -397,6 +447,24 @@ def __init__(
self.add_state("groundtruth_crowds", default=[], dist_reduce_fx=None)
self.add_state("groundtruth_area", default=[], dist_reduce_fx=None)

@property
def coco(self) -> object:
"""Returns the coco module for the given backend, done in this way to make metric picklable."""
coco, _, _ = _load_backend_tools(self.backend)
return coco

@property
def cocoeval(self) -> object:
"""Returns the coco eval module for the given backend, done in this way to make metric picklable."""
_, cocoeval, _ = _load_backend_tools(self.backend)
return cocoeval

@property
def mask_utils(self) -> object:
"""Returns the mask utils object for the given backend, done in this way to make metric picklable."""
_, _, mask_utils = _load_backend_tools(self.backend)
return mask_utils
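
The three properties above resolve the backend lazily instead of storing module objects on the instance; as their docstrings note, this keeps the metric picklable. A hedged sketch of what that buys, assuming the chosen backend is installed:

# Sketch: no backend module lives on the instance, so pickling round-trips cleanly,
# while the backend classes remain reachable on demand through the properties.
import pickle
from torchmetrics.detection import MeanAveragePrecision

metric = MeanAveragePrecision(backend="pycocotools")
clone = pickle.loads(pickle.dumps(metric))  # works; the state is plain tensors and lists
coco_cls = metric.coco           # resolved on access, e.g. pycocotools.coco.COCO
evaluator_cls = metric.cocoeval  # e.g. pycocotools.cocoeval.COCOeval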

def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]) -> None:
"""Update metric state.

@@ -454,7 +522,7 @@ def compute(self) -> dict:
for anno in coco_preds.dataset["annotations"]:
anno["area"] = anno[f"area_{i_type}"]

coco_eval = COCOeval(coco_target, coco_preds, iouType=i_type)
coco_eval = self.cocoeval(coco_target, coco_preds, iouType=i_type)
coco_eval.params.iouThrs = np.array(self.iou_thresholds, dtype=np.float64)
coco_eval.params.recThrs = np.array(self.rec_thresholds, dtype=np.float64)
coco_eval.params.maxDets = self.max_detection_thresholds
@@ -482,7 +550,7 @@ def compute(self) -> dict:
# since micro averaging have all the data in one class, we need to reinitialize the coco_eval
# object in macro mode to get the per class stats
coco_preds, coco_target = self._get_coco_datasets(average="macro")
coco_eval = COCOeval(coco_target, coco_preds, iouType=i_type)
coco_eval = self.cocoeval(coco_target, coco_preds, iouType=i_type)
coco_eval.params.iouThrs = np.array(self.iou_thresholds, dtype=np.float64)
coco_eval.params.recThrs = np.array(self.rec_thresholds, dtype=np.float64)
coco_eval.params.maxDets = self.max_detection_thresholds
@@ -516,7 +584,7 @@ def compute(self) -> dict:

return result_dict

def _get_coco_datasets(self, average: Literal["macro", "micro"]) -> Tuple[COCO, COCO]:
def _get_coco_datasets(self, average: Literal["macro", "micro"]) -> Tuple[object, object]:
"""Returns the coco datasets for the target and the predictions."""
if average == "micro":
# for micro averaging we set everything to be the same class
@@ -526,7 +594,7 @@ def _get_coco_datasets(self, average: Literal["macro", "micro"]) -> Tuple[COCO,
groundtruth_labels = self.groundtruth_labels
detection_labels = self.detection_labels

coco_target, coco_preds = COCO(), COCO()
coco_target, coco_preds = self.coco(), self.coco()

coco_target.dataset = self._get_coco_format(
labels=groundtruth_labels,
@@ -571,6 +639,7 @@ def coco_to_tm(
coco_preds: str,
coco_target: str,
iou_type: Union[Literal["bbox", "segm"], List[str]] = "bbox",
backend: Literal["pycocotools", "faster_coco_eval"] = "pycocotools",
) -> Tuple[List[Dict[str, Tensor]], List[Dict[str, Tensor]]]:
"""Utility function for converting .json coco format files to the input format of this metric.

@@ -581,6 +650,7 @@
coco_preds: Path to the json file containing the predictions in coco format
coco_target: Path to the json file containing the targets in coco format
iou_type: Type of input, either `bbox` for bounding boxes or `segm` for segmentation masks
backend: Backend to use for the conversion. Either `pycocotools` or `faster_coco_eval`.

Returns:
A tuple containing the predictions and targets in the input format of this metric. Each element of the
@@ -599,9 +669,10 @@

"""
iou_type = _validate_iou_type_arg(iou_type)
coco, _, _ = _load_backend_tools(backend)

with contextlib.redirect_stdout(io.StringIO()):
gt = COCO(coco_target)
gt = coco(coco_target)
dt = gt.loadRes(coco_preds)

gt_dataset = gt.dataset["annotations"]
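
For illustration, a sketch of converting existing COCO-format json files with the extended ``coco_to_tm`` and feeding the result back into the metric; the file paths are placeholders.

# Sketch: json files in COCO format -> torchmetrics input format, using the new backend argument.
from torchmetrics.detection import MeanAveragePrecision

preds, target = MeanAveragePrecision.coco_to_tm(
    "path/to/preds.json",    # placeholder path to predictions in COCO format
    "path/to/targets.json",  # placeholder path to ground truth in COCO format
    iou_type="bbox",
    backend="faster_coco_eval",
)
metric = MeanAveragePrecision(iou_type="bbox", backend="faster_coco_eval")
metric.update(preds, target)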
@@ -748,7 +819,7 @@ def _get_safe_item_values(
if "segm" in self.iou_type:
masks = []
for i in item["masks"].cpu().numpy():
rle = mask_utils.encode(np.asfortranarray(i))
rle = self.mask_utils.encode(np.asfortranarray(i))
masks.append((tuple(rle["size"]), rle["counts"]))
output[1] = tuple(masks)
if (output[0] is not None and len(output[0]) > self.max_detection_thresholds[-1]) or (
@@ -819,10 +890,12 @@ def _get_coco_format(
if area is not None and area[image_id][k].cpu().tolist() > 0:
area_stat = area[image_id][k].cpu().tolist()
else:
area_stat = mask_utils.area(image_mask) if "segm" in self.iou_type else image_box[2] * image_box[3]
area_stat = (
self.mask_utils.area(image_mask) if "segm" in self.iou_type else image_box[2] * image_box[3]
)
if len(self.iou_type) > 1:
area_stat_box = image_box[2] * image_box[3]
area_stat_mask = mask_utils.area(image_mask)
area_stat_mask = self.mask_utils.area(image_mask)

annotation = {
"id": annotation_id,
1 change: 1 addition & 0 deletions src/torchmetrics/utilities/imports.py
@@ -57,5 +57,6 @@
_MULTIPROCESSING_AVAILABLE: bool = package_available("multiprocessing")
_XLA_AVAILABLE: bool = package_available("torch_xla")
_PIQ_GREATER_EQUAL_0_8: Optional[bool] = compare_version("piq", operator.ge, "0.8.0")
_FASTER_COCO_EVAL_AVAILABLE: bool = package_available("faster_coco_eval")

_LATEX_AVAILABLE: bool = shutil.which("latex") is not None
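
The new flag follows the existing availability-flag pattern in this file. A hedged sketch of what such a check amounts to, using only the standard library (``package_available`` is assumed to behave equivalently):

# Sketch: importlib.util.find_spec returns None when the package is not importable.
import importlib.util

_FASTER_COCO_EVAL_AVAILABLE: bool = importlib.util.find_spec("faster_coco_eval") is not None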