Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add argument average to MeanAveragePrecision #2018

Merged
merged 13 commits into from
Aug 28, 2023
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

-
- Added `average` argument to `MeanAveragePrecision` ([#2018](https://github.com/Lightning-AI/torchmetrics/pull/2018)


### Changed
Expand Down
69 changes: 51 additions & 18 deletions src/torchmetrics/detection/mean_ap.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,9 @@ class MeanAveragePrecision(Metric):
IoU thresholds, ``K`` is the number of classes, ``A`` is the number of areas and ``M`` is the number
of max detections per image.

average:
Method for averaging scores over labels. Choose between "``macro``"" and "``micro``". Default is "macro"

kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
Expand Down Expand Up @@ -329,6 +332,7 @@ def __init__(
max_detection_thresholds: Optional[List[int]] = None,
class_metrics: bool = False,
extended_summary: bool = False,
average: Literal["macro", "micro"] = "macro",
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
Expand Down Expand Up @@ -379,6 +383,10 @@ def __init__(
raise ValueError("Expected argument `extended_summary` to be a boolean")
self.extended_summary = extended_summary

if average not in ("macro", "micro"):
raise ValueError(f"Expected argument `average` to be one of ('macro', 'micro') but got {average}")
self.average = average

self.add_state("detection_box", default=[], dist_reduce_fx=None)
self.add_state("detection_mask", default=[], dist_reduce_fx=None)
self.add_state("detection_scores", default=[], dist_reduce_fx=None)
Expand Down Expand Up @@ -434,27 +442,10 @@ def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]

def compute(self) -> dict:
"""Computes the metric."""
coco_target, coco_preds = COCO(), COCO()

coco_target.dataset = self._get_coco_format(
labels=self.groundtruth_labels,
boxes=self.groundtruth_box if len(self.groundtruth_box) > 0 else None,
masks=self.groundtruth_mask if len(self.groundtruth_mask) > 0 else None,
crowds=self.groundtruth_crowds,
area=self.groundtruth_area,
)
coco_preds.dataset = self._get_coco_format(
labels=self.detection_labels,
boxes=self.detection_box if len(self.detection_box) > 0 else None,
masks=self.detection_mask if len(self.detection_mask) > 0 else None,
scores=self.detection_scores,
)
coco_preds, coco_target = self._get_coco_datasets(average=self.average)

result_dict = {}
with contextlib.redirect_stdout(io.StringIO()):
coco_target.createIndex()
coco_preds.createIndex()

for i_type in self.iou_type:
prefix = "" if len(self.iou_type) == 1 else f"{i_type}_"
if len(self.iou_type) > 1:
Expand Down Expand Up @@ -487,6 +478,15 @@ def compute(self) -> dict:

# if class mode is enabled, evaluate metrics per class
if self.class_metrics:
if self.average == "micro":
# since micro averaging have all the data in one class, we need to reinitialize the coco_eval
# object in macro mode to get the per class stats
coco_preds, coco_target = self._get_coco_datasets(average="macro")
coco_eval = COCOeval(coco_target, coco_preds, iouType=i_type)
coco_eval.params.iouThrs = np.array(self.iou_thresholds, dtype=np.float64)
coco_eval.params.recThrs = np.array(self.rec_thresholds, dtype=np.float64)
coco_eval.params.maxDets = self.max_detection_thresholds

map_per_class_list = []
mar_100_per_class_list = []
for class_id in self._get_classes():
Expand Down Expand Up @@ -516,8 +516,41 @@ def compute(self) -> dict:

return result_dict

def _get_coco_datasets(self, average: Literal["macro", "micro"]) -> Tuple[COCO, COCO]:
"""Returns the coco datasets for the target and the predictions."""
if average == "micro":
# for micro averaging we set everything to be the same class
groundtruth_labels = apply_to_collection(self.groundtruth_labels, Tensor, lambda x: torch.zeros_like(x))
detection_labels = apply_to_collection(self.detection_labels, Tensor, lambda x: torch.zeros_like(x))
else:
groundtruth_labels = self.groundtruth_labels
detection_labels = self.detection_labels

coco_target, coco_preds = COCO(), COCO()

coco_target.dataset = self._get_coco_format(
labels=groundtruth_labels,
boxes=self.groundtruth_box if len(self.groundtruth_box) > 0 else None,
masks=self.groundtruth_mask if len(self.groundtruth_mask) > 0 else None,
crowds=self.groundtruth_crowds,
area=self.groundtruth_area,
)
coco_preds.dataset = self._get_coco_format(
labels=detection_labels,
boxes=self.detection_box if len(self.detection_box) > 0 else None,
masks=self.detection_mask if len(self.detection_mask) > 0 else None,
scores=self.detection_scores,
)

with contextlib.redirect_stdout(io.StringIO()):
coco_target.createIndex()
coco_preds.createIndex()

return coco_preds, coco_target

@staticmethod
def _coco_stats_to_tensor_dict(stats: List[float], prefix: str) -> Dict[str, Tensor]:
"""Converts the output of COCOeval.stats to a dict of tensors."""
return {
f"{prefix}map": torch.tensor([stats[0]], dtype=torch.float32),
f"{prefix}map_50": torch.tensor([stats[1]], dtype=torch.float32),
Expand Down
36 changes: 36 additions & 0 deletions tests/unittests/detection/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import numpy as np
import pytest
import torch
from lightning_utilities import apply_to_collection
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from torch import IntTensor, Tensor
Expand Down Expand Up @@ -794,3 +795,38 @@ def test_for_extended_stats(preds, target, expected_iou_len, iou_keys, precision
recall = result["recall"]
assert isinstance(recall, Tensor)
assert recall.shape == recall_shape


@pytest.mark.parametrize("class_metrics", [False, True])
def test_average_argument(class_metrics):
"""Test that average argument works.

Calculating macro on inputs that only have one label should be the same as micro. Calculating class metrics should
be the same regardless of average argument.

"""
if class_metrics:
_preds = _inputs.preds
_target = _inputs.target
else:
_preds = apply_to_collection(_inputs.preds, IntTensor, lambda x: torch.ones_like(x))
_target = apply_to_collection(_inputs.target, IntTensor, lambda x: torch.ones_like(x))

metric_macro = MeanAveragePrecision(average="macro", class_metrics=class_metrics)
metric_macro.update(_preds[0], _target[0])
metric_macro.update(_preds[1], _target[1])
result_macro = metric_macro.compute()

metric_micro = MeanAveragePrecision(average="micro", class_metrics=class_metrics)
metric_micro.update(_inputs.preds[0], _inputs.target[0])
metric_micro.update(_inputs.preds[1], _inputs.target[1])
result_micro = metric_micro.compute()

if class_metrics:
assert torch.allclose(result_macro["map_per_class"], result_micro["map_per_class"])
assert torch.allclose(result_macro["mar_100_per_class"], result_micro["mar_100_per_class"])
else:
for key in result_macro:
if key == "classes":
continue
assert torch.allclose(result_macro[key], result_micro[key])
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved