
Commit c63de49

Merge branch 'master' into fix/same_dict_metriccollection
2 parents 46883d7 + 63c7bbe

3 files changed (+102, -35 lines)

CHANGELOG.md (+3)

@@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### Added

+- Added `average` argument to `MeanAveragePrecision` ([#2018](https://github.com/Lightning-AI/torchmetrics/pull/2018))
+
+
 - Added `MutualInformationScore` metric to cluster package ([#2008](https://github.com/Lightning-AI/torchmetrics/pull/2008))


src/torchmetrics/detection/mean_ap.py (+51, -18)

@@ -187,6 +187,9 @@ class MeanAveragePrecision(Metric):
             IoU thresholds, ``K`` is the number of classes, ``A`` is the number of areas and ``M`` is the number
             of max detections per image.
+        average:
+            Method for averaging scores over labels. Choose between "``macro``" and "``micro``". Default is "macro".
+
         kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

     Raises:

@@ -329,6 +332,7 @@ def __init__(
         max_detection_thresholds: Optional[List[int]] = None,
         class_metrics: bool = False,
         extended_summary: bool = False,
+        average: Literal["macro", "micro"] = "macro",
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)

@@ -379,6 +383,10 @@ def __init__(
             raise ValueError("Expected argument `extended_summary` to be a boolean")
         self.extended_summary = extended_summary

+        if average not in ("macro", "micro"):
+            raise ValueError(f"Expected argument `average` to be one of ('macro', 'micro') but got {average}")
+        self.average = average
+
         self.add_state("detection_box", default=[], dist_reduce_fx=None)
         self.add_state("detection_mask", default=[], dist_reduce_fx=None)
         self.add_state("detection_scores", default=[], dist_reduce_fx=None)

@@ -434,27 +442,10 @@ def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]

     def compute(self) -> dict:
         """Computes the metric."""
-        coco_target, coco_preds = COCO(), COCO()
-
-        coco_target.dataset = self._get_coco_format(
-            labels=self.groundtruth_labels,
-            boxes=self.groundtruth_box if len(self.groundtruth_box) > 0 else None,
-            masks=self.groundtruth_mask if len(self.groundtruth_mask) > 0 else None,
-            crowds=self.groundtruth_crowds,
-            area=self.groundtruth_area,
-        )
-        coco_preds.dataset = self._get_coco_format(
-            labels=self.detection_labels,
-            boxes=self.detection_box if len(self.detection_box) > 0 else None,
-            masks=self.detection_mask if len(self.detection_mask) > 0 else None,
-            scores=self.detection_scores,
-        )
+        coco_preds, coco_target = self._get_coco_datasets(average=self.average)

         result_dict = {}
         with contextlib.redirect_stdout(io.StringIO()):
-            coco_target.createIndex()
-            coco_preds.createIndex()
-
             for i_type in self.iou_type:
                 prefix = "" if len(self.iou_type) == 1 else f"{i_type}_"
                 if len(self.iou_type) > 1:

@@ -487,6 +478,15 @@ def compute(self) -> dict:

                 # if class mode is enabled, evaluate metrics per class
                 if self.class_metrics:
+                    if self.average == "micro":
+                        # since micro averaging puts all the data in one class, we need to reinitialize the coco_eval
+                        # object in macro mode to get the per-class stats
+                        coco_preds, coco_target = self._get_coco_datasets(average="macro")
+                        coco_eval = COCOeval(coco_target, coco_preds, iouType=i_type)
+                        coco_eval.params.iouThrs = np.array(self.iou_thresholds, dtype=np.float64)
+                        coco_eval.params.recThrs = np.array(self.rec_thresholds, dtype=np.float64)
+                        coco_eval.params.maxDets = self.max_detection_thresholds
+
                     map_per_class_list = []
                     mar_100_per_class_list = []
                     for class_id in self._get_classes():

@@ -516,8 +516,41 @@ def compute(self) -> dict:

         return result_dict

+    def _get_coco_datasets(self, average: Literal["macro", "micro"]) -> Tuple[COCO, COCO]:
+        """Returns the coco datasets for the target and the predictions."""
+        if average == "micro":
+            # for micro averaging we set everything to be the same class
+            groundtruth_labels = apply_to_collection(self.groundtruth_labels, Tensor, lambda x: torch.zeros_like(x))
+            detection_labels = apply_to_collection(self.detection_labels, Tensor, lambda x: torch.zeros_like(x))
+        else:
+            groundtruth_labels = self.groundtruth_labels
+            detection_labels = self.detection_labels
+
+        coco_target, coco_preds = COCO(), COCO()
+
+        coco_target.dataset = self._get_coco_format(
+            labels=groundtruth_labels,
+            boxes=self.groundtruth_box if len(self.groundtruth_box) > 0 else None,
+            masks=self.groundtruth_mask if len(self.groundtruth_mask) > 0 else None,
+            crowds=self.groundtruth_crowds,
+            area=self.groundtruth_area,
+        )
+        coco_preds.dataset = self._get_coco_format(
+            labels=detection_labels,
+            boxes=self.detection_box if len(self.detection_box) > 0 else None,
+            masks=self.detection_mask if len(self.detection_mask) > 0 else None,
+            scores=self.detection_scores,
+        )
+
+        with contextlib.redirect_stdout(io.StringIO()):
+            coco_target.createIndex()
+            coco_preds.createIndex()
+
+        return coco_preds, coco_target
+
     @staticmethod
     def _coco_stats_to_tensor_dict(stats: List[float], prefix: str) -> Dict[str, Tensor]:
+        """Converts the output of COCOeval.stats to a dict of tensors."""
         return {
             f"{prefix}map": torch.tensor([stats[0]], dtype=torch.float32),
             f"{prefix}map_50": torch.tensor([stats[1]], dtype=torch.float32),

tests/unittests/detection/test_map.py (+48, -17)

@@ -22,6 +22,7 @@
 import numpy as np
 import pytest
 import torch
+from lightning_utilities import apply_to_collection
 from pycocotools.coco import COCO
 from pycocotools.cocoeval import COCOeval
 from torch import IntTensor, Tensor

@@ -474,37 +475,32 @@ def test_empty_preds_cxcywh():
     metric.compute()


-_gpu_test_condition = not torch.cuda.is_available()
-
-
-def _move_to_gpu(inputs):
-    for x in inputs:
-        for key in x:
-            if torch.is_tensor(x[key]):
-                x[key] = x[key].to("cuda")
-    return inputs
-
-
 @pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed")
-@pytest.mark.skipif(_gpu_test_condition, reason="test requires CUDA availability")
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires CUDA availability")
 @pytest.mark.parametrize("inputs", [_inputs, _inputs2, _inputs3])
 def test_map_gpu(inputs):
     """Test predictions on single gpu."""
     metric = MeanAveragePrecision()
     metric = metric.to("cuda")
-    for preds, targets in zip(inputs.preds, inputs.target):
-        metric.update(_move_to_gpu(preds), _move_to_gpu(targets))
+    for preds, targets in zip(deepcopy(inputs.preds), deepcopy(inputs.target)):
+        metric.update(
+            apply_to_collection(preds, Tensor, lambda x: x.to("cuda")),
+            apply_to_collection(targets, Tensor, lambda x: x.to("cuda")),
+        )
     metric.compute()


 @pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed")
-@pytest.mark.skipif(_gpu_test_condition, reason="test requires CUDA availability")
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires CUDA availability")
 def test_map_with_custom_thresholds():
     """Test that map works with custom iou thresholds."""
     metric = MeanAveragePrecision(iou_thresholds=[0.1, 0.2])
     metric = metric.to("cuda")
-    for preds, targets in zip(_inputs.preds, _inputs.target):
-        metric.update(_move_to_gpu(preds), _move_to_gpu(targets))
+    for preds, targets in zip(deepcopy(_inputs.preds), deepcopy(_inputs.target)):
+        metric.update(
+            apply_to_collection(preds, Tensor, lambda x: x.to("cuda")),
+            apply_to_collection(targets, Tensor, lambda x: x.to("cuda")),
+        )
     res = metric.compute()
     assert res["map_50"].item() == -1
     assert res["map_75"].item() == -1

@@ -794,3 +790,38 @@ def test_for_extended_stats(preds, target, expected_iou_len, iou_keys, precision
     recall = result["recall"]
     assert isinstance(recall, Tensor)
     assert recall.shape == recall_shape
+
+
+@pytest.mark.parametrize("class_metrics", [False, True])
+def test_average_argument(class_metrics):
+    """Test that average argument works.
+
+    Calculating macro on inputs that only have one label should be the same as micro. Calculating class metrics should
+    be the same regardless of average argument.
+
+    """
+    if class_metrics:
+        _preds = _inputs.preds
+        _target = _inputs.target
+    else:
+        _preds = apply_to_collection(deepcopy(_inputs.preds), IntTensor, lambda x: torch.ones_like(x))
+        _target = apply_to_collection(deepcopy(_inputs.target), IntTensor, lambda x: torch.ones_like(x))
+
+    metric_macro = MeanAveragePrecision(average="macro", class_metrics=class_metrics)
+    metric_macro.update(_preds[0], _target[0])
+    metric_macro.update(_preds[1], _target[1])
+    result_macro = metric_macro.compute()
+
+    metric_micro = MeanAveragePrecision(average="micro", class_metrics=class_metrics)
+    metric_micro.update(_inputs.preds[0], _inputs.target[0])
+    metric_micro.update(_inputs.preds[1], _inputs.target[1])
+    result_micro = metric_micro.compute()
+
+    if class_metrics:
+        assert torch.allclose(result_macro["map_per_class"], result_micro["map_per_class"])
+        assert torch.allclose(result_macro["mar_100_per_class"], result_micro["mar_100_per_class"])
+    else:
+        for key in result_macro:
+            if key == "classes":
+                continue
+            assert torch.allclose(result_macro[key], result_micro[key])
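The tests drop the hand-rolled `_move_to_gpu` helper in favour of `apply_to_collection` from `lightning_utilities`, which applies a function to every element of a given type inside a nested collection; the same helper is what `_get_coco_datasets` uses to collapse all labels to one class for micro averaging. A minimal sketch of the pattern, with made-up data:

import torch
from lightning_utilities import apply_to_collection

# A list of detection dicts, as consumed by MeanAveragePrecision.update (values are made up).
batch = [{"boxes": torch.zeros(2, 4), "scores": torch.tensor([0.5, 0.7]), "labels": torch.tensor([1, 3])}]

# Move every Tensor in the nested structure to the GPU (guarded, since CUDA may be unavailable).
if torch.cuda.is_available():
    batch = apply_to_collection(batch, torch.Tensor, lambda x: x.to("cuda"))

# Collapse every label tensor to a single class, mirroring the micro-averaging trick above.
labels = [sample["labels"] for sample in batch]
single_class = apply_to_collection(labels, torch.Tensor, torch.zeros_like)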
