Fix bug in MeanAveragePrecision.coco_to_tm (#2588)
* fix implementation

* tests

---------

Co-authored-by: Jirka Borovec <[email protected]>
SkafteNicki and Borda authored Jun 5, 2024
1 parent 1b77689 commit 8172c58
Showing 2 changed files with 72 additions and 6 deletions.
26 changes: 20 additions & 6 deletions src/torchmetrics/detection/mean_ap.py
@@ -665,7 +665,7 @@ def coco_to_tm(
     >>> # https://github.com/cocodataset/cocoapi/tree/master/results
     >>> from torchmetrics.detection import MeanAveragePrecision
     >>> preds, target = MeanAveragePrecision.coco_to_tm(
-    ...   "instances_val2014_fakebbox100_results.json.json",
+    ...   "instances_val2014_fakebbox100_results.json",
     ...   "val2014_fake_eval_res.txt.json",
     ...   iou_type="bbox"
     ... )  # doctest: +SKIP
@@ -775,21 +775,35 @@ def tm_to_coco(self, name: str = "tm_map_input") -> None:
     ...             labels=tensor([0]),
     ...         )
     ...     ]
-    >>> metric = MeanAveragePrecision()
+    >>> metric = MeanAveragePrecision(iou_type="bbox")
     >>> metric.update(preds, target)
-    >>> metric.tm_to_coco("tm_map_input")  # doctest: +SKIP
+    >>> metric.tm_to_coco("tm_map_input")

     """
     target_dataset = self._get_coco_format(
         labels=self.groundtruth_labels,
-        boxes=self.groundtruth_box,
-        masks=self.groundtruth_mask,
+        boxes=self.groundtruth_box if len(self.groundtruth_box) > 0 else None,
+        masks=self.groundtruth_mask if len(self.groundtruth_mask) > 0 else None,
         crowds=self.groundtruth_crowds,
         area=self.groundtruth_area,
     )
     preds_dataset = self._get_coco_format(
-        labels=self.detection_labels, boxes=self.detection_box, masks=self.detection_mask
+        labels=self.detection_labels,
+        boxes=self.detection_box if len(self.detection_box) > 0 else None,
+        masks=self.detection_mask if len(self.detection_mask) > 0 else None,
+        scores=self.detection_scores,
     )
+    if "segm" in self.iou_type:
+        # the RLE masks need to be decoded to be written to a file
+        preds_dataset["annotations"] = apply_to_collection(
+            preds_dataset["annotations"], dtype=bytes, function=lambda x: x.decode("utf-8")
+        )
+        preds_dataset["annotations"] = apply_to_collection(
+            preds_dataset["annotations"],
+            dtype=np.uint32,
+            function=lambda x: int(x),
+        )
+        target_dataset = apply_to_collection(target_dataset, dtype=bytes, function=lambda x: x.decode("utf-8"))

     preds_json = json.dumps(preds_dataset["annotations"], indent=4)
     target_json = json.dumps(target_dataset, indent=4)
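The "segm" branch added above exists because json.dumps cannot serialize what pycocotools puts into RLE-encoded masks: the counts field is bytes and the size values are numpy integers. The new empty-list guards on boxes/masks likewise pass None for input types that were never updated (e.g. no masks when iou_type="bbox"). A minimal standalone sketch of the serialization issue, not part of the commit; the sample annotation and its values are invented for illustration:

    import json

    import numpy as np
    from lightning_utilities import apply_to_collection  # same helper torchmetrics uses

    # A fake RLE-style annotation: pycocotools stores mask counts as bytes and
    # sizes as numpy integers, neither of which json.dumps can serialize.
    annotations = [{"segmentation": {"size": [np.uint32(4), np.uint32(4)], "counts": b"04"}}]

    # json.dumps(annotations)  # raises TypeError: Object of type bytes is not JSON serializable
    annotations = apply_to_collection(annotations, dtype=bytes, function=lambda x: x.decode("utf-8"))
    annotations = apply_to_collection(annotations, dtype=np.uint32, function=lambda x: int(x))
    print(json.dumps(annotations))  # now serializes cleanly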
52 changes: 52 additions & 0 deletions tests/unittests/detection/test_map.py
@@ -65,6 +65,58 @@ def _generate_coco_inputs(iou_type):
 _coco_segm_input = _generate_coco_inputs("segm")
 
 
+@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision>=0.8.0 and pycocotools are installed")
+@pytest.mark.parametrize("iou_type", ["bbox", "segm"])
+@pytest.mark.parametrize("backend", ["pycocotools", "faster_coco_eval"])
+def test_tm_to_coco(tmpdir, iou_type, backend):
+    """Test that the conversion from TM to COCO format works."""
+    preds, target = _coco_bbox_input if iou_type == "bbox" else _coco_segm_input
+    metric = MeanAveragePrecision(iou_type=iou_type, backend=backend, box_format="xywh")
+    for bp, bt in zip(preds, target):
+        metric.update(bp, bt)
+    metric.tm_to_coco(f"{tmpdir}/tm_map_input")
+    preds_2, target_2 = MeanAveragePrecision.coco_to_tm(
+        f"{tmpdir}/tm_map_input_preds.json",
+        f"{tmpdir}/tm_map_input_target.json",
+        iou_type=iou_type,
+        backend=backend,
+    )
+
+    preds = [p for batch in preds for p in batch]
+    target = [t for batch in target for t in batch]
+
+    # make sure that every prediction/target is found in the new prediction/target after saving and loading
+    for sample1 in preds:
+        sample_found = False
+        for sample2 in preds_2:
+            if iou_type == "segm":
+                if sample1["masks"].shape == sample2["masks"].shape and torch.allclose(
+                    sample1["masks"], sample2["masks"]
+                ):
+                    sample_found = True
+            else:
+                if sample1["boxes"].shape == sample2["boxes"].shape and torch.allclose(
+                    sample1["boxes"], sample2["boxes"]
+                ):
+                    sample_found = True
+        assert sample_found, "preds not found"
+
+    for sample1 in target:
+        sample_found = False
+        for sample2 in target_2:
+            if iou_type == "segm":
+                if sample1["masks"].shape == sample2["masks"].shape and torch.allclose(
+                    sample1["masks"], sample2["masks"]
+                ):
+                    sample_found = True
+            else:
+                if sample1["boxes"].shape == sample2["boxes"].shape and torch.allclose(
+                    sample1["boxes"], sample2["boxes"]
+                ):
+                    sample_found = True
+        assert sample_found, "target not found"
+
+
 def _compare_against_coco_fn(preds, target, iou_type, iou_thresholds=None, rec_thresholds=None, class_metrics=True):
     """Taken from https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb."""
     with contextlib.redirect_stdout(io.StringIO()):
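Outside the test suite, the save-and-reload round trip the new test exercises looks roughly like this. A hedged sketch, assuming pycocotools is installed and the working directory is writable; the box, score, and label values are invented, and the _preds.json/_target.json suffixes follow the pattern used in the test above:

    from torch import tensor
    from torchmetrics.detection import MeanAveragePrecision

    metric = MeanAveragePrecision(iou_type="bbox", box_format="xywh")
    metric.update(
        [{"boxes": tensor([[10.0, 20.0, 5.0, 5.0]]), "scores": tensor([0.9]), "labels": tensor([0])}],
        [{"boxes": tensor([[10.0, 20.0, 5.0, 5.0]]), "labels": tensor([0])}],
    )
    # writes tm_map_input_preds.json and tm_map_input_target.json
    metric.tm_to_coco("tm_map_input")
    preds, target = MeanAveragePrecision.coco_to_tm(
        "tm_map_input_preds.json", "tm_map_input_target.json", iou_type="bbox"
    )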
