diff --git a/CHANGELOG.md b/CHANGELOG.md index 8a9a7c49d0e..537d5cf61f7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `QualityWithNoReference` metric ([#2288](https://github.com/Lightning-AI/torchmetrics/pull/2288)) +- Added support for calculating segmentation quality and recognition quality in `PanopticQuality` metric ([#2381](https://github.com/Lightning-AI/torchmetrics/pull/2381)) + + ### Changed - Made `__getattr__` and `__setattr__` of `ClasswiseWrapper` more general ([#2424](https://github.com/Lightning-AI/torchmetrics/pull/2424)) diff --git a/src/torchmetrics/detection/_deprecated.py b/src/torchmetrics/detection/_deprecated.py index c162c751554..898f341bd62 100644 --- a/src/torchmetrics/detection/_deprecated.py +++ b/src/torchmetrics/detection/_deprecated.py @@ -1,8 +1,17 @@ from typing import Any, Collection from torchmetrics.detection import ModifiedPanopticQuality, PanopticQuality +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12 from torchmetrics.utilities.prints import _deprecated_root_import_class +if not _TORCH_GREATER_EQUAL_1_12: + __doctest_skip__ = [ + "_PanopticQuality", + "_PanopticQuality.*", + "_ModifiedPanopticQuality", + "_ModifiedPanopticQuality.*", + ] + class _ModifiedPanopticQuality(ModifiedPanopticQuality): """Wrapper for deprecated import. diff --git a/src/torchmetrics/detection/panoptic_qualities.py b/src/torchmetrics/detection/panoptic_qualities.py index 914cf7cf185..5f8aefabbc1 100644 --- a/src/torchmetrics/detection/panoptic_qualities.py +++ b/src/torchmetrics/detection/panoptic_qualities.py @@ -26,13 +26,17 @@ _validate_inputs, ) from torchmetrics.metric import Metric -from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCH_GREATER_EQUAL_1_12 from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE if not _MATPLOTLIB_AVAILABLE: __doctest_skip__ = ["PanopticQuality.plot", "ModifiedPanopticQuality.plot"] +if not _TORCH_GREATER_EQUAL_1_12: + __doctest_skip__ = ["PanopticQuality", "PanopticQuality.*", "ModifiedPanopticQuality", "ModifiedPanopticQuality.*"] + + class PanopticQuality(Metric): r"""Compute the `Panoptic Quality`_ for panoptic segmentations. @@ -47,6 +51,23 @@ class PanopticQuality(Metric): Points in the target tensor that do not map to a known category ID are automatically ignored in the metric computation. + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~torch.Tensor`): An int tensor of shape ``(B, *spatial_dims, 2)``, where there needs to + be at least one spatial dimension. + - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(B, *spatial_dims, 2)``, where there needs to + be at least one spatial dimension. + + As output to ``forward`` and ``compute`` the metric returns the following output: + + - ``quality`` (:class:`~torch.Tensor`): If ``return_sq_and_rq=False`` and ``return_per_class=False`` then a + single scalar tensor is returned with average panoptic quality over all classes. If ``return_sq_and_rq=True`` + and ``return_per_class=False`` a tensor of length 3 is returned with panoptic, segmentation and recognition + quality (in that order). If ``return_sq_and_rq=False`` and ``return_per_class=True`` a tensor of shape + ``(1, C)``, where ``C`` is the number of classes, is returned with the panoptic quality for each class.
Finally, if both arguments + are ``True`` a tensor of shape ``(C, 3)`` is returned with individual panoptic, segmentation and recognition + quality for each class. + Args: things: Set of ``category_id`` for countable things. @@ -55,6 +76,10 @@ class PanopticQuality(Metric): allow_unknown_preds_category: Boolean flag to specify if unknown categories in the predictions are to be ignored in the metric computation or raise an exception when found. + return_sq_and_rq: + Boolean flag to specify if Segmentation Quality and Recognition Quality should also be returned. + return_per_class: + Boolean flag to specify if the per-class values should be returned or the class average. Raises: @@ -80,6 +105,40 @@ class PanopticQuality(Metric): >>> panoptic_quality(preds, target) tensor(0.5463, dtype=torch.float64) + You can also return the segmentation and recognition quality alongside the PQ + >>> from torch import tensor + >>> from torchmetrics.detection import PanopticQuality + >>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]], + ... [[0, 0], [0, 0], [6, 0], [0, 1]], + ... [[0, 0], [0, 0], [6, 0], [0, 1]], + ... [[0, 0], [7, 0], [6, 0], [1, 0]], + ... [[0, 0], [7, 0], [7, 0], [7, 0]]]]) + >>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]], + ... [[0, 1], [0, 1], [6, 0], [0, 1]], + ... [[0, 1], [0, 1], [6, 0], [1, 0]], + ... [[0, 1], [7, 0], [1, 0], [1, 0]], + ... [[0, 1], [7, 0], [7, 0], [7, 0]]]]) + >>> panoptic_quality = PanopticQuality(things = {0, 1}, stuffs = {6, 7}, return_sq_and_rq=True) + >>> panoptic_quality(preds, target) + tensor([0.5463, 0.6111, 0.6667], dtype=torch.float64) + + You can also return the per-class metrics + >>> from torch import tensor + >>> from torchmetrics.detection import PanopticQuality + >>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]], + ... [[0, 0], [0, 0], [6, 0], [0, 1]], + ... [[0, 0], [0, 0], [6, 0], [0, 1]], + ... [[0, 0], [7, 0], [6, 0], [1, 0]], + ... [[0, 0], [7, 0], [7, 0], [7, 0]]]]) + >>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]], + ... [[0, 1], [0, 1], [6, 0], [0, 1]], + ... [[0, 1], [0, 1], [6, 0], [1, 0]], + ... [[0, 1], [7, 0], [1, 0], [1, 0]], + ... 
[[0, 1], [7, 0], [7, 0], [7, 0]]]]) + >>> panoptic_quality = PanopticQuality(things = {0, 1}, stuffs = {6, 7}, return_per_class=True) + >>> panoptic_quality(preds, target) + tensor([[0.5185, 0.0000, 0.6667, 1.0000]], dtype=torch.float64) + """ is_differentiable: bool = False @@ -98,9 +157,13 @@ def __init__( things: Collection[int], stuffs: Collection[int], allow_unknown_preds_category: bool = False, + return_sq_and_rq: bool = False, + return_per_class: bool = False, **kwargs: Any, ) -> None: super().__init__(**kwargs) + if not _TORCH_GREATER_EQUAL_1_12: + raise RuntimeError("Panoptic Quality metric requires PyTorch 1.12 or later") things, stuffs = _parse_categories(things, stuffs) self.things = things @@ -108,6 +171,8 @@ def __init__( self.void_color = _get_void_color(things, stuffs) self.cat_id_to_continuous_id = _get_category_id_to_continuous_id(things, stuffs) self.allow_unknown_preds_category = allow_unknown_preds_category + self.return_sq_and_rq = return_sq_and_rq + self.return_per_class = return_per_class # per category intermediate metrics num_categories = len(things) + len(stuffs) @@ -154,7 +219,16 @@ def update(self, preds: Tensor, target: Tensor) -> None: def compute(self) -> Tensor: """Compute panoptic quality based on inputs passed in to ``update`` previously.""" - return _panoptic_quality_compute(self.iou_sum, self.true_positives, self.false_positives, self.false_negatives) + pq, sq, rq, pq_avg, sq_avg, rq_avg = _panoptic_quality_compute( + self.iou_sum, self.true_positives, self.false_positives, self.false_negatives + ) + if self.return_per_class: + if self.return_sq_and_rq: + return torch.stack((pq, sq, rq), dim=-1) + return pq.view(1, -1) + if self.return_sq_and_rq: + return torch.stack((pq_avg, sq_avg, rq_avg), dim=0) + return pq_avg def plot( self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None @@ -337,7 +411,10 @@ def update(self, preds: Tensor, target: Tensor) -> None: def compute(self) -> Tensor: """Compute panoptic quality based on inputs passed in to ``update`` previously.""" - return _panoptic_quality_compute(self.iou_sum, self.true_positives, self.false_positives, self.false_negatives) + _, _, _, pq_avg, _, _ = _panoptic_quality_compute( + self.iou_sum, self.true_positives, self.false_positives, self.false_negatives + ) + return pq_avg def plot( self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None diff --git a/src/torchmetrics/functional/detection/_deprecated.py b/src/torchmetrics/functional/detection/_deprecated.py index ce0e1ba6acf..b2500d34f0d 100644 --- a/src/torchmetrics/functional/detection/_deprecated.py +++ b/src/torchmetrics/functional/detection/_deprecated.py @@ -3,8 +3,12 @@ from torch import Tensor from torchmetrics.functional.detection.panoptic_qualities import modified_panoptic_quality, panoptic_quality +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12 from torchmetrics.utilities.prints import _deprecated_root_import_func +if not _TORCH_GREATER_EQUAL_1_12: + __doctest_skip__ = ["_panoptic_quality", "_modified_panoptic_quality"] + def _modified_panoptic_quality( preds: Tensor, diff --git a/src/torchmetrics/functional/detection/_panoptic_quality_common.py b/src/torchmetrics/functional/detection/_panoptic_quality_common.py index a8df9dd2d23..e00beb98bd2 100644 --- a/src/torchmetrics/functional/detection/_panoptic_quality_common.py +++ b/src/torchmetrics/functional/detection/_panoptic_quality_common.py @@ -449,7 +449,7 @@ def _panoptic_quality_compute( 
true_positives: Tensor, false_positives: Tensor, false_negatives: Tensor, -) -> Tensor: +) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: """Compute the final panoptic quality from interim values. Args: @@ -459,11 +459,17 @@ false_negatives: the FN value from the update step Returns: - Panoptic quality as a tensor containing a single scalar. + A tuple containing the per-class panoptic, segmentation and recognition quality, followed by their averages. """ - # per category calculation - denominator = (true_positives + 0.5 * false_positives + 0.5 * false_negatives).double() - panoptic_quality = torch.where(denominator > 0.0, iou_sum / denominator, 0.0) - # Reduce across categories. TODO: is it useful to have the option of returning per class metrics? - return torch.mean(panoptic_quality[denominator > 0]) + # compute segmentation and recognition quality (per-class) + sq: Tensor = torch.where(true_positives > 0.0, iou_sum / true_positives, 0.0) + denominator: Tensor = true_positives + 0.5 * false_positives + 0.5 * false_negatives + rq: Tensor = torch.where(denominator > 0.0, true_positives / denominator, 0.0) + # compute per-class panoptic quality + pq: Tensor = sq * rq + # compute averages + pq_avg: Tensor = torch.mean(pq[denominator > 0]) + sq_avg: Tensor = torch.mean(sq[denominator > 0]) + rq_avg: Tensor = torch.mean(rq[denominator > 0]) + return pq, sq, rq, pq_avg, sq_avg, rq_avg diff --git a/src/torchmetrics/functional/detection/panoptic_qualities.py b/src/torchmetrics/functional/detection/panoptic_qualities.py index be34439f883..2de9fa09bfa 100644 --- a/src/torchmetrics/functional/detection/panoptic_qualities.py +++ b/src/torchmetrics/functional/detection/panoptic_qualities.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import Collection +import torch from torch import Tensor from torchmetrics.functional.detection._panoptic_quality_common import ( @@ -24,6 +25,10 @@ _prepocess_inputs, _validate_inputs, ) +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12 + +if not _TORCH_GREATER_EQUAL_1_12: + __doctest_skip__ = ["panoptic_quality", "modified_panoptic_quality"] def panoptic_quality( @@ -32,6 +37,8 @@ preds: Tensor, target: Tensor, things: Collection[int], stuffs: Collection[int], allow_unknown_preds_category: bool = False, + return_sq_and_rq: bool = False, + return_per_class: bool = False, ) -> Tensor: r"""Compute `Panoptic Quality`_ for panoptic segmentations. @@ -61,6 +68,10 @@ allow_unknown_preds_category: Boolean flag to specify if unknown categories in the predictions are to be ignored in the metric computation or raise an exception when found. + return_sq_and_rq: + Boolean flag to specify if Segmentation Quality and Recognition Quality should also be returned. + return_per_class: + Boolean flag to specify if the per-class values should be returned or the class average. Raises: ValueError: @@ -91,7 +102,59 @@ >>> panoptic_quality(preds, target, things = {0, 1}, stuffs = {6, 7}) tensor(0.5463, dtype=torch.float64) + You can also return the segmentation and recognition quality alongside the PQ + >>> from torch import tensor + >>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]], + ... [[0, 0], [0, 0], [6, 0], [0, 1]], + ... [[0, 0], [0, 0], [6, 0], [0, 1]], + ... [[0, 0], [7, 0], [6, 0], [1, 0]], + ... [[0, 0], [7, 0], [7, 0], [7, 0]]]]) + >>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]], + ... [[0, 1], [0, 1], [6, 0], [0, 1]], + ... 
[[0, 1], [0, 1], [6, 0], [1, 0]], + ... [[0, 1], [7, 0], [1, 0], [1, 0]], + ... [[0, 1], [7, 0], [7, 0], [7, 0]]]]) + >>> panoptic_quality(preds, target, things = {0, 1}, stuffs = {6, 7}, return_sq_and_rq=True) + tensor([0.5463, 0.6111, 0.6667], dtype=torch.float64) + + You can also return the per-class metrics + >>> from torch import tensor + >>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]], + ... [[0, 0], [0, 0], [6, 0], [0, 1]], + ... [[0, 0], [0, 0], [6, 0], [0, 1]], + ... [[0, 0], [7, 0], [6, 0], [1, 0]], + ... [[0, 0], [7, 0], [7, 0], [7, 0]]]]) + >>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]], + ... [[0, 1], [0, 1], [6, 0], [0, 1]], + ... [[0, 1], [0, 1], [6, 0], [1, 0]], + ... [[0, 1], [7, 0], [1, 0], [1, 0]], + ... [[0, 1], [7, 0], [7, 0], [7, 0]]]]) + >>> panoptic_quality(preds, target, things = {0, 1}, stuffs = {6, 7}, return_per_class=True) + tensor([[0.5185, 0.0000, 0.6667, 1.0000]], dtype=torch.float64) + + You can also return the per-class metrics together with the segmentation and recognition quality + >>> from torch import tensor + >>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]], + ... [[0, 0], [0, 0], [6, 0], [0, 1]], + ... [[0, 0], [0, 0], [6, 0], [0, 1]], + ... [[0, 0], [7, 0], [6, 0], [1, 0]], + ... [[0, 0], [7, 0], [7, 0], [7, 0]]]]) + >>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]], + ... [[0, 1], [0, 1], [6, 0], [0, 1]], + ... [[0, 1], [0, 1], [6, 0], [1, 0]], + ... [[0, 1], [7, 0], [1, 0], [1, 0]], + ... [[0, 1], [7, 0], [7, 0], [7, 0]]]]) + >>> panoptic_quality(preds, target, things = {0, 1}, stuffs = {6, 7}, + ... return_per_class=True, return_sq_and_rq=True) + tensor([[0.5185, 0.7778, 0.6667], + [0.0000, 0.0000, 0.0000], + [0.6667, 0.6667, 1.0000], + [1.0000, 1.0000, 1.0000]], dtype=torch.float64) + """ + if not _TORCH_GREATER_EQUAL_1_12: + raise RuntimeError("Panoptic Quality metric requires PyTorch 1.12 or later") + things, stuffs = _parse_categories(things, stuffs) _validate_inputs(preds, target) void_color = _get_void_color(things, stuffs) @@ -101,7 +164,19 @@ iou_sum, true_positives, false_positives, false_negatives = _panoptic_quality_update( flatten_preds, flatten_target, cat_id_to_continuous_id, void_color ) - return _panoptic_quality_compute(iou_sum, true_positives, false_positives, false_negatives) + pq, sq, rq, pq_avg, sq_avg, rq_avg = _panoptic_quality_compute( + iou_sum, + true_positives, + false_positives, + false_negatives, + ) + if return_per_class: + if return_sq_and_rq: + return torch.stack((pq, sq, rq), dim=-1) + return pq.view(1, -1) + if return_sq_and_rq: + return torch.stack((pq_avg, sq_avg, rq_avg), dim=0) + return pq_avg def modified_panoptic_quality( @@ -177,4 +252,5 @@ void_color, modified_metric_stuffs=stuffs, ) - return _panoptic_quality_compute(iou_sum, true_positives, false_positives, false_negatives) + _, _, _, pq_avg, _, _ = _panoptic_quality_compute(iou_sum, true_positives, false_positives, false_negatives) + return pq_avg diff --git a/tests/unittests/detection/test_modified_panoptic_quality.py b/tests/unittests/detection/test_modified_panoptic_quality.py index 4c864d0e9af..1d5a067a609 100644 --- a/tests/unittests/detection/test_modified_panoptic_quality.py +++ b/tests/unittests/detection/test_modified_panoptic_quality.py @@ -18,6 +18,7 @@ import torch from torchmetrics.detection import ModifiedPanopticQuality from torchmetrics.functional.detection import modified_panoptic_quality +from torchmetrics.utilities.imports import 
_TORCH_GREATER_EQUAL_1_12 from unittests import _Input from unittests._helpers import seed_all @@ -76,6 +77,7 @@ def _reference_fn_1_2(preds, target) -> np.ndarray: return np.array([23 / 30]) +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="PanopticQuality metric only supports PyTorch >= 1.12") class TestModifiedPanopticQuality(MetricTester): """Test class for `ModifiedPanopticQuality` metric.""" @@ -111,6 +113,7 @@ def test_panoptic_quality_functional(self): ) +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="PanopticQuality metric only supports PyTorch >= 1.12") def test_empty_metric(): """Test empty metric.""" with pytest.raises(ValueError, match="At least one of `things` and `stuffs` must be non-empty"): @@ -120,6 +123,7 @@ def test_empty_metric(): assert torch.isnan(metric.compute()) +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="PanopticQuality metric only supports PyTorch >= 1.12") def test_error_on_wrong_input(): """Test class input validation.""" with pytest.raises(TypeError, match="Expected argument `stuffs` to contain `int` categories.*"): @@ -162,6 +166,7 @@ def test_error_on_wrong_input(): metric.update(preds, preds) +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="PanopticQuality metric only supports PyTorch >= 1.12") def test_extreme_values(): """Test that the metric returns expected values in trivial cases.""" # Exact match between preds and target => metric is 1 @@ -170,6 +175,7 @@ def test_extreme_values(): assert modified_panoptic_quality(_INPUTS_0.target[0], _INPUTS_0.target[0] + 1, **_ARGS_0) == 0.0 +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="PanopticQuality metric only supports PyTorch >= 1.12") @pytest.mark.parametrize( ("inputs", "args", "cat_dim"), [ diff --git a/tests/unittests/detection/test_panoptic_quality.py b/tests/unittests/detection/test_panoptic_quality.py index a8263de61e0..4d087073266 100644 --- a/tests/unittests/detection/test_panoptic_quality.py +++ b/tests/unittests/detection/test_panoptic_quality.py @@ -18,6 +18,7 @@ import torch from torchmetrics.detection.panoptic_qualities import PanopticQuality from torchmetrics.functional.detection.panoptic_qualities import panoptic_quality +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12 from unittests import _Input from unittests._helpers import seed_all @@ -83,6 +84,7 @@ def _reference_fn_1_2(preds, target) -> np.ndarray: return np.array([(2 / 3 + 1 + 2 / 3) / 3]) +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="PanopticQuality metric only supports PyTorch >= 1.12") class TestPanopticQuality(MetricTester): """Test class for `PanopticQuality` metric.""" @@ -118,6 +120,7 @@ def test_panoptic_quality_functional(self): ) +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="PanopticQuality metric only supports PyTorch >= 1.12") def test_empty_metric(): """Test empty metric.""" with pytest.raises(ValueError, match="At least one of `things` and `stuffs` must be non-empty"): @@ -127,6 +130,7 @@ def test_empty_metric(): assert torch.isnan(metric.compute()) +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="PanopticQuality metric only supports PyTorch >= 1.12") def test_error_on_wrong_input(): """Test class input validation.""" with pytest.raises(TypeError, match="Expected argument `stuffs` to contain `int` categories.*"): @@ -169,6 +173,7 @@ def test_error_on_wrong_input(): metric.update(preds, preds) +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="PanopticQuality metric only supports PyTorch >= 
1.12") def test_extreme_values(): """Test that the metric returns expected values in trivial cases.""" # Exact match between preds and target => metric is 1 @@ -177,6 +182,7 @@ def test_extreme_values(): assert panoptic_quality(_INPUTS_0.target[0], _INPUTS_0.target[0] + 1, **_ARGS_0) == 0.0 +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="PanopticQuality metric only supports PyTorch >= 1.12") @pytest.mark.parametrize( ("inputs", "args", "cat_dim"), [