From a151c150e3062fa25319bce0c4ff42be4a9a098a Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Mon, 6 Mar 2023 15:51:33 +0100 Subject: [PATCH 1/8] boilerplate --- src/torchmetrics/image/fid.py | 55 ++++++++++++++++++++++++++-- src/torchmetrics/image/inception.py | 56 +++++++++++++++++++++++++++-- src/torchmetrics/image/kid.py | 55 ++++++++++++++++++++++++++-- src/torchmetrics/image/lpip.py | 55 ++++++++++++++++++++++++++-- src/torchmetrics/image/rase.py | 54 +++++++++++++++++++++++++++- src/torchmetrics/image/rmse_sw.py | 54 +++++++++++++++++++++++++++- src/torchmetrics/image/tv.py | 5 +++ 7 files changed, 324 insertions(+), 10 deletions(-) diff --git a/src/torchmetrics/image/fid.py b/src/torchmetrics/image/fid.py index 352f7d0a647..a0ae9371622 100644 --- a/src/torchmetrics/image/fid.py +++ b/src/torchmetrics/image/fid.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from copy import deepcopy -from typing import Any, List, Optional, Union +from typing import Any, List, Optional, Sequence, Union import numpy as np import torch @@ -22,7 +22,11 @@ from torchmetrics.metric import Metric from torchmetrics.utilities import rank_zero_info -from torchmetrics.utilities.imports import _SCIPY_AVAILABLE, _TORCH_FIDELITY_AVAILABLE +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _SCIPY_AVAILABLE, _TORCH_FIDELITY_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["FrechetInceptionDistance.plot"] if _TORCH_FIDELITY_AVAILABLE: from torch_fidelity.feature_extractor_inceptionv3 import FeatureExtractorInceptionV3 as _FeatureExtractorInceptionV3 @@ -303,3 +307,50 @@ def reset(self) -> None: self.real_features_num_samples = real_features_num_samples else: super().reset() + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> _ = torch.manual_seed(42) + >>> from torchmetrics import SpectralDistortionIndex + >>> preds = torch.rand([16, 3, 16, 16]) + >>> target = torch.rand([16, 3, 16, 16]) + >>> metric = SpectralDistortionIndex() + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> _ = torch.manual_seed(42) + >>> from torchmetrics import SpectralDistortionIndex + >>> preds = torch.rand([16, 3, 16, 16]) + >>> target = torch.rand([16, 3, 16, 16]) + >>> metric = SpectralDistortionIndex() + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/image/inception.py b/src/torchmetrics/image/inception.py index 9bf8cbe511f..ece4e74a373 100644 --- a/src/torchmetrics/image/inception.py +++ b/src/torchmetrics/image/inception.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Tuple, Union +from typing import Any, List, Optional, Sequence, Tuple, Union import torch from torch import Tensor @@ -21,7 +21,12 @@ from torchmetrics.metric import Metric from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat -from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCH_FIDELITY_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["InceptionScore.plot"] + __doctest_requires__ = {("InceptionScore", "IS"): ["torch_fidelity"]} @@ -162,3 +167,50 @@ def compute(self) -> Tuple[Tensor, Tensor]: # return mean and std return kl.mean(), kl.std() + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> _ = torch.manual_seed(42) + >>> from torchmetrics import SpectralDistortionIndex + >>> preds = torch.rand([16, 3, 16, 16]) + >>> target = torch.rand([16, 3, 16, 16]) + >>> metric = SpectralDistortionIndex() + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> _ = torch.manual_seed(42) + >>> from torchmetrics import SpectralDistortionIndex + >>> preds = torch.rand([16, 3, 16, 16]) + >>> target = torch.rand([16, 3, 16, 16]) + >>> metric = SpectralDistortionIndex() + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/image/kid.py b/src/torchmetrics/image/kid.py index d52dfbfe0b8..8197296c5fc 100644 --- a/src/torchmetrics/image/kid.py +++ b/src/torchmetrics/image/kid.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Tuple, Union +from typing import Any, List, Optional, Sequence, Tuple, Union import torch from torch import Tensor @@ -21,7 +21,11 @@ from torchmetrics.metric import Metric from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat -from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCH_FIDELITY_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["KernelInceptionDistance.plot"] __doctest_requires__ = {("KernelInceptionDistance", "KID"): ["torch_fidelity"]} @@ -274,3 +278,50 @@ def reset(self) -> None: self._defaults["real_features"] = value else: super().reset() + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> _ = torch.manual_seed(42) + >>> from torchmetrics import SpectralDistortionIndex + >>> preds = torch.rand([16, 3, 16, 16]) + >>> target = torch.rand([16, 3, 16, 16]) + >>> metric = SpectralDistortionIndex() + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> _ = torch.manual_seed(42) + >>> from torchmetrics import SpectralDistortionIndex + >>> preds = torch.rand([16, 3, 16, 16]) + >>> target = torch.rand([16, 3, 16, 16]) + >>> metric = SpectralDistortionIndex() + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/image/lpip.py b/src/torchmetrics/image/lpip.py index ae5216a1d0a..95496e7c41e 100644 --- a/src/torchmetrics/image/lpip.py +++ b/src/torchmetrics/image/lpip.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from typing import Any, List +from typing import Any, List, Optional, Sequence, Union import torch from torch import Tensor @@ -21,7 +21,11 @@ from torchmetrics.metric import Metric from torchmetrics.utilities.checks import _SKIP_SLOW_DOCTEST, _try_proceed_with_timeout -from torchmetrics.utilities.imports import _LPIPS_AVAILABLE +from torchmetrics.utilities.imports import _LPIPS_AVAILABLE, _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["LearnedPerceptualImagePatchSimilarity.plot"] if _LPIPS_AVAILABLE: from lpips import LPIPS as _LPIPS @@ -167,3 +171,50 @@ def compute(self) -> Tensor: if self.reduction == "sum": return self.sum_scores return None + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> _ = torch.manual_seed(42) + >>> from torchmetrics import SpectralDistortionIndex + >>> preds = torch.rand([16, 3, 16, 16]) + >>> target = torch.rand([16, 3, 16, 16]) + >>> metric = SpectralDistortionIndex() + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> _ = torch.manual_seed(42) + >>> from torchmetrics import SpectralDistortionIndex + >>> preds = torch.rand([16, 3, 16, 16]) + >>> target = torch.rand([16, 3, 16, 16]) + >>> metric = SpectralDistortionIndex() + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/image/rase.py b/src/torchmetrics/image/rase.py index e2b112a4671..08c45f13879 100644 --- a/src/torchmetrics/image/rase.py +++ b/src/torchmetrics/image/rase.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional, Sequence, Union import torch from torch import Tensor @@ -20,6 +20,11 @@ from torchmetrics.functional.image.rase import relative_average_spectral_error from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["RelativeAverageSpectralError.plot"] class RelativeAverageSpectralError(Metric): @@ -85,3 +90,50 @@ def compute(self) -> Tensor: preds = dim_zero_cat(self.preds) target = dim_zero_cat(self.target) return relative_average_spectral_error(preds, target, self.window_size) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> _ = torch.manual_seed(42) + >>> from torchmetrics import SpectralDistortionIndex + >>> preds = torch.rand([16, 3, 16, 16]) + >>> target = torch.rand([16, 3, 16, 16]) + >>> metric = SpectralDistortionIndex() + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> _ = torch.manual_seed(42) + >>> from torchmetrics import SpectralDistortionIndex + >>> preds = torch.rand([16, 3, 16, 16]) + >>> target = torch.rand([16, 3, 16, 16]) + >>> metric = SpectralDistortionIndex() + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/image/rmse_sw.py b/src/torchmetrics/image/rmse_sw.py index 9f61b490ef3..80faa661b19 100644 --- a/src/torchmetrics/image/rmse_sw.py +++ b/src/torchmetrics/image/rmse_sw.py @@ -12,13 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Sequence, Tuple, Union import torch from torch import Tensor from torchmetrics.functional.image.rmse_sw import _rmse_sw_compute, _rmse_sw_update from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["RootMeanSquaredErrorUsingSlidingWindow.plot"] class RootMeanSquaredErrorUsingSlidingWindow(Metric): @@ -85,3 +90,50 @@ def compute(self) -> Optional[Tensor]: assert self.rmse_map is not None rmse, _ = _rmse_sw_compute(self.rmse_val_sum, self.rmse_map, self.total_images) return rmse + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> _ = torch.manual_seed(42) + >>> from torchmetrics import SpectralDistortionIndex + >>> preds = torch.rand([16, 3, 16, 16]) + >>> target = torch.rand([16, 3, 16, 16]) + >>> metric = SpectralDistortionIndex() + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> _ = torch.manual_seed(42) + >>> from torchmetrics import SpectralDistortionIndex + >>> preds = torch.rand([16, 3, 16, 16]) + >>> target = torch.rand([16, 3, 16, 16]) + >>> metric = SpectralDistortionIndex() + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/image/tv.py b/src/torchmetrics/image/tv.py index 9d30148cd33..d843a8f7676 100644 --- a/src/torchmetrics/image/tv.py +++ b/src/torchmetrics/image/tv.py @@ -20,6 +20,11 @@ from torchmetrics.functional.image.tv import _total_variation_compute, _total_variation_update from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["TotalVariation.plot"] class TotalVariation(Metric): From 0659d591af073923efdc5763ba30595b9831d62c Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 7 Mar 2023 12:21:56 +0100 Subject: [PATCH 2/8] add tests + fix code --- src/torchmetrics/image/fid.py | 30 ++++++---- src/torchmetrics/image/inception.py | 24 ++++---- src/torchmetrics/image/kid.py | 29 +++++---- src/torchmetrics/image/lpip.py | 24 +++----- src/torchmetrics/image/rase.py | 17 ++---- src/torchmetrics/image/rmse_sw.py | 18 ++---- src/torchmetrics/image/tv.py | 43 +++++++++++++- tests/unittests/utilities/test_plot.py | 82 +++++++++++++++++++++++++- 8 files changed, 186 insertions(+), 81 deletions(-) diff --git a/src/torchmetrics/image/fid.py b/src/torchmetrics/image/fid.py index a0ae9371622..b3b92c53c31 100644 --- a/src/torchmetrics/image/fid.py +++ b/src/torchmetrics/image/fid.py @@ -330,12 +330,12 @@ def plot( >>> # Example plotting a single value >>> import torch - >>> _ = torch.manual_seed(42) - >>> from torchmetrics import SpectralDistortionIndex - >>> preds = torch.rand([16, 3, 16, 16]) - >>> target = torch.rand([16, 3, 16, 16]) - >>> metric = SpectralDistortionIndex() - >>> metric.update(preds, target) + >>> from torchmetrics.image.fid import FrechetInceptionDistance + >>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8) + >>> imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8) + >>> metric = FrechetInceptionDistance(feature=64) + >>> metric.update(imgs_dist1, real=True) + >>> metric.update(imgs_dist2, real=False) >>> fig_, ax_ = metric.plot() .. plot:: @@ -343,14 +343,18 @@ def plot( >>> # Example plotting multiple values >>> import torch - >>> _ = torch.manual_seed(42) - >>> from torchmetrics import SpectralDistortionIndex - >>> preds = torch.rand([16, 3, 16, 16]) - >>> target = torch.rand([16, 3, 16, 16]) - >>> metric = SpectralDistortionIndex() + >>> from torchmetrics.image.kid import KernelInceptionDistance + >>> imgs_dist1 = lambda: torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8) + >>> imgs_dist2 = lambda: torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8) + >>> metric = FrechetInceptionDistance(feature=64) >>> values = [ ] - >>> for _ in range(10): - ... values.append(metric(preds, target)) + >>> for _ in range(3): + ... metric.update(imgs_dist1(), real=True) + ... metric.update(imgs_dist2(), real=False) + ... values.append(metric.compute()) + ... metric.reset() >>> fig_, ax_ = metric.plot(values) + + """ return self._plot(val, ax) diff --git a/src/torchmetrics/image/inception.py b/src/torchmetrics/image/inception.py index ece4e74a373..c8034af54af 100644 --- a/src/torchmetrics/image/inception.py +++ b/src/torchmetrics/image/inception.py @@ -190,27 +190,23 @@ def plot( >>> # Example plotting a single value >>> import torch - >>> _ = torch.manual_seed(42) - >>> from torchmetrics import SpectralDistortionIndex - >>> preds = torch.rand([16, 3, 16, 16]) - >>> target = torch.rand([16, 3, 16, 16]) - >>> metric = SpectralDistortionIndex() - >>> metric.update(preds, target) - >>> fig_, ax_ = metric.plot() + >>> from torchmetrics.image.inception import InceptionScore + >>> metric = InceptionScore() + >>> metric.update(torch.randint(0, 255, (50, 3, 299, 299), dtype=torch.uint8)) + >>> fig_, ax_ = metric.plot() # the returned plot only shows the mean value by default .. plot:: :scale: 75 >>> # Example plotting multiple values >>> import torch - >>> _ = torch.manual_seed(42) - >>> from torchmetrics import SpectralDistortionIndex - >>> preds = torch.rand([16, 3, 16, 16]) - >>> target = torch.rand([16, 3, 16, 16]) - >>> metric = SpectralDistortionIndex() + >>> from torchmetrics.image.inception import InceptionScore + >>> metric = InceptionScore() >>> values = [ ] - >>> for _ in range(10): - ... values.append(metric(preds, target)) + >>> for _ in range(3): + ... # we index by 0 such that only the mean value is plotted + ... values.append(metric(torch.randint(0, 255, (50, 3, 299, 299), dtype=torch.uint8))[0]) >>> fig_, ax_ = metric.plot(values) """ + val = val or self.compute()[0] # by default we select the mean to plot return self._plot(val, ax) diff --git a/src/torchmetrics/image/kid.py b/src/torchmetrics/image/kid.py index 8197296c5fc..0540e95aa16 100644 --- a/src/torchmetrics/image/kid.py +++ b/src/torchmetrics/image/kid.py @@ -301,12 +301,12 @@ def plot( >>> # Example plotting a single value >>> import torch - >>> _ = torch.manual_seed(42) - >>> from torchmetrics import SpectralDistortionIndex - >>> preds = torch.rand([16, 3, 16, 16]) - >>> target = torch.rand([16, 3, 16, 16]) - >>> metric = SpectralDistortionIndex() - >>> metric.update(preds, target) + >>> from torchmetrics.image.kid import KernelInceptionDistance + >>> imgs_dist1 = torch.randint(0, 200, (30, 3, 299, 299), dtype=torch.uint8) + >>> imgs_dist2 = torch.randint(100, 255, (30, 3, 299, 299), dtype=torch.uint8) + >>> metric = KernelInceptionDistance(subsets=3, subset_size=20) + >>> metric.update(imgs_dist1, real=True) + >>> metric.update(imgs_dist2, real=False) >>> fig_, ax_ = metric.plot() .. plot:: @@ -314,14 +314,17 @@ def plot( >>> # Example plotting multiple values >>> import torch - >>> _ = torch.manual_seed(42) - >>> from torchmetrics import SpectralDistortionIndex - >>> preds = torch.rand([16, 3, 16, 16]) - >>> target = torch.rand([16, 3, 16, 16]) - >>> metric = SpectralDistortionIndex() + >>> from torchmetrics.image.kid import KernelInceptionDistance + >>> imgs_dist1 = lambda: torch.randint(0, 200, (30, 3, 299, 299), dtype=torch.uint8) + >>> imgs_dist2 = lambda: torch.randint(100, 255, (30, 3, 299, 299), dtype=torch.uint8) + >>> metric = KernelInceptionDistance(subsets=3, subset_size=20) >>> values = [ ] - >>> for _ in range(10): - ... values.append(metric(preds, target)) + >>> for _ in range(3): + ... metric.update(imgs_dist1(), real=True) + ... metric.update(imgs_dist2(), real=False) + ... values.append(metric.compute()[0]) + ... metric.reset() >>> fig_, ax_ = metric.plot(values) """ + val = val or self.compute()[0] # by default we select the mean to plot return self._plot(val, ax) diff --git a/src/torchmetrics/image/lpip.py b/src/torchmetrics/image/lpip.py index 95496e7c41e..2e7133a27e5 100644 --- a/src/torchmetrics/image/lpip.py +++ b/src/torchmetrics/image/lpip.py @@ -34,13 +34,13 @@ def _download_lpips() -> None: _LPIPS(pretrained=True, net="vgg") if _SKIP_SLOW_DOCTEST and not _try_proceed_with_timeout(_download_lpips): - __doctest_skip__ = ["LearnedPerceptualImagePatchSimilarity", "LPIPS"] + __doctest_skip__ = ["LearnedPerceptualImagePatchSimilarity", "LearnedPerceptualImagePatchSimilarity.plot"] else: class _LPIPS(Module): pass - __doctest_skip__ = ["LearnedPerceptualImagePatchSimilarity", "LPIPS"] + __doctest_skip__ = ["LearnedPerceptualImagePatchSimilarity", "LearnedPerceptualImagePatchSimilarity.plot"] class NoTrainLpips(_LPIPS): @@ -194,12 +194,9 @@ def plot( >>> # Example plotting a single value >>> import torch - >>> _ = torch.manual_seed(42) - >>> from torchmetrics import SpectralDistortionIndex - >>> preds = torch.rand([16, 3, 16, 16]) - >>> target = torch.rand([16, 3, 16, 16]) - >>> metric = SpectralDistortionIndex() - >>> metric.update(preds, target) + >>> from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity + >>> metric = LearnedPerceptualImagePatchSimilarity() + >>> metric.update(torch.rand(10, 3, 100, 100), torch.rand(10, 3, 100, 100)) >>> fig_, ax_ = metric.plot() .. plot:: @@ -207,14 +204,11 @@ def plot( >>> # Example plotting multiple values >>> import torch - >>> _ = torch.manual_seed(42) - >>> from torchmetrics import SpectralDistortionIndex - >>> preds = torch.rand([16, 3, 16, 16]) - >>> target = torch.rand([16, 3, 16, 16]) - >>> metric = SpectralDistortionIndex() + >>> from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity + >>> metric = LearnedPerceptualImagePatchSimilarity() >>> values = [ ] - >>> for _ in range(10): - ... values.append(metric(preds, target)) + >>> for _ in range(3): + ... values.append(metric(torch.rand(10, 3, 100, 100), torch.rand(10, 3, 100, 100))) >>> fig_, ax_ = metric.plot(values) """ return self._plot(val, ax) diff --git a/src/torchmetrics/image/rase.py b/src/torchmetrics/image/rase.py index 08c45f13879..7ed80c19afc 100644 --- a/src/torchmetrics/image/rase.py +++ b/src/torchmetrics/image/rase.py @@ -113,12 +113,9 @@ def plot( >>> # Example plotting a single value >>> import torch - >>> _ = torch.manual_seed(42) - >>> from torchmetrics import SpectralDistortionIndex - >>> preds = torch.rand([16, 3, 16, 16]) - >>> target = torch.rand([16, 3, 16, 16]) - >>> metric = SpectralDistortionIndex() - >>> metric.update(preds, target) + >>> from torchmetrics import RelativeAverageSpectralError + >>> metric = RelativeAverageSpectralError() + >>> metric.update(torch.rand(4, 3, 16, 16), torch.rand(4, 3, 16, 16)) >>> fig_, ax_ = metric.plot() .. plot:: @@ -127,13 +124,11 @@ def plot( >>> # Example plotting multiple values >>> import torch >>> _ = torch.manual_seed(42) - >>> from torchmetrics import SpectralDistortionIndex - >>> preds = torch.rand([16, 3, 16, 16]) - >>> target = torch.rand([16, 3, 16, 16]) - >>> metric = SpectralDistortionIndex() + >>> from torchmetrics import RelativeAverageSpectralError + >>> metric = RelativeAverageSpectralError() >>> values = [ ] >>> for _ in range(10): - ... values.append(metric(preds, target)) + ... values.append(metric(torch.rand(4, 3, 16, 16), torch.rand(4, 3, 16, 16))) >>> fig_, ax_ = metric.plot(values) """ return self._plot(val, ax) diff --git a/src/torchmetrics/image/rmse_sw.py b/src/torchmetrics/image/rmse_sw.py index 80faa661b19..4352ddca2e6 100644 --- a/src/torchmetrics/image/rmse_sw.py +++ b/src/torchmetrics/image/rmse_sw.py @@ -113,12 +113,9 @@ def plot( >>> # Example plotting a single value >>> import torch - >>> _ = torch.manual_seed(42) - >>> from torchmetrics import SpectralDistortionIndex - >>> preds = torch.rand([16, 3, 16, 16]) - >>> target = torch.rand([16, 3, 16, 16]) - >>> metric = SpectralDistortionIndex() - >>> metric.update(preds, target) + >>> from torchmetrics import RootMeanSquaredErrorUsingSlidingWindow + >>> metric = RootMeanSquaredErrorUsingSlidingWindow() + >>> metric.update(torch.rand(4, 3, 16, 16), torch.rand(4, 3, 16, 16)) >>> fig_, ax_ = metric.plot() .. plot:: @@ -126,14 +123,11 @@ def plot( >>> # Example plotting multiple values >>> import torch - >>> _ = torch.manual_seed(42) - >>> from torchmetrics import SpectralDistortionIndex - >>> preds = torch.rand([16, 3, 16, 16]) - >>> target = torch.rand([16, 3, 16, 16]) - >>> metric = SpectralDistortionIndex() + >>> from torchmetrics import RootMeanSquaredErrorUsingSlidingWindow + >>> metric = RootMeanSquaredErrorUsingSlidingWindow() >>> values = [ ] >>> for _ in range(10): - ... values.append(metric(preds, target)) + ... values.append(metric(torch.rand(4, 3, 16, 16), torch.rand(4, 3, 16, 16))) >>> fig_, ax_ = metric.plot(values) """ return self._plot(val, ax) diff --git a/src/torchmetrics/image/tv.py b/src/torchmetrics/image/tv.py index d843a8f7676..40b92297ba0 100644 --- a/src/torchmetrics/image/tv.py +++ b/src/torchmetrics/image/tv.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Optional, Sequence, Union import torch from torch import Tensor, tensor @@ -91,3 +91,44 @@ def compute(self) -> Tensor: """Compute final total variation.""" score = dim_zero_cat(self.score) if self.reduction is None or self.reduction == "none" else self.score return _total_variation_compute(score, self.num_elements, self.reduction) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics import TotalVariation + >>> metric = TotalVariation() + >>> metric.update(torch.rand(5, 3, 28, 28)) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics import TotalVariation + >>> metric = TotalVariation() + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(torch.rand(5, 3, 28, 28))) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 030a3dcaf31..2613731c6ca 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -50,11 +50,18 @@ from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio from torchmetrics.image import ( ErrorRelativeGlobalDimensionlessSynthesis, + FrechetInceptionDistance, + InceptionScore, + KernelInceptionDistance, + LearnedPerceptualImagePatchSimilarity, MultiScaleStructuralSimilarityIndexMeasure, PeakSignalNoiseRatio, + RelativeAverageSpectralError, + RootMeanSquaredErrorUsingSlidingWindow, SpectralAngleMapper, SpectralDistortionIndex, StructuralSimilarityIndexMeasure, + TotalVariation, UniversalImageQualityIndex, ) from torchmetrics.nominal import CramersV, PearsonsContingencyCoefficient, TheilsU, TschuprowsT @@ -224,13 +231,26 @@ _multilabel_randint_input, id="multilabel average precision", ), + pytest.param(TotalVariation, _image_input, None, id="total variation"), + pytest.param( + RootMeanSquaredErrorUsingSlidingWindow, + _image_input, + _image_input, + id="root mean squared error using sliding window", + ), + pytest.param(RelativeAverageSpectralError, _image_input, _image_input, id="relative average spectral error"), + pytest.param( + LearnedPerceptualImagePatchSimilarity, + lambda: torch.rand(10, 3, 100, 100), + lambda: torch.rand(10, 3, 100, 100), + id="learned perceptual image patch similarity", + ), ], ) @pytest.mark.parametrize("num_vals", [1, 5]) -def test_single_multi_val_plot_methods(metric_class: object, preds: Callable, target: Callable, num_vals: int): +def test_plot_methods(metric_class: object, preds: Callable, target: Callable, num_vals: int): """Test the plot method of metrics that only output a single tensor scalar.""" metric = metric_class() - input = (lambda: (preds(),)) if target is None else lambda: (preds(), target()) if num_vals == 1: @@ -246,6 +266,64 @@ def test_single_multi_val_plot_methods(metric_class: object, preds: Callable, ta assert isinstance(ax, matplotlib.axes.Axes) +@pytest.mark.parametrize( + ("metric_class", "preds", "target", "index_0"), + [ + pytest.param( + partial(KernelInceptionDistance, feature=64, subsets=3, subset_size=20), + lambda: torch.randint(0, 200, (30, 3, 299, 299), dtype=torch.uint8), + lambda: torch.randint(0, 200, (30, 3, 299, 299), dtype=torch.uint8), + True, + id="kernel inception distance", + ), + pytest.param( + partial(FrechetInceptionDistance, feature=64), + lambda: torch.randint(0, 200, (30, 3, 299, 299), dtype=torch.uint8), + lambda: torch.randint(0, 200, (30, 3, 299, 299), dtype=torch.uint8), + False, + id="frechet inception distance", + ), + pytest.param( + partial(InceptionScore, feature=64), + lambda: torch.randint(0, 255, (50, 3, 299, 299), dtype=torch.uint8), + None, + True, + id="inception score", + ), + ], +) +@pytest.mark.parametrize("num_vals", [1, 2]) +def test_plot_methods_special_image_metrics(metric_class, preds, target, index_0, num_vals): + """Test the plot method of metrics that only output a single tensor scalar. + + This takes care of FID, KID and inception score image metrics as these have a slightly different call and update + signature than other metrics. + """ + metric = metric_class() + + if num_vals == 1: + if target is None: + metric.update(preds()) + else: + metric.update(preds(), real=True) + metric.update(target(), real=False) + fig, ax = metric.plot() + else: + vals = [] + for _ in range(num_vals): + if target is None: + vals.append(metric(preds())) + else: + metric.update(preds(), real=True) + metric.update(target(), real=False) + vals.append(metric.compute() if not index_0 else metric.compute()[0]) + metric.reset() + fig, ax = metric.plot(vals) + + assert isinstance(fig, plt.Figure) + assert isinstance(ax, matplotlib.axes.Axes) + + @pytest.mark.parametrize( ("metric_class", "preds", "target", "labels"), [ From 27f874aa42f001a711309f4524662462a22b2dee Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 7 Mar 2023 12:54:53 +0100 Subject: [PATCH 3/8] fix --- tests/unittests/utilities/test_plot.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 2613731c6ca..ab8e8ff1107 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -285,7 +285,7 @@ def test_plot_methods(metric_class: object, preds: Callable, target: Callable, n ), pytest.param( partial(InceptionScore, feature=64), - lambda: torch.randint(0, 255, (50, 3, 299, 299), dtype=torch.uint8), + lambda: torch.randint(0, 255, (30, 3, 299, 299), dtype=torch.uint8), None, True, id="inception score", @@ -312,7 +312,7 @@ def test_plot_methods_special_image_metrics(metric_class, preds, target, index_0 vals = [] for _ in range(num_vals): if target is None: - vals.append(metric(preds())) + vals.append(metric(preds())[0]) else: metric.update(preds(), real=True) metric.update(target(), real=False) From ea11b30d86c6d65ed71aaadf88a1e11d64649f53 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 7 Mar 2023 12:57:33 +0100 Subject: [PATCH 4/8] changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ac070334422..f4e1f38d79a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#1581](https://github.com/Lightning-AI/metrics/pull/1581), [#1585](https://github.com/Lightning-AI/metrics/pull/1585), [#1593](https://github.com/Lightning-AI/metrics/pull/1593), + [#1600](https://github.com/Lightning-AI/metrics/pull/1600), ) From 575c7cd9c997b215ef89b402b2ce86da8cd8fb14 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 7 Mar 2023 14:55:17 +0100 Subject: [PATCH 5/8] req --- requirements/docs.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/docs.txt b/requirements/docs.txt index 4dd5da84959..f81bcda95d0 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -16,3 +16,4 @@ sphinx-copybutton>=0.3 -r visual.txt -r audio.txt -r detection.txt +-r image From 29d598b9aa44ccc344188e0be5965a0d4ccfcc74 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 9 Mar 2023 09:16:30 +0100 Subject: [PATCH 6/8] fix --- requirements/docs.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/docs.txt b/requirements/docs.txt index f81bcda95d0..1209e6d16b0 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -16,4 +16,4 @@ sphinx-copybutton>=0.3 -r visual.txt -r audio.txt -r detection.txt --r image +-r image.txt From 3a57e64010eab4d625688fafb531d30a32ea3037 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 9 Mar 2023 10:40:23 +0100 Subject: [PATCH 7/8] fix --- src/torchmetrics/image/fid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/image/fid.py b/src/torchmetrics/image/fid.py index b3b92c53c31..df3f228cafd 100644 --- a/src/torchmetrics/image/fid.py +++ b/src/torchmetrics/image/fid.py @@ -343,7 +343,7 @@ def plot( >>> # Example plotting multiple values >>> import torch - >>> from torchmetrics.image.kid import KernelInceptionDistance + >>> from torchmetrics.image.fid import FrechetInceptionDistance >>> imgs_dist1 = lambda: torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8) >>> imgs_dist2 = lambda: torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8) >>> metric = FrechetInceptionDistance(feature=64) From e526c55a4d420d5606c77a59cd00acd4aeba6103 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 9 Mar 2023 11:03:29 +0100 Subject: [PATCH 8/8] skip on missing import --- src/torchmetrics/image/fid.py | 2 +- src/torchmetrics/image/inception.py | 2 +- src/torchmetrics/image/kid.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchmetrics/image/fid.py b/src/torchmetrics/image/fid.py index df3f228cafd..cd0fe874a60 100644 --- a/src/torchmetrics/image/fid.py +++ b/src/torchmetrics/image/fid.py @@ -35,7 +35,7 @@ class _FeatureExtractorInceptionV3(Module): pass - __doctest_skip__ = ["FrechetInceptionDistance", "FID"] + __doctest_skip__ = ["FrechetInceptionDistance", "FrechetInceptionDistance.plot"] if _SCIPY_AVAILABLE: diff --git a/src/torchmetrics/image/inception.py b/src/torchmetrics/image/inception.py index c8034af54af..8ce90fb777d 100644 --- a/src/torchmetrics/image/inception.py +++ b/src/torchmetrics/image/inception.py @@ -28,7 +28,7 @@ __doctest_skip__ = ["InceptionScore.plot"] -__doctest_requires__ = {("InceptionScore", "IS"): ["torch_fidelity"]} +__doctest_requires__ = {("InceptionScore", "InceptionScore.plot"): ["torch_fidelity"]} class InceptionScore(Metric): diff --git a/src/torchmetrics/image/kid.py b/src/torchmetrics/image/kid.py index 0540e95aa16..b7b9906185e 100644 --- a/src/torchmetrics/image/kid.py +++ b/src/torchmetrics/image/kid.py @@ -27,7 +27,7 @@ if not _MATPLOTLIB_AVAILABLE: __doctest_skip__ = ["KernelInceptionDistance.plot"] -__doctest_requires__ = {("KernelInceptionDistance", "KID"): ["torch_fidelity"]} +__doctest_requires__ = {("KernelInceptionDistance", "KernelInceptionDistance.plot"): ["torch_fidelity"]} def maximum_mean_discrepancy(k_xx: Tensor, k_xy: Tensor, k_yy: Tensor) -> Tensor: