From b9fa4fda7f9172b8be656c0a098438878ff3959b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 21 Mar 2023 08:08:49 +0100 Subject: [PATCH 1/6] build(deps): update jiwer requirement from <=2.5.2,>=2.3.0 to >=2.3.0,<=3.0.0 in /requirements (#1635) build(deps): update jiwer requirement in /requirements Updates the requirements on [jiwer](https://github.com/jitsi/jiwer) to permit the latest version. - [Release notes](https://github.com/jitsi/jiwer/releases) - [Commits](https://github.com/jitsi/jiwer/compare/v2.3.0...v3.0.0) --- updated-dependencies: - dependency-name: jiwer dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements/text_test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/text_test.txt b/requirements/text_test.txt index 861ba187abf..9c271bf803e 100644 --- a/requirements/text_test.txt +++ b/requirements/text_test.txt @@ -1,7 +1,7 @@ # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment -jiwer >=2.3.0, <=2.5.2 +jiwer >=2.3.0, <=3.0.0 rouge-score >0.1.0, <=0.1.2 bert_score ==0.3.13 transformers >4.4.0, <4.26.2 From 18d3dd862517cfd72c6f6c58dc9b7f7ca2b5df4e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 21 Mar 2023 09:22:06 +0100 Subject: [PATCH 2/6] build(deps): bump mypy from 1.0.1 to 1.1.1 in /requirements (#1637) * build(deps): bump mypy from 1.0.1 to 1.1.1 in /requirements Bumps [mypy](https://github.com/python/mypy) from 1.0.1 to 1.1.1. - [Release notes](https://github.com/python/mypy/releases) - [Commits](https://github.com/python/mypy/compare/v1.0.1...v1.1.1) --- updated-dependencies: - dependency-name: mypy dependency-type: direct:production update-type: version-update:semver-minor ... 
Signed-off-by: dependabot[bot] * typing * typing --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Jirka --- requirements/typing.txt | 2 +- src/torchmetrics/metric.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/requirements/typing.txt b/requirements/typing.txt index db9bef08324..1ae1238eaa4 100644 --- a/requirements/typing.txt +++ b/requirements/typing.txt @@ -1,4 +1,4 @@ -mypy==1.0.1 +mypy==1.1.1 types-PyYAML types-emoji diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index 954ff90ad30..9a392c8dd5e 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -138,8 +138,8 @@ def __init__( # initialize self._update_signature = inspect.signature(self.update) - self.update: Callable = self._wrap_update(self.update) # type: ignore[assignment] - self.compute: Callable = self._wrap_compute(self.compute) # type: ignore[assignment] + self.update: Callable = self._wrap_update(self.update) # type: ignore[method-assign] + self.compute: Callable = self._wrap_compute(self.compute) # type: ignore[method-assign] self._computed = None self._forward_cache = None self._update_count = 0 @@ -639,8 +639,8 @@ def __setstate__(self, state: Dict[str, Any]) -> None: # manually restore update and compute functions for pickling self.__dict__.update(state) self._update_signature = inspect.signature(self.update) - self.update: Callable = self._wrap_update(self.update) # type: ignore[assignment] - self.compute: Callable = self._wrap_compute(self.compute) # type: ignore[assignment] + self.update: Callable = self._wrap_update(self.update) # type: ignore[method-assign] + self.compute: Callable = self._wrap_compute(self.compute) # type: ignore[method-assign] def __setattr__(self, name: str, value: Any) -> None: """Overwrite default method to prevent specific attributes from being set by user.""" From b785d287adb4316cd19246c8455650870e433dd3 Mon Sep 17 00:00:00 2001 From: Marco Caccin Date: Tue, 21 Mar 2023 10:28:23 +0100 Subject: [PATCH 3/6] Feature/modified panoptic quality (#1627) Co-authored-by: SkafteNicki Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> Co-authored-by: Jirka --- CHANGELOG.md | 4 + .../detection/modified_panoptic_quality.rst | 23 ++ docs/source/links.rst | 1 + src/torchmetrics/__init__.py | 3 +- src/torchmetrics/detection/__init__.py | 1 + .../detection/modified_panoptic_quality.py | 210 ++++++++++++++++++ .../detection/panoptic_quality.py | 56 ++--- src/torchmetrics/functional/__init__.py | 2 + .../functional/detection/__init__.py | 1 + .../detection/_panoptic_quality_common.py | 39 +++- .../detection/modified_panoptic_quality.py | 101 +++++++++ .../test_modified_panoptic_quality.py | 198 +++++++++++++++++ 12 files changed, 602 insertions(+), 37 deletions(-) create mode 100644 docs/source/detection/modified_panoptic_quality.rst create mode 100644 src/torchmetrics/detection/modified_panoptic_quality.py create mode 100644 src/torchmetrics/functional/detection/modified_panoptic_quality.py create mode 100644 tests/unittests/detection/test_modified_panoptic_quality.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 00a5b6509f6..40ba935bf4b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -73,6 +73,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added support for plotting of aggregation metrics through `.plot()` method ([#1485](https://github.com/Lightning-AI/metrics/pull/1485)) + +- Added 
`ModifiedPanopticQuality` metric to detection package ([#1627](https://github.com/Lightning-AI/metrics/pull/1627)) + + ### Changed - Changed `update_count` and `update_called` from private to public methods ([#1370](https://github.com/Lightning-AI/metrics/pull/1370)) diff --git a/docs/source/detection/modified_panoptic_quality.rst b/docs/source/detection/modified_panoptic_quality.rst new file mode 100644 index 00000000000..933df05d907 --- /dev/null +++ b/docs/source/detection/modified_panoptic_quality.rst @@ -0,0 +1,23 @@ +.. customcarditem:: + :header: Modified Panoptic Quality + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg + :tags: Detection + +######################### +Modified Panoptic Quality +######################### + +.. include:: ../links.rst + +Module Interface +________________ + +.. autoclass:: torchmetrics.ModifiedPanopticQuality + :noindex: + :exclude-members: update, compute + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.modified_panoptic_quality + :noindex: diff --git a/docs/source/links.rst b/docs/source/links.rst index fcc1d46a5d9..f8f8cf0ef9a 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -134,3 +134,4 @@ .. _Minkowski Distance: https://en.wikipedia.org/wiki/Minkowski_distance .. _Demographic parity: http://www.fairmlbook.org/ .. _Equal opportunity: https://proceedings.neurips.cc/paper/2016/hash/9d2682367c3935defcb1f9e247a97c0d-Abstract.html +.. _Seamless Scene Segmentation paper: https://arxiv.org/abs/1905.01220 diff --git a/src/torchmetrics/__init__.py b/src/torchmetrics/__init__.py index 9bddec9a057..d62a1b8452b 100644 --- a/src/torchmetrics/__init__.py +++ b/src/torchmetrics/__init__.py @@ -43,7 +43,7 @@ StatScores, ) from torchmetrics.collections import MetricCollection # noqa: E402 -from torchmetrics.detection import PanopticQuality # noqa: E402 +from torchmetrics.detection import ModifiedPanopticQuality, PanopticQuality # noqa: E402 from torchmetrics.image import ( # noqa: E402 ErrorRelativeGlobalDimensionlessSynthesis, MultiScaleStructuralSimilarityIndexMeasure, @@ -152,6 +152,7 @@ "MetricTracker", "MinMaxMetric", "MinMetric", + "ModifiedPanopticQuality", "MultioutputWrapper", "MultiScaleStructuralSimilarityIndexMeasure", "PanopticQuality", diff --git a/src/torchmetrics/detection/__init__.py b/src/torchmetrics/detection/__init__.py index dc5523aaa87..68dd3723400 100644 --- a/src/torchmetrics/detection/__init__.py +++ b/src/torchmetrics/detection/__init__.py @@ -16,4 +16,5 @@ if _TORCHVISION_GREATER_EQUAL_0_8: from torchmetrics.detection.mean_ap import MeanAveragePrecision # noqa: F401 +from torchmetrics.detection.modified_panoptic_quality import ModifiedPanopticQuality # noqa: F401 from torchmetrics.detection.panoptic_quality import PanopticQuality # noqa: F401 diff --git a/src/torchmetrics/detection/modified_panoptic_quality.py b/src/torchmetrics/detection/modified_panoptic_quality.py new file mode 100644 index 00000000000..aa3d0b0fb5f --- /dev/null +++ b/src/torchmetrics/detection/modified_panoptic_quality.py @@ -0,0 +1,210 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Collection, Optional, Sequence, Union + +import torch +from torch import Tensor + +from torchmetrics.functional.detection._panoptic_quality_common import ( + _get_category_id_to_continuous_id, + _get_void_color, + _panoptic_quality_compute, + _panoptic_quality_update, + _parse_categories, + _prepocess_inputs, + _validate_inputs, +) +from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["ModifiedPanopticQuality.plot"] + + +class ModifiedPanopticQuality(Metric): + r"""Compute `Modified Panoptic Quality`_ for panoptic segmentations. + + The metric was introduced in `Seamless Scene Segmentation paper`_, and is an adaptation of the original + `Panoptic Quality`_ where the metric for a stuff class is computed as + + .. math:: + PQ^{\dagger}_c = \frac{IOU_c}{|S_c|} + + where IOU_c is the sum of the intersection over union of all matching segments for a given class, and \|S_c| is + the overall number of segments in the ground truth for that class. + + .. note: + Points in the target tensor that do not map to a known category ID are automatically ignored in the metric + computation. + + Args: + things: + Set of ``category_id`` for countable things. + stuffs: + Set of ``category_id`` for uncountable stuffs. + allow_unknown_preds_category: + Boolean flag to specify if unknown categories in the predictions are to be ignored in the metric + computation or raise an exception when found. + + + Raises: + ValueError: + If ``things``, ``stuffs`` have at least one common ``category_id``. + TypeError: + If ``things``, ``stuffs`` contain non-integer ``category_id``. 
+ + Example: + >>> from torch import tensor + >>> from torchmetrics import ModifiedPanopticQuality + >>> preds = tensor([[[0, 0], [0, 1], [6, 0], [7, 0], [0, 2], [1, 0]]]) + >>> target = tensor([[[0, 1], [0, 0], [6, 0], [7, 0], [6, 0], [255, 0]]]) + >>> pq_modified = ModifiedPanopticQuality(things = {0, 1}, stuffs = {6, 7}) + >>> pq_modified(preds, target) + tensor(0.7667, dtype=torch.float64) + """ + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = False + + iou_sum: Tensor + true_positives: Tensor + false_positives: Tensor + false_negatives: Tensor + + def __init__( + self, + things: Collection[int], + stuffs: Collection[int], + allow_unknown_preds_category: bool = False, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + + things, stuffs = _parse_categories(things, stuffs) + self.things = things + self.stuffs = stuffs + self.void_color = _get_void_color(things, stuffs) + self.cat_id_to_continuous_id = _get_category_id_to_continuous_id(things, stuffs) + self.allow_unknown_preds_category = allow_unknown_preds_category + + # per category intermediate metrics + n_categories = len(things) + len(stuffs) + self.add_state("iou_sum", default=torch.zeros(n_categories, dtype=torch.double), dist_reduce_fx="sum") + self.add_state("true_positives", default=torch.zeros(n_categories, dtype=torch.int), dist_reduce_fx="sum") + self.add_state("false_positives", default=torch.zeros(n_categories, dtype=torch.int), dist_reduce_fx="sum") + self.add_state("false_negatives", default=torch.zeros(n_categories, dtype=torch.int), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: + r"""Update state with predictions and targets. + + Args: + preds: panoptic detection of shape ``[batch, *spatial_dims, 2]`` containing + the pair ``(category_id, instance_id)`` for each point. + If the ``category_id`` refer to a stuff, the instance_id is ignored. + + target: ground truth of shape ``[batch, *spatial_dims, 2]`` containing + the pair ``(category_id, instance_id)`` for each pixel of the image. + If the ``category_id`` refer to a stuff, the instance_id is ignored. + + Raises: + TypeError: + If ``preds`` or ``target`` is not an ``torch.Tensor``. + ValueError: + If ``preds`` and ``target`` have different shape. + ValueError: + If ``preds`` has less than 3 dimensions. + ValueError: + If the final dimension of ``preds`` has size != 2. + """ + _validate_inputs(preds, target) + flatten_preds = _prepocess_inputs( + self.things, self.stuffs, preds, self.void_color, self.allow_unknown_preds_category + ) + flatten_target = _prepocess_inputs(self.things, self.stuffs, target, self.void_color, True) + iou_sum, true_positives, false_positives, false_negatives = _panoptic_quality_update( + flatten_preds, + flatten_target, + self.cat_id_to_continuous_id, + self.void_color, + modified_metric_stuffs=self.stuffs, + ) + self.iou_sum += iou_sum + self.true_positives += true_positives + self.false_positives += false_positives + self.false_negatives += false_negatives + + def compute(self) -> Tensor: + """Compute panoptic quality based on inputs passed in to ``update`` previously.""" + return _panoptic_quality_compute(self.iou_sum, self.true_positives, self.false_positives, self.false_negatives) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. 
+ + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure object and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> from torch import tensor + >>> from torchmetrics import ModifiedPanopticQuality + >>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]], + ... [[0, 0], [0, 0], [6, 0], [0, 1]], + ... [[0, 0], [0, 0], [6, 0], [0, 1]], + ... [[0, 0], [7, 0], [6, 0], [1, 0]], + ... [[0, 0], [7, 0], [7, 0], [7, 0]]]]) + >>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]], + ... [[0, 1], [0, 1], [6, 0], [0, 1]], + ... [[0, 1], [0, 1], [6, 0], [1, 0]], + ... [[0, 1], [7, 0], [1, 0], [1, 0]], + ... [[0, 1], [7, 0], [7, 0], [7, 0]]]]) + >>> metric = ModifiedPanopticQuality(things = {0, 1}, stuffs = {6, 7}) + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torch import tensor + >>> from torchmetrics import ModifiedPanopticQuality + >>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]], + ... [[0, 0], [0, 0], [6, 0], [0, 1]], + ... [[0, 0], [0, 0], [6, 0], [0, 1]], + ... [[0, 0], [7, 0], [6, 0], [1, 0]], + ... [[0, 0], [7, 0], [7, 0], [7, 0]]]]) + >>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]], + ... [[0, 1], [0, 1], [6, 0], [0, 1]], + ... [[0, 1], [0, 1], [6, 0], [1, 0]], + ... [[0, 1], [7, 0], [1, 0], [1, 0]], + ... [[0, 1], [7, 0], [7, 0], [7, 0]]]]) + >>> metric = ModifiedPanopticQuality(things = {0, 1}, stuffs = {6, 7}) + >>> vals = [] + >>> for _ in range(20): + ... vals.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(vals) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/detection/panoptic_quality.py b/src/torchmetrics/detection/panoptic_quality.py index a13f2c4c97e..8d73c977d4b 100644 --- a/src/torchmetrics/detection/panoptic_quality.py +++ b/src/torchmetrics/detection/panoptic_quality.py @@ -48,37 +48,37 @@ class PanopticQuality(Metric): computation. Args: - things: - Set of ``category_id`` for countable things. - stuffs: - Set of ``category_id`` for uncountable stuffs. - allow_unknown_preds_category: - Boolean flag to specify if unknown categories in the predictions are to be ignored in the metric - computation or raise an exception when found. + things: + Set of ``category_id`` for countable things. + stuffs: + Set of ``category_id`` for uncountable stuffs. + allow_unknown_preds_category: + Boolean flag to specify if unknown categories in the predictions are to be ignored in the metric + computation or raise an exception when found. Raises: - ValueError: - If ``things``, ``stuffs`` have at least one common ``category_id``. - TypeError: - If ``things``, ``stuffs`` contain non-integer ``category_id``. - - Example:ty - >>> from torch import tensor - >>> from torchmetrics import PanopticQuality - >>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]], - ... [[0, 0], [0, 0], [6, 0], [0, 1]], - ... [[0, 0], [0, 0], [6, 0], [0, 1]], - ... [[0, 0], [7, 0], [6, 0], [1, 0]], - ... [[0, 0], [7, 0], [7, 0], [7, 0]]]]) - >>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]], - ... [[0, 1], [0, 1], [6, 0], [0, 1]], - ... [[0, 1], [0, 1], [6, 0], [1, 0]], - ... [[0, 1], [7, 0], [1, 0], [1, 0]], - ... 
[[0, 1], [7, 0], [7, 0], [7, 0]]]]) - >>> panoptic_quality = PanopticQuality(things = {0, 1}, stuffs = {6, 7}) - >>> panoptic_quality(preds, target) - tensor(0.5463, dtype=torch.float64) + ValueError: + If ``things``, ``stuffs`` have at least one common ``category_id``. + TypeError: + If ``things``, ``stuffs`` contain non-integer ``category_id``. + + Example: + >>> from torch import tensor + >>> from torchmetrics import PanopticQuality + >>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]], + ... [[0, 0], [0, 0], [6, 0], [0, 1]], + ... [[0, 0], [0, 0], [6, 0], [0, 1]], + ... [[0, 0], [7, 0], [6, 0], [1, 0]], + ... [[0, 0], [7, 0], [7, 0], [7, 0]]]]) + >>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]], + ... [[0, 1], [0, 1], [6, 0], [0, 1]], + ... [[0, 1], [0, 1], [6, 0], [1, 0]], + ... [[0, 1], [7, 0], [1, 0], [1, 0]], + ... [[0, 1], [7, 0], [7, 0], [7, 0]]]]) + >>> panoptic_quality = PanopticQuality(things = {0, 1}, stuffs = {6, 7}) + >>> panoptic_quality(preds, target) + tensor(0.5463, dtype=torch.float64) """ is_differentiable: bool = False diff --git a/src/torchmetrics/functional/__init__.py b/src/torchmetrics/functional/__init__.py index 424dbe19eba..dc050c66f95 100644 --- a/src/torchmetrics/functional/__init__.py +++ b/src/torchmetrics/functional/__init__.py @@ -32,6 +32,7 @@ from torchmetrics.functional.classification.roc import roc from torchmetrics.functional.classification.specificity import specificity from torchmetrics.functional.classification.stat_scores import stat_scores +from torchmetrics.functional.detection.modified_panoptic_quality import modified_panoptic_quality from torchmetrics.functional.detection.panoptic_quality import panoptic_quality from torchmetrics.functional.image.d_lambda import spectral_distortion_index from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis @@ -139,6 +140,7 @@ "mean_squared_error", "mean_squared_log_error", "minkowski_distance", + "modified_panoptic_quality", "multiscale_structural_similarity_index_measure", "pairwise_cosine_similarity", "pairwise_euclidean_distance", diff --git a/src/torchmetrics/functional/detection/__init__.py b/src/torchmetrics/functional/detection/__init__.py index f605c1d057d..15c5417b473 100644 --- a/src/torchmetrics/functional/detection/__init__.py +++ b/src/torchmetrics/functional/detection/__init__.py @@ -11,4 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from torchmetrics.functional.detection.modified_panoptic_quality import modified_panoptic_quality # noqa: F401 from torchmetrics.functional.detection.panoptic_quality import panoptic_quality # noqa: F401 diff --git a/src/torchmetrics/functional/detection/_panoptic_quality_common.py b/src/torchmetrics/functional/detection/_panoptic_quality_common.py index 12a924363f1..69245393799 100644 --- a/src/torchmetrics/functional/detection/_panoptic_quality_common.py +++ b/src/torchmetrics/functional/detection/_panoptic_quality_common.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Collection, Dict, Iterator, List, Set, Tuple, cast +from typing import Collection, Dict, Iterator, List, Optional, Set, Tuple, cast import torch from torch import Tensor @@ -302,16 +302,25 @@ def _panoptic_quality_update_sample( flatten_target: Tensor, cat_id_to_continuous_id: Dict[int, int], void_color: Tuple[int, int], + stuffs_modified_metric: Optional[Set[int]] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: - """Calculate stat scores required to compute accuracy **for a single sample**. + """Calculate stat scores required to compute the metric **for a single sample**. Computed scores: iou sum, true positives, false positives, false negatives. + NOTE: For the modified PQ case, this implementation uses the `true_positives` output tensor to aggregate the actual + TPs for things classes, but the number of target segments for stuff classes. + The `iou_sum` output tensor, instead, aggregates the IoU values at different thresholds (i.e., 0.5 for things + and 0 for stuffs). + This allows seamlessly using the same `.compute()` method for both PQ variants. + Args: flatten_preds: A flattened prediction tensor referring to a single sample, shape (num_points, 2). flatten_target: A flattened target tensor referring to a single sample, shape (num_points, 2). cat_id_to_continuous_id: Mapping from original category IDs to continuous IDs void_color: an additional, unused color. + stuffs_modified_metric: Set of stuff category IDs for which the PQ metric is computed using the "modified" + formula. If not specified, the original formula is used for all categories. Returns: - IOU Sum @@ -319,6 +328,7 @@ def _panoptic_quality_update_sample( - False positives - False negatives. """ + stuffs_modified_metric = stuffs_modified_metric or set() device = flatten_preds.device n_categories = len(cat_id_to_continuous_id) iou_sum = torch.zeros(n_categories, dtype=torch.double, device=device) @@ -345,19 +355,28 @@ def _panoptic_quality_update_sample( continue iou = _calculate_iou(pred_color, target_color, pred_areas, target_areas, intersection_areas, void_color) continuous_id = cat_id_to_continuous_id[target_color[0]] - if iou > 0.5: + if target_color[0] not in stuffs_modified_metric and iou > 0.5: pred_segment_matched.add(pred_color) target_segment_matched.add(target_color) iou_sum[continuous_id] += iou true_positives[continuous_id] += 1 + elif target_color[0] in stuffs_modified_metric and iou > 0: + iou_sum[continuous_id] += iou for cat_id in _filter_false_negatives(target_areas, target_segment_matched, intersection_areas, void_color): - continuous_id = cat_id_to_continuous_id[cat_id] - false_negatives[continuous_id] += 1 + if cat_id not in stuffs_modified_metric: + continuous_id = cat_id_to_continuous_id[cat_id] + false_negatives[continuous_id] += 1 for cat_id in _filter_false_positives(pred_areas, pred_segment_matched, intersection_areas, void_color): - continuous_id = cat_id_to_continuous_id[cat_id] - false_positives[continuous_id] += 1 + if cat_id not in stuffs_modified_metric: + continuous_id = cat_id_to_continuous_id[cat_id] + false_positives[continuous_id] += 1 + + for cat_id, _ in target_areas: + if cat_id in stuffs_modified_metric: + continuous_id = cat_id_to_continuous_id[cat_id] + true_positives[continuous_id] += 1 return iou_sum, true_positives, false_positives, false_negatives @@ -367,8 +386,9 @@ def _panoptic_quality_update( flatten_target: Tensor, cat_id_to_continuous_id: Dict[int, int], void_color: Tuple[int, int], + modified_metric_stuffs: Optional[Set[int]] = None, ) -> 
Tuple[Tensor, Tensor, Tensor, Tensor]: - """Calculate stat scores required to compute accuracy. + """Calculate stat scores required to compute the metric for a full batch. Computed scores: iou sum, true positives, false positives, false negatives. @@ -377,6 +397,8 @@ def _panoptic_quality_update( flatten_target: A flattened target tensor, shape (B, num_points, 2). cat_id_to_continuous_id: Mapping from original category IDs to continuous IDs. void_color: an additional, unused color. + modified_metric_stuffs: Set of stuff category IDs for which the PQ metric is computed using the "modified" + formula. If not specified, the original formula is used for all categories. Returns: - IOU Sum @@ -398,6 +420,7 @@ def _panoptic_quality_update( flatten_target_single, cat_id_to_continuous_id, void_color, + stuffs_modified_metric=modified_metric_stuffs, ) iou_sum += result[0] true_positives += result[1] diff --git a/src/torchmetrics/functional/detection/modified_panoptic_quality.py b/src/torchmetrics/functional/detection/modified_panoptic_quality.py new file mode 100644 index 00000000000..cda9c894b1d --- /dev/null +++ b/src/torchmetrics/functional/detection/modified_panoptic_quality.py @@ -0,0 +1,101 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Collection + +from torch import Tensor + +from torchmetrics.functional.detection._panoptic_quality_common import ( + _get_category_id_to_continuous_id, + _get_void_color, + _panoptic_quality_compute, + _panoptic_quality_update, + _parse_categories, + _prepocess_inputs, + _validate_inputs, +) + + +def modified_panoptic_quality( + preds: Tensor, + target: Tensor, + things: Collection[int], + stuffs: Collection[int], + allow_unknown_preds_category: bool = False, +) -> Tensor: + r"""Compute `Modified Panoptic Quality`_ for panoptic segmentations. + + The metric was introduced in `Seamless Scene Segmentation paper`_, and is an adaptation of the original + `Panoptic Quality`_ where the metric for a stuff class is computed as + + .. math:: + PQ^{\dagger}_c = \frac{IOU_c}{|S_c|} + + where IOU_c is the sum of the intersection over union of all matching segments for a given class, and \|S_c| is + the overall number of segments in the ground truth for that class. + + .. note: + Points in the target tensor that do not map to a known category ID are automatically ignored in the metric + computation. + + Args: + preds: + torch tensor with panoptic detection of shape [height, width, 2] containing the pair + (category_id, instance_id) for each pixel of the image. If the category_id refer to a stuff, the + instance_id is ignored. + target: + torch tensor with ground truth of shape [height, width, 2] containing the pair (category_id, instance_id) + for each pixel of the image. If the category_id refer to a stuff, the instance_id is ignored. + things: + Set of ``category_id`` for countable things. + stuffs: + Set of ``category_id`` for uncountable stuffs. 
+ allow_unknown_preds_category: + Boolean flag to specify if unknown categories in the predictions are to be ignored in the metric + computation or raise an exception when found. + + Raises: + ValueError: + If ``things``, ``stuffs`` have at least one common ``category_id``. + TypeError: + If ``things``, ``stuffs`` contain non-integer ``category_id``. + TypeError: + If ``preds`` or ``target`` is not an ``torch.Tensor``. + ValueError: + If ``preds`` or ``target`` has different shape. + ValueError: + If ``preds`` has less than 3 dimensions. + ValueError: + If the final dimension of ``preds`` has size != 2. + + Example: + >>> from torch import tensor + >>> preds = tensor([[[0, 0], [0, 1], [6, 0], [7, 0], [0, 2], [1, 0]]]) + >>> target = tensor([[[0, 1], [0, 0], [6, 0], [7, 0], [6, 0], [255, 0]]]) + >>> modified_panoptic_quality(preds, target, things = {0, 1}, stuffs = {6, 7}) + tensor(0.7667, dtype=torch.float64) + """ + things, stuffs = _parse_categories(things, stuffs) + _validate_inputs(preds, target) + void_color = _get_void_color(things, stuffs) + cat_id_to_continuous_id = _get_category_id_to_continuous_id(things, stuffs) + flatten_preds = _prepocess_inputs(things, stuffs, preds, void_color, allow_unknown_preds_category) + flatten_target = _prepocess_inputs(things, stuffs, target, void_color, True) + iou_sum, true_positives, false_positives, false_negatives = _panoptic_quality_update( + flatten_preds, + flatten_target, + cat_id_to_continuous_id, + void_color, + modified_metric_stuffs=stuffs, + ) + return _panoptic_quality_compute(iou_sum, true_positives, false_positives, false_negatives) diff --git a/tests/unittests/detection/test_modified_panoptic_quality.py b/tests/unittests/detection/test_modified_panoptic_quality.py new file mode 100644 index 00000000000..33eaec17a99 --- /dev/null +++ b/tests/unittests/detection/test_modified_panoptic_quality.py @@ -0,0 +1,198 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import namedtuple +from typing import Any, Dict + +import numpy as np +import pytest +import torch + +from torchmetrics.detection.modified_panoptic_quality import ModifiedPanopticQuality +from torchmetrics.functional.detection.modified_panoptic_quality import modified_panoptic_quality +from unittests.helpers import seed_all +from unittests.helpers.testers import MetricTester + +seed_all(42) + +Input = namedtuple("Input", ["preds", "target"]) + +_INPUTS_0 = Input( + # Shape of input tensors is (num_batches, batch_size, height, width, 2). 
+ preds=torch.tensor( + [ + [[6, 0], [0, 0], [6, 0], [6, 0], [0, 1]], + [[0, 0], [0, 0], [6, 0], [0, 1], [0, 1]], + [[0, 0], [0, 0], [6, 0], [0, 1], [1, 0]], + [[0, 0], [7, 0], [6, 0], [1, 0], [1, 0]], + [[0, 0], [7, 0], [7, 0], [7, 0], [7, 0]], + ] + ) + .reshape((1, 1, 5, 5, 2)) + .repeat(2, 1, 1, 1, 1), + target=torch.tensor( + [ + [[6, 0], [6, 0], [6, 0], [6, 0], [0, 0]], + [[0, 1], [0, 1], [6, 0], [0, 0], [0, 0]], + [[0, 1], [0, 1], [6, 0], [1, 0], [1, 0]], + [[0, 1], [7, 0], [7, 0], [1, 0], [1, 0]], + [[0, 1], [7, 0], [7, 0], [7, 0], [7, 0]], + ] + ) + .reshape((1, 1, 5, 5, 2)) + .repeat(2, 1, 1, 1, 1), +) +_INPUTS_1 = Input( + # Shape of input tensors is (num_batches, batch_size, num_points, 2). + # NOTE: IoU for stuff category 6 is < 0.5, modified PQ behaves differently there. + preds=torch.tensor([[0, 0], [0, 1], [6, 0], [7, 0], [0, 2], [1, 0]]).reshape((1, 1, 6, 2)).repeat(2, 1, 1, 1), + target=torch.tensor([[0, 1], [0, 0], [6, 0], [7, 0], [6, 0], [255, 0]]).reshape((1, 1, 6, 2)).repeat(2, 1, 1, 1), +) +_ARGS_0 = {"things": {0, 1}, "stuffs": {6, 7}} +_ARGS_1 = {"things": {2}, "stuffs": {3}, "allow_unknown_preds_category": True} +_ARGS_2 = {"things": {0, 1}, "stuffs": {6, 7}} + +# TODO: Improve _compare_fn by calling https://github.com/cocodataset/panopticapi/blob/master/panopticapi/evaluation.py +# directly and compare at runtime on multiple examples. + + +def _compare_fn_0_0(preds, target) -> np.ndarray: + """Baseline result for the _INPUTS_0, _ARGS_0 combination.""" + return np.array([0.7753]) + + +def _compare_fn_0_1(preds, target) -> np.ndarray: + """Baseline result for the _INPUTS_0, _ARGS_1 combination.""" + return np.array([np.nan]) + + +def _compare_fn_1_2(preds, target) -> np.ndarray: + """Baseline result for the _INPUTS_1, _ARGS_2 combination.""" + return np.array([23 / 30]) + + +class TestModifiedPanopticQuality(MetricTester): + """Test class for `ModifiedPanopticQuality` metric.""" + + @pytest.mark.parametrize("ddp", [False, True]) + @pytest.mark.parametrize( + ("inputs", "args", "reference_metric"), + [ + (_INPUTS_0, _ARGS_0, _compare_fn_0_0), + (_INPUTS_0, _ARGS_1, _compare_fn_0_1), + (_INPUTS_1, _ARGS_2, _compare_fn_1_2), + ], + ) + def test_panoptic_quality_class(self, ddp, inputs, args, reference_metric): + """Test class implementation of metric.""" + self.run_class_metric_test( + ddp=ddp, + preds=inputs.preds, + target=inputs.target, + metric_class=ModifiedPanopticQuality, + reference_metric=reference_metric, + check_batch=False, + metric_args=args, + ) + + def test_panoptic_quality_functional(self): + """Test functional implementation of metric.""" + self.run_functional_metric_test( + _INPUTS_0.preds, + _INPUTS_0.target, + metric_functional=modified_panoptic_quality, + reference_metric=_compare_fn_0_0, + metric_args=_ARGS_0, + ) + + +def test_empty_metric(): + """Test empty metric.""" + with pytest.raises(ValueError, match="At least one of `things` and `stuffs` must be non-empty"): + metric = ModifiedPanopticQuality(things=[], stuffs=[]) + + metric = ModifiedPanopticQuality(things=[0], stuffs=[]) + assert torch.isnan(metric.compute()) + + +def test_error_on_wrong_input(): + """Test class input validation.""" + with pytest.raises(TypeError, match="Expected argument `stuffs` to contain `int` categories.*"): + ModifiedPanopticQuality(things={0}, stuffs={"sky"}) + + with pytest.raises(ValueError, match="Expected arguments `things` and `stuffs` to have distinct keys.*"): + ModifiedPanopticQuality(things={0}, stuffs={0}) + + metric = 
ModifiedPanopticQuality(things={0, 1, 3}, stuffs={2, 8}, allow_unknown_preds_category=True) + valid_images = torch.randint(low=0, high=9, size=(8, 64, 64, 2)) + metric.update(valid_images, valid_images) + valid_point_clouds = torch.randint(low=0, high=9, size=(1, 100, 2)) + metric.update(valid_point_clouds, valid_point_clouds) + + with pytest.raises(TypeError, match="Expected argument `preds` to be of type `torch.Tensor`.*"): + metric.update([], valid_images) + + with pytest.raises(TypeError, match="Expected argument `target` to be of type `torch.Tensor`.*"): + metric.update(valid_images, []) + + preds = torch.randint(low=0, high=9, size=(2, 400, 300, 2)) + target = torch.randint(low=0, high=9, size=(2, 30, 40, 2)) + with pytest.raises(ValueError, match="Expected argument `preds` and `target` to have the same shape.*"): + metric.update(preds, target) + + preds = torch.randint(low=0, high=9, size=(1, 2)) + with pytest.raises(ValueError, match="Expected argument `preds` to have at least one spatial dimension.*"): + metric.update(preds, preds) + + preds = torch.randint(low=0, high=9, size=(1, 64, 64, 8)) + with pytest.raises( + ValueError, match="Expected argument `preds` to have exactly 2 channels in the last dimension.*" + ): + metric.update(preds, preds) + + metric = ModifiedPanopticQuality(things=[0], stuffs=[1], allow_unknown_preds_category=False) + preds = torch.randint(low=0, high=1, size=(1, 100, 2)) + preds[0, 0, 0] = 2 + with pytest.raises(ValueError, match="Unknown categories found.*"): + metric.update(preds, preds) + + +def test_extreme_values(): + """Test that the metric returns expected values in trivial cases.""" + # Exact match between preds and target => metric is 1 + assert modified_panoptic_quality(_INPUTS_0.target[0], _INPUTS_0.target[0], **_ARGS_0) == 1.0 + # Every element of the prediction is wrong => metric is 0 + assert modified_panoptic_quality(_INPUTS_0.target[0], _INPUTS_0.target[0] + 1, **_ARGS_0) == 0.0 + + +@pytest.mark.parametrize( + ("inputs", "args", "cat_dim"), + [ + (_INPUTS_0, _ARGS_0, 0), + (_INPUTS_0, _ARGS_0, 1), + (_INPUTS_0, _ARGS_0, 2), + (_INPUTS_1, _ARGS_2, 0), + (_INPUTS_1, _ARGS_2, 1), + ], +) +def test_ignore_mask(inputs: Input, args: Dict[str, Any], cat_dim: int): + """Test that the metric correctly ignores regions of the inputs that do not map to a know category ID.""" + preds = inputs.preds[0] + target = inputs.target[0] + value = modified_panoptic_quality(preds, target, **args) + ignored_regions = torch.zeros_like(preds) + ignored_regions[..., 0] = 255 + preds_new = torch.cat([preds, preds], dim=cat_dim) + target_new = torch.cat([target, ignored_regions], dim=cat_dim) + value_new = modified_panoptic_quality(preds_new, target_new, **args) + assert value == value_new From f735df14a9deed13f5e58dfec2da2b42d275990d Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Tue, 21 Mar 2023 11:28:06 +0100 Subject: [PATCH 4/6] Update documentation to lightning 2.0 (#1633) Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> --- docs/source/pages/lightning.rst | 66 ++++++++------------ docs/source/pages/overview.rst | 26 -------- tests/integrations/lightning/boring_model.py | 33 ++-------- tests/integrations/test_lightning.py | 40 ++++-------- 4 files changed, 43 insertions(+), 122 deletions(-) diff --git a/docs/source/pages/lightning.rst b/docs/source/pages/lightning.rst index f845b376443..fb7f8408c95 100644 --- a/docs/source/pages/lightning.rst +++ b/docs/source/pages/lightning.rst @@ -45,29 +45,31 @@ The example below 
shows how to use a metric in your `LightningModule `_ method, -Lightning will log the metric based on ``on_step`` and ``on_epoch`` flags present in ``self.log(...)``. If ``on_epoch`` is True, the logger automatically logs the end of epoch metric -value by calling ``.compute()``. +Logging metrics can be done in two ways: either logging the metric object directly or the computed metric values. +When :class:`~torchmetrics.Metric` objects, which return a scalar tensor are logged directly in Lightning using the +LightningModule `self.log `_ +method, Lightning will log the metric based on ``on_step`` and ``on_epoch`` flags present in ``self.log(...)``. If +``on_epoch`` is True, the logger automatically logs the end of epoch metric value by calling ``.compute()``. .. note:: - ``sync_dist``, ``sync_dist_op``, ``sync_dist_group``, ``reduce_fx`` and ``tbptt_reduce_fx`` - flags from ``self.log(...)`` don't affect the metric logging in any manner. The metric class - contains its own distributed synchronization logic. + ``sync_dist``, ``sync_dist_group`` and ``reduce_fx`` flags from ``self.log(...)`` don't affect the metric logging + in any manner. The metric class contains its own distributed synchronization logic. This however is only true for metrics that inherit the base class ``Metric``, and thus the functional metric API provides no support for in-built distributed synchronization @@ -96,8 +98,8 @@ value by calling ``.compute()``. self.valid_acc(logits, y) self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True) -As an alternative to logging the metric object and letting Lightning take care of when to reset the metric etc. you can also manually log the output -of the metrics. +As an alternative to logging the metric object and letting Lightning take care of when to reset the metric etc. you can +also manually log the output of the metrics. .. testcode:: python @@ -115,7 +117,7 @@ of the metrics. batch_value = self.train_acc(preds, y) self.log('train_acc_step', batch_value) - def training_epoch_end(self, outputs): + def on_train_epoch_end(self): self.train_acc.reset() def validation_step(self, batch, batch_idx): @@ -123,19 +125,20 @@ of the metrics. ... self.valid_acc.update(logits, y) - def validation_epoch_end(self, outputs): + def on_validation_epoch_end(self, outputs): self.log('valid_acc_epoch', self.valid_acc.compute()) self.valid_acc.reset() -Note that logging metrics this way will require you to manually reset the metrics at the end of the epoch yourself. In general, we recommend logging -the metric object to make sure that metrics are correctly computed and reset. Additionally, we highly recommend that the two ways of logging are not -mixed as it can lead to wrong results. +Note that logging metrics this way will require you to manually reset the metrics at the end of the epoch yourself. +In general, we recommend logging the metric object to make sure that metrics are correctly computed and reset. +Additionally, we highly recommend that the two ways of logging are not mixed as it can lead to wrong results. .. note:: - When using any Modular metric, calling ``self.metric(...)`` or ``self.metric.forward(...)`` serves the dual purpose of calling ``self.metric.update()`` - on its input and simultaneously returning the metric value over the provided input. So if you are logging a metric *only* on epoch-level (as in the - example above), it is recommended to call ``self.metric.update()`` directly to avoid the extra computation. 
+ When using any Modular metric, calling ``self.metric(...)`` or ``self.metric.forward(...)`` serves the dual purpose + of calling ``self.metric.update()`` on its input and simultaneously returning the metric value over the provided + input. So if you are logging a metric *only* on epoch-level (as in the example above), it is recommended to call + ``self.metric.update()`` directly to avoid the extra computation. .. testcode:: python @@ -158,25 +161,6 @@ Common Pitfalls The following contains a list of pitfalls to be aware of: -* If using metrics in data parallel mode (dp), the metric update/logging should be done - in the ``_step_end`` method (where ```` is either ``training``, ``validation`` - or ``test``). This is because ``dp`` split the batches during the forward pass and metric states are destroyed after each forward pass, thus leading to wrong accumulation. In practice do the following: - -.. testcode:: python - - class MyModule(LightningModule): - - def training_step(self, batch, batch_idx): - data, target = batch - preds = self(data) - # ... - return {'loss': loss, 'preds': preds, 'target': target} - - def training_step_end(self, outputs): - # update and log - self.metric(outputs['preds'], outputs['target']) - self.log('metric', self.metric) - * Modular metrics contain internal states that should belong to only one DataLoader. In case you are using multiple DataLoaders, it is recommended to initialize a separate modular metric instances for each DataLoader and use them separately. The same holds for using seperate metrics for training, validation and testing. diff --git a/docs/source/pages/overview.rst b/docs/source/pages/overview.rst index 8c523889b34..f30e38b1a7d 100644 --- a/docs/source/pages/overview.rst +++ b/docs/source/pages/overview.rst @@ -130,32 +130,6 @@ the native `MetricCollection`_ module can also be used to wrap multiple metrics. You can always check which device the metric is located on using the `.device` property. -Metrics in Dataparallel (DP) mode -================================= - -When using metrics in `Dataparallel (DP) `_ -mode, one should be aware DP will both create and clean-up replicas of Metric objects during a single forward pass. -This has the consequence, that the metric state of the replicas will as default be destroyed before we can sync -them. It is therefore recommended, when using metrics in DP mode, to initialize them with ``dist_sync_on_step=True`` -such that metric states are synchonized between the main process and the replicas before they are destroyed. - -Addtionally, if metrics are used together with a `LightningModule` the metric update/logging should be done -in the ``_step_end`` method (where ```` is either ``training``, ``validation`` or ``test``), else -it will lead to wrong accumulation. In practice do the following: - -.. testcode:: - - def training_step(self, batch, batch_idx): - data, target = batch - preds = self(data) - ... 
- return {'loss': loss, 'preds': preds, 'target': target} - - def training_step_end(self, outputs): - #update and log - self.metric(outputs['preds'], outputs['target']) - self.log('metric', self.metric) - Metrics in Distributed Data Parallel (DDP) mode =============================================== diff --git a/tests/integrations/lightning/boring_model.py b/tests/integrations/lightning/boring_model.py index 143ef756983..f67dd0e9b98 100644 --- a/tests/integrations/lightning/boring_model.py +++ b/tests/integrations/lightning/boring_model.py @@ -90,55 +90,32 @@ def training_step(self, batch, batch_idx): loss = self.loss(batch, output) return {"loss": loss} - @staticmethod - def training_step_end(training_step_outputs): - """Run at the end of a training step. Needed when using multiple devices.""" - return training_step_outputs - - @staticmethod - def training_epoch_end(outputs) -> None: - """Run at the end of a training epoch.""" - torch.stack([x["loss"] for x in outputs]).mean() - def validation_step(self, batch, batch_idx): """Single validation step in the model.""" output = self.layer(batch) loss = self.loss(batch, output) return {"x": loss} - @staticmethod - def validation_epoch_end(outputs) -> None: - """Run at the end of each validation epoch.""" - torch.stack([x["x"] for x in outputs]).mean() - def test_step(self, batch, batch_idx): """Single test step in the model.""" output = self.layer(batch) loss = self.loss(batch, output) return {"y": loss} - @staticmethod - def test_epoch_end(outputs) -> None: - """Run at the end of each test epoch.""" - torch.stack([x["y"] for x in outputs]).mean() - def configure_optimizers(self): """Configure which optimizer to use when training the model.""" optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) - lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) - return [optimizer], [lr_scheduler] + lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9) + return {"optimizer": optimizer, "scheduler": lr_scheduler} - @staticmethod - def train_dataloader(): + def train_dataloader(self): """Define train dataloader used for training the model.""" return torch.utils.data.DataLoader(RandomDataset(32, 64)) - @staticmethod - def val_dataloader(): + def val_dataloader(self): """Define validation dataloader used for validating the model.""" return torch.utils.data.DataLoader(RandomDataset(32, 64)) - @staticmethod - def test_dataloader(): + def test_dataloader(self): """Define test dataloader used for testing the mdoel.""" return torch.utils.data.DataLoader(RandomDataset(32, 64)) diff --git a/tests/integrations/test_lightning.py b/tests/integrations/test_lightning.py index e627eac2ea8..81cd8b15fff 100644 --- a/tests/integrations/test_lightning.py +++ b/tests/integrations/test_lightning.py @@ -49,7 +49,7 @@ def training_step(self, batch, batch_idx): return self.step(x) - def training_epoch_end(self, outs): + def on_training_epoch_end(self): if not torch.allclose(self.sum, self.metric.compute()): raise ValueError("Sum and computed value must be equal") self.sum = 0.0 @@ -71,10 +71,10 @@ def training_epoch_end(self, outs): def test_metrics_reset(tmpdir): """Tests that metrics are reset correctly after the end of the train/val/test epoch. 
- Taken from: `Metric Test for Reset`_ + Taken from: `Metric Test for Reset`_ """ - class TestModel(LightningModule): + class TestModel(BoringModel): def __init__(self) -> None: super().__init__() self.layer = torch.nn.Linear(32, 1) @@ -122,23 +122,6 @@ def validation_step(self, batch, batch_idx, *args, **kwargs): def test_step(self, batch, batch_idx, *args, **kwargs): return self._step("test", batch) - def configure_optimizers(self): - optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) - lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) - return [optimizer], [lr_scheduler] - - @staticmethod - def train_dataloader(): - return DataLoader(RandomDataset(32, 64), batch_size=2) - - @staticmethod - def val_dataloader(): - return DataLoader(RandomDataset(32, 64), batch_size=2) - - @staticmethod - def test_dataloader(): - return DataLoader(RandomDataset(32, 64), batch_size=2) - def _assert_epoch_end(self, stage): acc = self._modules[f"acc_{stage}"] ap = self._modules[f"ap_{stage}"] @@ -146,13 +129,13 @@ def _assert_epoch_end(self, stage): acc.reset.asset_not_called() ap.reset.assert_not_called() - def train_epoch_end(self, outputs): + def on_train_epoch_end(self): self._assert_epoch_end("train") - def validation_epoch_end(self, outputs): + def on_validation_epoch_end(self): self._assert_epoch_end("val") - def test_epoch_end(self, outputs): + def on_test_epoch_end(self): self._assert_epoch_end("test") def _assert_called(model, stage): @@ -194,6 +177,7 @@ def __init__(self) -> None: self.metric_step = SumMetric() self.metric_epoch = SumMetric() self.sum = torch.tensor(0.0) + self.outs = [] def on_train_epoch_start(self): self.sum = torch.tensor(0.0) @@ -203,10 +187,12 @@ def training_step(self, batch, batch_idx): self.metric_step(x.sum()) self.sum += x.sum() self.log("sum_step", self.metric_step, on_epoch=True, on_step=False) - return {"loss": self.step(x), "data": x} + self.outs.append(x) + return self.step(x) - def training_epoch_end(self, outs): - self.log("sum_epoch", self.metric_epoch(torch.stack([o["data"] for o in outs]).sum())) + def on_train_epoch_end(self): + self.log("sum_epoch", self.metric_epoch(torch.stack(self.outs))) + self.outs = [] model = TestModel() @@ -246,7 +232,7 @@ def training_step(self, batch, batch_idx): self.log_dict({f"{k}_step": v for k, v in metric_vals.items()}) return self.step(x) - def training_epoch_end(self, outputs): + def on_train_epoch_end(self): metric_vals = self.metric.compute() self.log_dict({f"{k}_epoch": v for k, v in metric_vals.items()}) From 9055b99f0db35a9420e9777d7036877f66dd4a0a Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Tue, 21 Mar 2023 13:07:25 +0100 Subject: [PATCH 5/6] Add plotting 14/n (#1631) --- CHANGELOG.md | 1 + src/torchmetrics/text/cer.py | 51 +++++++++++++++++++++++++- src/torchmetrics/text/eed.py | 51 +++++++++++++++++++++++++- src/torchmetrics/text/mer.py | 51 +++++++++++++++++++++++++- src/torchmetrics/text/wer.py | 50 ++++++++++++++++++++++++- src/torchmetrics/text/wil.py | 51 +++++++++++++++++++++++++- src/torchmetrics/text/wip.py | 51 +++++++++++++++++++++++++- tests/unittests/utilities/test_plot.py | 16 ++++++++ 8 files changed, 311 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 40ba935bf4b..73795d4194e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#1621](https://github.com/Lightning-AI/metrics/pull/1621), 
[#1624](https://github.com/Lightning-AI/metrics/pull/1624), [#1623](https://github.com/Lightning-AI/metrics/pull/1623), + [#1631](https://github.com/Lightning-AI/metrics/pull/1631), ) diff --git a/src/torchmetrics/text/cer.py b/src/torchmetrics/text/cer.py index 6dcd06cc9a2..4d90ad1b754 100644 --- a/src/torchmetrics/text/cer.py +++ b/src/torchmetrics/text/cer.py @@ -11,14 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from typing import Any, List, Union +from typing import Any, List, Optional, Sequence, Union import torch from torch import Tensor, tensor from torchmetrics.functional.text.cer import _cer_compute, _cer_update from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["CharErrorRate.plot"] class CharErrorRate(Metric): @@ -84,3 +88,46 @@ def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> def compute(self) -> Tensor: """Calculate the character error rate.""" return _cer_compute(self.errors, self.total) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torchmetrics import CharErrorRate + >>> metric = CharErrorRate() + >>> preds = ["this is the prediction", "there is an other sample"] + >>> target = ["this is the reference", "there is another one"] + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torchmetrics import CharErrorRate + >>> metric = CharErrorRate() + >>> preds = ["this is the prediction", "there is an other sample"] + >>> target = ["this is the reference", "there is another one"] + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/text/eed.py b/src/torchmetrics/text/eed.py index a28c6b18294..56c5c27c230 100644 --- a/src/torchmetrics/text/eed.py +++ b/src/torchmetrics/text/eed.py @@ -11,14 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- -from typing import Any, List, Sequence, Tuple, Union +from typing import Any, List, Optional, Sequence, Tuple, Union from torch import Tensor, stack from typing_extensions import Literal from torchmetrics.functional.text.eed import _eed_compute, _eed_update from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["ExtendedEditDistance.plot"] class ExtendedEditDistance(Metric): @@ -112,3 +116,46 @@ def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: if self.return_sentence_level_score: return average, stack(self.sentence_eed) return average + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torchmetrics import ExtendedEditDistance + >>> metric = ExtendedEditDistance() + >>> preds = ["this is the prediction", "there is an other sample"] + >>> target = ["this is the reference", "there is another one"] + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torchmetrics import ExtendedEditDistance + >>> metric = ExtendedEditDistance() + >>> preds = ["this is the prediction", "there is an other sample"] + >>> target = ["this is the reference", "there is another one"] + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/text/mer.py b/src/torchmetrics/text/mer.py index 088a5ab9fe0..1e6282b5c85 100644 --- a/src/torchmetrics/text/mer.py +++ b/src/torchmetrics/text/mer.py @@ -11,14 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from typing import Any, List, Union +from typing import Any, List, Optional, Sequence, Union import torch from torch import Tensor, tensor from torchmetrics.functional.text.mer import _mer_compute, _mer_update from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["MatchErrorRate.plot"] class MatchErrorRate(Metric): @@ -85,3 +89,46 @@ def update( def compute(self) -> Tensor: """Calculate the Match error rate.""" return _mer_compute(self.errors, self.total) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. 
If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torchmetrics import MatchErrorRate + >>> metric = MatchErrorRate() + >>> preds = ["this is the prediction", "there is an other sample"] + >>> target = ["this is the reference", "there is another one"] + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torchmetrics import MatchErrorRate + >>> metric = MatchErrorRate() + >>> preds = ["this is the prediction", "there is an other sample"] + >>> target = ["this is the reference", "there is another one"] + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/text/wer.py b/src/torchmetrics/text/wer.py index 4687da840b5..2d62438a292 100644 --- a/src/torchmetrics/text/wer.py +++ b/src/torchmetrics/text/wer.py @@ -11,13 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Union +from typing import Any, List, Optional, Sequence, Union import torch from torch import Tensor, tensor from torchmetrics.functional.text.wer import _wer_compute, _wer_update from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["WordErrorRate.plot"] class WordErrorRate(Metric): @@ -82,3 +87,46 @@ def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> def compute(self) -> Tensor: """Calculate the word error rate.""" return _wer_compute(self.errors, self.total) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torchmetrics import WordErrorRate + >>> metric = WordErrorRate() + >>> preds = ["this is the prediction", "there is an other sample"] + >>> target = ["this is the reference", "there is another one"] + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torchmetrics import WordErrorRate + >>> metric = WordErrorRate() + >>> preds = ["this is the prediction", "there is an other sample"] + >>> target = ["this is the reference", "there is another one"] + >>> values = [ ] + >>> for _ in range(10): + ... 
values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/text/wil.py b/src/torchmetrics/text/wil.py index d0c9c8bc96d..698854366f0 100644 --- a/src/torchmetrics/text/wil.py +++ b/src/torchmetrics/text/wil.py @@ -11,13 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from typing import Any, List, Union +from typing import Any, List, Optional, Sequence, Union from torch import Tensor, tensor from torchmetrics.functional.text.wil import _wil_compute, _wil_update from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["WordInfoLost.plot"] class WordInfoLost(Metric): @@ -83,3 +87,46 @@ def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> def compute(self) -> Tensor: """Calculate the Word Information Lost.""" return _wil_compute(self.errors, self.target_total, self.preds_total) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torchmetrics import WordInfoLost + >>> metric = WordInfoLost() + >>> preds = ["this is the prediction", "there is an other sample"] + >>> target = ["this is the reference", "there is another one"] + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torchmetrics import WordInfoLost + >>> metric = WordInfoLost() + >>> preds = ["this is the prediction", "there is an other sample"] + >>> target = ["this is the reference", "there is another one"] + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/text/wip.py b/src/torchmetrics/text/wip.py index 335c524a00d..992d6021eca 100644 --- a/src/torchmetrics/text/wip.py +++ b/src/torchmetrics/text/wip.py @@ -11,13 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- -from typing import Any, List, Union +from typing import Any, List, Optional, Sequence, Union from torch import Tensor, tensor from torchmetrics.functional.text.wip import _wip_compute, _wip_update from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["WordInfoPreserved.plot"] class WordInfoPreserved(Metric): @@ -84,3 +88,46 @@ def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> def compute(self) -> Tensor: """Calculate the Word Information Preserved.""" return _wip_compute(self.errors, self.target_total, self.preds_total) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torchmetrics import WordInfoPreserved + >>> metric = WordInfoPreserved() + >>> preds = ["this is the prediction", "there is an other sample"] + >>> target = ["this is the reference", "there is another one"] + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torchmetrics import WordInfoPreserved + >>> metric = WordInfoPreserved() + >>> preds = ["this is the prediction", "there is an other sample"] + >>> target = ["this is the reference", "there is another one"] + >>> values = [ ] + >>> for _ in range(10): + ... 
values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index dc4c742a87b..c181c88ceea 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -117,6 +117,14 @@ RetrievalRecallAtFixedPrecision, RetrievalRPrecision, ) +from torchmetrics.text import ( + CharErrorRate, + ExtendedEditDistance, + MatchErrorRate, + WordErrorRate, + WordInfoLost, + WordInfoPreserved, +) _rand_input = lambda: torch.rand(10) _binary_randint_input = lambda: torch.randint(2, (10,)) @@ -130,6 +138,8 @@ torch.tensor([1, 1, 0, 0, 0, 0, 1, 1]).float(), 40, replacement=True ).reshape(1, 5, 4, 2) _nominal_input = lambda: torch.randint(0, 4, (100,)) +_text_input_1 = lambda: ["this is the prediction", "there is an other sample"] +_text_input_2 = lambda: ["this is the reference", "there is another one"] @pytest.mark.parametrize( @@ -422,6 +432,12 @@ pytest.param(SymmetricMeanAbsolutePercentageError, _rand_input, _rand_input, id="symmetric mape"), pytest.param(TweedieDevianceScore, _rand_input, _rand_input, id="tweedie deviance score"), pytest.param(WeightedMeanAbsolutePercentageError, _rand_input, _rand_input, id="weighted mape"), + pytest.param(WordInfoPreserved, _text_input_1, _text_input_2, id="word info preserved"), + pytest.param(WordInfoLost, _text_input_1, _text_input_2, id="word info lost"), + pytest.param(WordErrorRate, _text_input_1, _text_input_2, id="word error rate"), + pytest.param(CharErrorRate, _text_input_1, _text_input_2, id="character error rate"), + pytest.param(ExtendedEditDistance, _text_input_1, _text_input_2, id="extended edit distance"), + pytest.param(MatchErrorRate, _text_input_1, _text_input_2, id="match error rate"), ], ) @pytest.mark.parametrize("num_vals", [1, 5]) From b613c5059e42f0b7ce05b3fdf56d238ccdb1a5ce Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Tue, 21 Mar 2023 15:51:11 +0100 Subject: [PATCH 6/6] Add plotting 15/n (#1638) --- CHANGELOG.md | 1 + src/torchmetrics/classification/dice.py | 48 ++++++- .../classification/exact_match.py | 89 +++++++++++- src/torchmetrics/classification/hamming.py | 134 +++++++++++++++++- src/torchmetrics/classification/hinge.py | 89 +++++++++++- src/torchmetrics/classification/jaccard.py | 130 ++++++++++++++++- tests/unittests/utilities/test_plot.py | 57 ++++++++ 7 files changed, 543 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 73795d4194e..afefa880f8e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#1621](https://github.com/Lightning-AI/metrics/pull/1621), [#1624](https://github.com/Lightning-AI/metrics/pull/1624), [#1623](https://github.com/Lightning-AI/metrics/pull/1623), + [#1638](https://github.com/Lightning-AI/metrics/pull/1638), [#1631](https://github.com/Lightning-AI/metrics/pull/1631), ) diff --git a/src/torchmetrics/classification/dice.py b/src/torchmetrics/classification/dice.py index e8e54cc2ed6..fc98037076d 100644 --- a/src/torchmetrics/classification/dice.py +++ b/src/torchmetrics/classification/dice.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Callable, Optional, Tuple, no_type_check +from typing import Any, Callable, Optional, Sequence, Tuple, Union, no_type_check import torch from torch import Tensor @@ -21,6 +21,11 @@ from torchmetrics.functional.classification.stat_scores import _stat_scores_update from torchmetrics.metric import Metric from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["Dice.plot"] class Dice(Metric): @@ -235,3 +240,44 @@ def compute(self) -> Tensor: """Compute metric.""" tp, fp, _, fn = self._get_final_stats() return _dice_compute(tp, fp, fn, self.average, self.mdmc_reduce, self.zero_division) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure object and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torch import randint + >>> from torchmetrics.classification import Dice + >>> metric = Dice() + >>> metric.update(randint(2,(10,)), randint(2,(10,))) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torch import randint + >>> from torchmetrics.classification import Dice + >>> metric = Dice() + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(randint(2,(10,)), randint(2,(10,)))) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/classification/exact_match.py b/src/torchmetrics/classification/exact_match.py index 596dd0f3dd5..8ddcacbee37 100644 --- a/src/torchmetrics/classification/exact_match.py +++ b/src/torchmetrics/classification/exact_match.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from typing import Any, Optional, Sequence, Union import torch from torch import Tensor @@ -33,6 +33,11 @@ from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat from torchmetrics.utilities.enums import ClassificationTaskNoBinary +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["MulticlassExactMatch.plot", "MultilabelExactMatch.plot"] class MulticlassExactMatch(Metric): @@ -140,6 +145,47 @@ def compute(self) -> Tensor: correct = dim_zero_cat(self.correct) if isinstance(self.correct, list) else self.correct return _exact_match_reduce(correct, self.total) + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. 
+ If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure object and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value per class + >>> from torch import randint + >>> from torchmetrics.classification import MulticlassExactMatch + >>> metric = MulticlassExactMatch(num_classes=3) + >>> metric.update(randint(3, (20,5)), randint(3, (20,5))) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> from torch import randint + >>> # Example plotting a multiple values per class + >>> from torchmetrics.classification import MulticlassExactMatch + >>> metric = MulticlassExactMatch(num_classes=3) + >>> values = [] + >>> for _ in range(20): + ... values.append(metric(randint(3, (20,5)), randint(3, (20,5)))) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) + class MultilabelExactMatch(Metric): r"""Compute Exact match (also known as subset accuracy) for multilabel tasks. @@ -261,6 +307,47 @@ def compute(self) -> Tensor: correct = dim_zero_cat(self.correct) if isinstance(self.correct, list) else self.correct return _exact_match_reduce(correct, self.total) + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torch import rand, randint + >>> from torchmetrics.classification import MultilabelExactMatch + >>> metric = MultilabelExactMatch(num_labels=3) + >>> metric.update(randint(2, (20, 3, 5)), randint(2, (20, 3, 5))) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torch import rand, randint + >>> from torchmetrics.classification import MultilabelExactMatch + >>> metric = MultilabelExactMatch(num_labels=3) + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(randint(2, (20, 3, 5)), randint(2, (20, 3, 5)))) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) + class ExactMatch: r"""Compute Exact match (also known as subset accuracy). diff --git a/src/torchmetrics/classification/hamming.py b/src/torchmetrics/classification/hamming.py index f02bd3dcdf8..c426d94285c 100644 --- a/src/torchmetrics/classification/hamming.py +++ b/src/torchmetrics/classification/hamming.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Optional +from typing import Any, Optional, Sequence, Union from torch import Tensor from typing_extensions import Literal @@ -20,6 +20,15 @@ from torchmetrics.functional.classification.hamming import _hamming_distance_reduce from torchmetrics.metric import Metric from torchmetrics.utilities.enums import ClassificationTask +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = [ + "BinaryHammingDistance.plot", + "MulticlassHammingDistance.plot", + "MultilabelHammingDistance.plot", + ] class BinaryHammingDistance(BinaryStatScores): @@ -98,6 +107,47 @@ def compute(self) -> Tensor: tp, fp, tn, fn = self._final_state() return _hamming_distance_reduce(tp, fp, tn, fn, average="binary", multidim_average=self.multidim_average) + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure object and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torch import rand, randint + >>> from torchmetrics.classification import BinaryHammingDistance + >>> metric = BinaryHammingDistance() + >>> metric.update(rand(10), randint(2,(10,))) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torch import rand, randint + >>> from torchmetrics.classification import BinaryHammingDistance + >>> metric = BinaryHammingDistance() + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(rand(10), randint(2,(10,)))) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) + class MulticlassHammingDistance(MulticlassStatScores): r"""Compute the average `Hamming distance`_ (also known as Hamming loss) for multiclass tasks. @@ -204,6 +254,47 @@ def compute(self) -> Tensor: tp, fp, tn, fn = self._final_state() return _hamming_distance_reduce(tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average) + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure object and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value per class + >>> from torch import randint + >>> from torchmetrics.classification import MulticlassHammingDistance + >>> metric = MulticlassHammingDistance(num_classes=3, average=None) + >>> metric.update(randint(3, (20,)), randint(3, (20,))) + >>> fig_, ax_ = metric.plot() + + .. 
plot:: + :scale: 75 + + >>> # Example plotting a multiple values per class + >>> from torch import randint + >>> from torchmetrics.classification import MulticlassHammingDistance + >>> metric = MulticlassHammingDistance(num_classes=3, average=None) + >>> values = [] + >>> for _ in range(20): + ... values.append(metric(randint(3, (20,)), randint(3, (20,)))) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) + class MultilabelHammingDistance(MultilabelStatScores): r"""Compute the average `Hamming distance`_ (also known as Hamming loss) for multilabel tasks. @@ -310,6 +401,47 @@ def compute(self) -> Tensor: tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, multilabel=True ) + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torch import rand, randint + >>> from torchmetrics.classification import MultilabelHammingDistance + >>> metric = MultilabelHammingDistance(num_labels=3) + >>> metric.update(randint(2, (20, 3)), randint(2, (20, 3))) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torch import rand, randint + >>> from torchmetrics.classification import MultilabelHammingDistance + >>> metric = MultilabelHammingDistance(num_labels=3) + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(randint(2, (20, 3)), randint(2, (20, 3)))) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) + class HammingDistance: r"""Compute the average `Hamming distance`_ (also known as Hamming loss). diff --git a/src/torchmetrics/classification/hinge.py b/src/torchmetrics/classification/hinge.py index 381eaf45354..d15eec8a8bd 100644 --- a/src/torchmetrics/classification/hinge.py +++ b/src/torchmetrics/classification/hinge.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from typing import Any, Optional, Sequence, Union import torch from torch import Tensor @@ -30,6 +30,11 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["BinaryHingeLoss.plot", "MulticlassHingeLoss.plot"] class BinaryHingeLoss(Metric): @@ -112,6 +117,47 @@ def compute(self) -> Tensor: """Compute metric.""" return _hinge_loss_compute(self.measures, self.total) + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. 
+ If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure object and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torch import rand, randint + >>> from torchmetrics.classification import BinaryHingeLoss + >>> metric = BinaryHingeLoss() + >>> metric.update(rand(10), randint(2,(10,))) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torch import rand, randint + >>> from torchmetrics.classification import BinaryHingeLoss + >>> metric = BinaryHingeLoss() + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(rand(10), randint(2,(10,)))) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) + class MulticlassHingeLoss(Metric): r"""Compute the mean `Hinge loss`_ typically used for Support Vector Machines (SVMs) for multiclass tasks. @@ -216,6 +262,47 @@ def compute(self) -> Tensor: """Compute metric.""" return _hinge_loss_compute(self.measures, self.total) + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure object and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value per class + >>> from torch import randint, randn + >>> from torchmetrics.classification import MulticlassHingeLoss + >>> metric = MulticlassHingeLoss(num_classes=3) + >>> metric.update(randn(20, 3), randint(3, (20,))) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting a multiple values per class + >>> from torch import randint, randn + >>> from torchmetrics.classification import MulticlassHingeLoss + >>> metric = MulticlassHingeLoss(num_classes=3) + >>> values = [] + >>> for _ in range(20): + ... values.append(metric(randn(20, 3), randint(3, (20,)))) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) + class HingeLoss: r"""Compute the mean `Hinge loss`_ typically used for Support Vector Machines (SVMs). diff --git a/src/torchmetrics/classification/jaccard.py b/src/torchmetrics/classification/jaccard.py index 36cbaa6e4e8..b4e08d65f24 100644 --- a/src/torchmetrics/classification/jaccard.py +++ b/src/torchmetrics/classification/jaccard.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Optional +from typing import Any, Optional, Sequence, Union from torch import Tensor from typing_extensions import Literal @@ -24,6 +24,11 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.enums import ClassificationTask +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["BinaryJaccardIndex.plot", "MulticlassJaccardIndex.plot", "MultilabelJaccardIndex.plot"] class BinaryJaccardIndex(BinaryConfusionMatrix): @@ -93,6 +98,47 @@ def compute(self) -> Tensor: """Compute metric.""" return _jaccard_index_reduce(self.confmat, average="binary") + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure object and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torch import rand, randint + >>> from torchmetrics.classification import BinaryJaccardIndex + >>> metric = BinaryJaccardIndex() + >>> metric.update(rand(10), randint(2,(10,))) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torch import rand, randint + >>> from torchmetrics.classification import BinaryJaccardIndex + >>> metric = BinaryJaccardIndex() + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(rand(10), randint(2,(10,)))) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) + class MulticlassJaccardIndex(MulticlassConfusionMatrix): r"""Calculate the Jaccard index for multiclass tasks. @@ -178,6 +224,47 @@ def compute(self) -> Tensor: """Compute metric.""" return _jaccard_index_reduce(self.confmat, average=self.average, ignore_index=self.ignore_index) + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure object and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value per class + >>> from torch import randint + >>> from torchmetrics.classification import MulticlassJaccardIndex + >>> metric = MulticlassJaccardIndex(num_classes=3, average=None) + >>> metric.update(randint(3, (20,)), randint(3, (20,))) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting a multiple values per class + >>> from torch import randint + >>> from torchmetrics.classification import MulticlassJaccardIndex + >>> metric = MulticlassJaccardIndex(num_classes=3, average=None) + >>> values = [] + >>> for _ in range(20): + ... 
values.append(metric(randint(3, (20,)), randint(3, (20,)))) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) + class MultilabelJaccardIndex(MultilabelConfusionMatrix): r"""Calculate the Jaccard index for multilabel tasks. @@ -267,6 +354,47 @@ def compute(self) -> Tensor: """Compute metric.""" return _jaccard_index_reduce(self.confmat, average=self.average) + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torch import rand, randint + >>> from torchmetrics.classification import MultilabelJaccardIndex + >>> metric = MultilabelJaccardIndex(num_labels=3) + >>> metric.update(randint(2, (20, 3)), randint(2, (20, 3))) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torch import rand, randint + >>> from torchmetrics.classification import MultilabelJaccardIndex + >>> metric = MultilabelJaccardIndex(num_labels=3) + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(randint(2, (20, 3)), randint(2, (20, 3)))) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) + class JaccardIndex: r"""Calculate the Jaccard index for multilabel tasks. diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index c181c88ceea..fa25c215840 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -38,18 +38,26 @@ BinaryCalibrationError, BinaryCohenKappa, BinaryConfusionMatrix, + BinaryHammingDistance, + BinaryHingeLoss, + BinaryJaccardIndex, BinaryMatthewsCorrCoef, BinaryPrecision, BinaryRecall, BinaryRecallAtFixedPrecision, BinaryROC, BinarySpecificity, + Dice, MulticlassAccuracy, MulticlassAUROC, MulticlassAveragePrecision, MulticlassCalibrationError, MulticlassCohenKappa, MulticlassConfusionMatrix, + MulticlassExactMatch, + MulticlassHammingDistance, + MulticlassHingeLoss, + MulticlassJaccardIndex, MulticlassMatthewsCorrCoef, MulticlassPrecision, MulticlassRecall, @@ -58,6 +66,9 @@ MultilabelAveragePrecision, MultilabelConfusionMatrix, MultilabelCoverageError, + MultilabelExactMatch, + MultilabelHammingDistance, + MultilabelJaccardIndex, MultilabelMatthewsCorrCoef, MultilabelPrecision, MultilabelRankingAveragePrecision, @@ -432,6 +443,52 @@ pytest.param(SymmetricMeanAbsolutePercentageError, _rand_input, _rand_input, id="symmetric mape"), pytest.param(TweedieDevianceScore, _rand_input, _rand_input, id="tweedie deviance score"), pytest.param(WeightedMeanAbsolutePercentageError, _rand_input, _rand_input, id="weighted mape"), + pytest.param(Dice, _multiclass_randint_input, _multiclass_randint_input, id="dice"), + pytest.param( + partial(MulticlassExactMatch, num_classes=3), + lambda: torch.randint(3, (20, 5)), + lambda: torch.randint(3, (20, 5)), + id="multiclass exact match", + ), + pytest.param( + partial(MultilabelExactMatch, num_labels=3), + lambda: torch.randint(2, (20, 3, 5)), + lambda: torch.randint(2, (20, 3, 5)), + 
id="multilabel exact match", + ), + pytest.param(BinaryHammingDistance, _rand_input, _binary_randint_input, id="binary hamming distance"), + pytest.param( + partial(MulticlassHammingDistance, num_classes=3), + _multiclass_randn_input, + _multiclass_randint_input, + id="multiclass hamming distance", + ), + pytest.param( + partial(MultilabelHammingDistance, num_labels=3), + _multilabel_rand_input, + _multilabel_randint_input, + id="multilabel hamming distance", + ), + pytest.param(BinaryHingeLoss, _rand_input, _binary_randint_input, id="binary hinge loss"), + pytest.param( + partial(MulticlassHingeLoss, num_classes=3), + _multiclass_randn_input, + _multiclass_randint_input, + id="multiclass hinge loss", + ), + pytest.param(BinaryJaccardIndex, _rand_input, _binary_randint_input, id="binary jaccard index"), + pytest.param( + partial(MulticlassJaccardIndex, num_classes=3), + _multiclass_randn_input, + _multiclass_randint_input, + id="multiclass jaccard index", + ), + pytest.param( + partial(MultilabelJaccardIndex, num_labels=3), + _multilabel_rand_input, + _multilabel_randint_input, + id="multilabel jaccard index", + ), pytest.param(WordInfoPreserved, _text_input_1, _text_input_2, id="word info preserved"), pytest.param(WordInfoLost, _text_input_1, _text_input_2, id="word info lost"), pytest.param(WordErrorRate, _text_input_1, _text_input_2, id="word error rate"),