From 9058fee3c48b4166d3bb63418d8d65f21e343f4b Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 8 Jun 2023 11:04:07 +0200 Subject: [PATCH 01/16] refactor --- src/torchmetrics/detection/mean_ap.py | 783 +++++++------------------- tests/unittests/detection/test_map.py | 233 ++++++-- 2 files changed, 405 insertions(+), 611 deletions(-) diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py index 86e5cb0e2b8..a320492dddb 100644 --- a/src/torchmetrics/detection/mean_ap.py +++ b/src/torchmetrics/detection/mean_ap.py @@ -1,4 +1,4 @@ -# Copyright The Lightning team. +# Copyright The PyTorch Lightning team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,31 +12,43 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +import sys +from dataclasses import dataclass +from types import TracebackType +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union import numpy as np import torch -import torch.distributed as dist -from torch import IntTensor, Tensor +from torch import Tensor +from torch import distributed as dist from torchmetrics.detection.helpers import _fix_empty_tensors, _input_validator from torchmetrics.metric import Metric -from torchmetrics.utilities.data import _cumsum -from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _PYCOCOTOOLS_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_8 +from torchmetrics.utilities.imports import ( + _MATPLOTLIB_AVAILABLE, + _PYCOCOTOOLS_AVAILABLE, + _TORCHVISION_AVAILABLE, + _TORCHVISION_GREATER_EQUAL_0_8, +) from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE if not _MATPLOTLIB_AVAILABLE: __doctest_skip__ = ["MeanAveragePrecision.plot"] + if _TORCHVISION_GREATER_EQUAL_0_8: from torchvision.ops import box_area, box_convert, box_iou else: box_convert = box_iou = box_area = None __doctest_skip__ = ["MeanAveragePrecision.plot", "MeanAveragePrecision"] + if _PYCOCOTOOLS_AVAILABLE: import pycocotools.mask as mask_utils + from pycocotools.coco import COCO + from pycocotools.cocoeval import COCOeval else: + COCO, COCOeval = None, None mask_utils = None __doctest_skip__ = ["MeanAveragePrecision.plot", "MeanAveragePrecision"] @@ -44,108 +56,69 @@ log = logging.getLogger(__name__) -def compute_area(inputs: List[Any], iou_type: str = "bbox") -> Tensor: - """Compute area of input depending on the specified iou_type. 
- - Default output for empty input is :class:`~torch.Tensor` - """ - if len(inputs) == 0: - return Tensor([]) - - if iou_type == "bbox": - return box_area(torch.stack(inputs)) - if iou_type == "segm": - inputs = [{"size": i[0], "counts": i[1]} for i in inputs] - area = torch.tensor(mask_utils.area(inputs).astype("float")) - return area - - raise Exception(f"IOU type {iou_type} is not supported") - - -def compute_iou( - det: List[Any], - gt: List[Any], - iou_type: str = "bbox", -) -> Tensor: - """Compute IOU between detections and ground-truth using the specified iou_type.""" - if iou_type == "bbox": - return box_iou(torch.stack(det), torch.stack(gt)) - if iou_type == "segm": - return _segm_iou(det, gt) - raise Exception(f"IOU type {iou_type} is not supported") - - -class BaseMetricResults(dict): - """Base metric class, that allows fields for pre-defined metrics.""" +@dataclass +class MAPMetricResults: + """Dataclass to wrap the final mAP results.""" - def __getattr__(self, key: str) -> Tensor: - """Get a specific metric attribute.""" - # Using this you get the correct error message, an AttributeError instead of a KeyError - if key in self: - return self[key] - raise AttributeError(f"No such attribute: {key}") + map: Tensor # noqa: A003 + map_50: Tensor + map_75: Tensor + map_small: Tensor + map_medium: Tensor + map_large: Tensor + mar_1: Tensor + mar_10: Tensor + mar_100: Tensor + mar_small: Tensor + mar_medium: Tensor + mar_large: Tensor + map_per_class: Tensor + mar_100_per_class: Tensor + classes: Tensor - def __setattr__(self, key: str, value: Tensor) -> None: - """Set a specific metric attribute.""" - self[key] = value + def __getitem__(self, key: str) -> Union[Tensor, List[Tensor]]: + """Enables accessing the results via `result['map']` instead of `result.map`.""" + return getattr(self, key) - def __delattr__(self, key: str) -> None: - """Delete a specific metric attribute.""" - if key in self: - del self[key] - raise AttributeError(f"No such attribute: {key}") +# noinspection PyMethodMayBeStatic +class WriteToLog: + """Logging class to move logs to log.debug().""" -class MAPMetricResults(BaseMetricResults): - """Class to wrap the final mAP results.""" + def write(self, buf: str) -> None: # skipcq: PY-D0003, PYL-R0201 + """Write to log.debug() instead of stdout.""" + for line in buf.rstrip().splitlines(): + log.debug(line.rstrip()) - __slots__ = ("map", "map_50", "map_75", "map_small", "map_medium", "map_large", "classes") + def flush(self) -> None: # skipcq: PY-D0003, PYL-R0201 + """Flush the logger.""" + for handler in log.handlers: + handler.flush() + def close(self) -> None: # skipcq: PY-D0003, PYL-R0201 + """Close the logger.""" + for handler in log.handlers: + handler.close() -class MARMetricResults(BaseMetricResults): - """Class to wrap the final mAR results.""" - __slots__ = ("mar_1", "mar_10", "mar_100", "mar_small", "mar_medium", "mar_large") +class HidePrints: + """Internal helper context to suppress the default output of the pycocotools package.""" + def __init__(self) -> None: + """Initialize the context.""" + self._original_stdout = None -class COCOMetricResults(BaseMetricResults): - """Class to wrap the final COCO metric results including various mAP/mAR values.""" - - __slots__ = ( - "map", - "map_50", - "map_75", - "map_small", - "map_medium", - "map_large", - "mar_1", - "mar_10", - "mar_100", - "mar_small", - "mar_medium", - "mar_large", - "map_per_class", - "mar_100_per_class", - ) - - -def _segm_iou(det: List[Tuple[np.ndarray, np.ndarray]], gt: 
List[Tuple[np.ndarray, np.ndarray]]) -> Tensor: - """Compute IOU between detections and ground-truths using mask-IOU. - - Implementation is based on pycocotools toolkit for mask_utils. - - Args: - det: A list of detection masks as ``[(RLE_SIZE, RLE_COUNTS)]``, where ``RLE_SIZE`` is (width, height) dimension - of the input and RLE_COUNTS is its RLE representation; - - gt: A list of ground-truth masks as ``[(RLE_SIZE, RLE_COUNTS)]``, where ``RLE_SIZE`` is (width, height) dimension - of the input and RLE_COUNTS is its RLE representation; - - """ - det_coco_format = [{"size": i[0], "counts": i[1]} for i in det] - gt_coco_format = [{"size": i[0], "counts": i[1]} for i in gt] + def __enter__(self) -> None: + """Redirect stdout to log.debug().""" + self._original_stdout = sys.stdout # type: ignore + sys.stdout = WriteToLog() # type: ignore - return torch.tensor(mask_utils.iou(det_coco_format, gt_coco_format, [False for _ in gt])) + def __exit__( + self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_t: Optional[TracebackType] + ) -> None: # type: ignore + """Restore stdout.""" + sys.stdout.close() + sys.stdout = self._original_stdout # type: ignore class MeanAveragePrecision(Metric): @@ -325,6 +298,11 @@ def __init__( ) -> None: super().__init__(**kwargs) + if not _PYCOCOTOOLS_AVAILABLE: + raise ImportError( + "`MAP` metric requires that `pycocotools` installed." + " Please install with `pip install pycocotools` or `pip install torchmetrics[detection]`" + ) if not _TORCHVISION_GREATER_EQUAL_0_8: raise ModuleNotFoundError( "`MeanAveragePrecision` metric requires that `torchvision` version 0.8.0 or newer is installed." @@ -332,38 +310,47 @@ def __init__( ) allowed_box_formats = ("xyxy", "xywh", "cxcywh") - allowed_iou_types = ("segm", "bbox") if box_format not in allowed_box_formats: raise ValueError(f"Expected argument `box_format` to be one of {allowed_box_formats} but got {box_format}") self.box_format = box_format - self.iou_thresholds = iou_thresholds or torch.linspace(0.5, 0.95, round((0.95 - 0.5) / 0.05) + 1).tolist() - self.rec_thresholds = rec_thresholds or torch.linspace(0.0, 1.00, round(1.00 / 0.01) + 1).tolist() - max_det_thr, _ = torch.sort(IntTensor(max_detection_thresholds or [1, 10, 100])) - self.max_detection_thresholds = max_det_thr.tolist() + + allowed_iou_types = ("segm", "bbox") if iou_type not in allowed_iou_types: raise ValueError(f"Expected argument `iou_type` to be one of {allowed_iou_types} but got {iou_type}") - if iou_type == "segm" and not _PYCOCOTOOLS_AVAILABLE: - raise ModuleNotFoundError("When `iou_type` is set to 'segm', pycocotools need to be installed") self.iou_type = iou_type - self.bbox_area_ranges = { - "all": (float(0**2), float(1e5**2)), - "small": (float(0**2), float(32**2)), - "medium": (float(32**2), float(96**2)), - "large": (float(96**2), float(1e5**2)), - } + + if iou_thresholds is not None and not isinstance(iou_thresholds, list): + raise ValueError( + f"Expected argument `iou_thresholds` to either be `None` or a list of floats but got {iou_thresholds}" + ) + self.iou_thresholds = iou_thresholds or torch.linspace(0.5, 0.95, round((0.95 - 0.5) / 0.05) + 1).tolist() + + if rec_thresholds is not None and not isinstance(rec_thresholds, list): + raise ValueError( + f"Expected argument `rec_thresholds` to either be `None` or a list of floats but got {rec_thresholds}" + ) + self.rec_thresholds = rec_thresholds or torch.linspace(0.0, 1.00, round(1.00 / 0.01) + 1).tolist() + + if max_detection_thresholds is not None and not 
isinstance(max_detection_thresholds, list): + raise ValueError( + f"Expected argument `max_detection_thresholds` to either be `None` or a list of ints" + f" but got {max_detection_thresholds}" + ) + max_det_thr, _ = torch.sort(torch.tensor(max_detection_thresholds or [1, 10, 100], dtype=torch.int)) + self.max_detection_thresholds = max_det_thr.tolist() if not isinstance(class_metrics, bool): raise ValueError("Expected argument `class_metrics` to be a boolean") - self.class_metrics = class_metrics + self.add_state("detections", default=[], dist_reduce_fx=None) self.add_state("detection_scores", default=[], dist_reduce_fx=None) self.add_state("detection_labels", default=[], dist_reduce_fx=None) self.add_state("groundtruths", default=[], dist_reduce_fx=None) self.add_state("groundtruth_labels", default=[], dist_reduce_fx=None) - def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]) -> None: - """Update state with predictions and targets.""" + def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]) -> None: # type: ignore + """Update metric state.""" _input_validator(preds, target, iou_type=self.iou_type) for item in preds: @@ -378,19 +365,6 @@ def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]] self.groundtruths.append(groundtruths) self.groundtruth_labels.append(item["labels"]) - def _move_list_states_to_cpu(self) -> None: - """Move list states to cpu to save GPU memory.""" - for key in self._defaults: - current_val = getattr(self, key) - current_to_cpu = [] - if isinstance(current_val, Sequence): - for cur_v in current_val: - # Cannot handle RLE as Tensor - if not isinstance(cur_v, tuple): - cur_v = cur_v.to("cpu") - current_to_cpu.append(cur_v) - setattr(self, key, current_to_cpu) - def _get_safe_item_values(self, item: Dict[str, Any]) -> Union[Tensor, Tuple]: if self.iou_type == "bbox": boxes = _fix_empty_tensors(item["boxes"]) @@ -408,468 +382,129 @@ def _get_safe_item_values(self, item: Dict[str, Any]) -> Union[Tensor, Tuple]: def _get_classes(self) -> List: """Return a list of unique classes found in ground truth and detection data.""" if len(self.detection_labels) > 0 or len(self.groundtruth_labels) > 0: - return torch.cat(self.detection_labels + self.groundtruth_labels).unique().tolist() + return torch.cat(self.detection_labels + self.groundtruth_labels).unique().cpu().tolist() return [] - def _compute_iou(self, idx: int, class_id: int, max_det: int) -> Tensor: - """Compute the Intersection over Union (IoU) between bounding boxes for the given image and class. 
+ def compute(self) -> dict: + """Computes the metric.""" + coco_target, coco_preds = COCO(), COCO() - Args: - idx: - Image Id, equivalent to the index of supplied samples - class_id: - Class Id of the supplied ground truth and detection labels - max_det: - Maximum number of evaluated detection bounding boxes - """ - # if self.iou_type == "bbox": - gt = self.groundtruths[idx] - det = self.detections[idx] - - gt_label_mask = (self.groundtruth_labels[idx] == class_id).nonzero().squeeze(1) - det_label_mask = (self.detection_labels[idx] == class_id).nonzero().squeeze(1) - - if len(gt_label_mask) == 0 or len(det_label_mask) == 0: - return Tensor([]) - - gt = [gt[i] for i in gt_label_mask] - det = [det[i] for i in det_label_mask] - - if len(gt) == 0 or len(det) == 0: - return Tensor([]) - - # Sort by scores and use only max detections - scores = self.detection_scores[idx] - scores_filtered = scores[self.detection_labels[idx] == class_id] - inds = torch.argsort(scores_filtered, descending=True) - - # TODO Fix (only for masks is necessary) - det = [det[i] for i in inds] - if len(det) > max_det: - det = det[:max_det] - - return compute_iou(det, gt, self.iou_type).to(self.device) - - def __evaluate_image_gt_no_preds( - self, gt: Tensor, gt_label_mask: Tensor, area_range: Tuple[int, int], nb_iou_thrs: int - ) -> Dict[str, Any]: - """Evaluate images with a ground truth but no predictions.""" - # GTs - gt = [gt[i] for i in gt_label_mask] - nb_gt = len(gt) - areas = compute_area(gt, iou_type=self.iou_type).to(self.device) - ignore_area = (areas < area_range[0]) | (areas > area_range[1]) - gt_ignore, _ = torch.sort(ignore_area.to(torch.uint8)) - gt_ignore = gt_ignore.to(torch.bool) - - # Detections - nb_det = 0 - det_ignore = torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool, device=self.device) - - return { - "dtMatches": torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool, device=self.device), - "gtMatches": torch.zeros((nb_iou_thrs, nb_gt), dtype=torch.bool, device=self.device), - "dtScores": torch.zeros(nb_det, dtype=torch.float32, device=self.device), - "gtIgnore": gt_ignore, - "dtIgnore": det_ignore, - } - - def __evaluate_image_preds_no_gt( - self, det: Tensor, idx: int, det_label_mask: Tensor, max_det: int, area_range: Tuple[int, int], nb_iou_thrs: int - ) -> Dict[str, Any]: - """Evaluate images with a prediction but no ground truth.""" - # GTs - nb_gt = 0 - - gt_ignore = torch.zeros(nb_gt, dtype=torch.bool, device=self.device) - - # Detections - - det = [det[i] for i in det_label_mask] - scores = self.detection_scores[idx] - scores_filtered = scores[det_label_mask] - scores_sorted, dtind = torch.sort(scores_filtered, descending=True) - - det = [det[i] for i in dtind] - if len(det) > max_det: - det = det[:max_det] - nb_det = len(det) - det_areas = compute_area(det, iou_type=self.iou_type).to(self.device) - det_ignore_area = (det_areas < area_range[0]) | (det_areas > area_range[1]) - ar = det_ignore_area.reshape((1, nb_det)) - det_ignore = torch.repeat_interleave(ar, nb_iou_thrs, 0) - - return { - "dtMatches": torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool, device=self.device), - "gtMatches": torch.zeros((nb_iou_thrs, nb_gt), dtype=torch.bool, device=self.device), - "dtScores": scores_sorted.to(self.device), - "gtIgnore": gt_ignore.to(self.device), - "dtIgnore": det_ignore.to(self.device), - } - - def _evaluate_image( - self, idx: int, class_id: int, area_range: Tuple[int, int], max_det: int, ious: dict - ) -> Optional[dict]: - """Perform evaluation for single class and image. 
+ coco_target.dataset = self._get_coco_format(self.groundtruths, self.groundtruth_labels) + coco_preds.dataset = self._get_coco_format(self.detections, self.detection_labels, self.detection_scores) - Args: - idx: - Image Id, equivalent to the index of supplied samples. - class_id: - Class Id of the supplied ground truth and detection labels. - area_range: - List of lower and upper bounding box area threshold. - max_det: - Maximum number of evaluated detection bounding boxes. - ious: - IoU results for image and class. - """ - gt = self.groundtruths[idx] - det = self.detections[idx] - gt_label_mask = (self.groundtruth_labels[idx] == class_id).nonzero().squeeze(1) - det_label_mask = (self.detection_labels[idx] == class_id).nonzero().squeeze(1) - - # No Gt and No predictions --> ignore image - if len(gt_label_mask) == 0 and len(det_label_mask) == 0: - return None - - nb_iou_thrs = len(self.iou_thresholds) - - # Some GT but no predictions - if len(gt_label_mask) > 0 and len(det_label_mask) == 0: - return self.__evaluate_image_gt_no_preds(gt, gt_label_mask, area_range, nb_iou_thrs) - - # Some predictions but no GT - if len(gt_label_mask) == 0 and len(det_label_mask) >= 0: - return self.__evaluate_image_preds_no_gt(det, idx, det_label_mask, max_det, area_range, nb_iou_thrs) - - gt = [gt[i] for i in gt_label_mask] - det = [det[i] for i in det_label_mask] - if len(gt) == 0 and len(det) == 0: - return None - if isinstance(det, dict): - det = [det] - if isinstance(gt, dict): - gt = [gt] - - areas = compute_area(gt, iou_type=self.iou_type).to(self.device) - - ignore_area = torch.logical_or(areas < area_range[0], areas > area_range[1]) - - # sort dt highest score first, sort gt ignore last - ignore_area_sorted, gtind = torch.sort(ignore_area.to(torch.uint8)) - # Convert to uint8 temporarily and back to bool, because "Sort currently does not support bool dtype on CUDA" - - ignore_area_sorted = ignore_area_sorted.to(torch.bool).to(self.device) - - gt = [gt[i] for i in gtind] - scores = self.detection_scores[idx] - scores_filtered = scores[det_label_mask] - scores_sorted, dtind = torch.sort(scores_filtered, descending=True) - det = [det[i] for i in dtind] - if len(det) > max_det: - det = det[:max_det] - # load computed ious - ious = ious[idx, class_id][:, gtind] if len(ious[idx, class_id]) > 0 else ious[idx, class_id] - - nb_iou_thrs = len(self.iou_thresholds) - nb_gt = len(gt) - nb_det = len(det) - gt_matches = torch.zeros((nb_iou_thrs, nb_gt), dtype=torch.bool, device=self.device) - det_matches = torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool, device=self.device) - gt_ignore = ignore_area_sorted - det_ignore = torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool, device=self.device) - - if torch.numel(ious) > 0: - for idx_iou, t in enumerate(self.iou_thresholds): - for idx_det, _ in enumerate(det): - m = MeanAveragePrecision._find_best_gt_match(t, gt_matches, idx_iou, gt_ignore, ious, idx_det) - if m == -1: - continue - det_ignore[idx_iou, idx_det] = gt_ignore[m] - det_matches[idx_iou, idx_det] = 1 - gt_matches[idx_iou, m] = 1 - - # set unmatched detections outside of area range to ignore - det_areas = compute_area(det, iou_type=self.iou_type).to(self.device) - det_ignore_area = (det_areas < area_range[0]) | (det_areas > area_range[1]) - ar = det_ignore_area.reshape((1, nb_det)) - det_ignore = torch.logical_or( - det_ignore, torch.logical_and(det_matches == 0, torch.repeat_interleave(ar, nb_iou_thrs, 0)) - ) + with HidePrints(): + coco_target.createIndex() + coco_preds.createIndex() - return { - 
"dtMatches": det_matches.to(self.device), - "gtMatches": gt_matches.to(self.device), - "dtScores": scores_sorted.to(self.device), - "gtIgnore": gt_ignore.to(self.device), - "dtIgnore": det_ignore.to(self.device), - } + coco_eval = COCOeval(coco_target, coco_preds, iouType=self.iou_type) + coco_eval.params.iouThrs = np.array(self.iou_thresholds, dtype=np.float64) + coco_eval.params.recThrs = np.array(self.rec_thresholds, dtype=np.float64) + coco_eval.params.maxDets = self.max_detection_thresholds - @staticmethod - def _find_best_gt_match( - thr: int, gt_matches: Tensor, idx_iou: float, gt_ignore: Tensor, ious: Tensor, idx_det: int - ) -> int: - """Return id of best ground truth match with current detection. + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + stats = coco_eval.stats - Args: - thr: - Current threshold value. - gt_matches: - Tensor showing if a ground truth matches for threshold ``t`` exists. - idx_iou: - Id of threshold ``t``. - gt_ignore: - Tensor showing if ground truth should be ignored. - ious: - IoUs for all combinations of detection and ground truth. - idx_det: - Id of current detection. - """ - previously_matched = gt_matches[idx_iou] - # Remove previously matched or ignored gts - remove_mask = previously_matched | gt_ignore - gt_ious = ious[idx_det] * ~remove_mask - match_idx = gt_ious.argmax().item() - if gt_ious[match_idx] > thr: - return match_idx - return -1 - - def _summarize( - self, - results: Dict, - avg_prec: bool = True, - iou_threshold: Optional[float] = None, - area_range: str = "all", - max_dets: int = 100, - ) -> Tensor: - """Perform evaluation for single class and image. - - Args: - results: - Dictionary including precision, recall and scores for all combinations. - avg_prec: - Calculate average precision. Else calculate average recall. - iou_threshold: - IoU threshold. If set to ``None`` it all values are used. Else results are filtered. - area_range: - Bounding box area range key. - max_dets: - Maximum detections. 
- """ - area_inds = [i for i, k in enumerate(self.bbox_area_ranges.keys()) if k == area_range] - mdet_inds = [i for i, k in enumerate(self.max_detection_thresholds) if k == max_dets] - if avg_prec: - # dimension of precision: [TxRxKxAxM] - prec = results["precision"] - # IoU - if iou_threshold is not None: - thr = self.iou_thresholds.index(iou_threshold) - prec = prec[thr, :, :, area_inds, mdet_inds] - else: - prec = prec[:, :, :, area_inds, mdet_inds] - else: - # dimension of recall: [TxKxAxM] - prec = results["recall"] - if iou_threshold is not None: - thr = self.iou_thresholds.index(iou_threshold) - prec = prec[thr, :, :, area_inds, mdet_inds] - else: - prec = prec[:, :, area_inds, mdet_inds] - - return torch.tensor([-1.0]) if len(prec[prec > -1]) == 0 else torch.mean(prec[prec > -1]) + map_per_class_values: Tensor = torch.Tensor([-1]) + mar_100_per_class_values: Tensor = torch.Tensor([-1]) + # if class mode is enabled, evaluate metrics per class + if self.class_metrics: + map_per_class_list = [] + mar_100_per_class_list = [] + for class_id in torch.cat(self.detection_labels + self.groundtruth_labels).unique().cpu().tolist(): + coco_eval.params.catIds = [class_id] + with HidePrints(): + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + class_stats = coco_eval.stats + + map_per_class_list.append(torch.Tensor([class_stats[0]])) + mar_100_per_class_list.append(torch.Tensor([class_stats[8]])) + map_per_class_values = torch.Tensor(map_per_class_list) + mar_100_per_class_values = torch.Tensor(mar_100_per_class_list) + + metrics = MAPMetricResults( + map=torch.Tensor([stats[0]]), + map_50=torch.Tensor([stats[1]]), + map_75=torch.Tensor([stats[2]]), + map_small=torch.Tensor([stats[3]]), + map_medium=torch.Tensor([stats[4]]), + map_large=torch.Tensor([stats[5]]), + mar_1=torch.Tensor([stats[6]]), + mar_10=torch.Tensor([stats[7]]), + mar_100=torch.Tensor([stats[8]]), + mar_small=torch.Tensor([stats[9]]), + mar_medium=torch.Tensor([stats[10]]), + mar_large=torch.Tensor([stats[11]]), + map_per_class=map_per_class_values, + mar_100_per_class=mar_100_per_class_values, + classes=torch.Tensor(self._get_classes()), + ) + return metrics.__dict__ - def _calculate(self, class_ids: List) -> Tuple[MAPMetricResults, MARMetricResults]: - """Calculate the precision and recall for all supplied classes to calculate mAP/mAR. + def _get_coco_format( + self, boxes: List[torch.Tensor], labels: List[torch.Tensor], scores: Optional[List[torch.Tensor]] = None + ) -> Dict: + """Transforms and returns all cached targets or predictions in COCO format. - Args: - class_ids: - List of label class Ids. 
+ Format is defined at https://cocodataset.org/#format-data """ - img_ids = range(len(self.groundtruths)) - max_detections = self.max_detection_thresholds[-1] - area_ranges = self.bbox_area_ranges.values() - - ious = { - (idx, class_id): self._compute_iou(idx, class_id, max_detections) - for idx in img_ids - for class_id in class_ids - } - - eval_imgs = [ - self._evaluate_image(img_id, class_id, area, max_detections, ious) - for class_id in class_ids - for area in area_ranges - for img_id in img_ids - ] - - nb_iou_thrs = len(self.iou_thresholds) - nb_rec_thrs = len(self.rec_thresholds) - nb_classes = len(class_ids) - nb_bbox_areas = len(self.bbox_area_ranges) - nb_max_det_thrs = len(self.max_detection_thresholds) - nb_imgs = len(img_ids) - precision = -torch.ones((nb_iou_thrs, nb_rec_thrs, nb_classes, nb_bbox_areas, nb_max_det_thrs)) - recall = -torch.ones((nb_iou_thrs, nb_classes, nb_bbox_areas, nb_max_det_thrs)) - scores = -torch.ones((nb_iou_thrs, nb_rec_thrs, nb_classes, nb_bbox_areas, nb_max_det_thrs)) - - # move tensors if necessary - rec_thresholds_tensor = torch.tensor(self.rec_thresholds) - - # retrieve E at each category, area range, and max number of detections - for idx_cls, _ in enumerate(class_ids): - for idx_bbox_area, _ in enumerate(self.bbox_area_ranges): - for idx_max_det_thrs, max_det in enumerate(self.max_detection_thresholds): - recall, precision, scores = MeanAveragePrecision.__calculate_recall_precision_scores( - recall, - precision, - scores, - idx_cls=idx_cls, - idx_bbox_area=idx_bbox_area, - idx_max_det_thrs=idx_max_det_thrs, - eval_imgs=eval_imgs, - rec_thresholds=rec_thresholds_tensor, - max_det=max_det, - nb_imgs=nb_imgs, - nb_bbox_areas=nb_bbox_areas, + images = [] + annotations = [] + annotation_id = 1 # has to start with 1, otherwise COCOEval results are wrong + + for image_id, (image_boxes, image_labels) in enumerate(zip(boxes, labels)): + if self.iou_type == "segm" and len(image_boxes) == 0: + continue + + if self.iou_type == "bbox": + image_boxes = image_boxes.cpu().tolist() + image_labels = image_labels.cpu().tolist() + + images.append({"id": image_id}) + if self.iou_type == "segm": + images[-1]["height"], images[-1]["width"] = image_boxes[0][0][0], image_boxes[0][0][1] + + for k, (image_box, image_label) in enumerate(zip(image_boxes, image_labels)): + if self.iou_type == "bbox" and len(image_box) != 4: + raise ValueError( + f"Invalid input box of sample {image_id}, element {k} (expected 4 values, got {len(image_box)})" ) - return precision, recall - - def _summarize_results(self, precisions: Tensor, recalls: Tensor) -> Tuple[MAPMetricResults, MARMetricResults]: - """Summarizes the precision and recall values to calculate mAP/mAR. 
- - Args: - precisions: - Precision values for different thresholds - recalls: - Recall values for different thresholds - """ - results = {"precision": precisions, "recall": recalls} - map_metrics = MAPMetricResults() - last_max_det_thr = self.max_detection_thresholds[-1] - map_metrics.map = self._summarize(results, True, max_dets=last_max_det_thr) - if 0.5 in self.iou_thresholds: - map_metrics.map_50 = self._summarize(results, True, iou_threshold=0.5, max_dets=last_max_det_thr) - else: - map_metrics.map_50 = torch.tensor([-1]) - if 0.75 in self.iou_thresholds: - map_metrics.map_75 = self._summarize(results, True, iou_threshold=0.75, max_dets=last_max_det_thr) - else: - map_metrics.map_75 = torch.tensor([-1]) - map_metrics.map_small = self._summarize(results, True, area_range="small", max_dets=last_max_det_thr) - map_metrics.map_medium = self._summarize(results, True, area_range="medium", max_dets=last_max_det_thr) - map_metrics.map_large = self._summarize(results, True, area_range="large", max_dets=last_max_det_thr) - - mar_metrics = MARMetricResults() - for max_det in self.max_detection_thresholds: - mar_metrics[f"mar_{max_det}"] = self._summarize(results, False, max_dets=max_det) - mar_metrics.mar_small = self._summarize(results, False, area_range="small", max_dets=last_max_det_thr) - mar_metrics.mar_medium = self._summarize(results, False, area_range="medium", max_dets=last_max_det_thr) - mar_metrics.mar_large = self._summarize(results, False, area_range="large", max_dets=last_max_det_thr) - - return map_metrics, mar_metrics - - @staticmethod - def __calculate_recall_precision_scores( - recall: Tensor, - precision: Tensor, - scores: Tensor, - idx_cls: int, - idx_bbox_area: int, - idx_max_det_thrs: int, - eval_imgs: list, - rec_thresholds: Tensor, - max_det: int, - nb_imgs: int, - nb_bbox_areas: int, - ) -> Tuple[Tensor, Tensor, Tensor]: - nb_rec_thrs = len(rec_thresholds) - idx_cls_pointer = idx_cls * nb_bbox_areas * nb_imgs - idx_bbox_area_pointer = idx_bbox_area * nb_imgs - # Load all image evals for current class_id and area_range - img_eval_cls_bbox = [eval_imgs[idx_cls_pointer + idx_bbox_area_pointer + i] for i in range(nb_imgs)] - img_eval_cls_bbox = [e for e in img_eval_cls_bbox if e is not None] - if not img_eval_cls_bbox: - return recall, precision, scores - - det_scores = torch.cat([e["dtScores"][:max_det] for e in img_eval_cls_bbox]) - - # different sorting method generates slightly different results. - # mergesort is used to be consistent as Matlab implementation. 
- # Sort in PyTorch does not support bool types on CUDA (yet, 1.11.0) - dtype = torch.uint8 if det_scores.is_cuda and det_scores.dtype is torch.bool else det_scores.dtype - # Explicitly cast to uint8 to avoid error for bool inputs on CUDA to argsort - inds = torch.argsort(det_scores.to(dtype), descending=True) - det_scores_sorted = det_scores[inds] - - det_matches = torch.cat([e["dtMatches"][:, :max_det] for e in img_eval_cls_bbox], axis=1)[:, inds] - det_ignore = torch.cat([e["dtIgnore"][:, :max_det] for e in img_eval_cls_bbox], axis=1)[:, inds] - gt_ignore = torch.cat([e["gtIgnore"] for e in img_eval_cls_bbox]) - npig = torch.count_nonzero(gt_ignore == False) # noqa: E712 - if npig == 0: - return recall, precision, scores - tps = torch.logical_and(det_matches, torch.logical_not(det_ignore)) - fps = torch.logical_and(torch.logical_not(det_matches), torch.logical_not(det_ignore)) - - tp_sum = _cumsum(tps, dim=1, dtype=torch.float) - fp_sum = _cumsum(fps, dim=1, dtype=torch.float) - for idx, (tp, fp) in enumerate(zip(tp_sum, fp_sum)): - nd = len(tp) - rc = tp / npig - pr = tp / (fp + tp + torch.finfo(torch.float64).eps) - prec = torch.zeros((nb_rec_thrs,)) - score = torch.zeros((nb_rec_thrs,)) - - recall[idx, idx_cls, idx_bbox_area, idx_max_det_thrs] = rc[-1] if nd else 0 - - # Remove zigzags for AUC - diff_zero = torch.zeros((1,), device=pr.device) - diff = torch.ones((1,), device=pr.device) - while not torch.all(diff == 0): - diff = torch.clamp(torch.cat(((pr[1:] - pr[:-1]), diff_zero), 0), min=0) - pr += diff - - inds = torch.searchsorted(rc, rec_thresholds.to(rc.device), right=False) - num_inds = inds.argmax() if inds.max() >= nd else nb_rec_thrs - inds = inds[:num_inds] - prec[:num_inds] = pr[inds] - score[:num_inds] = det_scores_sorted[inds] - precision[idx, :, idx_cls, idx_bbox_area, idx_max_det_thrs] = prec - scores[idx, :, idx_cls, idx_bbox_area, idx_max_det_thrs] = score - - return recall, precision, scores - - def compute(self) -> dict: - """Compute metric.""" - classes = self._get_classes() - precisions, recalls = self._calculate(classes) - map_val, mar_val = self._summarize_results(precisions, recalls) + if type(image_label) != int: + raise ValueError( + f"Invalid input class of sample {image_id}, element {k}" + f" (expected value of type integer, got type {type(image_label)})" + ) - # if class mode is enabled, evaluate metrics per class - map_per_class_values: Tensor = torch.tensor([-1.0]) - mar_max_dets_per_class_values: Tensor = torch.tensor([-1.0]) - if self.class_metrics: - map_per_class_list = [] - mar_max_dets_per_class_list = [] - - for class_idx, _ in enumerate(classes): - cls_precisions = precisions[:, :, class_idx].unsqueeze(dim=2) - cls_recalls = recalls[:, class_idx].unsqueeze(dim=1) - cls_map, cls_mar = self._summarize_results(cls_precisions, cls_recalls) - map_per_class_list.append(cls_map.map) - mar_max_dets_per_class_list.append(cls_mar[f"mar_{self.max_detection_thresholds[-1]}"]) - - map_per_class_values = torch.tensor(map_per_class_list, dtype=torch.float) - mar_max_dets_per_class_values = torch.tensor(mar_max_dets_per_class_list, dtype=torch.float) - - metrics = COCOMetricResults() - metrics.update(map_val) - metrics.update(mar_val) - metrics.map_per_class = map_per_class_values - metrics[f"mar_{self.max_detection_thresholds[-1]}_per_class"] = mar_max_dets_per_class_values - metrics.classes = torch.tensor(classes, dtype=torch.int) - return metrics + stat = image_box if self.iou_type == "bbox" else {"size": image_box[0], "counts": image_box[1]} + + annotation 
= { + "id": annotation_id, + "image_id": image_id, + "bbox" if self.iou_type == "bbox" else "segmentation": stat, + "area": image_box[2] * image_box[3] if self.iou_type == "bbox" else mask_utils.area(stat), + "category_id": image_label, + "iscrowd": 0, + } + + if scores is not None: + score = scores[image_id][k].cpu().tolist() + if type(score) != float: + raise ValueError( + f"Invalid input score of sample {image_id}, element {k}" + f" (expected value of type float, got type {type(score)})" + ) + annotation["score"] = score + annotations.append(annotation) + annotation_id += 1 + + classes = [{"id": i, "name": str(i)} for i in self._get_classes()] + return {"images": images, "annotations": annotations, "categories": classes} + + # specialized syncronization and apply functions for this metric def _apply(self, fn: Callable) -> torch.nn.Module: """Custom apply function. diff --git a/tests/unittests/detection/test_map.py b/tests/unittests/detection/test_map.py index 19e0e0e06cd..206987f16f9 100644 --- a/tests/unittests/detection/test_map.py +++ b/tests/unittests/detection/test_map.py @@ -244,30 +244,107 @@ def _create_inputs_masks() -> Input: { "boxes": Tensor([[1.0, 2.0, 3.0, 4.0]]), "scores": Tensor([0.8]), # target does not have scores - "labels": Tensor([1]), + "labels": IntTensor([1]), }, ], ], ) +_inputs4 = Input( + preds=[ + [ + { + "boxes": torch.Tensor([[258.15, 41.29, 606.41, 285.07]]), + "scores": torch.Tensor([0.236]), + "labels": torch.IntTensor([4]), + }, # coco image id 42 + { + "boxes": torch.Tensor([[61.00, 22.75, 565.00, 632.42], [12.66, 3.32, 281.26, 275.23]]), + "scores": torch.Tensor([0.318, 0.726]), + "labels": torch.IntTensor([3, 2]), + }, # coco image id 73 + ], + [ + { + "boxes": torch.Tensor( + [ + [87.87, 276.25, 384.29, 379.43], + [0.00, 3.66, 142.15, 316.06], + [296.55, 93.96, 314.97, 152.79], + [328.94, 97.05, 342.49, 122.98], + [356.62, 95.47, 372.33, 147.55], + [464.08, 105.09, 495.74, 146.99], + [276.11, 103.84, 291.44, 150.72], + ] + ), + "scores": torch.Tensor([0.546, 0.3, 0.407, 0.611, 0.335, 0.805, 0.953]), + "labels": torch.IntTensor([4, 1, 0, 0, 0, 0, 0]), + }, # coco image id 74 + { + "boxes": torch.Tensor([[0.00, 2.87, 601.00, 421.52]]), + "scores": torch.Tensor([0.423]), + "labels": torch.IntTensor([5]), + }, # coco image id 133 + ], + ], + target=[ + [ + { + "boxes": torch.Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]), + "labels": torch.IntTensor([4]), + }, # coco image id 42 + { + "boxes": torch.Tensor( + [ + [13.00, 22.75, 548.98, 632.42], + [1.66, 3.32, 270.26, 275.23], + ] + ), + "labels": torch.IntTensor([2, 2]), + }, # coco image id 73 + ], + [ + { + "boxes": torch.Tensor( + [ + [61.87, 276.25, 358.29, 379.43], + [2.75, 3.66, 162.15, 316.06], + [295.55, 93.96, 313.97, 152.79], + [326.94, 97.05, 340.49, 122.98], + [356.62, 95.47, 372.33, 147.55], + [462.08, 105.09, 493.74, 146.99], + [277.11, 103.84, 292.44, 150.72], + ] + ), + "labels": torch.IntTensor([4, 1, 0, 0, 0, 0, 0]), + }, # coco image id 74 + { + "boxes": torch.Tensor([[13.99, 2.87, 640.00, 421.52]]), + "labels": torch.IntTensor([5]), + }, # coco image id 133 + ], + ], +) + + def _compare_fn(preds, target) -> dict: """Comparison function for map implementation. 
Official pycocotools results calculated from a subset of https://github.com/cocodataset/cocoapi/tree/master/results All classes - Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.637 - Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.859 - Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.761 - Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.622 + Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.706 + Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.901 + Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.846 + Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.689 Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.800 - Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.635 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.432 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.652 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.652 - Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.673 + Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.701 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.592 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.716 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.716 + Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.767 Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.800 - Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.633 + Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.700 Class 0 Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.725 @@ -275,7 +352,7 @@ def _compare_fn(preds, target) -> dict: Class 1 Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.800 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.800 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.800 Class 2 Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.454 @@ -289,26 +366,25 @@ def _compare_fn(preds, target) -> dict: Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.650 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.650 - Class 49 - Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.556 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.580 + Class 5 + Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.900 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.900 """ return { - "map": Tensor([0.637]), - "map_50": Tensor([0.859]), - "map_75": Tensor([0.761]), - "map_small": Tensor([0.622]), - "map_medium": Tensor([0.800]), - "map_large": Tensor([0.635]), - "mar_1": Tensor([0.432]), - "mar_10": Tensor([0.652]), - "mar_100": Tensor([0.652]), - "mar_small": Tensor([0.673]), - "mar_medium": Tensor([0.800]), - "mar_large": Tensor([0.633]), - "map_per_class": Tensor([0.725, 0.800, 0.454, -1.000, 0.650, 0.556]), - "mar_100_per_class": Tensor([0.780, 0.800, 0.450, -1.000, 0.650, 0.580]), - "classes": Tensor([0, 1, 2, 3, 4, 49]), + "map": torch.Tensor([0.706]), + "map_50": torch.Tensor([0.901]), + "map_75": torch.Tensor([0.846]), + "map_small": torch.Tensor([0.689]), + "map_medium": torch.Tensor([0.800]), + 
"map_large": torch.Tensor([0.701]), + "mar_1": torch.Tensor([0.592]), + "mar_10": torch.Tensor([0.716]), + "mar_100": torch.Tensor([0.716]), + "mar_small": torch.Tensor([0.767]), + "mar_medium": torch.Tensor([0.800]), + "mar_large": torch.Tensor([0.700]), + "map_per_class": torch.Tensor([0.725, 0.800, 0.454, -1.000, 0.650, 0.900]), + "mar_100_per_class": torch.Tensor([0.780, 0.800, 0.450, -1.000, 0.650, 0.900]), } @@ -352,7 +428,6 @@ def _compare_fn_segm(preds, target) -> dict: @pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed") -@pytest.mark.parametrize("compute_on_cpu", [True]) class TestMAP(MetricTester): """Test the MAP metric for object detection predictions. @@ -364,20 +439,20 @@ class TestMAP(MetricTester): atol = 1e-2 @pytest.mark.parametrize("ddp", [False, True]) - def test_map_bbox(self, compute_on_cpu, ddp): + def test_map_bbox(self, ddp): """Test modular implementation for correctness.""" self.run_class_metric_test( ddp=ddp, - preds=_inputs.preds, - target=_inputs.target, + preds=_inputs4.preds, + target=_inputs4.target, metric_class=MeanAveragePrecision, reference_metric=_compare_fn, check_batch=False, - metric_args={"class_metrics": True, "compute_on_cpu": compute_on_cpu}, + metric_args={"class_metrics": True}, ) @pytest.mark.parametrize("ddp", [False, True]) - def test_map_segm(self, compute_on_cpu, ddp): + def test_map_segm(self, ddp): """Test modular implementation for correctness.""" _inputs_masks = _create_inputs_masks() self.run_class_metric_test( @@ -387,7 +462,7 @@ def test_map_segm(self, compute_on_cpu, ddp): metric_class=MeanAveragePrecision, reference_metric=_compare_fn_segm, check_batch=False, - metric_args={"class_metrics": True, "compute_on_cpu": compute_on_cpu, "iou_type": "segm"}, + metric_args={"class_metrics": True, "iou_type": "segm"}, ) @@ -716,3 +791,87 @@ def test_device_changing(): metric = metric.cpu() val = metric.compute() assert isinstance(val, dict) + + +def test_order(): + """Test that the ordering of input does not matter. 
+ + Issue: https://github.com/Lightning-AI/torchmetrics/issues/1774 + """ + targets = [ + { + "boxes": torch.zeros((0, 4), dtype=torch.float32), + "labels": torch.zeros((0,), dtype=torch.long), + }, + { + "boxes": torch.FloatTensor([[0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4]]), + "labels": torch.LongTensor([1, 2]), + }, + ] + + preds = [ + { + "boxes": torch.FloatTensor([[0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4]]), + "labels": torch.LongTensor([1, 2]), + "scores": torch.FloatTensor([0.9, 0.8]), + }, + { + "boxes": torch.FloatTensor([[0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4]]), + "labels": torch.LongTensor([1, 2]), + "scores": torch.FloatTensor([0.9, 0.8]), + }, + ] + metric = MeanAveragePrecision(box_format="xyxy", iou_type="bbox") + metrics = metric(preds, targets) + assert metrics["map_50"] == torch.tensor([0.5]) + + targets = [ + { + "boxes": torch.FloatTensor([[0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4]]), + "labels": torch.LongTensor([1, 2]), + }, + { + "boxes": torch.zeros((0, 4), dtype=torch.float32), + "labels": torch.zeros((0,), dtype=torch.long), + }, + ] + + preds = [ + { + "boxes": torch.FloatTensor([[0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4]]), + "labels": torch.LongTensor([1, 2]), + "scores": torch.FloatTensor([0.9, 0.8]), + }, + { + "boxes": torch.FloatTensor([[0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4]]), + "labels": torch.LongTensor([1, 2]), + "scores": torch.FloatTensor([0.9, 0.8]), + }, + ] + metric = MeanAveragePrecision(box_format="xyxy", iou_type="bbox") + metrics = metric(preds, targets) + assert metrics["map_50"] == torch.tensor([0.5]) + + +def test_corner_case(): + """Issue: https://github.com/Lightning-AI/torchmetrics/issues/1184.""" + metric = MeanAveragePrecision(iou_thresholds=[0.5], class_metrics=True) + preds = [ + { + "boxes": torch.Tensor( + [[0, 0, 20, 20], [30, 30, 50, 50], [70, 70, 90, 90], [100, 100, 120, 120]] + ), # FP # FP + "scores": torch.Tensor([0.6, 0.6, 0.6, 0.6]), + "labels": torch.IntTensor([0, 1, 2, 3]), + } + ] + + targets = [ + { + "boxes": torch.Tensor([[0, 0, 20, 20], [30, 30, 50, 50]]), + "labels": torch.IntTensor([0, 1]), + } + ] + metric.update(preds, targets) + res = metric.compute() + assert res["map"] == torch.tensor([0.5]) From 48f3788d680a2a3d6e3a040232772ec5cba3a616 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 9 Jun 2023 08:43:01 +0200 Subject: [PATCH 02/16] tests --- src/torchmetrics/detection/mean_ap.py | 12 +- tests/unittests/detection/__init__.py | 3 + tests/unittests/detection/test_map.py | 155 +++++++++++++++++++++++++- 3 files changed, 158 insertions(+), 12 deletions(-) diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py index a320492dddb..58b45a6937b 100644 --- a/src/torchmetrics/detection/mean_ap.py +++ b/src/torchmetrics/detection/mean_ap.py @@ -27,7 +27,6 @@ from torchmetrics.utilities.imports import ( _MATPLOTLIB_AVAILABLE, _PYCOCOTOOLS_AVAILABLE, - _TORCHVISION_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_8, ) from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE @@ -37,9 +36,9 @@ if _TORCHVISION_GREATER_EQUAL_0_8: - from torchvision.ops import box_area, box_convert, box_iou + from torchvision.ops import box_convert else: - box_convert = box_iou = box_area = None + box_convert = None __doctest_skip__ = ["MeanAveragePrecision.plot", "MeanAveragePrecision"] @@ -81,21 +80,20 @@ def __getitem__(self, key: str) -> Union[Tensor, List[Tensor]]: return getattr(self, key) -# noinspection PyMethodMayBeStatic class WriteToLog: """Logging class to move logs to log.debug().""" 
- def write(self, buf: str) -> None: # skipcq: PY-D0003, PYL-R0201 + def write(self, buf: str) -> None: """Write to log.debug() instead of stdout.""" for line in buf.rstrip().splitlines(): log.debug(line.rstrip()) - def flush(self) -> None: # skipcq: PY-D0003, PYL-R0201 + def flush(self) -> None: """Flush the logger.""" for handler in log.handlers: handler.flush() - def close(self) -> None: # skipcq: PY-D0003, PYL-R0201 + def close(self) -> None: """Close the logger.""" for handler in log.handlers: handler.close() diff --git a/tests/unittests/detection/__init__.py b/tests/unittests/detection/__init__.py index 6fac88ad7ef..ec3fb8193f7 100644 --- a/tests/unittests/detection/__init__.py +++ b/tests/unittests/detection/__init__.py @@ -3,3 +3,6 @@ from unittests import _PATH_ROOT _SAMPLE_DETECTION_SEGMENTATION = os.path.join(_PATH_ROOT, "_data", "detection", "instance_segmentation_inputs.json") +_DETECTION_VAL = os.path.join(_PATH_ROOT, "_data", "detection", "instances_val2014.json") +_DETECTION_BBOX = os.path.join(_PATH_ROOT, "_data", "detection", "instances_val2014_fakebbox100_results.json") +_DETECTION_SEGM = os.path.join(_PATH_ROOT, "_data", "detection", "instances_val2014_fakesegm100_results.json") diff --git a/tests/unittests/detection/test_map.py b/tests/unittests/detection/test_map.py index 206987f16f9..56a963d382d 100644 --- a/tests/unittests/detection/test_map.py +++ b/tests/unittests/detection/test_map.py @@ -11,21 +11,166 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import json from collections import namedtuple +from copy import deepcopy +from functools import partial import numpy as np import pytest import torch from pycocotools import mask +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval from torch import IntTensor, Tensor from torchmetrics.detection.mean_ap import MeanAveragePrecision -from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_8 +from torchmetrics.utilities.imports import _PYCOCOTOOLS_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_8 -from unittests.detection import _SAMPLE_DETECTION_SEGMENTATION +from unittests.detection import _DETECTION_BBOX, _DETECTION_SEGM, _DETECTION_VAL, _SAMPLE_DETECTION_SEGMENTATION from unittests.helpers.testers import MetricTester + +def _generate_inputs(iou_type): + """Generates inputs for the MAP metric.""" + gt = COCO(_DETECTION_VAL) + dt = gt.loadRes(_DETECTION_BBOX if iou_type == "bbox" else _DETECTION_SEGM) + img_ids = sorted(gt.getImgIds()) + img_ids = img_ids[0:100] + + gt_dataset = gt.dataset["annotations"] + dt_dataset = dt.dataset["annotations"] + + preds = {} + for p in dt_dataset: + if p["image_id"] not in preds: + preds[p["image_id"]] = {"boxes" if iou_type == "bbox" else "masks": [], "scores": [], "labels": []} + if iou_type == "bbox": + preds[p["image_id"]]["boxes"].append(p["bbox"]) + else: + preds[p["image_id"]]["masks"].append(gt.annToMask(p)) + preds[p["image_id"]]["scores"].append(p["score"]) + preds[p["image_id"]]["labels"].append(p["category_id"]) + missing_pred = set(img_ids) - set(preds.keys()) + for i in missing_pred: + preds[i] = {"boxes" if iou_type == "bbox" else "masks": [], "scores": [], "labels": []} + + target = {} + for t in gt_dataset: + if t["image_id"] not in img_ids: + continue + if t["image_id"] not in target: + target[t["image_id"]] = {"boxes" if iou_type == "bbox" else "masks": [], "labels": []} + if 
iou_type == "bbox": + target[t["image_id"]]["boxes"].append(t["bbox"]) + else: + target[t["image_id"]]["masks"].append(gt.annToMask(t)) + target[t["image_id"]]["labels"].append(t["category_id"]) + + if iou_type == "bbox": + preds = [ + { + "boxes": torch.tensor(p["boxes"]), + "scores": torch.tensor(p["scores"]), + "labels": torch.tensor(p["labels"]), + } + for p in preds.values() + ] + target = [{"boxes": torch.tensor(t["boxes"]), "labels": torch.tensor(t["labels"])} for t in target.values()] + else: + preds = [ + { + "masks": torch.tensor(p["masks"]), + "scores": torch.tensor(p["scores"]), + "labels": torch.tensor(p["labels"]), + } + for p in preds.values() + ] + target = [{"masks": torch.tensor(t["masks"]), "labels": torch.tensor(t["labels"])} for t in target.values()] + + # create 10 batches of 10 preds/targets each + preds = [preds[10 * i : 10 * (i + 1)] for i in range(10)] + target = [target[10 * i : 10 * (i + 1)] for i in range(10)] + + return preds, target + + +_bbox_input = _generate_inputs("bbox") +_segm_input = _generate_inputs("segm") + + +def _compare_fn(preds, target, iou_type, class_metrics=True): + """Taken from https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb.""" + gt = COCO(_DETECTION_VAL) + dt = gt.loadRes(_DETECTION_BBOX if iou_type == "bbox" else _DETECTION_SEGM) + img_ids = sorted(gt.getImgIds()) + img_ids = img_ids[0:100] + coco_eval = COCOeval(gt, dt, iou_type) + coco_eval.params.imgIds = img_ids + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + global_stats = deepcopy(coco_eval.stats) + + map_per_class_values = torch.Tensor([-1]) + mar_100_per_class_values = torch.Tensor([-1]) + classes = Tensor(np.unique([x["category_id"] for x in gt.dataset["annotations"]])) + if class_metrics: + map_per_class_list = [] + mar_100_per_class_list = [] + for class_id in classes.tolist(): + coco_eval.params.catIds = [class_id] + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + class_stats = coco_eval.stats + map_per_class_list.append(torch.Tensor([class_stats[0]])) + mar_100_per_class_list.append(torch.Tensor([class_stats[8]])) + + map_per_class_values = torch.Tensor(map_per_class_list) + mar_100_per_class_values = torch.Tensor(mar_100_per_class_list) + + return { + "map": Tensor([global_stats[0]]), + "map_50": Tensor([global_stats[1]]), + "map_75": Tensor([global_stats[2]]), + "map_small": Tensor([global_stats[3]]), + "map_medium": Tensor([global_stats[4]]), + "map_large": Tensor([global_stats[5]]), + "mar_1": Tensor([global_stats[6]]), + "mar_10": Tensor([global_stats[7]]), + "mar_100": Tensor([global_stats[8]]), + "mar_small": Tensor([global_stats[9]]), + "mar_medium": Tensor([global_stats[10]]), + "mar_large": Tensor([global_stats[11]]), + "map_per_class": map_per_class_values, + "mar_100_per_class": mar_100_per_class_values, + "classes": classes, + } + + +_pytest_condition = not (_PYCOCOTOOLS_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_8) + + +@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 and pycocotools is installed") +@pytest.mark.parametrize("iou_type", ["bbox", "segm"]) +class TestMAPNew(MetricTester): + """Test map metric.""" + + # @pytest.mark.parametrize("ddp", [False, True]) + def test_map(self, iou_type): + """Test modular implementation for correctness.""" + preds, target = _segm_input if iou_type == "segm" else _bbox_input + self.run_class_metric_test( + ddp=False, + preds=preds, + target=target, + metric_class=MeanAveragePrecision, + 
reference_metric=partial(_compare_fn, iou_type=iou_type), + metric_args={"iou_type": iou_type}, + check_batch=False, + ) + + Input = namedtuple("Input", ["preds", "target"]) @@ -424,7 +569,7 @@ def _compare_fn_segm(preds, target) -> dict: } -_pytest_condition = not (_TORCHVISION_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_8) +_pytest_condition = not _TORCHVISION_GREATER_EQUAL_0_8 @pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed") @@ -855,7 +1000,7 @@ def test_order(): def test_corner_case(): """Issue: https://github.com/Lightning-AI/torchmetrics/issues/1184.""" - metric = MeanAveragePrecision(iou_thresholds=[0.5], class_metrics=True) + metric = MeanAveragePrecision(iou_thresholds=[0.501], class_metrics=True) preds = [ { "boxes": torch.Tensor( From b586a42e787cab548e67e7132951088ae33669c7 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 13 Jun 2023 12:06:06 +0200 Subject: [PATCH 03/16] working reference --- src/torchmetrics/detection/mean_ap.py | 295 ++++++++++---- tests/unittests/detection/__init__.py | 2 +- tests/unittests/detection/test_map.py | 538 +++++++------------------- 3 files changed, 359 insertions(+), 476 deletions(-) diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py index 58b45a6937b..21fd6f13e01 100644 --- a/src/torchmetrics/detection/mean_ap.py +++ b/src/torchmetrics/detection/mean_ap.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import json import logging import sys from dataclasses import dataclass @@ -21,6 +22,7 @@ import torch from torch import Tensor from torch import distributed as dist +from typing_extensions import Literal from torchmetrics.detection.helpers import _fix_empty_tensors, _input_validator from torchmetrics.metric import Metric @@ -99,7 +101,7 @@ def close(self) -> None: handler.close() -class HidePrints: +class _HidePrints: """Internal helper context to suppress the default output of the pycocotools package.""" def __init__(self) -> None: @@ -153,6 +155,12 @@ class MeanAveragePrecision(Metric): classes for the boxes. - masks: :class:`~torch.bool` of shape ``(num_boxes, image_height, image_width)`` containing boolean masks. Only required when `iou_type="segm"`. + - iscrowd: :class:`~torch.IntTensor` of shape ``(num_boxes)`` containing 0/1 values indicating whether + the bounding box/masks indicate a crowd of objects. Value is optional, and if not provided it will + automatically be set to 0. + - area: :class:`~torch.FloatTensor` of shape ``(num_boxes)`` containing the area of the object. Value if + optional, and if not provided will be automatically calculated based on the bounding box/masks provided. 
+ Only affects which samples contribute to the `map_small`, `map_medium`, `map_large` values As output of ``forward`` and ``compute`` the metric returns the following output: @@ -283,11 +291,13 @@ class MeanAveragePrecision(Metric): detection_labels: List[Tensor] groundtruths: List[Tensor] groundtruth_labels: List[Tensor] + groundtruth_crowds: List[Tensor] + groundtruth_area: List[Tensor] def __init__( self, box_format: str = "xyxy", - iou_type: str = "bbox", + iou_type: Literal["bbox", "segm"] = "bbox", iou_thresholds: Optional[List[float]] = None, rec_thresholds: Optional[List[float]] = None, max_detection_thresholds: Optional[List[int]] = None, @@ -346,6 +356,8 @@ def __init__( self.add_state("detection_labels", default=[], dist_reduce_fx=None) self.add_state("groundtruths", default=[], dist_reduce_fx=None) self.add_state("groundtruth_labels", default=[], dist_reduce_fx=None) + self.add_state("groundtruth_crowds", default=[], dist_reduce_fx=None) + self.add_state("groundtruth_area", default=[], dist_reduce_fx=None) def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]) -> None: # type: ignore """Update metric state.""" @@ -362,35 +374,19 @@ def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]] groundtruths = self._get_safe_item_values(item) self.groundtruths.append(groundtruths) self.groundtruth_labels.append(item["labels"]) - - def _get_safe_item_values(self, item: Dict[str, Any]) -> Union[Tensor, Tuple]: - if self.iou_type == "bbox": - boxes = _fix_empty_tensors(item["boxes"]) - if boxes.numel() > 0: - boxes = box_convert(boxes, in_fmt=self.box_format, out_fmt="xyxy") - return boxes - if self.iou_type == "segm": - masks = [] - for i in item["masks"].cpu().numpy(): - rle = mask_utils.encode(np.asfortranarray(i)) - masks.append((tuple(rle["size"]), rle["counts"])) - return tuple(masks) - raise Exception(f"IOU type {self.iou_type} is not supported") - - def _get_classes(self) -> List: - """Return a list of unique classes found in ground truth and detection data.""" - if len(self.detection_labels) > 0 or len(self.groundtruth_labels) > 0: - return torch.cat(self.detection_labels + self.groundtruth_labels).unique().cpu().tolist() - return [] + self.groundtruth_crowds.append(item.get("iscrowd", torch.zeros_like(item["labels"]))) + self.groundtruth_area.append(item.get("area", -1 * torch.zeros_like(item["labels"]))) def compute(self) -> dict: """Computes the metric.""" coco_target, coco_preds = COCO(), COCO() - coco_target.dataset = self._get_coco_format(self.groundtruths, self.groundtruth_labels) - coco_preds.dataset = self._get_coco_format(self.detections, self.detection_labels, self.detection_scores) + coco_target.dataset = self._get_coco_format( + self.groundtruths, self.groundtruth_labels, crowds=self.groundtruth_crowds, area=self.groundtruth_area + ) + coco_preds.dataset = self._get_coco_format(self.detections, self.detection_labels, scores=self.detection_scores) - with HidePrints(): + with _HidePrints(): coco_target.createIndex() coco_preds.createIndex() @@ -404,15 +400,13 @@ def compute(self) -> dict: coco_eval.summarize() stats = coco_eval.stats - map_per_class_values: Tensor = torch.Tensor([-1]) - mar_100_per_class_values: Tensor = torch.Tensor([-1]) # if class mode is enabled, evaluate metrics per class if self.class_metrics: map_per_class_list = [] mar_100_per_class_list = [] for class_id in torch.cat(self.detection_labels + self.groundtruth_labels).unique().cpu().tolist(): coco_eval.params.catIds = [class_id] - with 
HidePrints(): + with _HidePrints(): coco_eval.evaluate() coco_eval.accumulate() coco_eval.summarize() @@ -420,8 +414,12 @@ def compute(self) -> dict: map_per_class_list.append(torch.Tensor([class_stats[0]])) mar_100_per_class_list.append(torch.Tensor([class_stats[8]])) + map_per_class_values = torch.Tensor(map_per_class_list) mar_100_per_class_values = torch.Tensor(mar_100_per_class_list) + else: + map_per_class_values: Tensor = torch.Tensor([-1]) + mar_100_per_class_values: Tensor = torch.Tensor([-1]) metrics = MAPMetricResults( map=torch.Tensor([stats[0]]), @@ -440,10 +438,143 @@ def compute(self) -> dict: mar_100_per_class=mar_100_per_class_values, classes=torch.Tensor(self._get_classes()), ) + return metrics.__dict__ + @staticmethod + def coco_to_tm( + coco_preds: str, + coco_target: str, + iou_type: Literal["bbox", "segm"] = "bbox", + ) -> Tuple[List[Dict[str, Tensor]], List[Dict[str, Tensor]]]: + """Convert coco format to the input format of the map metric. + + Args: + coco_preds: Path to the json file containing the predictions in coco format + coco_target: Path to the json file containing the targets in coco format + iou_type: Type of input, either `bbox` for bounding boxes or `segm` for segmentation masks + + Returns: + preds: List of dictionaries containing the predictions in the input format of this metric + target: List of dictionaries containing the targets in the input format of this metric + + """ + gt = COCO(coco_target) + dt = gt.loadRes(coco_preds) + + gt_dataset = gt.dataset["annotations"] + dt_dataset = dt.dataset["annotations"] + + target = {} + for t in gt_dataset: + if t["image_id"] not in target: + target[t["image_id"]] = { + "boxes" if iou_type == "bbox" else "masks": [], + "labels": [], + "iscrowd": [], + "area": [], + } + if iou_type == "bbox": + target[t["image_id"]]["boxes"].append(t["bbox"]) + else: + target[t["image_id"]]["masks"].append(gt.annToMask(t)) + target[t["image_id"]]["labels"].append(t["category_id"]) + target[t["image_id"]]["iscrowd"].append(t["iscrowd"]) + target[t["image_id"]]["area"].append(t["area"]) + + preds = {} + for p in dt_dataset: + if p["image_id"] not in preds: + preds[p["image_id"]] = {"boxes" if iou_type == "bbox" else "masks": [], "scores": [], "labels": []} + if iou_type == "bbox": + preds[p["image_id"]]["boxes"].append(p["bbox"]) + else: + preds[p["image_id"]]["masks"].append(gt.annToMask(p)) + preds[p["image_id"]]["scores"].append(p["score"]) + preds[p["image_id"]]["labels"].append(p["category_id"]) + for k in target: # add empty predictions for images without predictions + if k not in preds: + preds[k] = {"boxes" if iou_type == "bbox" else "masks": [], "scores": [], "labels": []} + + batched_preds, batched_target = [], [] + for key in target: + name = "boxes" if iou_type == "bbox" else "masks" + batched_preds.append( + { + name: torch.tensor(preds[key]["boxes"]) + if iou_type == "bbox" + else torch.tensor(preds[key]["masks"]), + "scores": torch.tensor(preds[key]["scores"]), + "labels": torch.tensor(preds[key]["labels"]), + } + ) + batched_target.append( + { + name: torch.tensor(target[key]["boxes"]) + if iou_type == "bbox" + else torch.tensor(target[key]["masks"]), + "labels": torch.tensor(target[key]["labels"]), + "iscrowd": torch.tensor(target[key]["iscrowd"]), + "area": torch.tensor(target[key]["area"]), + } + ) + + return batched_preds, batched_target + + def tm_to_coco(self, name: str = "tm_map_input") -> None: + """Write the input to the map metric to a json file in coco format. 
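# Editor's sketch (not part of the patch): the per-class evaluation pattern used in
# ``compute`` above, restricting COCOeval to one category at a time. ``gt_file`` and
# ``dt_file`` are hypothetical paths to COCO-format ground-truth and result json files;
# pycocotools must be installed.
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

def per_class_map(gt_file: str, dt_file: str, iou_type: str = "bbox") -> dict:
    gt = COCO(gt_file)
    dt = gt.loadRes(dt_file)
    results = {}
    for cat_id in gt.getCatIds():
        coco_eval = COCOeval(gt, dt, iou_type)
        coco_eval.params.catIds = [cat_id]  # evaluate a single class only
        coco_eval.evaluate()
        coco_eval.accumulate()
        coco_eval.summarize()
        # stats[0] is mAP@[.5:.95] and stats[8] is mAR@100, the same indices read in compute()
        results[cat_id] = (coco_eval.stats[0], coco_eval.stats[8])
    return results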
+ + Args: + name: Name of the output file, which will be appended with "_preds.json" and "_target.json" + """ + target_dataset = self._get_coco_format(self.groundtruths, self.groundtruth_labels) + preds_dataset = self._get_coco_format(self.detections, self.detection_labels, self.detection_scores) + + preds_json = json.dumps(preds_dataset["annotations"], indent=4) + target_json = json.dumps(target_dataset, indent=4) + + with open(f"{name}_preds.json", "w") as f: + f.write(preds_json) + + with open(f"{name}_target.json", "w") as f: + f.write(target_json) + + def _get_safe_item_values(self, item: Dict[str, Any]) -> Union[Tensor, Tuple]: + """Convert and return the boxes or masks from the item depending on the iou_type. + + Args: + item: input dictionary containing the boxes or masks + + Returns: + boxes or masks depending on the iou_type + + """ + if self.iou_type == "bbox": + boxes = _fix_empty_tensors(item["boxes"]) + if boxes.numel() > 0: + boxes = box_convert(boxes, in_fmt=self.box_format, out_fmt="xyxy") + return boxes + if self.iou_type == "segm": + masks = [] + for i in item["masks"].cpu().numpy(): + rle = mask_utils.encode(np.asfortranarray(i)) + masks.append((tuple(rle["size"]), rle["counts"])) + return tuple(masks) + raise Exception(f"IOU type {self.iou_type} is not supported") + + def _get_classes(self) -> List: + """Return a list of unique classes found in ground truth and detection data.""" + if len(self.detection_labels) > 0 or len(self.groundtruth_labels) > 0: + return torch.cat(self.detection_labels + self.groundtruth_labels).unique().cpu().tolist() + return [] + def _get_coco_format( - self, boxes: List[torch.Tensor], labels: List[torch.Tensor], scores: Optional[List[torch.Tensor]] = None + self, + boxes: List[torch.Tensor], + labels: List[torch.Tensor], + scores: Optional[List[torch.Tensor]] = None, + crowds: Optional[List[torch.Tensor]] = None, + area: Optional[List[torch.Tensor]] = None, ) -> Dict: """Transforms and returns all cached targets or predictions in COCO format. @@ -479,13 +610,18 @@ def _get_coco_format( stat = image_box if self.iou_type == "bbox" else {"size": image_box[0], "counts": image_box[1]} + if area is not None and area[image_id][k].cpu().tolist() > 0: + area_stat = area[image_id][k].cpu().tolist() + else: + area_stat = image_box[2] * image_box[3] if self.iou_type == "bbox" else mask_utils.area(stat) + annotation = { "id": annotation_id, "image_id": image_id, "bbox" if self.iou_type == "bbox" else "segmentation": stat, - "area": image_box[2] * image_box[3] if self.iou_type == "bbox" else mask_utils.area(stat), + "area": area_stat, "category_id": image_label, - "iscrowd": 0, + "iscrowd": crowds[image_id][k].cpu().tolist() if crowds is not None else 0, } if scores is not None: @@ -502,49 +638,6 @@ def _get_coco_format( classes = [{"id": i, "name": str(i)} for i in self._get_classes()] return {"images": images, "annotations": annotations, "categories": classes} - # specialized syncronization and apply functions for this metric - - def _apply(self, fn: Callable) -> torch.nn.Module: - """Custom apply function. - - Excludes the detections and groundtruths from the casting when the iou_type is set to `segm` as the state is - no longer a tensor but a tuple. - """ - if self.iou_type == "segm": - this = super()._apply(fn, exclude_state=("detections", "groundtruths")) - else: - this = super()._apply(fn) - return this - - def _sync_dist(self, dist_sync_fn: Optional[Callable] = None, process_group: Optional[Any] = None) -> None: - """Custom sync function. 
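# Editor's sketch (not part of the patch): the RLE encoding performed in
# ``_get_safe_item_values`` for ``iou_type="segm"``. Each mask is run-length encoded with
# pycocotools, and ``mask_utils.area`` gives the pixel area used as the fallback "area"
# value in ``_get_coco_format``. The mask below is a made-up example.
import numpy as np
import pycocotools.mask as mask_utils

mask = np.zeros((32, 32), dtype=np.uint8)
mask[8:24, 8:24] = 1  # a 16x16 square object
rle = mask_utils.encode(np.asfortranarray(mask))  # dict with "size" and "counts"
area = mask_utils.area(rle)  # 256 pixels, matching the fallback area computation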
- - For the iou_type `segm` the detections and groundtruths are no longer tensors but tuples. Therefore, we need - to gather the list of tuples and then convert it back to a list of tuples. - - """ - super()._sync_dist(dist_sync_fn=dist_sync_fn, process_group=process_group) - - if self.iou_type == "segm": - self.detections = self._gather_tuple_list(self.detections, process_group) - self.groundtruths = self._gather_tuple_list(self.groundtruths, process_group) - - @staticmethod - def _gather_tuple_list(list_to_gather: List[Tuple], process_group: Optional[Any] = None) -> List[Any]: - """Gather a list of tuples over multiple devices.""" - world_size = dist.get_world_size(group=process_group) - dist.barrier(group=process_group) - - list_gathered = [None for _ in range(world_size)] - dist.all_gather_object(list_gathered, list_to_gather, group=process_group) - - list_merged = [] - for idx in range(len(list_gathered[0])): - for rank in range(world_size): - list_merged.append(list_gathered[rank][idx]) - - return list_merged - def plot( self, val: Optional[Union[Dict[str, Tensor], Sequence[Dict[str, Tensor]]]] = None, ax: Optional[_AX_TYPE] = None ) -> _PLOT_OUT_TYPE: @@ -602,3 +695,57 @@ def plot( >>> fig_, ax_ = metric.plot(vals) """ return self._plot(val, ax) + + # -------------------- + # specialized syncronization and apply functions for this metric + # -------------------- + + def _apply(self, fn: Callable) -> torch.nn.Module: + """Custom apply function. + + Excludes the detections and groundtruths from the casting when the iou_type is set to `segm` as the state is + no longer a tensor but a tuple. + """ + if self.iou_type == "segm": + this = super()._apply(fn, exclude_state=("detections", "groundtruths")) + else: + this = super()._apply(fn) + return this + + def _sync_dist(self, dist_sync_fn: Optional[Callable] = None, process_group: Optional[Any] = None) -> None: + """Custom sync function. + + For the iou_type `segm` the detections and groundtruths are no longer tensors but tuples. Therefore, we need + to gather the list of tuples and then convert it back to a list of tuples. + + """ + super()._sync_dist(dist_sync_fn=dist_sync_fn, process_group=process_group) + + if self.iou_type == "segm": + self.detections = self._gather_tuple_list(self.detections, process_group) + self.groundtruths = self._gather_tuple_list(self.groundtruths, process_group) + + @staticmethod + def _gather_tuple_list(list_to_gather: List[Tuple], process_group: Optional[Any] = None) -> List[Any]: + """Gather a list of tuples over multiple devices. 
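# Editor's sketch (not part of the patch): a single-process illustration of the merge
# order implemented in ``_gather_tuple_list`` below. After ``all_gather_object`` every
# rank holds ``list_gathered`` (one list per rank), and the merged result interleaves
# the per-rank entries index by index. Payload strings are made up.
list_gathered = [["r0_a", "r0_b"], ["r1_a", "r1_b"]]  # pretend world_size == 2

list_merged = []
for idx in range(len(list_gathered[0])):
    for rank in range(len(list_gathered)):
        list_merged.append(list_gathered[rank][idx])

assert list_merged == ["r0_a", "r1_a", "r0_b", "r1_b"]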
+ + Args: + list_to_gather: input list of tuples that should be gathered across devices + process_group: process group to gather the list of tuples + + Returns: + list of tuples gathered across devices + + """ + world_size = dist.get_world_size(group=process_group) + dist.barrier(group=process_group) + + list_gathered = [None for _ in range(world_size)] + dist.all_gather_object(list_gathered, list_to_gather, group=process_group) + + list_merged = [] + for idx in range(len(list_gathered[0])): + for rank in range(world_size): + list_merged.append(list_gathered[rank][idx]) + + return list_merged diff --git a/tests/unittests/detection/__init__.py b/tests/unittests/detection/__init__.py index ec3fb8193f7..bc4e1199cbf 100644 --- a/tests/unittests/detection/__init__.py +++ b/tests/unittests/detection/__init__.py @@ -3,6 +3,6 @@ from unittests import _PATH_ROOT _SAMPLE_DETECTION_SEGMENTATION = os.path.join(_PATH_ROOT, "_data", "detection", "instance_segmentation_inputs.json") -_DETECTION_VAL = os.path.join(_PATH_ROOT, "_data", "detection", "instances_val2014.json") +_DETECTION_VAL = os.path.join(_PATH_ROOT, "_data", "detection", "instances_val2014_100.json") _DETECTION_BBOX = os.path.join(_PATH_ROOT, "_data", "detection", "instances_val2014_fakebbox100_results.json") _DETECTION_SEGM = os.path.join(_PATH_ROOT, "_data", "detection", "instances_val2014_fakesegm100_results.json") diff --git a/tests/unittests/detection/test_map.py b/tests/unittests/detection/test_map.py index 56a963d382d..21b34ca3c44 100644 --- a/tests/unittests/detection/test_map.py +++ b/tests/unittests/detection/test_map.py @@ -23,15 +23,22 @@ from pycocotools.coco import COCO from pycocotools.cocoeval import COCOeval from torch import IntTensor, Tensor -from torchmetrics.detection.mean_ap import MeanAveragePrecision +from torchmetrics.detection.mean_ap import MeanAveragePrecision, _HidePrints from torchmetrics.utilities.imports import _PYCOCOTOOLS_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_8 from unittests.detection import _DETECTION_BBOX, _DETECTION_SEGM, _DETECTION_VAL, _SAMPLE_DETECTION_SEGMENTATION from unittests.helpers.testers import MetricTester +_pytest_condition = not (_PYCOCOTOOLS_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_8) + + +def _generate_coco_inputs(iou_type): + """Generates inputs for the MAP metric. 
-def _generate_inputs(iou_type): - """Generates inputs for the MAP metric.""" + The inputs are generated from the official COCO results json files: + https://github.com/cocodataset/cocoapi/tree/master/results + and should therefore correspond directly to the result on the webpage + """ gt = COCO(_DETECTION_VAL) dt = gt.loadRes(_DETECTION_BBOX if iou_type == "bbox" else _DETECTION_SEGM) img_ids = sorted(gt.getImgIds()) @@ -59,69 +66,152 @@ def _generate_inputs(iou_type): if t["image_id"] not in img_ids: continue if t["image_id"] not in target: - target[t["image_id"]] = {"boxes" if iou_type == "bbox" else "masks": [], "labels": []} + target[t["image_id"]] = { + "boxes" if iou_type == "bbox" else "masks": [], + "labels": [], + "iscrowd": [], + "area": [], + } if iou_type == "bbox": target[t["image_id"]]["boxes"].append(t["bbox"]) else: target[t["image_id"]]["masks"].append(gt.annToMask(t)) target[t["image_id"]]["labels"].append(t["category_id"]) + target[t["image_id"]]["iscrowd"].append(t["iscrowd"]) + target[t["image_id"]]["area"].append(t["area"]) - if iou_type == "bbox": - preds = [ + batched_preds, batched_target = [], [] + for key in target: + name = "boxes" if iou_type == "bbox" else "masks" + batched_preds.append( { - "boxes": torch.tensor(p["boxes"]), - "scores": torch.tensor(p["scores"]), - "labels": torch.tensor(p["labels"]), + name: torch.tensor(preds[key]["boxes"]) if iou_type == "bbox" else torch.tensor(preds[key]["masks"]), + "scores": torch.tensor(preds[key]["scores"]), + "labels": torch.tensor(preds[key]["labels"]), } - for p in preds.values() - ] - target = [{"boxes": torch.tensor(t["boxes"]), "labels": torch.tensor(t["labels"])} for t in target.values()] - else: - preds = [ + ) + batched_target.append( { - "masks": torch.tensor(p["masks"]), - "scores": torch.tensor(p["scores"]), - "labels": torch.tensor(p["labels"]), + name: torch.tensor(target[key]["boxes"]) if iou_type == "bbox" else torch.tensor(target[key]["masks"]), + "labels": torch.tensor(target[key]["labels"]), + "iscrowd": torch.tensor(target[key]["iscrowd"]), + "area": torch.tensor(target[key]["area"]), } - for p in preds.values() - ] - target = [{"masks": torch.tensor(t["masks"]), "labels": torch.tensor(t["labels"])} for t in target.values()] + ) # create 10 batches of 10 preds/targets each - preds = [preds[10 * i : 10 * (i + 1)] for i in range(10)] - target = [target[10 * i : 10 * (i + 1)] for i in range(10)] + batched_preds = [batched_preds[10 * i : 10 * (i + 1)] for i in range(10)] + batched_target = [batched_target[10 * i : 10 * (i + 1)] for i in range(10)] + return batched_preds, batched_target - return preds, target +_coco_bbox_input = _generate_coco_inputs("bbox") +_coco_segm_input = _generate_coco_inputs("segm") -_bbox_input = _generate_inputs("bbox") -_segm_input = _generate_inputs("segm") - -def _compare_fn(preds, target, iou_type, class_metrics=True): +def _compare_again_coco_fn(preds, target, iou_type, class_metrics=True): """Taken from https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb.""" gt = COCO(_DETECTION_VAL) dt = gt.loadRes(_DETECTION_BBOX if iou_type == "bbox" else _DETECTION_SEGM) - img_ids = sorted(gt.getImgIds()) - img_ids = img_ids[0:100] + coco_eval = COCOeval(gt, dt, iou_type) - coco_eval.params.imgIds = img_ids - coco_eval.evaluate() - coco_eval.accumulate() - coco_eval.summarize() + with _HidePrints(): + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() global_stats = deepcopy(coco_eval.stats) map_per_class_values = torch.Tensor([-1]) 
mar_100_per_class_values = torch.Tensor([-1]) - classes = Tensor(np.unique([x["category_id"] for x in gt.dataset["annotations"]])) + classes = torch.tensor( + [ + 1.0, + 2.0, + 3.0, + 4.0, + 5.0, + 6.0, + 7.0, + 8.0, + 9.0, + 10.0, + 11.0, + 13.0, + 14.0, + 15.0, + 16.0, + 17.0, + 18.0, + 20.0, + 21.0, + 22.0, + 23.0, + 24.0, + 25.0, + 27.0, + 28.0, + 31.0, + 32.0, + 33.0, + 34.0, + 35.0, + 36.0, + 37.0, + 38.0, + 39.0, + 40.0, + 41.0, + 42.0, + 43.0, + 44.0, + 46.0, + 47.0, + 48.0, + 49.0, + 50.0, + 51.0, + 52.0, + 53.0, + 54.0, + 55.0, + 56.0, + 57.0, + 58.0, + 59.0, + 60.0, + 61.0, + 62.0, + 63.0, + 64.0, + 65.0, + 67.0, + 70.0, + 72.0, + 73.0, + 74.0, + 75.0, + 77.0, + 78.0, + 79.0, + 80.0, + 81.0, + 82.0, + 84.0, + 85.0, + 86.0, + 88.0, + 90.0, + ] + ) if class_metrics: map_per_class_list = [] mar_100_per_class_list = [] for class_id in classes.tolist(): coco_eval.params.catIds = [class_id] - coco_eval.evaluate() - coco_eval.accumulate() - coco_eval.summarize() + with _HidePrints(): + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() class_stats = coco_eval.stats map_per_class_list.append(torch.Tensor([class_stats[0]])) mar_100_per_class_list.append(torch.Tensor([class_stats[8]])) @@ -148,25 +238,24 @@ def _compare_fn(preds, target, iou_type, class_metrics=True): } -_pytest_condition = not (_PYCOCOTOOLS_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_8) - - @pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 and pycocotools is installed") @pytest.mark.parametrize("iou_type", ["bbox", "segm"]) -class TestMAPNew(MetricTester): - """Test map metric.""" +@pytest.mark.parametrize("ddp", [False, True]) +class TestMAPUsingCOCOReference(MetricTester): + """Test map metric on the reference coco data.""" - # @pytest.mark.parametrize("ddp", [False, True]) - def test_map(self, iou_type): + atol = 1e-1 + + def test_map(self, iou_type, ddp): """Test modular implementation for correctness.""" - preds, target = _segm_input if iou_type == "segm" else _bbox_input + preds, target = _coco_bbox_input if iou_type == "bbox" else _coco_segm_input self.run_class_metric_test( - ddp=False, + ddp=ddp, preds=preds, target=target, metric_class=MeanAveragePrecision, - reference_metric=partial(_compare_fn, iou_type=iou_type), - metric_args={"iou_type": iou_type}, + reference_metric=partial(_compare_again_coco_fn, iou_type=iou_type, class_metrics=True), + metric_args={"iou_type": iou_type, "class_metrics": True}, check_batch=False, ) @@ -174,59 +263,6 @@ def test_map(self, iou_type): Input = namedtuple("Input", ["preds", "target"]) -def _create_inputs_masks() -> Input: - with open(_SAMPLE_DETECTION_SEGMENTATION) as fp: - inputs_json = json.load(fp) - - _mask_unsqueeze_bool = lambda m: Tensor(mask.decode(m)).unsqueeze(0).bool() - _masks_stack_bool = lambda ms: Tensor(np.stack([mask.decode(m) for m in ms])).bool() - - return Input( - preds=[ - [ - { - "masks": _mask_unsqueeze_bool(inputs_json["preds"][0]), - "scores": Tensor([0.236]), - "labels": IntTensor([4]), - }, - { - "masks": _masks_stack_bool([inputs_json["preds"][1], inputs_json["preds"][2]]), - "scores": Tensor([0.318, 0.726]), - "labels": IntTensor([3, 2]), - }, # 73 - ], - [ - { - "masks": _mask_unsqueeze_bool(inputs_json["preds"][0]), - "scores": Tensor([0.236]), - "labels": IntTensor([4]), - }, - { - "masks": _masks_stack_bool([inputs_json["preds"][1], inputs_json["preds"][2]]), - "scores": Tensor([0.318, 0.726]), - "labels": IntTensor([3, 2]), - }, # 73 - ], - ], - target=[ - [ - {"masks": 
_mask_unsqueeze_bool(inputs_json["targets"][0]), "labels": IntTensor([4])}, # 42 - { - "masks": _masks_stack_bool([inputs_json["targets"][1], inputs_json["targets"][2]]), - "labels": IntTensor([2, 2]), - }, # 73 - ], - [ - {"masks": _mask_unsqueeze_bool(inputs_json["targets"][0]), "labels": IntTensor([4])}, # 42 - { - "masks": _masks_stack_bool([inputs_json["targets"][1], inputs_json["targets"][2]]), - "labels": IntTensor([2, 2]), - }, # 73 - ], - ], - ) - - _inputs = Input( preds=[ [ @@ -396,222 +432,6 @@ def _create_inputs_masks() -> Input: ) -_inputs4 = Input( - preds=[ - [ - { - "boxes": torch.Tensor([[258.15, 41.29, 606.41, 285.07]]), - "scores": torch.Tensor([0.236]), - "labels": torch.IntTensor([4]), - }, # coco image id 42 - { - "boxes": torch.Tensor([[61.00, 22.75, 565.00, 632.42], [12.66, 3.32, 281.26, 275.23]]), - "scores": torch.Tensor([0.318, 0.726]), - "labels": torch.IntTensor([3, 2]), - }, # coco image id 73 - ], - [ - { - "boxes": torch.Tensor( - [ - [87.87, 276.25, 384.29, 379.43], - [0.00, 3.66, 142.15, 316.06], - [296.55, 93.96, 314.97, 152.79], - [328.94, 97.05, 342.49, 122.98], - [356.62, 95.47, 372.33, 147.55], - [464.08, 105.09, 495.74, 146.99], - [276.11, 103.84, 291.44, 150.72], - ] - ), - "scores": torch.Tensor([0.546, 0.3, 0.407, 0.611, 0.335, 0.805, 0.953]), - "labels": torch.IntTensor([4, 1, 0, 0, 0, 0, 0]), - }, # coco image id 74 - { - "boxes": torch.Tensor([[0.00, 2.87, 601.00, 421.52]]), - "scores": torch.Tensor([0.423]), - "labels": torch.IntTensor([5]), - }, # coco image id 133 - ], - ], - target=[ - [ - { - "boxes": torch.Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]), - "labels": torch.IntTensor([4]), - }, # coco image id 42 - { - "boxes": torch.Tensor( - [ - [13.00, 22.75, 548.98, 632.42], - [1.66, 3.32, 270.26, 275.23], - ] - ), - "labels": torch.IntTensor([2, 2]), - }, # coco image id 73 - ], - [ - { - "boxes": torch.Tensor( - [ - [61.87, 276.25, 358.29, 379.43], - [2.75, 3.66, 162.15, 316.06], - [295.55, 93.96, 313.97, 152.79], - [326.94, 97.05, 340.49, 122.98], - [356.62, 95.47, 372.33, 147.55], - [462.08, 105.09, 493.74, 146.99], - [277.11, 103.84, 292.44, 150.72], - ] - ), - "labels": torch.IntTensor([4, 1, 0, 0, 0, 0, 0]), - }, # coco image id 74 - { - "boxes": torch.Tensor([[13.99, 2.87, 640.00, 421.52]]), - "labels": torch.IntTensor([5]), - }, # coco image id 133 - ], - ], -) - - -def _compare_fn(preds, target) -> dict: - """Comparison function for map implementation. 
- - Official pycocotools results calculated from a subset of https://github.com/cocodataset/cocoapi/tree/master/results - All classes - Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.706 - Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.901 - Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.846 - Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.689 - Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.800 - Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.701 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.592 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.716 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.716 - Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.767 - Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.800 - Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.700 - - Class 0 - Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.725 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.780 - - Class 1 - Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.800 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.800 - - Class 2 - Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.454 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.450 - - Class 3 - Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = -1.000 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = -1.000 - - Class 4 - Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.650 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.650 - - Class 5 - Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.900 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.900 - """ - return { - "map": torch.Tensor([0.706]), - "map_50": torch.Tensor([0.901]), - "map_75": torch.Tensor([0.846]), - "map_small": torch.Tensor([0.689]), - "map_medium": torch.Tensor([0.800]), - "map_large": torch.Tensor([0.701]), - "mar_1": torch.Tensor([0.592]), - "mar_10": torch.Tensor([0.716]), - "mar_100": torch.Tensor([0.716]), - "mar_small": torch.Tensor([0.767]), - "mar_medium": torch.Tensor([0.800]), - "mar_large": torch.Tensor([0.700]), - "map_per_class": torch.Tensor([0.725, 0.800, 0.454, -1.000, 0.650, 0.900]), - "mar_100_per_class": torch.Tensor([0.780, 0.800, 0.450, -1.000, 0.650, 0.900]), - } - - -def _compare_fn_segm(preds, target) -> dict: - """Comparison function for map implementation for instance segmentation. 
- - Official pycocotools results calculated from a subset of https://github.com/cocodataset/cocoapi/tree/master/results - Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.352 - Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.752 - Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.252 - Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000 - Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000 - Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.352 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.350 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.350 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.350 - Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000 - Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000 - Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.350 - """ - return { - "map": Tensor([0.352]), - "map_50": Tensor([0.752]), - "map_75": Tensor([0.252]), - "map_small": Tensor([-1]), - "map_medium": Tensor([-1]), - "map_large": Tensor([0.352]), - "mar_1": Tensor([0.35]), - "mar_10": Tensor([0.35]), - "mar_100": Tensor([0.35]), - "mar_small": Tensor([-1]), - "mar_medium": Tensor([-1]), - "mar_large": Tensor([0.35]), - "map_per_class": Tensor([0.4039604, -1.0, 0.3]), - "mar_100_per_class": Tensor([0.4, -1.0, 0.3]), - "classes": Tensor([2, 3, 4]), - } - - -_pytest_condition = not _TORCHVISION_GREATER_EQUAL_0_8 - - -@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed") -class TestMAP(MetricTester): - """Test the MAP metric for object detection predictions. - - Results are compared to original values from the pycocotools implementation. A subset of the first 10 fake - predictions of the official repo is used: - https://github.com/cocodataset/cocoapi/blob/master/results/instances_val2014_fakebbox100_results.json - """ - - atol = 1e-2 - - @pytest.mark.parametrize("ddp", [False, True]) - def test_map_bbox(self, ddp): - """Test modular implementation for correctness.""" - self.run_class_metric_test( - ddp=ddp, - preds=_inputs4.preds, - target=_inputs4.target, - metric_class=MeanAveragePrecision, - reference_metric=_compare_fn, - check_batch=False, - metric_args={"class_metrics": True}, - ) - - @pytest.mark.parametrize("ddp", [False, True]) - def test_map_segm(self, ddp): - """Test modular implementation for correctness.""" - _inputs_masks = _create_inputs_masks() - self.run_class_metric_test( - ddp=ddp, - preds=_inputs_masks.preds, - target=_inputs_masks.target, - metric_class=MeanAveragePrecision, - reference_metric=_compare_fn_segm, - check_batch=False, - metric_args={"class_metrics": True, "iou_type": "segm"}, - ) - - -# noinspection PyTypeChecker @pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed") def test_error_on_wrong_init(): """Test class raises the expected errors.""" @@ -936,87 +756,3 @@ def test_device_changing(): metric = metric.cpu() val = metric.compute() assert isinstance(val, dict) - - -def test_order(): - """Test that the ordering of input does not matter. 
- - Issue: https://github.com/Lightning-AI/torchmetrics/issues/1774 - """ - targets = [ - { - "boxes": torch.zeros((0, 4), dtype=torch.float32), - "labels": torch.zeros((0,), dtype=torch.long), - }, - { - "boxes": torch.FloatTensor([[0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4]]), - "labels": torch.LongTensor([1, 2]), - }, - ] - - preds = [ - { - "boxes": torch.FloatTensor([[0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4]]), - "labels": torch.LongTensor([1, 2]), - "scores": torch.FloatTensor([0.9, 0.8]), - }, - { - "boxes": torch.FloatTensor([[0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4]]), - "labels": torch.LongTensor([1, 2]), - "scores": torch.FloatTensor([0.9, 0.8]), - }, - ] - metric = MeanAveragePrecision(box_format="xyxy", iou_type="bbox") - metrics = metric(preds, targets) - assert metrics["map_50"] == torch.tensor([0.5]) - - targets = [ - { - "boxes": torch.FloatTensor([[0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4]]), - "labels": torch.LongTensor([1, 2]), - }, - { - "boxes": torch.zeros((0, 4), dtype=torch.float32), - "labels": torch.zeros((0,), dtype=torch.long), - }, - ] - - preds = [ - { - "boxes": torch.FloatTensor([[0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4]]), - "labels": torch.LongTensor([1, 2]), - "scores": torch.FloatTensor([0.9, 0.8]), - }, - { - "boxes": torch.FloatTensor([[0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4]]), - "labels": torch.LongTensor([1, 2]), - "scores": torch.FloatTensor([0.9, 0.8]), - }, - ] - metric = MeanAveragePrecision(box_format="xyxy", iou_type="bbox") - metrics = metric(preds, targets) - assert metrics["map_50"] == torch.tensor([0.5]) - - -def test_corner_case(): - """Issue: https://github.com/Lightning-AI/torchmetrics/issues/1184.""" - metric = MeanAveragePrecision(iou_thresholds=[0.501], class_metrics=True) - preds = [ - { - "boxes": torch.Tensor( - [[0, 0, 20, 20], [30, 30, 50, 50], [70, 70, 90, 90], [100, 100, 120, 120]] - ), # FP # FP - "scores": torch.Tensor([0.6, 0.6, 0.6, 0.6]), - "labels": torch.IntTensor([0, 1, 2, 3]), - } - ] - - targets = [ - { - "boxes": torch.Tensor([[0, 0, 20, 20], [30, 30, 50, 50]]), - "labels": torch.IntTensor([0, 1]), - } - ] - metric.update(preds, targets) - res = metric.compute() - assert res["map"] == torch.tensor([0.5]) From b8c99e77b381d8557c54c2a18fbeed9ac967dedd Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 13 Jun 2023 12:14:12 +0200 Subject: [PATCH 04/16] fix dtype casting --- src/torchmetrics/detection/mean_ap.py | 38 +++++++++++++-------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py index 21fd6f13e01..ae10dc31396 100644 --- a/src/torchmetrics/detection/mean_ap.py +++ b/src/torchmetrics/detection/mean_ap.py @@ -412,31 +412,31 @@ def compute(self) -> dict: coco_eval.summarize() class_stats = coco_eval.stats - map_per_class_list.append(torch.Tensor([class_stats[0]])) - mar_100_per_class_list.append(torch.Tensor([class_stats[8]])) + map_per_class_list.append(torch.tensor([class_stats[0]])) + mar_100_per_class_list.append(torch.tensor([class_stats[8]])) - map_per_class_values = torch.Tensor(map_per_class_list) - mar_100_per_class_values = torch.Tensor(mar_100_per_class_list) + map_per_class_values = torch.tensor(map_per_class_list, dtype=torch.float32) + mar_100_per_class_values = torch.tensor(mar_100_per_class_list, dtype=torch.float32) else: - map_per_class_values: Tensor = torch.Tensor([-1]) - mar_100_per_class_values: Tensor = torch.Tensor([-1]) + map_per_class_values: Tensor = torch.tensor([-1], 
dtype=torch.float32) + mar_100_per_class_values: Tensor = torch.tensor([-1], dtype=torch.float32) metrics = MAPMetricResults( - map=torch.Tensor([stats[0]]), - map_50=torch.Tensor([stats[1]]), - map_75=torch.Tensor([stats[2]]), - map_small=torch.Tensor([stats[3]]), - map_medium=torch.Tensor([stats[4]]), - map_large=torch.Tensor([stats[5]]), - mar_1=torch.Tensor([stats[6]]), - mar_10=torch.Tensor([stats[7]]), - mar_100=torch.Tensor([stats[8]]), - mar_small=torch.Tensor([stats[9]]), - mar_medium=torch.Tensor([stats[10]]), - mar_large=torch.Tensor([stats[11]]), + map=torch.tensor([stats[0]], dtype=torch.float32), + map_50=torch.tensor([stats[1]], dtype=torch.float32), + map_75=torch.tensor([stats[2]], dtype=torch.float32), + map_small=torch.tensor([stats[3]], dtype=torch.float32), + map_medium=torch.tensor([stats[4]], dtype=torch.float32), + map_large=torch.tensor([stats[5]], dtype=torch.float32), + mar_1=torch.tensor([stats[6]], dtype=torch.float32), + mar_10=torch.tensor([stats[7]], dtype=torch.float32), + mar_100=torch.tensor([stats[8]], dtype=torch.float32), + mar_small=torch.tensor([stats[9]], dtype=torch.float32), + mar_medium=torch.tensor([stats[10]], dtype=torch.float32), + mar_large=torch.tensor([stats[11]], dtype=torch.float32), map_per_class=map_per_class_values, mar_100_per_class=mar_100_per_class_values, - classes=torch.Tensor(self._get_classes()), + classes=torch.tensor(self._get_classes(), dtype=torch.int32), ) return metrics.__dict__ From 13fd068bfd063489cbfaf50ababe4cad42a79552 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 13 Jun 2023 12:22:07 +0200 Subject: [PATCH 05/16] remove old code --- src/torchmetrics/detection/mean_ap.py | 61 ++++++++------------------- tests/unittests/helpers/testers.py | 4 -- 2 files changed, 17 insertions(+), 48 deletions(-) diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py index ae10dc31396..eb2240f3466 100644 --- a/src/torchmetrics/detection/mean_ap.py +++ b/src/torchmetrics/detection/mean_ap.py @@ -57,31 +57,6 @@ log = logging.getLogger(__name__) -@dataclass -class MAPMetricResults: - """Dataclass to wrap the final mAP results.""" - - map: Tensor # noqa: A003 - map_50: Tensor - map_75: Tensor - map_small: Tensor - map_medium: Tensor - map_large: Tensor - mar_1: Tensor - mar_10: Tensor - mar_100: Tensor - mar_small: Tensor - mar_medium: Tensor - mar_large: Tensor - map_per_class: Tensor - mar_100_per_class: Tensor - classes: Tensor - - def __getitem__(self, key: str) -> Union[Tensor, List[Tensor]]: - """Enables accessing the results via `result['map']` instead of `result.map`.""" - return getattr(self, key) - - class WriteToLog: """Logging class to move logs to log.debug().""" @@ -421,25 +396,23 @@ def compute(self) -> dict: map_per_class_values: Tensor = torch.tensor([-1], dtype=torch.float32) mar_100_per_class_values: Tensor = torch.tensor([-1], dtype=torch.float32) - metrics = MAPMetricResults( - map=torch.tensor([stats[0]], dtype=torch.float32), - map_50=torch.tensor([stats[1]], dtype=torch.float32), - map_75=torch.tensor([stats[2]], dtype=torch.float32), - map_small=torch.tensor([stats[3]], dtype=torch.float32), - map_medium=torch.tensor([stats[4]], dtype=torch.float32), - map_large=torch.tensor([stats[5]], dtype=torch.float32), - mar_1=torch.tensor([stats[6]], dtype=torch.float32), - mar_10=torch.tensor([stats[7]], dtype=torch.float32), - mar_100=torch.tensor([stats[8]], dtype=torch.float32), - mar_small=torch.tensor([stats[9]], dtype=torch.float32), - mar_medium=torch.tensor([stats[10]], 
dtype=torch.float32), - mar_large=torch.tensor([stats[11]], dtype=torch.float32), - map_per_class=map_per_class_values, - mar_100_per_class=mar_100_per_class_values, - classes=torch.tensor(self._get_classes(), dtype=torch.int32), - ) - - return metrics.__dict__ + return { + "map": torch.tensor([stats[0]], dtype=torch.float32), + "map_50": torch.tensor([stats[1]], dtype=torch.float32), + "map_75": torch.tensor([stats[2]], dtype=torch.float32), + "map_small": torch.tensor([stats[3]], dtype=torch.float32), + "map_medium": torch.tensor([stats[4]], dtype=torch.float32), + "map_large": torch.tensor([stats[5]], dtype=torch.float32), + "mar_1": torch.tensor([stats[6]], dtype=torch.float32), + "mar_10": torch.tensor([stats[7]], dtype=torch.float32), + "mar_100": torch.tensor([stats[8]], dtype=torch.float32), + "mar_small": torch.tensor([stats[9]], dtype=torch.float32), + "mar_medium": torch.tensor([stats[10]], dtype=torch.float32), + "mar_large": torch.tensor([stats[11]], dtype=torch.float32), + "map_per_class": map_per_class_values, + "mar_100_per_class": mar_100_per_class_values, + "classes": torch.tensor(self._get_classes(), dtype=torch.int32), + } @staticmethod def coco_to_tm( diff --git a/tests/unittests/helpers/testers.py b/tests/unittests/helpers/testers.py index ad09d430a13..04511c4fc7d 100644 --- a/tests/unittests/helpers/testers.py +++ b/tests/unittests/helpers/testers.py @@ -22,7 +22,6 @@ import torch from torch import Tensor, tensor from torchmetrics import Metric -from torchmetrics.detection.mean_ap import MAPMetricResults from torchmetrics.utilities.data import _flatten, apply_to_collection from unittests import NUM_PROCESSES @@ -54,9 +53,6 @@ def _assert_tensor(tm_result: Any, key: Optional[str] = None) -> None: if key is None: raise KeyError("Provide Key for Dict based metric results.") assert isinstance(tm_result[key], Tensor) - elif isinstance(tm_result, MAPMetricResults): - for val_index in [a for a in dir(tm_result) if not a.startswith("__")]: - assert isinstance(tm_result[val_index], Tensor) else: assert isinstance(tm_result, Tensor) From 3c6f01cabaf1b6982c7e64cfc773bc280a7f3897 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 13 Jun 2023 12:56:56 +0200 Subject: [PATCH 06/16] improve docs --- src/torchmetrics/detection/helpers.py | 46 +++++- src/torchmetrics/detection/mean_ap.py | 194 ++++++++++++++------------ 2 files changed, 146 insertions(+), 94 deletions(-) diff --git a/src/torchmetrics/detection/helpers.py b/src/torchmetrics/detection/helpers.py index c86787992f3..c4f728e85f7 100644 --- a/src/torchmetrics/detection/helpers.py +++ b/src/torchmetrics/detection/helpers.py @@ -11,7 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
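# Editor's sketch (not part of the patch): with the MAPMetricResults dataclass removed
# above, ``compute`` now returns a plain dict of tensors keyed by metric name. A short
# usage sketch, assuming pycocotools and torchvision>=0.8 are installed; the box values
# are made up.
from torch import tensor
from torchmetrics.detection.mean_ap import MeanAveragePrecision

preds = [{"boxes": tensor([[0.0, 0.0, 50.0, 50.0]]), "scores": tensor([0.9]), "labels": tensor([0])}]
target = [{"boxes": tensor([[0.0, 0.0, 50.0, 50.0]]), "labels": tensor([0])}]

metric = MeanAveragePrecision(iou_type="bbox")
metric.update(preds, target)
result = metric.compute()  # plain dict, accessed as result["map"], result["map_50"], ...
print(result["map"], result["classes"])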
-from typing import Dict, Sequence +import logging +import sys +from types import TracebackType +from typing import Dict, Optional, Sequence, Type from torch import Tensor @@ -75,3 +78,44 @@ def _fix_empty_tensors(boxes: Tensor) -> Tensor: if boxes.numel() == 0 and boxes.ndim == 1: return boxes.unsqueeze(0) return boxes + + +class _WriteToLog: + """Logging class to move logs to log.debug().""" + + _log = logging.getLogger(__name__) + + def write(self, buf: str) -> None: + """Write to log.debug() instead of stdout.""" + for line in buf.rstrip().splitlines(): + self._log.debug(line.rstrip()) + + def flush(self) -> None: + """Flush the logger.""" + for handler in self._log.handlers: + handler.flush() + + def close(self) -> None: + """Close the logger.""" + for handler in self._log.handlers: + handler.close() + + +class _HidePrints: + """Internal helper context to suppress the default output of the pycocotools package.""" + + def __init__(self) -> None: + """Initialize the context.""" + self._original_stdout = None + + def __enter__(self) -> None: + """Redirect stdout to log.debug().""" + self._original_stdout = sys.stdout # type: ignore + sys.stdout = _WriteToLog() # type: ignore + + def __exit__( + self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_t: Optional[TracebackType] + ) -> None: # type: ignore + """Restore stdout.""" + sys.stdout.close() + sys.stdout = self._original_stdout # type: ignore diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py index eb2240f3466..62fec072b37 100644 --- a/src/torchmetrics/detection/mean_ap.py +++ b/src/torchmetrics/detection/mean_ap.py @@ -13,10 +13,7 @@ # limitations under the License. import json import logging -import sys -from dataclasses import dataclass -from types import TracebackType -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -24,7 +21,7 @@ from torch import distributed as dist from typing_extensions import Literal -from torchmetrics.detection.helpers import _fix_empty_tensors, _input_validator +from torchmetrics.detection.helpers import _fix_empty_tensors, _HidePrints, _input_validator from torchmetrics.metric import Metric from torchmetrics.utilities.imports import ( _MATPLOTLIB_AVAILABLE, @@ -54,48 +51,6 @@ __doctest_skip__ = ["MeanAveragePrecision.plot", "MeanAveragePrecision"] -log = logging.getLogger(__name__) - - -class WriteToLog: - """Logging class to move logs to log.debug().""" - - def write(self, buf: str) -> None: - """Write to log.debug() instead of stdout.""" - for line in buf.rstrip().splitlines(): - log.debug(line.rstrip()) - - def flush(self) -> None: - """Flush the logger.""" - for handler in log.handlers: - handler.flush() - - def close(self) -> None: - """Close the logger.""" - for handler in log.handlers: - handler.close() - - -class _HidePrints: - """Internal helper context to suppress the default output of the pycocotools package.""" - - def __init__(self) -> None: - """Initialize the context.""" - self._original_stdout = None - - def __enter__(self) -> None: - """Redirect stdout to log.debug().""" - self._original_stdout = sys.stdout # type: ignore - sys.stdout = WriteToLog() # type: ignore - - def __exit__( - self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_t: Optional[TracebackType] - ) -> None: # type: ignore - """Restore stdout.""" - sys.stdout.close() - 
sys.stdout = self._original_stdout # type: ignore - - class MeanAveragePrecision(Metric): r"""Compute the `Mean-Average-Precision (mAP) and Mean-Average-Recall (mAR)`_ for object detection predictions. @@ -103,8 +58,12 @@ class MeanAveragePrecision(Metric): \text{mAP} = \frac{1}{n} \sum_{i=1}^{n} AP_i where :math:`AP_i` is the average precision for class :math:`i` and :math:`n` is the number of classes. The average - precision is defined as the area under the precision-recall curve. If argument `class_metrics` is set to ``True``, - the metric will also return the mAP/mAR per class. + precision is defined as the area under the precision-recall curve. For object detection the recall and precision are + defined based on the intersection of union (IoU) between the predicted bounding boxes and the ground truth bounding + boxes e.g. if two boxes have an IoU > t (with t being some threshold) they are considered a match and therefore + considered a true positive. The precision is then defined as the number of true positives divided by the number of + all detected boxes and the recall is defined as the number of true positives divided by the number of all ground + boxes. As input to ``forward`` and ``update`` the metric accepts the following input: @@ -141,21 +100,25 @@ class MeanAveragePrecision(Metric): - ``map_dict``: A dictionary containing the following key-values: - - map: (:class:`~torch.Tensor`) - - map_small: (:class:`~torch.Tensor`) - - map_medium:(:class:`~torch.Tensor`) - - map_large: (:class:`~torch.Tensor`) - - mar_1: (:class:`~torch.Tensor`) - - mar_10: (:class:`~torch.Tensor`) - - mar_100: (:class:`~torch.Tensor`) - - mar_small: (:class:`~torch.Tensor`) - - mar_medium: (:class:`~torch.Tensor`) - - mar_large: (:class:`~torch.Tensor`) - - map_50: (:class:`~torch.Tensor`) (-1 if 0.5 not in the list of iou thresholds) - - map_75: (:class:`~torch.Tensor`) (-1 if 0.75 not in the list of iou thresholds) - - map_per_class: (:class:`~torch.Tensor`) (-1 if class metrics are disabled) - - mar_100_per_class: (:class:`~torch.Tensor`) (-1 if class metrics are disabled) - - classes (:class:`~torch.Tensor`) + - map: (:class:`~torch.Tensor`), global mean average precision + - map_small: (:class:`~torch.Tensor`), mean average precision for small objects + - map_medium:(:class:`~torch.Tensor`), mean average precision for medium objects + - map_large: (:class:`~torch.Tensor`), mean average precision for large objects + - mar_1: (:class:`~torch.Tensor`), mean average recall for 1 detection per image + - mar_10: (:class:`~torch.Tensor`), mean average recall for 10 detections per image + - mar_100: (:class:`~torch.Tensor`), mean average recall for 100 detections per image + - mar_small: (:class:`~torch.Tensor`), mean average recall for small objects + - mar_medium: (:class:`~torch.Tensor`), mean average recall for medium objects + - mar_large: (:class:`~torch.Tensor`), mean average recall for large objects + - map_50: (:class:`~torch.Tensor`) (-1 if 0.5 not in the list of iou thresholds), mean average precision at + IoU=0.50 + - map_75: (:class:`~torch.Tensor`) (-1 if 0.75 not in the list of iou thresholds), mean average precision at + IoU=0.75 + - map_per_class: (:class:`~torch.Tensor`) (-1 if class metrics are disabled), mean average precision per + observed class + - mar_100_per_class: (:class:`~torch.Tensor`) (-1 if class metrics are disabled), mean average recall for 100 + detections per image per observed class + - classes (:class:`~torch.Tensor`), list of all observed classes For an example on how 
to use this metric check the `torchmetrics mAP example`_. @@ -165,23 +128,16 @@ class MeanAveragePrecision(Metric): The default properties are also accessible via fields and will raise an ``AttributeError`` if not available. .. note:: - This metric is following the mAP implementation of - `pycocotools `_, - a standard implementation for the mAP metric for object detection. - - .. note:: - This metric requires you to have `torchvision` version 0.8.0 or newer installed - (with corresponding version 1.7.0 of torch or newer). This metric requires `pycocotools` - installed when iou_type is `segm`. Please install with ``pip install torchvision`` or - ``pip install torchmetrics[detection]``. + This metric utilizes the official `pycocotools` implementation as its backend. This means that the metric + requires you to have `pycocotools` installed. In addition we require `torchvision` version 0.8.0 or newer. + Please install with ``pip install torchmetrics[detection]``. Args: box_format: Input format of given boxes. Supported formats are ``[`xyxy`, `xywh`, `cxcywh`]``. iou_type: Type of input (either masks or bounding-boxes) used for computing IOU. - Supported IOU types are ``["bbox", "segm"]``. - If using ``"segm"``, masks should be provided (see :meth:`update`). + Supported IOU types are ``["bbox", "segm"]``. If using ``"segm"``, masks should be provided in input. iou_thresholds: IoU thresholds for evaluation. If set to ``None`` it corresponds to the stepped range ``[0.5,...,0.95]`` with step ``0.05``. Else provide a list of floats. @@ -197,27 +153,21 @@ class MeanAveragePrecision(Metric): Raises: ModuleNotFoundError: - If ``torchvision`` is not installed or version installed is lower than 0.8.0 + If ``pycocotools`` is not installed ModuleNotFoundError: - If ``iou_type`` is equal to ``segm`` and ``pycocotools`` is not installed - ValueError: - If ``class_metrics`` is not a boolean - ValueError: - If ``preds`` is not of type (:class:`~List[Dict[str, Tensor]]`) - ValueError: - If ``target`` is not of type ``List[Dict[str, Tensor]]`` + If ``torchvision`` is not installed or version installed is lower than 0.8.0 ValueError: - If ``preds`` and ``target`` are not of the same length + If ``box_format`` is not one of ``"xyxy"``, ``"xywh"`` or ``"cxcywh"`` ValueError: - If any of ``preds.boxes``, ``preds.scores`` and ``preds.labels`` are not of the same length + If ``iou_type`` is not one of ``"bbox"`` or ``"segm"`` ValueError: - If any of ``target.boxes`` and ``target.labels`` are not of the same length + If ``iou_thresholds`` is not None or a list of floats ValueError: - If any box is not type float and of length 4 + If ``rec_thresholds`` is not None or a list of floats ValueError: - If any class is not type int and of length 1 + If ``max_detection_thresholds`` is not None or a list of ints ValueError: - If any score is not type float and of length 1 + If ``class_metrics`` is not a boolean Example: >>> from torch import tensor @@ -282,7 +232,7 @@ def __init__( super().__init__(**kwargs) if not _PYCOCOTOOLS_AVAILABLE: - raise ImportError( + raise ModuleNotFoundError( "`MAP` metric requires that `pycocotools` installed." " Please install with `pip install pycocotools` or `pip install torchmetrics[detection]`" ) @@ -335,7 +285,26 @@ def __init__( self.add_state("groundtruth_area", default=[], dist_reduce_fx=None) def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]) -> None: # type: ignore - """Update metric state.""" + """Update metric state. 
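# Editor's sketch (not part of the patch): calling ``update`` with the optional
# ``iscrowd`` and ``area`` target keys documented in the class docstring above.
# The coordinates and area are made-up values.
from torch import tensor
from torchmetrics.detection.mean_ap import MeanAveragePrecision

preds = [
    {"boxes": tensor([[258.0, 41.0, 606.0, 285.0]]), "scores": tensor([0.536]), "labels": tensor([0])}
]
target = [
    {
        "boxes": tensor([[214.0, 41.0, 562.0, 285.0]]),
        "labels": tensor([0]),
        "iscrowd": tensor([0]),     # optional, treated as 0 when missing
        "area": tensor([85000.0]),  # optional, otherwise derived from the box/mask
    }
]
metric = MeanAveragePrecision(iou_type="bbox")
metric.update(preds, target)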
+ + Raises: + ValueError: + If ``preds`` is not of type (:class:`~List[Dict[str, Tensor]]`) + ValueError: + If ``target`` is not of type ``List[Dict[str, Tensor]]`` + ValueError: + If ``preds`` and ``target`` are not of the same length + ValueError: + If any of ``preds.boxes``, ``preds.scores`` and ``preds.labels`` are not of the same length + ValueError: + If any of ``target.boxes`` and ``target.labels`` are not of the same length + ValueError: + If any box is not type float and of length 4 + ValueError: + If any class is not type int and of length 1 + ValueError: + If any score is not type float and of length 1 + """ _input_validator(preds, target, iou_type=self.iou_type) for item in preds: @@ -420,7 +389,10 @@ def coco_to_tm( coco_target: str, iou_type: Literal["bbox", "segm"] = "bbox", ) -> Tuple[List[Dict[str, Tensor]], List[Dict[str, Tensor]]]: - """Convert coco format to the input format of the map metric. + """Utility function for converting .json coco format files to the input format of this metric. + + The function accepts a file for the predictions and a file for the target in coco format and converts them to + a list of dictionaries containing the boxes, labels and scores in the input format of this metric. Args: coco_preds: Path to the json file containing the predictions in coco format @@ -431,6 +403,17 @@ def coco_to_tm( preds: List of dictionaries containing the predictions in the input format of this metric target: List of dictionaries containing the targets in the input format of this metric + Example: + >>> # File formats are defined at https://cocodataset.org/#format-data + >>> # Example files can be found at + >>> # https://github.com/cocodataset/cocoapi/tree/master/results + >>> from torchmetrics.detection import MeanAveragePrecision + >>> preds, target = MeanAveragePrecision.coco_to_tm( + ... "instances_val2014_fakebbox100_results.json.json", + ... "val2014_fake_eval_res.txt.json" + ... iou_type="bbox" + ... ) # doctest: +SKIP + """ gt = COCO(coco_target) dt = gt.loadRes(coco_preds) @@ -495,10 +478,35 @@ def coco_to_tm( return batched_preds, batched_target def tm_to_coco(self, name: str = "tm_map_input") -> None: - """Write the input to the map metric to a json file in coco format. + """Utility function for converting the input for this metric to coco format and saving it to a json file. + + This function should be used after calling `.update(...)` or `.forward(...)` on all data that should be written + to the file, as the input is then internally cached. The function then converts to information to coco format + a writes it to json files. Args: name: Name of the output file, which will be appended with "_preds.json" and "_target.json" + + Example: + >>> from torch import tensor + >>> from torchmetrics.detection import MeanAveragePrecision + >>> preds = [ + ... dict( + ... boxes=tensor([[258.0, 41.0, 606.0, 285.0]]), + ... scores=tensor([0.536]), + ... labels=tensor([0]), + ... ) + ... ] + >>> target = [ + ... dict( + ... boxes=tensor([[214.0, 41.0, 562.0, 285.0]]), + ... labels=tensor([0]), + ... ) + ... 
] + >>> metric = MeanAveragePrecision() + >>> metric.update(preds, target) + >>> metric.tm_to_coco("tm_map_input") # doctest: +SKIP + """ target_dataset = self._get_coco_format(self.groundtruths, self.groundtruth_labels) preds_dataset = self._get_coco_format(self.detections, self.detection_labels, self.detection_scores) From 6132051a471bca0eadb9372ea9ca441fd378e33b Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 13 Jun 2023 13:01:04 +0200 Subject: [PATCH 07/16] changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0fa48d2de0d..76051590423 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -149,6 +149,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Changed calculation in `PearsonCorrCoeff` to be more robust in certain cases ([#1729](https://github.com/Lightning-AI/torchmetrics/pull/1729)) + +- Changed `MeanAveragePrecision` to `pycocotools` backend ([#1832](https://github.com/Lightning-AI/torchmetrics/pull/1832)) + + ### Deprecated - Deprecated domain metrics import from package root ( From 908f658d61095815717541a9fefacf70ff48f88b Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 13 Jun 2023 13:29:54 +0200 Subject: [PATCH 08/16] refactor --- tests/unittests/detection/test_map.py | 142 +------------------------- 1 file changed, 5 insertions(+), 137 deletions(-) diff --git a/tests/unittests/detection/test_map.py b/tests/unittests/detection/test_map.py index 21b34ca3c44..ebd1207be27 100644 --- a/tests/unittests/detection/test_map.py +++ b/tests/unittests/detection/test_map.py @@ -39,65 +39,9 @@ def _generate_coco_inputs(iou_type): https://github.com/cocodataset/cocoapi/tree/master/results and should therefore correspond directly to the result on the webpage """ - gt = COCO(_DETECTION_VAL) - dt = gt.loadRes(_DETECTION_BBOX if iou_type == "bbox" else _DETECTION_SEGM) - img_ids = sorted(gt.getImgIds()) - img_ids = img_ids[0:100] - - gt_dataset = gt.dataset["annotations"] - dt_dataset = dt.dataset["annotations"] - - preds = {} - for p in dt_dataset: - if p["image_id"] not in preds: - preds[p["image_id"]] = {"boxes" if iou_type == "bbox" else "masks": [], "scores": [], "labels": []} - if iou_type == "bbox": - preds[p["image_id"]]["boxes"].append(p["bbox"]) - else: - preds[p["image_id"]]["masks"].append(gt.annToMask(p)) - preds[p["image_id"]]["scores"].append(p["score"]) - preds[p["image_id"]]["labels"].append(p["category_id"]) - missing_pred = set(img_ids) - set(preds.keys()) - for i in missing_pred: - preds[i] = {"boxes" if iou_type == "bbox" else "masks": [], "scores": [], "labels": []} - - target = {} - for t in gt_dataset: - if t["image_id"] not in img_ids: - continue - if t["image_id"] not in target: - target[t["image_id"]] = { - "boxes" if iou_type == "bbox" else "masks": [], - "labels": [], - "iscrowd": [], - "area": [], - } - if iou_type == "bbox": - target[t["image_id"]]["boxes"].append(t["bbox"]) - else: - target[t["image_id"]]["masks"].append(gt.annToMask(t)) - target[t["image_id"]]["labels"].append(t["category_id"]) - target[t["image_id"]]["iscrowd"].append(t["iscrowd"]) - target[t["image_id"]]["area"].append(t["area"]) - - batched_preds, batched_target = [], [] - for key in target: - name = "boxes" if iou_type == "bbox" else "masks" - batched_preds.append( - { - name: torch.tensor(preds[key]["boxes"]) if iou_type == "bbox" else torch.tensor(preds[key]["masks"]), - "scores": torch.tensor(preds[key]["scores"]), - "labels": torch.tensor(preds[key]["labels"]), - } - ) - 
batched_target.append( - { - name: torch.tensor(target[key]["boxes"]) if iou_type == "bbox" else torch.tensor(target[key]["masks"]), - "labels": torch.tensor(target[key]["labels"]), - "iscrowd": torch.tensor(target[key]["iscrowd"]), - "area": torch.tensor(target[key]["area"]), - } - ) + batched_preds, batched_target = MeanAveragePrecision.coco_to_tm( + _DETECTION_BBOX if iou_type == "bbox" else _DETECTION_SEGM, _DETECTION_VAL, iou_type + ) # create 10 batches of 10 preds/targets each batched_preds = [batched_preds[10 * i : 10 * (i + 1)] for i in range(10)] @@ -124,85 +68,9 @@ def _compare_again_coco_fn(preds, target, iou_type, class_metrics=True): map_per_class_values = torch.Tensor([-1]) mar_100_per_class_values = torch.Tensor([-1]) classes = torch.tensor( - [ - 1.0, - 2.0, - 3.0, - 4.0, - 5.0, - 6.0, - 7.0, - 8.0, - 9.0, - 10.0, - 11.0, - 13.0, - 14.0, - 15.0, - 16.0, - 17.0, - 18.0, - 20.0, - 21.0, - 22.0, - 23.0, - 24.0, - 25.0, - 27.0, - 28.0, - 31.0, - 32.0, - 33.0, - 34.0, - 35.0, - 36.0, - 37.0, - 38.0, - 39.0, - 40.0, - 41.0, - 42.0, - 43.0, - 44.0, - 46.0, - 47.0, - 48.0, - 49.0, - 50.0, - 51.0, - 52.0, - 53.0, - 54.0, - 55.0, - 56.0, - 57.0, - 58.0, - 59.0, - 60.0, - 61.0, - 62.0, - 63.0, - 64.0, - 65.0, - 67.0, - 70.0, - 72.0, - 73.0, - 74.0, - 75.0, - 77.0, - 78.0, - 79.0, - 80.0, - 81.0, - 82.0, - 84.0, - 85.0, - 86.0, - 88.0, - 90.0, - ] + list(set(torch.arange(91).tolist()) - {0, 12, 19, 26, 29, 30, 45, 66, 68, 69, 71, 76, 83, 87, 89}) ) + if class_metrics: map_per_class_list = [] mar_100_per_class_list = [] From fd21da01f245b5665615e273d3267fd8323ae571 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 13 Jun 2023 13:33:07 +0200 Subject: [PATCH 09/16] fix doc formatting --- src/torchmetrics/detection/mean_ap.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py index 62fec072b37..457dbcad3ed 100644 --- a/src/torchmetrics/detection/mean_ap.py +++ b/src/torchmetrics/detection/mean_ap.py @@ -90,11 +90,11 @@ class MeanAveragePrecision(Metric): - masks: :class:`~torch.bool` of shape ``(num_boxes, image_height, image_width)`` containing boolean masks. Only required when `iou_type="segm"`. - iscrowd: :class:`~torch.IntTensor` of shape ``(num_boxes)`` containing 0/1 values indicating whether - the bounding box/masks indicate a crowd of objects. Value is optional, and if not provided it will - automatically be set to 0. + the bounding box/masks indicate a crowd of objects. Value is optional, and if not provided it will + automatically be set to 0. - area: :class:`~torch.FloatTensor` of shape ``(num_boxes)`` containing the area of the object. Value if - optional, and if not provided will be automatically calculated based on the bounding box/masks provided. - Only affects which samples contribute to the `map_small`, `map_medium`, `map_large` values + optional, and if not provided will be automatically calculated based on the bounding box/masks provided. 
+ Only affects which samples contribute to the `map_small`, `map_medium`, `map_large` values As output of ``forward`` and ``compute`` the metric returns the following output: @@ -111,13 +111,13 @@ class MeanAveragePrecision(Metric): - mar_medium: (:class:`~torch.Tensor`), mean average recall for medium objects - mar_large: (:class:`~torch.Tensor`), mean average recall for large objects - map_50: (:class:`~torch.Tensor`) (-1 if 0.5 not in the list of iou thresholds), mean average precision at - IoU=0.50 + IoU=0.50 - map_75: (:class:`~torch.Tensor`) (-1 if 0.75 not in the list of iou thresholds), mean average precision at - IoU=0.75 + IoU=0.75 - map_per_class: (:class:`~torch.Tensor`) (-1 if class metrics are disabled), mean average precision per - observed class + observed class - mar_100_per_class: (:class:`~torch.Tensor`) (-1 if class metrics are disabled), mean average recall for 100 - detections per image per observed class + detections per image per observed class - classes (:class:`~torch.Tensor`), list of all observed classes For an example on how to use this metric check the `torchmetrics mAP example`_. From 8c0a41d2fe8b3b70406ec40744e5e214563fd969 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 13 Jun 2023 14:34:14 +0200 Subject: [PATCH 10/16] mypy --- src/torchmetrics/detection/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/detection/helpers.py b/src/torchmetrics/detection/helpers.py index c4f728e85f7..f519b97568d 100644 --- a/src/torchmetrics/detection/helpers.py +++ b/src/torchmetrics/detection/helpers.py @@ -115,7 +115,7 @@ def __enter__(self) -> None: def __exit__( self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_t: Optional[TracebackType] - ) -> None: # type: ignore + ) -> None: """Restore stdout.""" sys.stdout.close() sys.stdout = self._original_stdout # type: ignore From 5651da050cfbbe2d87347c7217c143e137838330 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 14 Jun 2023 10:57:19 +0200 Subject: [PATCH 11/16] skip doctest on missing import --- src/torchmetrics/detection/mean_ap.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py index 457dbcad3ed..e5a1dcb0262 100644 --- a/src/torchmetrics/detection/mean_ap.py +++ b/src/torchmetrics/detection/mean_ap.py @@ -38,7 +38,12 @@ from torchvision.ops import box_convert else: box_convert = None - __doctest_skip__ = ["MeanAveragePrecision.plot", "MeanAveragePrecision"] + __doctest_skip__ = [ + "MeanAveragePrecision.plot", + "MeanAveragePrecision", + "MeanAveragePrecision.tm_to_coco", + "MeanAveragePrecision.coco_to_tm", + ] if _PYCOCOTOOLS_AVAILABLE: @@ -48,7 +53,12 @@ else: COCO, COCOeval = None, None mask_utils = None - __doctest_skip__ = ["MeanAveragePrecision.plot", "MeanAveragePrecision"] + __doctest_skip__ = [ + "MeanAveragePrecision.plot", + "MeanAveragePrecision", + "MeanAveragePrecision.tm_to_coco", + "MeanAveragePrecision.coco_to_tm", + ] class MeanAveragePrecision(Metric): From 1e7bb62d7771004d2b058662881e087f5156d6f7 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 15 Jun 2023 14:06:55 +0200 Subject: [PATCH 12/16] remove helper --- src/torchmetrics/detection/helpers.py | 46 +-------------------------- src/torchmetrics/detection/mean_ap.py | 20 ++++++------ tests/unittests/detection/test_map.py | 18 +++++++---- 3 files changed, 21 insertions(+), 63 deletions(-) diff --git a/src/torchmetrics/detection/helpers.py 
b/src/torchmetrics/detection/helpers.py index f519b97568d..c86787992f3 100644 --- a/src/torchmetrics/detection/helpers.py +++ b/src/torchmetrics/detection/helpers.py @@ -11,10 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging -import sys -from types import TracebackType -from typing import Dict, Optional, Sequence, Type +from typing import Dict, Sequence from torch import Tensor @@ -78,44 +75,3 @@ def _fix_empty_tensors(boxes: Tensor) -> Tensor: if boxes.numel() == 0 and boxes.ndim == 1: return boxes.unsqueeze(0) return boxes - - -class _WriteToLog: - """Logging class to move logs to log.debug().""" - - _log = logging.getLogger(__name__) - - def write(self, buf: str) -> None: - """Write to log.debug() instead of stdout.""" - for line in buf.rstrip().splitlines(): - self._log.debug(line.rstrip()) - - def flush(self) -> None: - """Flush the logger.""" - for handler in self._log.handlers: - handler.flush() - - def close(self) -> None: - """Close the logger.""" - for handler in self._log.handlers: - handler.close() - - -class _HidePrints: - """Internal helper context to suppress the default output of the pycocotools package.""" - - def __init__(self) -> None: - """Initialize the context.""" - self._original_stdout = None - - def __enter__(self) -> None: - """Redirect stdout to log.debug().""" - self._original_stdout = sys.stdout # type: ignore - sys.stdout = _WriteToLog() # type: ignore - - def __exit__( - self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_t: Optional[TracebackType] - ) -> None: - """Restore stdout.""" - sys.stdout.close() - sys.stdout = self._original_stdout # type: ignore diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py index e5a1dcb0262..5d8cf797605 100644 --- a/src/torchmetrics/detection/mean_ap.py +++ b/src/torchmetrics/detection/mean_ap.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import contextlib +import io import json -import logging from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import numpy as np @@ -21,7 +22,7 @@ from torch import distributed as dist from typing_extensions import Literal -from torchmetrics.detection.helpers import _fix_empty_tensors, _HidePrints, _input_validator +from torchmetrics.detection.helpers import _fix_empty_tensors, _input_validator from torchmetrics.metric import Metric from torchmetrics.utilities.imports import ( _MATPLOTLIB_AVAILABLE, @@ -340,7 +341,7 @@ def compute(self) -> dict: ) coco_preds.dataset = self._get_coco_format(self.detections, self.detection_labels, scores=self.detection_scores) - with _HidePrints(): + with contextlib.redirect_stdout(io.StringIO()): coco_target.createIndex() coco_preds.createIndex() @@ -360,7 +361,7 @@ def compute(self) -> dict: mar_100_per_class_list = [] for class_id in torch.cat(self.detection_labels + self.groundtruth_labels).unique().cpu().tolist(): coco_eval.params.catIds = [class_id] - with _HidePrints(): + with contextlib.redirect_stdout(io.StringIO()): coco_eval.evaluate() coco_eval.accumulate() coco_eval.summarize() @@ -425,8 +426,9 @@ def coco_to_tm( ... 
) # doctest: +SKIP """ - gt = COCO(coco_target) - dt = gt.loadRes(coco_preds) + with contextlib.redirect_stdout(io.StringIO()): + gt = COCO(coco_target) + dt = gt.loadRes(coco_preds) gt_dataset = gt.dataset["annotations"] dt_dataset = dt.dataset["annotations"] @@ -697,11 +699,7 @@ def _apply(self, fn: Callable) -> torch.nn.Module: Excludes the detections and groundtruths from the casting when the iou_type is set to `segm` as the state is no longer a tensor but a tuple. """ - if self.iou_type == "segm": - this = super()._apply(fn, exclude_state=("detections", "groundtruths")) - else: - this = super()._apply(fn) - return this + return super()._apply(fn, exclude_state=("detections", "groundtruths") if self.iou_type == "segm" else "") def _sync_dist(self, dist_sync_fn: Optional[Callable] = None, process_group: Optional[Any] = None) -> None: """Custom sync function. diff --git a/tests/unittests/detection/test_map.py b/tests/unittests/detection/test_map.py index ebd1207be27..26444d68901 100644 --- a/tests/unittests/detection/test_map.py +++ b/tests/unittests/detection/test_map.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json +import contextlib +import io from collections import namedtuple from copy import deepcopy from functools import partial @@ -23,7 +24,7 @@ from pycocotools.coco import COCO from pycocotools.cocoeval import COCOeval from torch import IntTensor, Tensor -from torchmetrics.detection.mean_ap import MeanAveragePrecision, _HidePrints +from torchmetrics.detection.mean_ap import MeanAveragePrecision from torchmetrics.utilities.imports import _PYCOCOTOOLS_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_8 from unittests.detection import _DETECTION_BBOX, _DETECTION_SEGM, _DETECTION_VAL, _SAMPLE_DETECTION_SEGMENTATION @@ -52,14 +53,17 @@ def _generate_coco_inputs(iou_type): _coco_bbox_input = _generate_coco_inputs("bbox") _coco_segm_input = _generate_coco_inputs("segm") +with contextlib.redirect_stdout(io.StringIO()): + gt = COCO(_DETECTION_VAL) + dt_box = gt.loadRes(_DETECTION_BBOX) + dt_segm = gt.loadRes(_DETECTION_SEGM) + def _compare_again_coco_fn(preds, target, iou_type, class_metrics=True): """Taken from https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb.""" - gt = COCO(_DETECTION_VAL) - dt = gt.loadRes(_DETECTION_BBOX if iou_type == "bbox" else _DETECTION_SEGM) - + dt = dt_box if iou_type == "bbox" else dt_segm coco_eval = COCOeval(gt, dt, iou_type) - with _HidePrints(): + with contextlib.redirect_stdout(io.StringIO()): coco_eval.evaluate() coco_eval.accumulate() coco_eval.summarize() @@ -76,7 +80,7 @@ def _compare_again_coco_fn(preds, target, iou_type, class_metrics=True): mar_100_per_class_list = [] for class_id in classes.tolist(): coco_eval.params.catIds = [class_id] - with _HidePrints(): + with contextlib.redirect_stdout(io.StringIO()): coco_eval.evaluate() coco_eval.accumulate() coco_eval.summarize() From dc875d9122a4068dbed77d82b81613bf8a7e4742 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 21 Jun 2023 08:42:13 +0200 Subject: [PATCH 13/16] fix tests --- src/torchmetrics/detection/mean_ap.py | 22 +++++++++++----------- tests/unittests/conftest.py | 3 ++- tests/unittests/detection/test_map.py | 11 ++++------- tests/unittests/helpers/testers.py | 7 ++++++- 4 files changed, 23 insertions(+), 20 deletions(-) diff --git a/src/torchmetrics/detection/mean_ap.py 
b/src/torchmetrics/detection/mean_ap.py index 5d8cf797605..69cbb0c457c 100644 --- a/src/torchmetrics/detection/mean_ap.py +++ b/src/torchmetrics/detection/mean_ap.py @@ -330,7 +330,7 @@ def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]] self.groundtruths.append(groundtruths) self.groundtruth_labels.append(item["labels"]) self.groundtruth_crowds.append(item.get("iscrowd", torch.zeros_like(item["labels"]))) - self.groundtruth_area.append(item.get("area", -1 * torch.zeros_like(item["labels"]))) + self.groundtruth_area.append(item.get("area", torch.zeros_like(item["labels"]))) def compute(self) -> dict: """Computes the metric.""" @@ -359,7 +359,7 @@ def compute(self) -> dict: if self.class_metrics: map_per_class_list = [] mar_100_per_class_list = [] - for class_id in torch.cat(self.detection_labels + self.groundtruth_labels).unique().cpu().tolist(): + for class_id in self._get_classes(): coco_eval.params.catIds = [class_id] with contextlib.redirect_stdout(io.StringIO()): coco_eval.evaluate() @@ -469,21 +469,21 @@ def coco_to_tm( name = "boxes" if iou_type == "bbox" else "masks" batched_preds.append( { - name: torch.tensor(preds[key]["boxes"]) + name: torch.tensor(np.array(preds[key]["boxes"]), dtype=torch.float32) if iou_type == "bbox" - else torch.tensor(preds[key]["masks"]), - "scores": torch.tensor(preds[key]["scores"]), - "labels": torch.tensor(preds[key]["labels"]), + else torch.tensor(np.array(preds[key]["masks"]), dtype=torch.uint8), + "scores": torch.tensor(preds[key]["scores"], dtype=torch.float32), + "labels": torch.tensor(preds[key]["labels"], dtype=torch.int32), } ) batched_target.append( { - name: torch.tensor(target[key]["boxes"]) + name: torch.tensor(target[key]["boxes"], dtype=torch.float32) if iou_type == "bbox" - else torch.tensor(target[key]["masks"]), - "labels": torch.tensor(target[key]["labels"]), - "iscrowd": torch.tensor(target[key]["iscrowd"]), - "area": torch.tensor(target[key]["area"]), + else torch.tensor(np.array(target[key]["masks"]), dtype=torch.uint8), + "labels": torch.tensor(target[key]["labels"], dtype=torch.int32), + "iscrowd": torch.tensor(target[key]["iscrowd"], dtype=torch.int32), + "area": torch.tensor(target[key]["area"], dtype=torch.float32), } ) diff --git a/tests/unittests/conftest.py b/tests/unittests/conftest.py index ce8495511d9..2f59972e5de 100644 --- a/tests/unittests/conftest.py +++ b/tests/unittests/conftest.py @@ -17,10 +17,11 @@ import pytest import torch -from torch.multiprocessing import Pool, set_start_method +from torch.multiprocessing import Pool, set_sharing_strategy, set_start_method with contextlib.suppress(RuntimeError): set_start_method("spawn") + set_sharing_strategy("file_system") NUM_PROCESSES = 2 # torch.cuda.device_count() if torch.cuda.is_available() else 2 NUM_BATCHES = 2 * NUM_PROCESSES # Need to be divisible with the number of processes diff --git a/tests/unittests/detection/test_map.py b/tests/unittests/detection/test_map.py index 26444d68901..954b9c0509a 100644 --- a/tests/unittests/detection/test_map.py +++ b/tests/unittests/detection/test_map.py @@ -53,17 +53,13 @@ def _generate_coco_inputs(iou_type): _coco_bbox_input = _generate_coco_inputs("bbox") _coco_segm_input = _generate_coco_inputs("segm") -with contextlib.redirect_stdout(io.StringIO()): - gt = COCO(_DETECTION_VAL) - dt_box = gt.loadRes(_DETECTION_BBOX) - dt_segm = gt.loadRes(_DETECTION_SEGM) - def _compare_again_coco_fn(preds, target, iou_type, class_metrics=True): """Taken from 
https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb.""" - dt = dt_box if iou_type == "bbox" else dt_segm - coco_eval = COCOeval(gt, dt, iou_type) with contextlib.redirect_stdout(io.StringIO()): + gt = COCO(_DETECTION_VAL) + dt = gt.loadRes(_DETECTION_BBOX) if iou_type == "bbox" else gt.loadRes(_DETECTION_SEGM) + coco_eval = COCOeval(gt, dt, iou_type) coco_eval.evaluate() coco_eval.accumulate() coco_eval.summarize() @@ -116,6 +112,7 @@ def _compare_again_coco_fn(preds, target, iou_type, class_metrics=True): class TestMAPUsingCOCOReference(MetricTester): """Test map metric on the reference coco data.""" + # the aggregated metrics pass with atol < 1e-2, but class_metrics=True only passes with atol=1e-1 atol = 1e-1 def test_map(self, iou_type, ddp): diff --git a/tests/unittests/helpers/testers.py b/tests/unittests/helpers/testers.py index 04511c4fc7d..ddf385bf789 100644 --- a/tests/unittests/helpers/testers.py +++ b/tests/unittests/helpers/testers.py @@ -149,7 +149,12 @@ def _class_test( for i in range(rank, num_batches, world_size): batch_kwargs_update = {k: v[i] if isinstance(v, Tensor) else v for k, v in kwargs_update.items()} - batch_result = metric(preds[i], target[i], **batch_kwargs_update) + if (dist_sync_on_step and check_dist_sync_on_step == 0 and rank == 0) or ( + check_batch and not dist_sync_on_step + ): + batch_result = metric(preds[i], target[i], **batch_kwargs_update) + else: + metric.update(preds[i], target[i], **batch_kwargs_update) if metric.dist_sync_on_step and check_dist_sync_on_step and rank == 0: if isinstance(preds, Tensor): From 2dec2dbad3d916e49b0ff70ee0187c202852cdc5 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 21 Jun 2023 10:02:46 +0200 Subject: [PATCH 14/16] fix tests --- tests/unittests/detection/test_map.py | 41 +++++++++++++++++++++++---- tests/unittests/helpers/testers.py | 16 +++++------ 2 files changed, 44 insertions(+), 13 deletions(-) diff --git a/tests/unittests/detection/test_map.py b/tests/unittests/detection/test_map.py index 954b9c0509a..8f432f8c0cf 100644 --- a/tests/unittests/detection/test_map.py +++ b/tests/unittests/detection/test_map.py @@ -54,12 +54,18 @@ def _generate_coco_inputs(iou_type): _coco_segm_input = _generate_coco_inputs("segm") -def _compare_again_coco_fn(preds, target, iou_type, class_metrics=True): +def _compare_again_coco_fn(preds, target, iou_type, iou_thresholds=None, rec_thresholds=None, class_metrics=True): """Taken from https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb.""" with contextlib.redirect_stdout(io.StringIO()): gt = COCO(_DETECTION_VAL) dt = gt.loadRes(_DETECTION_BBOX) if iou_type == "bbox" else gt.loadRes(_DETECTION_SEGM) + coco_eval = COCOeval(gt, dt, iou_type) + if iou_thresholds is not None: + coco_eval.params.iouThrs = np.array(iou_thresholds, dtype=np.float64) + if rec_thresholds is not None: + coco_eval.params.recThrs = np.array(rec_thresholds, dtype=np.float64) + coco_eval.evaluate() coco_eval.accumulate() coco_eval.summarize() @@ -112,12 +118,36 @@ def _compare_again_coco_fn(preds, target, iou_type, class_metrics=True): class TestMAPUsingCOCOReference(MetricTester): """Test map metric on the reference coco data.""" - # the aggregated metrics pass with atol < 1e-2, but class_metrics=True only passes with atol=1e-1 - atol = 1e-1 - - def test_map(self, iou_type, ddp): + @pytest.mark.parametrize("iou_thresholds", [None, [0.25, 0.5, 0.75]]) + @pytest.mark.parametrize("rec_thresholds", [None, [0.25, 0.5, 0.75]]) + def test_map(self, iou_type, 
iou_thresholds, rec_thresholds, ddp): """Test modular implementation for correctness.""" preds, target = _coco_bbox_input if iou_type == "bbox" else _coco_segm_input + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MeanAveragePrecision, + reference_metric=partial( + _compare_again_coco_fn, + iou_type=iou_type, + iou_thresholds=iou_thresholds, + rec_thresholds=rec_thresholds, + class_metrics=False, + ), + metric_args={ + "iou_type": iou_type, + "iou_thresholds": iou_thresholds, + "rec_thresholds": rec_thresholds, + "class_metrics": False, + }, + check_batch=False, + atol=1e-2, + ) + + def test_map_classwise(self, iou_type, ddp): + """Test modular implementation for correctness with classwise=True. Needs bigger atol to be stable.""" + preds, target = _coco_bbox_input if iou_type == "bbox" else _coco_segm_input self.run_class_metric_test( ddp=ddp, preds=preds, @@ -126,6 +156,7 @@ def test_map(self, iou_type, ddp): reference_metric=partial(_compare_again_coco_fn, iou_type=iou_type, class_metrics=True), metric_args={"iou_type": iou_type, "class_metrics": True}, check_batch=False, + atol=1e-1, ) diff --git a/tests/unittests/helpers/testers.py b/tests/unittests/helpers/testers.py index ddf385bf789..7ebb3538050 100644 --- a/tests/unittests/helpers/testers.py +++ b/tests/unittests/helpers/testers.py @@ -149,12 +149,8 @@ def _class_test( for i in range(rank, num_batches, world_size): batch_kwargs_update = {k: v[i] if isinstance(v, Tensor) else v for k, v in kwargs_update.items()} - if (dist_sync_on_step and check_dist_sync_on_step == 0 and rank == 0) or ( - check_batch and not dist_sync_on_step - ): - batch_result = metric(preds[i], target[i], **batch_kwargs_update) - else: - metric.update(preds[i], target[i], **batch_kwargs_update) + # compute batch stats and aggregate for global stats + batch_result = metric(preds[i], target[i], **batch_kwargs_update) if metric.dist_sync_on_step and check_dist_sync_on_step and rank == 0: if isinstance(preds, Tensor): @@ -376,6 +372,7 @@ def run_class_metric_test( check_batch: bool = True, fragment_kwargs: bool = False, check_scriptable: bool = True, + atol: Optional[float] = None, **kwargs_update: Any, ): """Core method that should be used for testing class. Call this inside testing methods. @@ -394,9 +391,12 @@ def run_class_metric_test( calculated across devices for each batch (and not just at the end) fragment_kwargs: whether tensors in kwargs should be divided as `preds` and `target` among processes check_scriptable: bool indicating if metric should also be tested if it can be scripted + atol: absolute tolerance used for comparison of results, if None will use self.atol kwargs_update: Additional keyword arguments that will be passed with preds and target when running update on the metric. 
+ """ + atol = atol or self.atol metric_args = metric_args or {} if ddp: if sys.platform == "win32": @@ -413,7 +413,7 @@ def run_class_metric_test( metric_args=metric_args, check_dist_sync_on_step=check_dist_sync_on_step, check_batch=check_batch, - atol=self.atol, + atol=atol, fragment_kwargs=fragment_kwargs, check_scriptable=check_scriptable, **kwargs_update, @@ -434,7 +434,7 @@ def run_class_metric_test( metric_args=metric_args, check_dist_sync_on_step=check_dist_sync_on_step, check_batch=check_batch, - atol=self.atol, + atol=atol, device=device, fragment_kwargs=fragment_kwargs, check_scriptable=check_scriptable, From b97eeb9115d4fab807d08200b44f6dda28e33d70 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 29 Jun 2023 13:30:16 +0200 Subject: [PATCH 15/16] readd old implementation --- src/torchmetrics/detection/_mean_ap.py | 970 +++++++++++++++++++++++++ 1 file changed, 970 insertions(+) create mode 100644 src/torchmetrics/detection/_mean_ap.py diff --git a/src/torchmetrics/detection/_mean_ap.py b/src/torchmetrics/detection/_mean_ap.py new file mode 100644 index 00000000000..b2a8435066a --- /dev/null +++ b/src/torchmetrics/detection/_mean_ap.py @@ -0,0 +1,970 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +import numpy as np +import torch +import torch.distributed as dist +from torch import IntTensor, Tensor + +from torchmetrics.detection.helpers import _fix_empty_tensors, _input_validator +from torchmetrics.metric import Metric +from torchmetrics.utilities.data import _cumsum +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _PYCOCOTOOLS_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_8 +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["MeanAveragePrecision.plot"] + +if _TORCHVISION_GREATER_EQUAL_0_8: + from torchvision.ops import box_area, box_convert, box_iou +else: + box_convert = box_iou = box_area = None + __doctest_skip__ = ["MeanAveragePrecision.plot", "MeanAveragePrecision"] + +if _PYCOCOTOOLS_AVAILABLE: + import pycocotools.mask as mask_utils +else: + mask_utils = None + __doctest_skip__ = ["MeanAveragePrecision.plot", "MeanAveragePrecision"] + + +log = logging.getLogger(__name__) + + +def compute_area(inputs: List[Any], iou_type: str = "bbox") -> Tensor: + """Compute area of input depending on the specified iou_type. 
+ + Default output for empty input is :class:`~torch.Tensor` + """ + if len(inputs) == 0: + return Tensor([]) + + if iou_type == "bbox": + return box_area(torch.stack(inputs)) + if iou_type == "segm": + inputs = [{"size": i[0], "counts": i[1]} for i in inputs] + return torch.tensor(mask_utils.area(inputs).astype("float")) + + raise Exception(f"IOU type {iou_type} is not supported") + + +def compute_iou( + det: List[Any], + gt: List[Any], + iou_type: str = "bbox", +) -> Tensor: + """Compute IOU between detections and ground-truth using the specified iou_type.""" + if iou_type == "bbox": + return box_iou(torch.stack(det), torch.stack(gt)) + if iou_type == "segm": + return _segm_iou(det, gt) + raise Exception(f"IOU type {iou_type} is not supported") + + +class BaseMetricResults(dict): + """Base metric class, that allows fields for pre-defined metrics.""" + + def __getattr__(self, key: str) -> Tensor: + """Get a specific metric attribute.""" + # Using this you get the correct error message, an AttributeError instead of a KeyError + if key in self: + return self[key] + raise AttributeError(f"No such attribute: {key}") + + def __setattr__(self, key: str, value: Tensor) -> None: + """Set a specific metric attribute.""" + self[key] = value + + def __delattr__(self, key: str) -> None: + """Delete a specific metric attribute.""" + if key in self: + del self[key] + raise AttributeError(f"No such attribute: {key}") + + +class MAPMetricResults(BaseMetricResults): + """Class to wrap the final mAP results.""" + + __slots__ = ("map", "map_50", "map_75", "map_small", "map_medium", "map_large", "classes") + + +class MARMetricResults(BaseMetricResults): + """Class to wrap the final mAR results.""" + + __slots__ = ("mar_1", "mar_10", "mar_100", "mar_small", "mar_medium", "mar_large") + + +class COCOMetricResults(BaseMetricResults): + """Class to wrap the final COCO metric results including various mAP/mAR values.""" + + __slots__ = ( + "map", + "map_50", + "map_75", + "map_small", + "map_medium", + "map_large", + "mar_1", + "mar_10", + "mar_100", + "mar_small", + "mar_medium", + "mar_large", + "map_per_class", + "mar_100_per_class", + ) + + +def _segm_iou(det: List[Tuple[np.ndarray, np.ndarray]], gt: List[Tuple[np.ndarray, np.ndarray]]) -> Tensor: + """Compute IOU between detections and ground-truths using mask-IOU. + + Implementation is based on pycocotools toolkit for mask_utils. + + Args: + det: A list of detection masks as ``[(RLE_SIZE, RLE_COUNTS)]``, where ``RLE_SIZE`` is (width, height) dimension + of the input and RLE_COUNTS is its RLE representation; + + gt: A list of ground-truth masks as ``[(RLE_SIZE, RLE_COUNTS)]``, where ``RLE_SIZE`` is (width, height) dimension + of the input and RLE_COUNTS is its RLE representation; + + """ + det_coco_format = [{"size": i[0], "counts": i[1]} for i in det] + gt_coco_format = [{"size": i[0], "counts": i[1]} for i in gt] + + return torch.tensor(mask_utils.iou(det_coco_format, gt_coco_format, [False for _ in gt])) + + +class MeanAveragePrecision(Metric): + r"""Compute the `Mean-Average-Precision (mAP) and Mean-Average-Recall (mAR)`_ for object detection predictions. + + .. math:: + \text{mAP} = \frac{1}{n} \sum_{i=1}^{n} AP_i + + where :math:`AP_i` is the average precision for class :math:`i` and :math:`n` is the number of classes. The average + precision is defined as the area under the precision-recall curve. If argument `class_metrics` is set to ``True``, + the metric will also return the mAP/mAR per class. 
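+    For instance, if the per-class average precisions are 0.50 and 0.70, the reported mAP is their
+    mean, 0.60.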
+ + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~List`): A list consisting of dictionaries each containing the key-values + (each dictionary corresponds to a single image). Parameters that should be provided per dict + + - boxes: (:class:`~torch.FloatTensor`) of shape ``(num_boxes, 4)`` containing ``num_boxes`` detection + boxes of the format specified in the constructor. + By default, this method expects ``(xmin, ymin, xmax, ymax)`` in absolute image coordinates. + - scores: :class:`~torch.FloatTensor` of shape ``(num_boxes)`` containing detection scores for the boxes. + - labels: :class:`~torch.IntTensor` of shape ``(num_boxes)`` containing 0-indexed detection classes for + the boxes. + - masks: :class:`~torch.bool` of shape ``(num_boxes, image_height, image_width)`` containing boolean masks. + Only required when `iou_type="segm"`. + + - ``target`` (:class:`~List`) A list consisting of dictionaries each containing the key-values + (each dictionary corresponds to a single image). Parameters that should be provided per dict: + + - boxes: :class:`~torch.FloatTensor` of shape ``(num_boxes, 4)`` containing ``num_boxes`` ground truth + boxes of the format specified in the constructor. + By default, this method expects ``(xmin, ymin, xmax, ymax)`` in absolute image coordinates. + - labels: :class:`~torch.IntTensor` of shape ``(num_boxes)`` containing 0-indexed ground truth + classes for the boxes. + - masks: :class:`~torch.bool` of shape ``(num_boxes, image_height, image_width)`` containing boolean masks. + Only required when `iou_type="segm"`. + + As output of ``forward`` and ``compute`` the metric returns the following output: + + - ``map_dict``: A dictionary containing the following key-values: + + - map: (:class:`~torch.Tensor`) + - map_small: (:class:`~torch.Tensor`) + - map_medium:(:class:`~torch.Tensor`) + - map_large: (:class:`~torch.Tensor`) + - mar_1: (:class:`~torch.Tensor`) + - mar_10: (:class:`~torch.Tensor`) + - mar_100: (:class:`~torch.Tensor`) + - mar_small: (:class:`~torch.Tensor`) + - mar_medium: (:class:`~torch.Tensor`) + - mar_large: (:class:`~torch.Tensor`) + - map_50: (:class:`~torch.Tensor`) (-1 if 0.5 not in the list of iou thresholds) + - map_75: (:class:`~torch.Tensor`) (-1 if 0.75 not in the list of iou thresholds) + - map_per_class: (:class:`~torch.Tensor`) (-1 if class metrics are disabled) + - mar_100_per_class: (:class:`~torch.Tensor`) (-1 if class metrics are disabled) + - classes (:class:`~torch.Tensor`) + + For an example on how to use this metric check the `torchmetrics mAP example`_. + + .. note:: + ``map`` score is calculated with @[ IoU=self.iou_thresholds | area=all | max_dets=max_detection_thresholds ]. + Caution: If the initialization parameters are changed, dictionary keys for mAR can change as well. + The default properties are also accessible via fields and will raise an ``AttributeError`` if not available. + + .. note:: + This metric is following the mAP implementation of + `pycocotools `_, + a standard implementation for the mAP metric for object detection. + + .. note:: + This metric requires you to have `torchvision` version 0.8.0 or newer installed + (with corresponding version 1.7.0 of torch or newer). This metric requires `pycocotools` + installed when iou_type is `segm`. Please install with ``pip install torchvision`` or + ``pip install torchmetrics[detection]``. + + Args: + box_format: + Input format of given boxes. Supported formats are ``[`xyxy`, `xywh`, `cxcywh`]``. 
+ iou_type: + Type of input (either masks or bounding-boxes) used for computing IOU. + Supported IOU types are ``["bbox", "segm"]``. + If using ``"segm"``, masks should be provided (see :meth:`update`). + iou_thresholds: + IoU thresholds for evaluation. If set to ``None`` it corresponds to the stepped range ``[0.5,...,0.95]`` + with step ``0.05``. Else provide a list of floats. + rec_thresholds: + Recall thresholds for evaluation. If set to ``None`` it corresponds to the stepped range ``[0,...,1]`` + with step ``0.01``. Else provide a list of floats. + max_detection_thresholds: + Thresholds on max detections per image. If set to `None` will use thresholds ``[1, 10, 100]``. + Else, please provide a list of ints. + class_metrics: + Option to enable per-class metrics for mAP and mAR_100. Has a performance impact. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Raises: + ModuleNotFoundError: + If ``torchvision`` is not installed or version installed is lower than 0.8.0 + ModuleNotFoundError: + If ``iou_type`` is equal to ``segm`` and ``pycocotools`` is not installed + ValueError: + If ``class_metrics`` is not a boolean + ValueError: + If ``preds`` is not of type (:class:`~List[Dict[str, Tensor]]`) + ValueError: + If ``target`` is not of type ``List[Dict[str, Tensor]]`` + ValueError: + If ``preds`` and ``target`` are not of the same length + ValueError: + If any of ``preds.boxes``, ``preds.scores`` and ``preds.labels`` are not of the same length + ValueError: + If any of ``target.boxes`` and ``target.labels`` are not of the same length + ValueError: + If any box is not type float and of length 4 + ValueError: + If any class is not type int and of length 1 + ValueError: + If any score is not type float and of length 1 + + Example: + >>> from torch import tensor + >>> from torchmetrics.detection import MeanAveragePrecision + >>> preds = [ + ... dict( + ... boxes=tensor([[258.0, 41.0, 606.0, 285.0]]), + ... scores=tensor([0.536]), + ... labels=tensor([0]), + ... ) + ... ] + >>> target = [ + ... dict( + ... boxes=tensor([[214.0, 41.0, 562.0, 285.0]]), + ... labels=tensor([0]), + ... ) + ... 
] + >>> metric = MeanAveragePrecision() + >>> metric.update(preds, target) + >>> from pprint import pprint + >>> pprint(metric.compute()) + {'classes': tensor(0, dtype=torch.int32), + 'map': tensor(0.6000), + 'map_50': tensor(1.), + 'map_75': tensor(1.), + 'map_large': tensor(0.6000), + 'map_medium': tensor(-1.), + 'map_per_class': tensor(-1.), + 'map_small': tensor(-1.), + 'mar_1': tensor(0.6000), + 'mar_10': tensor(0.6000), + 'mar_100': tensor(0.6000), + 'mar_100_per_class': tensor(-1.), + 'mar_large': tensor(0.6000), + 'mar_medium': tensor(-1.), + 'mar_small': tensor(-1.)} + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = True + full_state_update: bool = True + plot_lower_bound: float = 0.0 + plot_upper_bound: float = 1.0 + + detections: List[Tensor] + detection_scores: List[Tensor] + detection_labels: List[Tensor] + groundtruths: List[Tensor] + groundtruth_labels: List[Tensor] + + def __init__( + self, + box_format: str = "xyxy", + iou_type: str = "bbox", + iou_thresholds: Optional[List[float]] = None, + rec_thresholds: Optional[List[float]] = None, + max_detection_thresholds: Optional[List[int]] = None, + class_metrics: bool = False, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + + if not _TORCHVISION_GREATER_EQUAL_0_8: + raise ModuleNotFoundError( + "`MeanAveragePrecision` metric requires that `torchvision` version 0.8.0 or newer is installed." + " Please install with `pip install torchvision>=0.8` or `pip install torchmetrics[detection]`." + ) + + allowed_box_formats = ("xyxy", "xywh", "cxcywh") + allowed_iou_types = ("segm", "bbox") + if box_format not in allowed_box_formats: + raise ValueError(f"Expected argument `box_format` to be one of {allowed_box_formats} but got {box_format}") + self.box_format = box_format + self.iou_thresholds = iou_thresholds or torch.linspace(0.5, 0.95, round((0.95 - 0.5) / 0.05) + 1).tolist() + self.rec_thresholds = rec_thresholds or torch.linspace(0.0, 1.00, round(1.00 / 0.01) + 1).tolist() + max_det_thr, _ = torch.sort(IntTensor(max_detection_thresholds or [1, 10, 100])) + self.max_detection_thresholds = max_det_thr.tolist() + if iou_type not in allowed_iou_types: + raise ValueError(f"Expected argument `iou_type` to be one of {allowed_iou_types} but got {iou_type}") + if iou_type == "segm" and not _PYCOCOTOOLS_AVAILABLE: + raise ModuleNotFoundError("When `iou_type` is set to 'segm', pycocotools need to be installed") + self.iou_type = iou_type + self.bbox_area_ranges = { + "all": (float(0**2), float(1e5**2)), + "small": (float(0**2), float(32**2)), + "medium": (float(32**2), float(96**2)), + "large": (float(96**2), float(1e5**2)), + } + + if not isinstance(class_metrics, bool): + raise ValueError("Expected argument `class_metrics` to be a boolean") + + self.class_metrics = class_metrics + self.add_state("detections", default=[], dist_reduce_fx=None) + self.add_state("detection_scores", default=[], dist_reduce_fx=None) + self.add_state("detection_labels", default=[], dist_reduce_fx=None) + self.add_state("groundtruths", default=[], dist_reduce_fx=None) + self.add_state("groundtruth_labels", default=[], dist_reduce_fx=None) + + def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]) -> None: + """Update state with predictions and targets.""" + _input_validator(preds, target, iou_type=self.iou_type) + + for item in preds: + detections = self._get_safe_item_values(item) + + self.detections.append(detections) + self.detection_labels.append(item["labels"]) + 
self.detection_scores.append(item["scores"]) + + for item in target: + groundtruths = self._get_safe_item_values(item) + self.groundtruths.append(groundtruths) + self.groundtruth_labels.append(item["labels"]) + + def _move_list_states_to_cpu(self) -> None: + """Move list states to cpu to save GPU memory.""" + for key in self._defaults: + current_val = getattr(self, key) + current_to_cpu = [] + if isinstance(current_val, Sequence): + for cur_v in current_val: + # Cannot handle RLE as Tensor + if not isinstance(cur_v, tuple): + cur_v = cur_v.to("cpu") + current_to_cpu.append(cur_v) + setattr(self, key, current_to_cpu) + + def _get_safe_item_values(self, item: Dict[str, Any]) -> Union[Tensor, Tuple]: + if self.iou_type == "bbox": + boxes = _fix_empty_tensors(item["boxes"]) + if boxes.numel() > 0: + boxes = box_convert(boxes, in_fmt=self.box_format, out_fmt="xyxy") + return boxes + if self.iou_type == "segm": + masks = [] + for i in item["masks"].cpu().numpy(): + rle = mask_utils.encode(np.asfortranarray(i)) + masks.append((tuple(rle["size"]), rle["counts"])) + return tuple(masks) + raise Exception(f"IOU type {self.iou_type} is not supported") + + def _get_classes(self) -> List: + """Return a list of unique classes found in ground truth and detection data.""" + if len(self.detection_labels) > 0 or len(self.groundtruth_labels) > 0: + return torch.cat(self.detection_labels + self.groundtruth_labels).unique().tolist() + return [] + + def _compute_iou(self, idx: int, class_id: int, max_det: int) -> Tensor: + """Compute the Intersection over Union (IoU) between bounding boxes for the given image and class. + + Args: + idx: + Image Id, equivalent to the index of supplied samples + class_id: + Class Id of the supplied ground truth and detection labels + max_det: + Maximum number of evaluated detection bounding boxes + """ + # if self.iou_type == "bbox": + gt = self.groundtruths[idx] + det = self.detections[idx] + + gt_label_mask = (self.groundtruth_labels[idx] == class_id).nonzero().squeeze(1) + det_label_mask = (self.detection_labels[idx] == class_id).nonzero().squeeze(1) + + if len(gt_label_mask) == 0 or len(det_label_mask) == 0: + return Tensor([]) + + gt = [gt[i] for i in gt_label_mask] + det = [det[i] for i in det_label_mask] + + if len(gt) == 0 or len(det) == 0: + return Tensor([]) + + # Sort by scores and use only max detections + scores = self.detection_scores[idx] + scores_filtered = scores[self.detection_labels[idx] == class_id] + inds = torch.argsort(scores_filtered, descending=True) + + # TODO Fix (only for masks is necessary) + det = [det[i] for i in inds] + if len(det) > max_det: + det = det[:max_det] + + return compute_iou(det, gt, self.iou_type).to(self.device) + + def __evaluate_image_gt_no_preds( + self, gt: Tensor, gt_label_mask: Tensor, area_range: Tuple[int, int], nb_iou_thrs: int + ) -> Dict[str, Any]: + """Evaluate images with a ground truth but no predictions.""" + # GTs + gt = [gt[i] for i in gt_label_mask] + nb_gt = len(gt) + areas = compute_area(gt, iou_type=self.iou_type).to(self.device) + ignore_area = (areas < area_range[0]) | (areas > area_range[1]) + gt_ignore, _ = torch.sort(ignore_area.to(torch.uint8)) + gt_ignore = gt_ignore.to(torch.bool) + + # Detections + nb_det = 0 + det_ignore = torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool, device=self.device) + + return { + "dtMatches": torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool, device=self.device), + "gtMatches": torch.zeros((nb_iou_thrs, nb_gt), dtype=torch.bool, device=self.device), + "dtScores": 
torch.zeros(nb_det, dtype=torch.float32, device=self.device), + "gtIgnore": gt_ignore, + "dtIgnore": det_ignore, + } + + def __evaluate_image_preds_no_gt( + self, det: Tensor, idx: int, det_label_mask: Tensor, max_det: int, area_range: Tuple[int, int], nb_iou_thrs: int + ) -> Dict[str, Any]: + """Evaluate images with a prediction but no ground truth.""" + # GTs + nb_gt = 0 + + gt_ignore = torch.zeros(nb_gt, dtype=torch.bool, device=self.device) + + # Detections + + det = [det[i] for i in det_label_mask] + scores = self.detection_scores[idx] + scores_filtered = scores[det_label_mask] + scores_sorted, dtind = torch.sort(scores_filtered, descending=True) + + det = [det[i] for i in dtind] + if len(det) > max_det: + det = det[:max_det] + nb_det = len(det) + det_areas = compute_area(det, iou_type=self.iou_type).to(self.device) + det_ignore_area = (det_areas < area_range[0]) | (det_areas > area_range[1]) + ar = det_ignore_area.reshape((1, nb_det)) + det_ignore = torch.repeat_interleave(ar, nb_iou_thrs, 0) + + return { + "dtMatches": torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool, device=self.device), + "gtMatches": torch.zeros((nb_iou_thrs, nb_gt), dtype=torch.bool, device=self.device), + "dtScores": scores_sorted.to(self.device), + "gtIgnore": gt_ignore.to(self.device), + "dtIgnore": det_ignore.to(self.device), + } + + def _evaluate_image( + self, idx: int, class_id: int, area_range: Tuple[int, int], max_det: int, ious: dict + ) -> Optional[dict]: + """Perform evaluation for single class and image. + + Args: + idx: + Image Id, equivalent to the index of supplied samples. + class_id: + Class Id of the supplied ground truth and detection labels. + area_range: + List of lower and upper bounding box area threshold. + max_det: + Maximum number of evaluated detection bounding boxes. + ious: + IoU results for image and class. 
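+
+        Returns:
+            ``None`` if the image has neither ground truths nor detections for the given class, otherwise a
+            dictionary with ``dtMatches``, ``gtMatches``, ``dtScores``, ``gtIgnore`` and ``dtIgnore`` entries
+            for this image, class and area-range combination.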
+ """ + gt = self.groundtruths[idx] + det = self.detections[idx] + gt_label_mask = (self.groundtruth_labels[idx] == class_id).nonzero().squeeze(1) + det_label_mask = (self.detection_labels[idx] == class_id).nonzero().squeeze(1) + + # No Gt and No predictions --> ignore image + if len(gt_label_mask) == 0 and len(det_label_mask) == 0: + return None + + nb_iou_thrs = len(self.iou_thresholds) + + # Some GT but no predictions + if len(gt_label_mask) > 0 and len(det_label_mask) == 0: + return self.__evaluate_image_gt_no_preds(gt, gt_label_mask, area_range, nb_iou_thrs) + + # Some predictions but no GT + if len(gt_label_mask) == 0 and len(det_label_mask) >= 0: + return self.__evaluate_image_preds_no_gt(det, idx, det_label_mask, max_det, area_range, nb_iou_thrs) + + gt = [gt[i] for i in gt_label_mask] + det = [det[i] for i in det_label_mask] + if len(gt) == 0 and len(det) == 0: + return None + if isinstance(det, dict): + det = [det] + if isinstance(gt, dict): + gt = [gt] + + areas = compute_area(gt, iou_type=self.iou_type).to(self.device) + + ignore_area = torch.logical_or(areas < area_range[0], areas > area_range[1]) + + # sort dt highest score first, sort gt ignore last + ignore_area_sorted, gtind = torch.sort(ignore_area.to(torch.uint8)) + # Convert to uint8 temporarily and back to bool, because "Sort currently does not support bool dtype on CUDA" + + ignore_area_sorted = ignore_area_sorted.to(torch.bool).to(self.device) + + gt = [gt[i] for i in gtind] + scores = self.detection_scores[idx] + scores_filtered = scores[det_label_mask] + scores_sorted, dtind = torch.sort(scores_filtered, descending=True) + det = [det[i] for i in dtind] + if len(det) > max_det: + det = det[:max_det] + # load computed ious + ious = ious[idx, class_id][:, gtind] if len(ious[idx, class_id]) > 0 else ious[idx, class_id] + + nb_iou_thrs = len(self.iou_thresholds) + nb_gt = len(gt) + nb_det = len(det) + gt_matches = torch.zeros((nb_iou_thrs, nb_gt), dtype=torch.bool, device=self.device) + det_matches = torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool, device=self.device) + gt_ignore = ignore_area_sorted + det_ignore = torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool, device=self.device) + + if torch.numel(ious) > 0: + for idx_iou, t in enumerate(self.iou_thresholds): + for idx_det, _ in enumerate(det): + m = MeanAveragePrecision._find_best_gt_match(t, gt_matches, idx_iou, gt_ignore, ious, idx_det) + if m == -1: + continue + det_ignore[idx_iou, idx_det] = gt_ignore[m] + det_matches[idx_iou, idx_det] = 1 + gt_matches[idx_iou, m] = 1 + + # set unmatched detections outside of area range to ignore + det_areas = compute_area(det, iou_type=self.iou_type).to(self.device) + det_ignore_area = (det_areas < area_range[0]) | (det_areas > area_range[1]) + ar = det_ignore_area.reshape((1, nb_det)) + det_ignore = torch.logical_or( + det_ignore, torch.logical_and(det_matches == 0, torch.repeat_interleave(ar, nb_iou_thrs, 0)) + ) + + return { + "dtMatches": det_matches.to(self.device), + "gtMatches": gt_matches.to(self.device), + "dtScores": scores_sorted.to(self.device), + "gtIgnore": gt_ignore.to(self.device), + "dtIgnore": det_ignore.to(self.device), + } + + @staticmethod + def _find_best_gt_match( + thr: int, gt_matches: Tensor, idx_iou: float, gt_ignore: Tensor, ious: Tensor, idx_det: int + ) -> int: + """Return id of best ground truth match with current detection. + + Args: + thr: + Current threshold value. + gt_matches: + Tensor showing if a ground truth matches for threshold ``t`` exists. + idx_iou: + Id of threshold ``t``. 
+ gt_ignore: + Tensor showing if ground truth should be ignored. + ious: + IoUs for all combinations of detection and ground truth. + idx_det: + Id of current detection. + """ + previously_matched = gt_matches[idx_iou] + # Remove previously matched or ignored gts + remove_mask = previously_matched | gt_ignore + gt_ious = ious[idx_det] * ~remove_mask + match_idx = gt_ious.argmax().item() + if gt_ious[match_idx] > thr: + return match_idx + return -1 + + def _summarize( + self, + results: Dict, + avg_prec: bool = True, + iou_threshold: Optional[float] = None, + area_range: str = "all", + max_dets: int = 100, + ) -> Tensor: + """Perform evaluation for single class and image. + + Args: + results: + Dictionary including precision, recall and scores for all combinations. + avg_prec: + Calculate average precision. Else calculate average recall. + iou_threshold: + IoU threshold. If set to ``None`` it all values are used. Else results are filtered. + area_range: + Bounding box area range key. + max_dets: + Maximum detections. + """ + area_inds = [i for i, k in enumerate(self.bbox_area_ranges.keys()) if k == area_range] + mdet_inds = [i for i, k in enumerate(self.max_detection_thresholds) if k == max_dets] + if avg_prec: + # dimension of precision: [TxRxKxAxM] + prec = results["precision"] + # IoU + if iou_threshold is not None: + thr = self.iou_thresholds.index(iou_threshold) + prec = prec[thr, :, :, area_inds, mdet_inds] + else: + prec = prec[:, :, :, area_inds, mdet_inds] + else: + # dimension of recall: [TxKxAxM] + prec = results["recall"] + if iou_threshold is not None: + thr = self.iou_thresholds.index(iou_threshold) + prec = prec[thr, :, :, area_inds, mdet_inds] + else: + prec = prec[:, :, area_inds, mdet_inds] + + return torch.tensor([-1.0]) if len(prec[prec > -1]) == 0 else torch.mean(prec[prec > -1]) + + def _calculate(self, class_ids: List) -> Tuple[MAPMetricResults, MARMetricResults]: + """Calculate the precision and recall for all supplied classes to calculate mAP/mAR. + + Args: + class_ids: + List of label class Ids. 
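+
+        Returns:
+            A precision tensor of shape ``(n_iou_thrs, n_rec_thrs, n_classes, n_bbox_areas, n_max_det_thrs)``
+            and a recall tensor of shape ``(n_iou_thrs, n_classes, n_bbox_areas, n_max_det_thrs)``.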
+ """ + img_ids = range(len(self.groundtruths)) + max_detections = self.max_detection_thresholds[-1] + area_ranges = self.bbox_area_ranges.values() + + ious = { + (idx, class_id): self._compute_iou(idx, class_id, max_detections) + for idx in img_ids + for class_id in class_ids + } + + eval_imgs = [ + self._evaluate_image(img_id, class_id, area, max_detections, ious) + for class_id in class_ids + for area in area_ranges + for img_id in img_ids + ] + + nb_iou_thrs = len(self.iou_thresholds) + nb_rec_thrs = len(self.rec_thresholds) + nb_classes = len(class_ids) + nb_bbox_areas = len(self.bbox_area_ranges) + nb_max_det_thrs = len(self.max_detection_thresholds) + nb_imgs = len(img_ids) + precision = -torch.ones((nb_iou_thrs, nb_rec_thrs, nb_classes, nb_bbox_areas, nb_max_det_thrs)) + recall = -torch.ones((nb_iou_thrs, nb_classes, nb_bbox_areas, nb_max_det_thrs)) + scores = -torch.ones((nb_iou_thrs, nb_rec_thrs, nb_classes, nb_bbox_areas, nb_max_det_thrs)) + + # move tensors if necessary + rec_thresholds_tensor = torch.tensor(self.rec_thresholds) + + # retrieve E at each category, area range, and max number of detections + for idx_cls, _ in enumerate(class_ids): + for idx_bbox_area, _ in enumerate(self.bbox_area_ranges): + for idx_max_det_thrs, max_det in enumerate(self.max_detection_thresholds): + recall, precision, scores = MeanAveragePrecision.__calculate_recall_precision_scores( + recall, + precision, + scores, + idx_cls=idx_cls, + idx_bbox_area=idx_bbox_area, + idx_max_det_thrs=idx_max_det_thrs, + eval_imgs=eval_imgs, + rec_thresholds=rec_thresholds_tensor, + max_det=max_det, + nb_imgs=nb_imgs, + nb_bbox_areas=nb_bbox_areas, + ) + + return precision, recall + + def _summarize_results(self, precisions: Tensor, recalls: Tensor) -> Tuple[MAPMetricResults, MARMetricResults]: + """Summarizes the precision and recall values to calculate mAP/mAR. 
+ + Args: + precisions: + Precision values for different thresholds + recalls: + Recall values for different thresholds + """ + results = {"precision": precisions, "recall": recalls} + map_metrics = MAPMetricResults() + last_max_det_thr = self.max_detection_thresholds[-1] + map_metrics.map = self._summarize(results, True, max_dets=last_max_det_thr) + if 0.5 in self.iou_thresholds: + map_metrics.map_50 = self._summarize(results, True, iou_threshold=0.5, max_dets=last_max_det_thr) + else: + map_metrics.map_50 = torch.tensor([-1]) + if 0.75 in self.iou_thresholds: + map_metrics.map_75 = self._summarize(results, True, iou_threshold=0.75, max_dets=last_max_det_thr) + else: + map_metrics.map_75 = torch.tensor([-1]) + map_metrics.map_small = self._summarize(results, True, area_range="small", max_dets=last_max_det_thr) + map_metrics.map_medium = self._summarize(results, True, area_range="medium", max_dets=last_max_det_thr) + map_metrics.map_large = self._summarize(results, True, area_range="large", max_dets=last_max_det_thr) + + mar_metrics = MARMetricResults() + for max_det in self.max_detection_thresholds: + mar_metrics[f"mar_{max_det}"] = self._summarize(results, False, max_dets=max_det) + mar_metrics.mar_small = self._summarize(results, False, area_range="small", max_dets=last_max_det_thr) + mar_metrics.mar_medium = self._summarize(results, False, area_range="medium", max_dets=last_max_det_thr) + mar_metrics.mar_large = self._summarize(results, False, area_range="large", max_dets=last_max_det_thr) + + return map_metrics, mar_metrics + + @staticmethod + def __calculate_recall_precision_scores( + recall: Tensor, + precision: Tensor, + scores: Tensor, + idx_cls: int, + idx_bbox_area: int, + idx_max_det_thrs: int, + eval_imgs: list, + rec_thresholds: Tensor, + max_det: int, + nb_imgs: int, + nb_bbox_areas: int, + ) -> Tuple[Tensor, Tensor, Tensor]: + nb_rec_thrs = len(rec_thresholds) + idx_cls_pointer = idx_cls * nb_bbox_areas * nb_imgs + idx_bbox_area_pointer = idx_bbox_area * nb_imgs + # Load all image evals for current class_id and area_range + img_eval_cls_bbox = [eval_imgs[idx_cls_pointer + idx_bbox_area_pointer + i] for i in range(nb_imgs)] + img_eval_cls_bbox = [e for e in img_eval_cls_bbox if e is not None] + if not img_eval_cls_bbox: + return recall, precision, scores + + det_scores = torch.cat([e["dtScores"][:max_det] for e in img_eval_cls_bbox]) + + # different sorting method generates slightly different results. + # mergesort is used to be consistent as Matlab implementation. 
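+        # NB: the pycocotools reference implementation uses a stable sort (np.argsort with kind="mergesort");
+        # torch.argsort below is not guaranteed to be stable, so tied detection scores may end up
+        # in a slightly different order.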
+ # Sort in PyTorch does not support bool types on CUDA (yet, 1.11.0) + dtype = torch.uint8 if det_scores.is_cuda and det_scores.dtype is torch.bool else det_scores.dtype + # Explicitly cast to uint8 to avoid error for bool inputs on CUDA to argsort + inds = torch.argsort(det_scores.to(dtype), descending=True) + det_scores_sorted = det_scores[inds] + + det_matches = torch.cat([e["dtMatches"][:, :max_det] for e in img_eval_cls_bbox], axis=1)[:, inds] + det_ignore = torch.cat([e["dtIgnore"][:, :max_det] for e in img_eval_cls_bbox], axis=1)[:, inds] + gt_ignore = torch.cat([e["gtIgnore"] for e in img_eval_cls_bbox]) + npig = torch.count_nonzero(gt_ignore == False) # noqa: E712 + if npig == 0: + return recall, precision, scores + tps = torch.logical_and(det_matches, torch.logical_not(det_ignore)) + fps = torch.logical_and(torch.logical_not(det_matches), torch.logical_not(det_ignore)) + + tp_sum = _cumsum(tps, dim=1, dtype=torch.float) + fp_sum = _cumsum(fps, dim=1, dtype=torch.float) + for idx, (tp, fp) in enumerate(zip(tp_sum, fp_sum)): + nd = len(tp) + rc = tp / npig + pr = tp / (fp + tp + torch.finfo(torch.float64).eps) + prec = torch.zeros((nb_rec_thrs,)) + score = torch.zeros((nb_rec_thrs,)) + + recall[idx, idx_cls, idx_bbox_area, idx_max_det_thrs] = rc[-1] if nd else 0 + + # Remove zigzags for AUC + diff_zero = torch.zeros((1,), device=pr.device) + diff = torch.ones((1,), device=pr.device) + while not torch.all(diff == 0): + diff = torch.clamp(torch.cat(((pr[1:] - pr[:-1]), diff_zero), 0), min=0) + pr += diff + + inds = torch.searchsorted(rc, rec_thresholds.to(rc.device), right=False) + num_inds = inds.argmax() if inds.max() >= nd else nb_rec_thrs + inds = inds[:num_inds] + prec[:num_inds] = pr[inds] + score[:num_inds] = det_scores_sorted[inds] + precision[idx, :, idx_cls, idx_bbox_area, idx_max_det_thrs] = prec + scores[idx, :, idx_cls, idx_bbox_area, idx_max_det_thrs] = score + + return recall, precision, scores + + def compute(self) -> dict: + """Compute metric.""" + classes = self._get_classes() + precisions, recalls = self._calculate(classes) + map_val, mar_val = self._summarize_results(precisions, recalls) + + # if class mode is enabled, evaluate metrics per class + map_per_class_values: Tensor = torch.tensor([-1.0]) + mar_max_dets_per_class_values: Tensor = torch.tensor([-1.0]) + if self.class_metrics: + map_per_class_list = [] + mar_max_dets_per_class_list = [] + + for class_idx, _ in enumerate(classes): + cls_precisions = precisions[:, :, class_idx].unsqueeze(dim=2) + cls_recalls = recalls[:, class_idx].unsqueeze(dim=1) + cls_map, cls_mar = self._summarize_results(cls_precisions, cls_recalls) + map_per_class_list.append(cls_map.map) + mar_max_dets_per_class_list.append(cls_mar[f"mar_{self.max_detection_thresholds[-1]}"]) + + map_per_class_values = torch.tensor(map_per_class_list, dtype=torch.float) + mar_max_dets_per_class_values = torch.tensor(mar_max_dets_per_class_list, dtype=torch.float) + + metrics = COCOMetricResults() + metrics.update(map_val) + metrics.update(mar_val) + metrics.map_per_class = map_per_class_values + metrics[f"mar_{self.max_detection_thresholds[-1]}_per_class"] = mar_max_dets_per_class_values + metrics.classes = torch.tensor(classes, dtype=torch.int) + return metrics + + def _apply(self, fn: Callable) -> torch.nn.Module: + """Custom apply function. + + Excludes the detections and groundtruths from the casting when the iou_type is set to `segm` as the state is + no longer a tensor but a tuple. 
+ """ + if self.iou_type == "segm": + this = super()._apply(fn, exclude_state=("detections", "groundtruths")) + else: + this = super()._apply(fn) + return this + + def _sync_dist(self, dist_sync_fn: Optional[Callable] = None, process_group: Optional[Any] = None) -> None: + """Custom sync function. + + For the iou_type `segm` the detections and groundtruths are no longer tensors but tuples. Therefore, we need + to gather the list of tuples and then convert it back to a list of tuples. + + """ + super()._sync_dist(dist_sync_fn=dist_sync_fn, process_group=process_group) + + if self.iou_type == "segm": + self.detections = self._gather_tuple_list(self.detections, process_group) + self.groundtruths = self._gather_tuple_list(self.groundtruths, process_group) + + @staticmethod + def _gather_tuple_list(list_to_gather: List[Tuple], process_group: Optional[Any] = None) -> List[Any]: + """Gather a list of tuples over multiple devices.""" + world_size = dist.get_world_size(group=process_group) + dist.barrier(group=process_group) + + list_gathered = [None for _ in range(world_size)] + dist.all_gather_object(list_gathered, list_to_gather, group=process_group) + + list_merged = [] + for idx in range(len(list_gathered[0])): + for rank in range(world_size): + list_merged.append(list_gathered[rank][idx]) + + return list_merged + + def plot( + self, val: Optional[Union[Dict[str, Tensor], Sequence[Dict[str, Tensor]]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure object and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> from torch import tensor + >>> from torchmetrics.detection.mean_ap import MeanAveragePrecision + >>> preds = [dict( + ... boxes=tensor([[258.0, 41.0, 606.0, 285.0]]), + ... scores=tensor([0.536]), + ... labels=tensor([0]), + ... )] + >>> target = [dict( + ... boxes=tensor([[214.0, 41.0, 562.0, 285.0]]), + ... labels=tensor([0]), + ... )] + >>> metric = MeanAveragePrecision() + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.detection.mean_ap import MeanAveragePrecision + >>> preds = lambda: [dict( + ... boxes=torch.tensor([[258.0, 41.0, 606.0, 285.0]]) + torch.randint(10, (1,4)), + ... scores=torch.tensor([0.536]) + 0.1*torch.rand(1), + ... labels=torch.tensor([0]), + ... )] + >>> target = [dict( + ... boxes=torch.tensor([[214.0, 41.0, 562.0, 285.0]]), + ... labels=torch.tensor([0]), + ... )] + >>> metric = MeanAveragePrecision() + >>> vals = [] + >>> for _ in range(20): + ... 
vals.append(metric(preds(), target)) + >>> fig_, ax_ = metric.plot(vals) + """ + return self._plot(val, ax) From f5ba0fdcdbf3e41634aa57ad1f39cd3057de34e7 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Sat, 1 Jul 2023 15:21:01 +0200 Subject: [PATCH 16/16] ignore mypy --- pyproject.toml | 1 + src/torchmetrics/detection/mean_ap.py | 8 ++++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a2743a85e28..496242f3a87 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -184,6 +184,7 @@ module = [ "torchmetrics.classification.roc", "torchmetrics.classification.specificity", "torchmetrics.classification.stat_scores", + "torchmetrics.detection._mean_ap", "torchmetrics.detection.mean_ap", "torchmetrics.functional.classification.calibration_error", "torchmetrics.functional.classification.confusion_matrix", diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py index 69cbb0c457c..754b3aff09f 100644 --- a/src/torchmetrics/detection/mean_ap.py +++ b/src/torchmetrics/detection/mean_ap.py @@ -295,7 +295,7 @@ def __init__( self.add_state("groundtruth_crowds", default=[], dist_reduce_fx=None) self.add_state("groundtruth_area", default=[], dist_reduce_fx=None) - def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]) -> None: # type: ignore + def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]) -> None: """Update metric state. Raises: @@ -373,8 +373,8 @@ def compute(self) -> dict: map_per_class_values = torch.tensor(map_per_class_list, dtype=torch.float32) mar_100_per_class_values = torch.tensor(mar_100_per_class_list, dtype=torch.float32) else: - map_per_class_values: Tensor = torch.tensor([-1], dtype=torch.float32) - mar_100_per_class_values: Tensor = torch.tensor([-1], dtype=torch.float32) + map_per_class_values = torch.tensor([-1], dtype=torch.float32) + mar_100_per_class_values = torch.tensor([-1], dtype=torch.float32) return { "map": torch.tensor([stats[0]], dtype=torch.float32), @@ -693,7 +693,7 @@ def plot( # specialized syncronization and apply functions for this metric # -------------------- - def _apply(self, fn: Callable) -> torch.nn.Module: + def _apply(self, fn: Callable) -> torch.nn.Module: # type: ignore[override] """Custom apply function. Excludes the detections and groundtruths from the casting when the iou_type is set to `segm` as the state is
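A minimal end-to-end sketch of the converter pair exercised in the tests above, assuming pycocotools and
torchvision>=0.8 are installed; the two json paths are placeholders for files in COCO annotation/result
format, not files shipped with the repository:

    from torchmetrics.detection import MeanAveragePrecision

    # Convert COCO-format json files into the list-of-dict format (one entry per image)
    # that update() expects. "coco_preds.json" and "coco_target.json" are placeholder paths.
    preds, target = MeanAveragePrecision.coco_to_tm("coco_preds.json", "coco_target.json", "bbox")

    metric = MeanAveragePrecision(iou_type="bbox")
    metric.update(preds, target)
    print(metric.compute()["map"])

    # Write the accumulated state back out in COCO format
    # (see tm_to_coco for the exact file naming).
    metric.tm_to_coco("tm_map_input")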