Skip to content

Commit

Permalink
IOU with segm masks and MAP for instance segment (#822)
Browse files Browse the repository at this point in the history
* implementaed IOU with segmentation masks and MAP for instance segmentation

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jirka <[email protected]>
  • Loading branch information
5 people authored May 24, 2022
1 parent 85d798e commit c31f43f
Show file tree
Hide file tree
Showing 9 changed files with 314 additions and 65 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added `Dice` to classification package ([#1021](https://github.com/PyTorchLightning/metrics/pull/1021))

- Added support to segmentation type `segm` as IOU for mean average precision ([#822](https://github.com/PyTorchLightning/metrics/pull/822))

### Changed

Expand Down
1 change: 1 addition & 0 deletions requirements/detection.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
torchvision>=0.8
pycocotools
1 change: 1 addition & 0 deletions requirements/detection_test.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pycocotools
1 change: 1 addition & 0 deletions requirements/devel.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@
-r image_test.txt
-r text_test.txt
-r audio_test.txt
-r detection_test.txt
5 changes: 5 additions & 0 deletions tests/detection/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import os

from tests import _PATH_ROOT

_SAMPLE_DETECTION_SEGMENTATION = os.path.join(_PATH_ROOT, "_data", "detection", "instance_segmentation_inputs.json")
115 changes: 110 additions & 5 deletions tests/detection/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,51 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from collections import namedtuple

import numpy as np
import pytest
import torch
from pycocotools import mask
from torch import IntTensor, Tensor

from tests.detection import _SAMPLE_DETECTION_SEGMENTATION
from tests.helpers.testers import MetricTester
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_8

Input = namedtuple("Input", ["preds", "target"])

with open(_SAMPLE_DETECTION_SEGMENTATION) as fp:
inputs_json = json.load(fp)

_mask_unsqueeze_bool = lambda m: Tensor(mask.decode(m)).unsqueeze(0).bool()
_masks_stack_bool = lambda ms: Tensor(np.stack([mask.decode(m) for m in ms])).bool()

_inputs_masks = Input(
preds=[
[
dict(masks=_mask_unsqueeze_bool(inputs_json["preds"][0]), scores=Tensor([0.236]), labels=IntTensor([4])),
dict(
masks=_masks_stack_bool([inputs_json["preds"][1], inputs_json["preds"][2]]),
scores=Tensor([0.318, 0.726]),
labels=IntTensor([3, 2]),
), # 73
],
],
target=[
[
dict(masks=_mask_unsqueeze_bool(inputs_json["targets"][0]), labels=IntTensor([4])), # 42
dict(
masks=_masks_stack_bool([inputs_json["targets"][1], inputs_json["targets"][2]]),
labels=IntTensor([2, 2]),
), # 73
],
],
)


_inputs = Input(
preds=[
[
Expand Down Expand Up @@ -139,15 +172,15 @@
_inputs3 = Input(
preds=[
[
dict(boxes=torch.tensor([]), scores=torch.tensor([]), labels=torch.tensor([])),
dict(boxes=Tensor([]), scores=Tensor([]), labels=Tensor([])),
],
],
target=[
[
dict(
boxes=torch.tensor([[1.0, 2.0, 3.0, 4.0]]),
scores=torch.tensor([0.8]),
labels=torch.tensor([1]),
boxes=Tensor([[1.0, 2.0, 3.0, 4.0]]),
scores=Tensor([0.8]),
labels=Tensor([1]),
),
],
],
Expand Down Expand Up @@ -214,6 +247,41 @@ def _compare_fn(preds, target) -> dict:
}


def _compare_fn_segm(preds, target) -> dict:
"""Comparison function for map implementation for instance segmentation.
Official pycocotools results calculated from a subset of https://github.com/cocodataset/cocoapi/tree/master/results
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.352
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.752
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.252
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.352
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.350
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.350
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.350
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.350
"""
return {
"map": Tensor([0.352]),
"map_50": Tensor([0.742]),
"map_75": Tensor([0.252]),
"map_small": Tensor([-1]),
"map_medium": Tensor([-1]),
"map_large": Tensor([0.352]),
"mar_1": Tensor([0.35]),
"mar_10": Tensor([0.35]),
"mar_100": Tensor([0.35]),
"mar_small": Tensor([-1]),
"mar_medium": Tensor([-1]),
"mar_large": Tensor([0.35]),
"map_per_class": Tensor([0.4039604, -1.0, 0.3]),
"mar_100_per_class": Tensor([0.4, -1.0, 0.3]),
}


_pytest_condition = not (_TORCHVISION_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_8)


Expand All @@ -230,7 +298,8 @@ class TestMAP(MetricTester):
atol = 1e-1

@pytest.mark.parametrize("ddp", [False, True])
def test_map(self, compute_on_cpu, ddp):
def test_map_bbox(self, compute_on_cpu, ddp):

"""Test modular implementation for correctness."""
self.run_class_metric_test(
ddp=ddp,
Expand All @@ -243,6 +312,21 @@ def test_map(self, compute_on_cpu, ddp):
metric_args={"class_metrics": True, "compute_on_cpu": compute_on_cpu},
)

@pytest.mark.parametrize("ddp", [False])
def test_map_segm(self, compute_on_cpu, ddp):
"""Test modular implementation for correctness."""

self.run_class_metric_test(
ddp=ddp,
preds=_inputs_masks.preds,
target=_inputs_masks.target,
metric_class=MeanAveragePrecision,
sk_metric=_compare_fn_segm,
dist_sync_on_step=False,
check_batch=False,
metric_args={"class_metrics": True, "compute_on_cpu": compute_on_cpu, "iou_type": "segm"},
)


# noinspection PyTypeChecker
@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed")
Expand Down Expand Up @@ -377,6 +461,27 @@ def test_missing_gt():
assert result["map"] < 1, "MAP cannot be 1, as there is an image with no ground truth, but some predictions."


@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed")
def test_segm_iou_empty_mask():
"""Test empty ground truths."""
metric = MeanAveragePrecision(iou_type="segm")

metric.update(
[
dict(
masks=torch.randint(0, 1, (1, 10, 10)).bool(),
scores=Tensor([0.5]),
labels=IntTensor([4]),
),
],
[
dict(masks=Tensor([]), labels=IntTensor([])),
],
)

metric.compute()


@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed")
def test_error_on_wrong_input():
"""Test class input validation."""
Expand Down
1 change: 1 addition & 0 deletions tests/helpers/testers.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ def _class_test(
_assert_allclose(batch_result, sk_batch_result, atol=atol)

# check that metrics are hashable

assert hash(metric)

# assert that state dict is empty
Expand Down
Loading

0 comments on commit c31f43f

Please sign in to comment.