Skip to content

Commit a4bcb12

Browse files
nibergerNicolas Bergerpre-commit-ci[bot]SkafteNickiBorda
authored
Detection: panoptic quality (#929)
Co-authored-by: Nicolas Berger <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Nicki Skafte Detlefsen <[email protected]> Co-authored-by: Jirka <[email protected]> Co-authored-by: Justus Schock <[email protected]> Co-authored-by: Justus Schock <[email protected]> Co-authored-by: Jirka Borovec <[email protected]>
1 parent 61c2385 commit a4bcb12

File tree

10 files changed

+688
-0
lines changed

10 files changed

+688
-0
lines changed

CHANGELOG.md

+4
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2323
- Added `classes` to output from `MAP` metric ([#1419](https://github.com/Lightning-AI/metrics/pull/1419))
2424

2525

26+
- Added new detection metric `PanopticQuality` ([#929](https://github.com/PyTorchLightning/metrics/pull/929))
27+
28+
2629
- Add `ClassificationTask` Enum and use in metrics ([#1479](https://github.com/Lightning-AI/metrics/pull/1479))
2730

2831

@@ -314,6 +317,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
314317
- Added `reset_real_features` argument image quality assessment metrics ([#722](https://github.com/Lightning-AI/metrics/pull/722))
315318
- Added new keyword argument `compute_on_cpu` to all metrics ([#867](https://github.com/Lightning-AI/metrics/pull/867))
316319

320+
317321
### Changed
318322

319323
- Made `num_classes` in `jaccard_index` a required argument ([#853](https://github.com/Lightning-AI/metrics/pull/853), [#914](https://github.com/Lightning-AI/metrics/pull/914))
+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
.. customcarditem::
2+
:header: Panoptic Quality
3+
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg
4+
:tags: Detection
5+
6+
################
7+
Panoptic Quality
8+
################
9+
10+
Module Interface
11+
________________
12+
13+
.. autoclass:: torchmetrics.PanopticQuality
14+
:noindex:
15+
:exclude-members: update, compute
16+
17+
Functional Interface
18+
____________________
19+
20+
.. autofunction:: torchmetrics.functional.panoptic_quality
21+
:noindex:

docs/source/links.rst

+1
Original file line numberDiff line numberDiff line change
@@ -128,4 +128,5 @@
128128
.. _kid ref2: https://arxiv.org/abs/1706.08500
129129
.. _Spectral Angle Mapper: https://ntrs.nasa.gov/citations/19940012238
130130
.. _Multilabel coverage error: https://link.springer.com/chapter/10.1007/978-0-387-09823-4_34
131+
.. _Panoptic Quality: https://arxiv.org/abs/1801.00868
131132
.. _torchmetrics mAP example: https://github.com/Lightning-AI/metrics/blob/master/examples/detection_map.py

src/torchmetrics/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
StatScores,
4444
)
4545
from torchmetrics.collections import MetricCollection # noqa: E402
46+
from torchmetrics.detection import PanopticQuality # noqa: E402
4647
from torchmetrics.image import ( # noqa: E402
4748
ErrorRelativeGlobalDimensionlessSynthesis,
4849
MultiScaleStructuralSimilarityIndexMeasure,
@@ -153,6 +154,7 @@
153154
"MinMetric",
154155
"MultioutputWrapper",
155156
"MultiScaleStructuralSimilarityIndexMeasure",
157+
"PanopticQuality",
156158
"PearsonCorrCoef",
157159
"PearsonsContingencyCoefficient",
158160
"PermutationInvariantTraining",

src/torchmetrics/detection/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,5 @@
1515

1616
if _TORCHVISION_GREATER_EQUAL_0_8:
1717
from torchmetrics.detection.mean_ap import MeanAveragePrecision # noqa: F401
18+
19+
from torchmetrics.detection.panoptic_quality import PanopticQuality # noqa: F401
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import warnings
15+
from typing import Any, Set
16+
17+
import torch
18+
from torch import Tensor
19+
20+
from torchmetrics.functional.detection.panoptic_quality import (
21+
_get_category_id_to_continuous_id,
22+
_get_void_color,
23+
_panoptic_quality_compute,
24+
_panoptic_quality_update,
25+
_prepocess_image,
26+
_validate_categories,
27+
_validate_inputs,
28+
)
29+
from torchmetrics.metric import Metric
30+
31+
32+
class PanopticQuality(Metric):
33+
r"""Compute the `Panoptic Quality`_ for panoptic segmentations.
34+
35+
.. math::
36+
PQ = \frac{IOU}{TP + 0.5 FP + 0.5 FN}
37+
38+
where IOU, TP, FP and FN are respectively the sum of the intersection over union for true positives,
39+
the number of true postitives, false positives and false negatives. This metric is inspired by the PQ
40+
implementati on of panopticapi, a standard implementation for the PQ metric for object detection.
41+
42+
.. note:
43+
Metric is currently experimental
44+
45+
Args:
46+
things:
47+
Set of ``category_id`` for countable things.
48+
stuffs:
49+
Set of ``category_id`` for uncountable stuffs.
50+
allow_unknown_preds_category:
51+
Bool indication if unknown categories in preds is allowed
52+
53+
Raises:
54+
ValueError:
55+
If ``things``, ``stuffs`` share the same ``category_id``.
56+
57+
Example:
58+
>>> from torch import tensor
59+
>>> preds = tensor([[[6, 0], [0, 0], [6, 0], [6, 0]],
60+
... [[0, 0], [0, 0], [6, 0], [0, 1]],
61+
... [[0, 0], [0, 0], [6, 0], [0, 1]],
62+
... [[0, 0], [7, 0], [6, 0], [1, 0]],
63+
... [[0, 0], [7, 0], [7, 0], [7, 0]]])
64+
>>> target = tensor([[[6, 0], [0, 1], [6, 0], [0, 1]],
65+
... [[0, 1], [0, 1], [6, 0], [0, 1]],
66+
... [[0, 1], [0, 1], [6, 0], [1, 0]],
67+
... [[0, 1], [7, 0], [1, 0], [1, 0]],
68+
... [[0, 1], [7, 0], [7, 0], [7, 0]]])
69+
>>> panoptic_quality = PanopticQuality(things = {0, 1}, stuffs = {6, 7})
70+
>>> panoptic_quality(preds, target)
71+
tensor(0.5463, dtype=torch.float64)
72+
"""
73+
is_differentiable: bool = False
74+
higher_is_better: bool = True
75+
full_state_update: bool = False
76+
77+
iou_sum: Tensor
78+
true_positives: Tensor
79+
false_positives: Tensor
80+
false_negatives: Tensor
81+
82+
def __init__(
83+
self,
84+
things: Set[int],
85+
stuffs: Set[int],
86+
allow_unknown_preds_category: bool = False,
87+
**kwargs: Any,
88+
):
89+
super().__init__(**kwargs)
90+
91+
# todo: better testing for correctness of metric
92+
warnings.warn("This is experimental version and are actively working on its stability.")
93+
94+
_validate_categories(things, stuffs)
95+
self.things = things
96+
self.stuffs = stuffs
97+
self.void_color = _get_void_color(things, stuffs)
98+
self.cat_id_to_continuous_id = _get_category_id_to_continuous_id(things, stuffs)
99+
self.allow_unknown_preds_category = allow_unknown_preds_category
100+
101+
# per category intermediate metrics
102+
n_categories = len(things) + len(stuffs)
103+
self.add_state("iou_sum", default=torch.zeros(n_categories, dtype=torch.double), dist_reduce_fx="sum")
104+
self.add_state("true_positives", default=torch.zeros(n_categories, dtype=torch.int), dist_reduce_fx="sum")
105+
self.add_state("false_positives", default=torch.zeros(n_categories, dtype=torch.int), dist_reduce_fx="sum")
106+
self.add_state("false_negatives", default=torch.zeros(n_categories, dtype=torch.int), dist_reduce_fx="sum")
107+
108+
def update(self, preds: Tensor, target: Tensor) -> None:
109+
r"""Update state with predictions and targets.
110+
111+
Args:
112+
preds: panoptic detection of shape ``[height, width, 2]`` containing
113+
the pair ``(category_id, instance_id)`` for each pixel of the image.
114+
If the ``category_id`` refer to a stuff, the instance_id is ignored.
115+
116+
target: ground truth of shape ``[height, width, 2]`` containing
117+
the pair ``(category_id, instance_id)`` for each pixel of the image.
118+
If the ``category_id`` refer to a stuff, the instance_id is ignored.
119+
120+
Raises:
121+
TypeError:
122+
If ``preds`` or ``target`` is not an ``torch.Tensor``
123+
ValueError:
124+
If ``preds`` or ``target`` has different shape.
125+
ValueError:
126+
If ``preds`` is not a 3D tensor where the final dimension have size 2
127+
"""
128+
_validate_inputs(preds, target)
129+
flatten_preds = _prepocess_image(
130+
self.things, self.stuffs, preds, self.void_color, self.allow_unknown_preds_category
131+
)
132+
flatten_target = _prepocess_image(self.things, self.stuffs, target, self.void_color, True)
133+
iou_sum, true_positives, false_positives, false_negatives = _panoptic_quality_update(
134+
flatten_preds, flatten_target, self.cat_id_to_continuous_id, self.void_color
135+
)
136+
self.iou_sum += iou_sum
137+
self.true_positives += true_positives
138+
self.false_positives += false_positives
139+
self.false_negatives += false_negatives
140+
141+
def compute(self) -> Tensor:
142+
"""Computes panoptic quality based on inputs passed in to ``update`` previously."""
143+
return _panoptic_quality_compute(self.iou_sum, self.true_positives, self.false_positives, self.false_negatives)

src/torchmetrics/functional/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from torchmetrics.functional.classification.roc import roc
3333
from torchmetrics.functional.classification.specificity import specificity
3434
from torchmetrics.functional.classification.stat_scores import stat_scores
35+
from torchmetrics.functional.detection.panoptic_quality import panoptic_quality
3536
from torchmetrics.functional.image.d_lambda import spectral_distortion_index
3637
from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis
3738
from torchmetrics.functional.image.gradients import image_gradients
@@ -138,6 +139,7 @@
138139
"pairwise_euclidean_distance",
139140
"pairwise_linear_similarity",
140141
"pairwise_manhattan_distance",
142+
"panoptic_quality",
141143
"pearson_corrcoef",
142144
"pearsons_contingency_coefficient",
143145
"pearsons_contingency_coefficient_matrix",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from torchmetrics.functional.detection.panoptic_quality import panoptic_quality # noqa: F401

0 commit comments

Comments
 (0)