From fea56ee4d9350f1e4b2cfe567f4454a766614458 Mon Sep 17 00:00:00 2001 From: Ankita Sharma Date: Thu, 3 Feb 2022 14:36:15 +0530 Subject: [PATCH 01/26] Added new metric - UQI --- torchmetrics/__init__.py | 2 + torchmetrics/functional/image/uqi.py | 246 +++++++++++++++++++++++++++ torchmetrics/image/__init__.py | 4 + torchmetrics/image/uqi.py | 130 ++++++++++++++ 4 files changed, 382 insertions(+) create mode 100644 torchmetrics/functional/image/uqi.py create mode 100644 torchmetrics/image/uqi.py diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 50af1f17447..23c1f178d8a 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -54,9 +54,11 @@ from torchmetrics.image import ( # noqa: E402 PSNR, SSIM, + UQI, MultiScaleStructuralSimilarityIndexMeasure, PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure, + UniversalImageQualityIndex, ) from torchmetrics.metric import Metric # noqa: E402 from torchmetrics.regression import ( # noqa: E402 diff --git a/torchmetrics/functional/image/uqi.py b/torchmetrics/functional/image/uqi.py new file mode 100644 index 00000000000..cf48369d630 --- /dev/null +++ b/torchmetrics/functional/image/uqi.py @@ -0,0 +1,246 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Optional, Sequence, Tuple, Union + +import torch +from deprecate import deprecated, void +from torch import Tensor +from torch.nn import functional as F +from typing_extensions import Literal + +from torchmetrics.utilities import _future_warning +from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.utilities.distributed import reduce + + +def _gaussian(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> Tensor: + """Computes 1D gaussian kernel. + + Args: + kernel_size: size of the gaussian kernel + sigma: Standard deviation of the gaussian kernel + dtype: data type of the output tensor + device: device of the output tensor + + Example: + >>> _gaussian(3, 1, torch.float, 'cpu') + tensor([[0.2741, 0.4519, 0.2741]]) + """ + dist = torch.arange(start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1, dtype=dtype, device=device) + gauss = torch.exp(-torch.pow(dist / sigma, 2) / 2) + return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size) + + +def _gaussian_kernel( + channel: int, kernel_size: Sequence[int], sigma: Sequence[float], dtype: torch.dtype, device: torch.device +) -> Tensor: + """Computes 2D gaussian kernel. + + Args: + channel: number of channels in the image + kernel_size: size of the gaussian kernel as a tuple (h, w) + sigma: Standard deviation of the gaussian kernel + dtype: data type of the output tensor + device: device of the output tensor + + Example: + >>> _gaussian_kernel(1, (5,5), (1,1), torch.float, "cpu") + tensor([[[[0.0030, 0.0133, 0.0219, 0.0133, 0.0030], + [0.0133, 0.0596, 0.0983, 0.0596, 0.0133], + [0.0219, 0.0983, 0.1621, 0.0983, 0.0219], + [0.0133, 0.0596, 0.0983, 0.0596, 0.0133], + [0.0030, 0.0133, 0.0219, 0.0133, 0.0030]]]]) + """ + + gaussian_kernel_x = _gaussian(kernel_size[0], sigma[0], dtype, device) + gaussian_kernel_y = _gaussian(kernel_size[1], sigma[1], dtype, device) + kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size) + + return kernel.expand(channel, 1, kernel_size[0], kernel_size[1]) + + +def _uqi_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: + """Updates and returns variables required to compute Universal Image Quality Index. Checks for same shape + and type of the input tensors. + + Args: + preds: Predicted tensor + target: Ground truth tensor + """ + + if preds.dtype != target.dtype: + raise TypeError( + "Expected `preds` and `target` to have the same data type." + f" Got preds: {preds.dtype} and target: {target.dtype}." + ) + _check_same_shape(preds, target) + if len(preds.shape) != 4: + raise ValueError( + "Expected `preds` and `target` to have BxCxHxW shape." + f" Got preds: {preds.shape} and target: {target.shape}." + ) + return preds, target + + +def _uqi_compute( + preds: Tensor, + target: Tensor, + kernel_size: Sequence[int] = (11, 11), + sigma: Sequence[float] = (1.5, 1.5), + reduction: str = "elementwise_mean", + data_range: Optional[float] = None, + return_contrast_sensitivity: bool = False, +) -> Union[Tensor, Tuple[Tensor, Tensor]]: + """Computes Universal Image Quality Index. + + Args: + preds: estimated image + target: ground truth image + kernel_size: size of the gaussian kernel (default: (11, 11)) + sigma: Standard deviation of the gaussian kernel (default: (1.5, 1.5)) + reduction: a method to reduce metric score over labels. + + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied + + data_range: Range of the image. If ``None``, it is determined from the image (max - min) + + Example: + >>> preds = torch.rand([16, 1, 16, 16]) + >>> target = preds * 0.75 + >>> preds, target = _uqi_update(preds, target) + >>> _uqi_compute(preds, target) + tensor(0.9216) + """ + if len(kernel_size) != 2 or len(sigma) != 2: + raise ValueError( + "Expected `kernel_size` and `sigma` to have the length of two." + f" Got kernel_size: {len(kernel_size)} and sigma: {len(sigma)}." + ) + + if any(x % 2 == 0 or x <= 0 for x in kernel_size): + raise ValueError(f"Expected `kernel_size` to have odd positive number. Got {kernel_size}.") + + if any(y <= 0 for y in sigma): + raise ValueError(f"Expected `sigma` to have positive number. Got {sigma}.") + + if data_range is None: + data_range = max(preds.max() - preds.min(), target.max() - target.min()) + + device = preds.device + channel = preds.size(1) + dtype = preds.dtype + kernel = _gaussian_kernel(channel, kernel_size, sigma, dtype, device) + pad_h = (kernel_size[0] - 1) // 2 + pad_w = (kernel_size[1] - 1) // 2 + + preds = F.pad(preds, (pad_h, pad_h, pad_w, pad_w), mode="reflect") + target = F.pad(target, (pad_h, pad_h, pad_w, pad_w), mode="reflect") + + input_list = torch.cat((preds, target, preds * preds, target * target, preds * target)) # (5 * B, C, H, W) + outputs = F.conv2d(input_list, kernel, groups=channel) + output_list = outputs.split(preds.shape[0]) + + mu_pred_sq = output_list[0].pow(2) + mu_target_sq = output_list[1].pow(2) + mu_pred_target = output_list[0] * output_list[1] + + sigma_pred_sq = output_list[2] - mu_pred_sq + sigma_target_sq = output_list[3] - mu_target_sq + sigma_pred_target = output_list[4] - mu_pred_target + + upper = 2 * sigma_pred_target + lower = sigma_pred_sq + sigma_target_sq + + uqi_idx = ((2 * mu_pred_target) * upper) / ((mu_pred_sq + mu_target_sq) * lower) + uqi_idx = uqi_idx[..., pad_h:-pad_h, pad_w:-pad_w] + + return reduce(uqi_idx, reduction) + + +def universal_image_quality_index( + preds: Tensor, + target: Tensor, + kernel_size: Sequence[int] = (11, 11), + sigma: Sequence[float] = (1.5, 1.5), + reduction: str = "elementwise_mean", + data_range: Optional[float] = None, +) -> Tensor: + """Universal Image Quality Index. + + Args: + preds: estimated image + target: ground truth image + kernel_size: size of the gaussian kernel (default: (11, 11)) + sigma: Standard deviation of the gaussian kernel (default: (1.5, 1.5)) + reduction: a method to reduce metric score over labels. + + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied + + data_range: Range of the image. If ``None``, it is determined from the image (max - min) + + Return: + Tensor with UQI score + + Raises: + TypeError: + If ``preds`` and ``target`` don't have the same data type. + ValueError: + If ``preds`` and ``target`` don't have ``BxCxHxW shape``. + ValueError: + If the length of ``kernel_size`` or ``sigma`` is not ``2``. + ValueError: + If one of the elements of ``kernel_size`` is not an ``odd positive number``. + ValueError: + If one of the elements of ``sigma`` is not a ``positive number``. + + Example: + >>> from torchmetrics.functional import structural_similarity_index_measure + >>> preds = torch.rand([16, 1, 16, 16]) + >>> target = preds * 0.75 + >>> universal_image_quality_index(preds, target) + tensor(0.9216) + + References: + [1] Zhou Wang and A. C. Bovik, "A universal image quality index," in IEEE Signal Processing Letters, vol. 9, + no. 3, pp. 81-84, March 2002, doi: 10.1109/97.995823. + [2] Zhou Wang, A. C. Bovik, H. R. Sheikh and E. P. Simoncelli, "Image quality assessment: from error visibility + to structural similarity," in IEEE Transactions on Image Processing, vol. 13, no. 4, pp. 600-612, April 2004, + doi: 10.1109/TIP.2003.819861. + """ + preds, target = _uqi_update(preds, target) + return _uqi_compute(preds, target, kernel_size, sigma, reduction, data_range) + + +@deprecated(target=universal_image_quality_index, deprecated_in="0.7", remove_in="0.8", stream=_future_warning) +def uqi( + preds: Tensor, + target: Tensor, + kernel_size: Sequence[int] = (11, 11), + sigma: Sequence[float] = (1.5, 1.5), + reduction: str = "elementwise_mean", + data_range: Optional[float] = None, +) -> Tensor: + """Computes Universal Image Quality Index. + + Example: + >>> preds = torch.rand([16, 1, 16, 16]) + >>> target = preds * 0.75 + >>> uqi(preds, target) + tensor(0.9216) + """ + return void(preds, target, kernel_size, sigma, reduction, data_range) diff --git a/torchmetrics/image/__init__.py b/torchmetrics/image/__init__.py index b5c456a5e8c..9cd2f8f77a2 100644 --- a/torchmetrics/image/__init__.py +++ b/torchmetrics/image/__init__.py @@ -19,6 +19,10 @@ MultiScaleStructuralSimilarityIndexMeasure, StructuralSimilarityIndexMeasure, ) +from torchmetrics.image.uqi import ( # noqa: F401 + UQI, + UniversalImageQualityIndex, +) from torchmetrics.utilities.imports import _LPIPS_AVAILABLE, _TORCH_FIDELITY_AVAILABLE if _TORCH_FIDELITY_AVAILABLE: diff --git a/torchmetrics/image/uqi.py b/torchmetrics/image/uqi.py new file mode 100644 index 00000000000..8ae01559d0e --- /dev/null +++ b/torchmetrics/image/uqi.py @@ -0,0 +1,130 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, List, Optional, Sequence, Tuple + +import torch +from deprecate import deprecated, void +from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.functional.image.uqi import _uqi_compute, _uqi_update +from torchmetrics.metric import Metric +from torchmetrics.utilities import _future_warning, rank_zero_warn +from torchmetrics.utilities.data import dim_zero_cat + + +class UniversalImageQualityIndex(Metric): + """Computes Universal Image Quality Index (UQI_). + + Args: + kernel_size: size of the gaussian kernel + sigma: Standard deviation of the gaussian kernel + reduction: a method to reduce metric score over labels. + + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied + + data_range: Range of the image. If ``None``, it is determined from the image (max - min) + + Return: + Tensor with UQI score + + Example: + >>> from torchmetrics import UniversalImageQualityIndex + >>> preds = torch.rand([16, 1, 16, 16]) + >>> target = preds * 0.75 + >>> uqi = UniversalImageQualityIndex() + >>> uqi(preds, target) + tensor(0.9219) + """ + + preds: List[Tensor] + target: List[Tensor] + higher_is_better = True + + def __init__( + self, + kernel_size: Sequence[int] = (11, 11), + sigma: Sequence[float] = (1.5, 1.5), + reduction: str = "elementwise_mean", + data_range: Optional[float] = None, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + ) -> None: + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + ) + rank_zero_warn( + "Metric `UQI` will save all targets and" + " predictions in buffer. For large datasets this may lead" + " to large memory footprint." + ) + + self.add_state("preds", default=[], dist_reduce_fx="cat") + self.add_state("target", default=[], dist_reduce_fx="cat") + self.kernel_size = kernel_size + self.sigma = sigma + self.data_range = data_range + self.reduction = reduction + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Update state with predictions and targets. + + Args: + preds: Predictions from model + target: Ground truth values + """ + preds, target = _uqi_update(preds, target) + self.preds.append(preds) + self.target.append(target) + + def compute(self) -> Tensor: + """Computes explained variance over state.""" + preds = dim_zero_cat(self.preds) + target = dim_zero_cat(self.target) + return _uqi_compute( + preds, target, self.kernel_size, self.sigma, self.reduction, self.data_range + ) + + +class UQI(UniversalImageQualityIndex): + """Computes Universal Image Quality Index (UQI_). + + .. deprecated:: v0.7 + Use :class:`torchmetrics.UniversalImageQualityIndex`. Will be removed in v0.8. + + Example: + >>> preds = torch.rand([16, 1, 16, 16]) + >>> target = preds * 0.75 + >>> uqi = UQI() + >>> uqi(preds, target) + tensor(0.9219) + """ + + @deprecated(target=UniversalImageQualityIndex, deprecated_in="0.7", remove_in="0.8", stream=_future_warning) + def __init__( + self, + kernel_size: Sequence[int] = (11, 11), + sigma: Sequence[float] = (1.5, 1.5), + reduction: str = "elementwise_mean", + data_range: Optional[float] = None, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + ) -> None: + void(kernel_size, sigma, reduction, data_range, compute_on_step, dist_sync_on_step, process_group) \ No newline at end of file From 30c5c5080bb2d6437b8859784adb0a5b6c3147e6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 3 Feb 2022 09:35:04 +0000 Subject: [PATCH 02/26] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/functional/image/uqi.py | 10 +++++----- torchmetrics/image/__init__.py | 5 +---- torchmetrics/image/uqi.py | 6 ++---- 3 files changed, 8 insertions(+), 13 deletions(-) diff --git a/torchmetrics/functional/image/uqi.py b/torchmetrics/functional/image/uqi.py index cf48369d630..ed63a44de68 100644 --- a/torchmetrics/functional/image/uqi.py +++ b/torchmetrics/functional/image/uqi.py @@ -71,8 +71,8 @@ def _gaussian_kernel( def _uqi_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: - """Updates and returns variables required to compute Universal Image Quality Index. Checks for same shape - and type of the input tensors. + """Updates and returns variables required to compute Universal Image Quality Index. Checks for same shape and + type of the input tensors. Args: preds: Predicted tensor @@ -161,7 +161,7 @@ def _uqi_compute( sigma_target_sq = output_list[3] - mu_target_sq sigma_pred_target = output_list[4] - mu_pred_target - upper = 2 * sigma_pred_target + upper = 2 * sigma_pred_target lower = sigma_pred_sq + sigma_target_sq uqi_idx = ((2 * mu_pred_target) * upper) / ((mu_pred_sq + mu_target_sq) * lower) @@ -216,10 +216,10 @@ def universal_image_quality_index( tensor(0.9216) References: - [1] Zhou Wang and A. C. Bovik, "A universal image quality index," in IEEE Signal Processing Letters, vol. 9, + [1] Zhou Wang and A. C. Bovik, "A universal image quality index," in IEEE Signal Processing Letters, vol. 9, no. 3, pp. 81-84, March 2002, doi: 10.1109/97.995823. [2] Zhou Wang, A. C. Bovik, H. R. Sheikh and E. P. Simoncelli, "Image quality assessment: from error visibility - to structural similarity," in IEEE Transactions on Image Processing, vol. 13, no. 4, pp. 600-612, April 2004, + to structural similarity," in IEEE Transactions on Image Processing, vol. 13, no. 4, pp. 600-612, April 2004, doi: 10.1109/TIP.2003.819861. """ preds, target = _uqi_update(preds, target) diff --git a/torchmetrics/image/__init__.py b/torchmetrics/image/__init__.py index 9cd2f8f77a2..e2e28586f2c 100644 --- a/torchmetrics/image/__init__.py +++ b/torchmetrics/image/__init__.py @@ -19,10 +19,7 @@ MultiScaleStructuralSimilarityIndexMeasure, StructuralSimilarityIndexMeasure, ) -from torchmetrics.image.uqi import ( # noqa: F401 - UQI, - UniversalImageQualityIndex, -) +from torchmetrics.image.uqi import UQI, UniversalImageQualityIndex # noqa: F401 from torchmetrics.utilities.imports import _LPIPS_AVAILABLE, _TORCH_FIDELITY_AVAILABLE if _TORCH_FIDELITY_AVAILABLE: diff --git a/torchmetrics/image/uqi.py b/torchmetrics/image/uqi.py index 8ae01559d0e..978ba2765e3 100644 --- a/torchmetrics/image/uqi.py +++ b/torchmetrics/image/uqi.py @@ -97,9 +97,7 @@ def compute(self) -> Tensor: """Computes explained variance over state.""" preds = dim_zero_cat(self.preds) target = dim_zero_cat(self.target) - return _uqi_compute( - preds, target, self.kernel_size, self.sigma, self.reduction, self.data_range - ) + return _uqi_compute(preds, target, self.kernel_size, self.sigma, self.reduction, self.data_range) class UQI(UniversalImageQualityIndex): @@ -127,4 +125,4 @@ def __init__( dist_sync_on_step: bool = False, process_group: Optional[Any] = None, ) -> None: - void(kernel_size, sigma, reduction, data_range, compute_on_step, dist_sync_on_step, process_group) \ No newline at end of file + void(kernel_size, sigma, reduction, data_range, compute_on_step, dist_sync_on_step, process_group) From b2078b8612e0b70559005033758c20d0f9fbc7bb Mon Sep 17 00:00:00 2001 From: Ankita Sharma Date: Thu, 3 Feb 2022 23:49:44 +0530 Subject: [PATCH 03/26] minor fixes --- torchmetrics/__init__.py | 2 ++ torchmetrics/functional/image/uqi.py | 5 ++--- torchmetrics/image/uqi.py | 7 +++---- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 23c1f178d8a..4b169c3c218 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -178,6 +178,8 @@ "SumMetric", "SymmetricMeanAbsolutePercentageError", "TranslationEditRate", + "UQI", + "UniversalImageQualityIndex", "WordErrorRate", "CharErrorRate", "MatchErrorRate", diff --git a/torchmetrics/functional/image/uqi.py b/torchmetrics/functional/image/uqi.py index cf48369d630..d952a7274a5 100644 --- a/torchmetrics/functional/image/uqi.py +++ b/torchmetrics/functional/image/uqi.py @@ -11,13 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Sequence, Tuple, Union +from typing import Optional, Sequence, Tuple, Union import torch from deprecate import deprecated, void from torch import Tensor from torch.nn import functional as F -from typing_extensions import Literal from torchmetrics.utilities import _future_warning from torchmetrics.utilities.checks import _check_same_shape @@ -209,7 +208,7 @@ def universal_image_quality_index( If one of the elements of ``sigma`` is not a ``positive number``. Example: - >>> from torchmetrics.functional import structural_similarity_index_measure + >>> from torchmetrics.functional import universal_image_quality_index >>> preds = torch.rand([16, 1, 16, 16]) >>> target = preds * 0.75 >>> universal_image_quality_index(preds, target) diff --git a/torchmetrics/image/uqi.py b/torchmetrics/image/uqi.py index 8ae01559d0e..2611115c037 100644 --- a/torchmetrics/image/uqi.py +++ b/torchmetrics/image/uqi.py @@ -11,12 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Tuple +from typing import Any, List, Optional, Sequence import torch from deprecate import deprecated, void from torch import Tensor -from typing_extensions import Literal from torchmetrics.functional.image.uqi import _uqi_compute, _uqi_update from torchmetrics.metric import Metric @@ -47,7 +46,7 @@ class UniversalImageQualityIndex(Metric): >>> target = preds * 0.75 >>> uqi = UniversalImageQualityIndex() >>> uqi(preds, target) - tensor(0.9219) + tensor(0.9216) """ preds: List[Tensor] @@ -113,7 +112,7 @@ class UQI(UniversalImageQualityIndex): >>> target = preds * 0.75 >>> uqi = UQI() >>> uqi(preds, target) - tensor(0.9219) + tensor(0.9216) """ @deprecated(target=UniversalImageQualityIndex, deprecated_in="0.7", remove_in="0.8", stream=_future_warning) From a688cb48800e3d995eb18f2e76c8f8a23cbe0a99 Mon Sep 17 00:00:00 2001 From: Ankita Sharma Date: Fri, 4 Feb 2022 00:41:58 +0530 Subject: [PATCH 04/26] Added missing import --- torchmetrics/functional/image/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchmetrics/functional/image/__init__.py b/torchmetrics/functional/image/__init__.py index 04f38e3ef38..dd813c174fb 100644 --- a/torchmetrics/functional/image/__init__.py +++ b/torchmetrics/functional/image/__init__.py @@ -18,3 +18,7 @@ ssim, structural_similarity_index_measure, ) +from torchmetrics.functional.image.uqi import ( # noqa: F401 + uqi, + universal_image_quality_index, +) From ae11a3b07fa2b1dbde9fd6b09f2612b82b044e13 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 3 Feb 2022 19:12:44 +0000 Subject: [PATCH 05/26] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/functional/image/__init__.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/torchmetrics/functional/image/__init__.py b/torchmetrics/functional/image/__init__.py index dd813c174fb..c61fed614a5 100644 --- a/torchmetrics/functional/image/__init__.py +++ b/torchmetrics/functional/image/__init__.py @@ -18,7 +18,4 @@ ssim, structural_similarity_index_measure, ) -from torchmetrics.functional.image.uqi import ( # noqa: F401 - uqi, - universal_image_quality_index, -) +from torchmetrics.functional.image.uqi import universal_image_quality_index, uqi # noqa: F401 From 29420c2099e5e542b271ea046ac5fa82d94ae4d7 Mon Sep 17 00:00:00 2001 From: Ankita Sharma Date: Fri, 4 Feb 2022 01:01:51 +0530 Subject: [PATCH 06/26] Registered UQI to functional init; tested locally --- torchmetrics/functional/__init__.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/torchmetrics/functional/__init__.py b/torchmetrics/functional/__init__.py index 3089fad2fe1..d6c5674d12b 100644 --- a/torchmetrics/functional/__init__.py +++ b/torchmetrics/functional/__init__.py @@ -41,6 +41,10 @@ ssim, structural_similarity_index_measure, ) +from torchmetrics.functional.image.uqi import ( + uqi, + universal_image_quality_index, +) from torchmetrics.functional.pairwise.cosine import pairwise_cosine_similarity from torchmetrics.functional.pairwise.euclidean import pairwise_euclidean_distance from torchmetrics.functional.pairwise.linear import pairwise_linear_similarity @@ -149,6 +153,8 @@ "stat_scores", "symmetric_mean_absolute_percentage_error", "translation_edit_rate", + "uqi", + "universal_image_quality_index", "word_error_rate", "char_error_rate", "match_error_rate", From 3a08596ccd2ca5bfd73cc7d91933719475d48231 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 3 Feb 2022 19:33:21 +0000 Subject: [PATCH 07/26] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/functional/__init__.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/torchmetrics/functional/__init__.py b/torchmetrics/functional/__init__.py index d6c5674d12b..a1744fec70e 100644 --- a/torchmetrics/functional/__init__.py +++ b/torchmetrics/functional/__init__.py @@ -41,10 +41,7 @@ ssim, structural_similarity_index_measure, ) -from torchmetrics.functional.image.uqi import ( - uqi, - universal_image_quality_index, -) +from torchmetrics.functional.image.uqi import universal_image_quality_index, uqi from torchmetrics.functional.pairwise.cosine import pairwise_cosine_similarity from torchmetrics.functional.pairwise.euclidean import pairwise_euclidean_distance from torchmetrics.functional.pairwise.linear import pairwise_linear_similarity From 53bfe2ae0f3c47460988e9283fa61c3df77b6145 Mon Sep 17 00:00:00 2001 From: Ankita Sharma Date: Fri, 4 Feb 2022 17:25:58 +0530 Subject: [PATCH 08/26] Testcases added for UQI --- tests/image/test_uqi.py | 174 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 174 insertions(+) create mode 100644 tests/image/test_uqi.py diff --git a/tests/image/test_uqi.py b/tests/image/test_uqi.py new file mode 100644 index 00000000000..6ebd574b3a7 --- /dev/null +++ b/tests/image/test_uqi.py @@ -0,0 +1,174 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import namedtuple +from functools import partial + +import pytest +import torch +from skimage.metrics import structural_similarity + +from tests.helpers import seed_all +from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester +from torchmetrics.functional import universal_image_quality_index +from torchmetrics.image import UniversalImageQualityIndex + +seed_all(42) + +# UQI is SSIM with both constants k1 and k2 as 0 +skimage_uqi = partial(structural_similarity, k1 = 0, k2 = 0) + +Input = namedtuple("Input", ["preds", "target", "multichannel"]) + +_inputs = [] +for size, channel, coef, multichannel, dtype in [ + (12, 3, 0.9, True, torch.float), + (13, 1, 0.8, False, torch.float32), + (14, 1, 0.7, False, torch.double), + (15, 3, 0.6, True, torch.float64), +]: + preds = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype) + _inputs.append( + Input( + preds=preds, + target=preds * coef, + multichannel=multichannel, + ) + ) + + +def _sk_uqi(preds, target, data_range, multichannel, kernel_size): + c, h, w = preds.shape[-3:] + sk_preds = preds.view(-1, c, h, w).permute(0, 2, 3, 1).numpy() + sk_target = target.view(-1, c, h, w).permute(0, 2, 3, 1).numpy() + if not multichannel: + sk_preds = sk_preds[:, :, :, 0] + sk_target = sk_target[:, :, :, 0] + + return skimage_uqi( + sk_target, + sk_preds, + data_range=data_range, + multichannel=multichannel, + gaussian_weights=True, + win_size=kernel_size, + sigma=1.5, + use_sample_covariance=False, + ) + + +@pytest.mark.parametrize( + "preds, target, multichannel", + [(i.preds, i.target, i.multichannel) for i in _inputs], +) +@pytest.mark.parametrize("kernel_size", [5, 11]) +class TestUQI(MetricTester): + atol = 6e-3 + + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_uqi(self, preds, target, multichannel, kernel_size, ddp, dist_sync_on_step): + self.run_class_metric_test( + ddp, + preds, + target, + UniversalImageQualityIndex, + partial(_sk_uqi, data_range=1.0, multichannel=multichannel, kernel_size=kernel_size), + metric_args={"data_range": 1.0, "kernel_size": (kernel_size, kernel_size)}, + dist_sync_on_step=dist_sync_on_step, + ) + + def test_uqi_functional(self, preds, target, multichannel, kernel_size): + self.run_functional_metric_test( + preds, + target, + universal_image_quality_index, + partial(_sk_uqi, data_range=1.0, multichannel=multichannel, kernel_size=kernel_size), + metric_args={"data_range": 1.0, "kernel_size": (kernel_size, kernel_size)}, + ) + + # UQI half + cpu does not work due to missing support in torch.log + @pytest.mark.xfail(reason="UQI metric does not support cpu + half precision") + def test_uqi_half_cpu(self, preds, target, multichannel, kernel_size): + self.run_precision_test_cpu( + preds, target, UniversalImageQualityIndex, universal_image_quality_index, {"data_range": 1.0} + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + def test_uqi_half_gpu(self, preds, target, multichannel, kernel_size): + self.run_precision_test_gpu( + preds, target, UniversalImageQualityIndex, universal_image_quality_index, {"data_range": 1.0} + ) + + +@pytest.mark.parametrize( + ["pred", "target", "kernel", "sigma"], + [ + ([1, 16, 16], [1, 16, 16], [11, 11], [1.5, 1.5]), # len(shape) + ([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5]), # len(kernel), len(sigma) + ([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5, 1.5]), # len(kernel), len(sigma) + ([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5]), # len(kernel), len(sigma) + ([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, 1.5]), # invalid kernel input + ([1, 1, 16, 16], [1, 1, 16, 16], [11, 10], [1.5, 1.5]), # invalid kernel input + ([1, 1, 16, 16], [1, 1, 16, 16], [11, -11], [1.5, 1.5]), # invalid kernel input + ([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5, 0]), # invalid sigma input + ([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, -1.5]), # invalid sigma input + ], +) +def test_uqi_invalid_inputs(pred, target, kernel, sigma): + pred_t = torch.rand(pred) + target_t = torch.rand(target, dtype=torch.float64) + with pytest.raises(TypeError): + universal_image_quality_index(pred_t, target_t) + + pred = torch.rand(pred) + target = torch.rand(target) + with pytest.raises(ValueError): + universal_image_quality_index(pred, target, kernel, sigma) + + +def test_uqi_unequal_kernel_size(): + """Test the case where kernel_size[0] != kernel_size[1]""" + preds = torch.tensor( + [ + [ + [ + [1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0], + [1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0], + ] + ] + ] + ) + target = torch.tensor( + [ + [ + [ + [1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0], + [1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0], + ] + ] + ] + ) + # kernel order matters + assert universal_image_quality_index(preds, target, kernel_size=(3, 5)) == torch.tensor(0.10662283) + assert universal_image_quality_index(preds, target, kernel_size=(5, 3)) != torch.tensor(0.10662283) From d624e675ce4ce1bd9e4df56770471ef80c7b6ace Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 4 Feb 2022 11:57:28 +0000 Subject: [PATCH 09/26] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/image/test_uqi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/image/test_uqi.py b/tests/image/test_uqi.py index 6ebd574b3a7..16b99856006 100644 --- a/tests/image/test_uqi.py +++ b/tests/image/test_uqi.py @@ -26,7 +26,7 @@ seed_all(42) # UQI is SSIM with both constants k1 and k2 as 0 -skimage_uqi = partial(structural_similarity, k1 = 0, k2 = 0) +skimage_uqi = partial(structural_similarity, k1=0, k2=0) Input = namedtuple("Input", ["preds", "target", "multichannel"]) From ba326f1fdf0ca7240d0e09277bde593289d4fe73 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Sat, 5 Feb 2022 11:01:07 +0100 Subject: [PATCH 10/26] Apply suggestions from code review Co-authored-by: Daniel Stancl <46073029+stancld@users.noreply.github.com> --- torchmetrics/functional/__init__.py | 3 +-- torchmetrics/functional/image/__init__.py | 2 +- torchmetrics/functional/image/uqi.py | 19 ---------------- torchmetrics/image/__init__.py | 2 +- torchmetrics/image/uqi.py | 27 ----------------------- 5 files changed, 3 insertions(+), 50 deletions(-) diff --git a/torchmetrics/functional/__init__.py b/torchmetrics/functional/__init__.py index a1744fec70e..c79c0992262 100644 --- a/torchmetrics/functional/__init__.py +++ b/torchmetrics/functional/__init__.py @@ -41,7 +41,7 @@ ssim, structural_similarity_index_measure, ) -from torchmetrics.functional.image.uqi import universal_image_quality_index, uqi +from torchmetrics.functional.image.uqi import universal_image_quality_index from torchmetrics.functional.pairwise.cosine import pairwise_cosine_similarity from torchmetrics.functional.pairwise.euclidean import pairwise_euclidean_distance from torchmetrics.functional.pairwise.linear import pairwise_linear_similarity @@ -150,7 +150,6 @@ "stat_scores", "symmetric_mean_absolute_percentage_error", "translation_edit_rate", - "uqi", "universal_image_quality_index", "word_error_rate", "char_error_rate", diff --git a/torchmetrics/functional/image/__init__.py b/torchmetrics/functional/image/__init__.py index c61fed614a5..16f477241bb 100644 --- a/torchmetrics/functional/image/__init__.py +++ b/torchmetrics/functional/image/__init__.py @@ -18,4 +18,4 @@ ssim, structural_similarity_index_measure, ) -from torchmetrics.functional.image.uqi import universal_image_quality_index, uqi # noqa: F401 +from torchmetrics.functional.image.uqi import universal_image_quality_index # noqa: F401 diff --git a/torchmetrics/functional/image/uqi.py b/torchmetrics/functional/image/uqi.py index 26cbea9daca..23a14544bf1 100644 --- a/torchmetrics/functional/image/uqi.py +++ b/torchmetrics/functional/image/uqi.py @@ -14,7 +14,6 @@ from typing import Optional, Sequence, Tuple, Union import torch -from deprecate import deprecated, void from torch import Tensor from torch.nn import functional as F @@ -225,21 +224,3 @@ def universal_image_quality_index( return _uqi_compute(preds, target, kernel_size, sigma, reduction, data_range) -@deprecated(target=universal_image_quality_index, deprecated_in="0.7", remove_in="0.8", stream=_future_warning) -def uqi( - preds: Tensor, - target: Tensor, - kernel_size: Sequence[int] = (11, 11), - sigma: Sequence[float] = (1.5, 1.5), - reduction: str = "elementwise_mean", - data_range: Optional[float] = None, -) -> Tensor: - """Computes Universal Image Quality Index. - - Example: - >>> preds = torch.rand([16, 1, 16, 16]) - >>> target = preds * 0.75 - >>> uqi(preds, target) - tensor(0.9216) - """ - return void(preds, target, kernel_size, sigma, reduction, data_range) diff --git a/torchmetrics/image/__init__.py b/torchmetrics/image/__init__.py index e2e28586f2c..1d0b00cc174 100644 --- a/torchmetrics/image/__init__.py +++ b/torchmetrics/image/__init__.py @@ -19,7 +19,7 @@ MultiScaleStructuralSimilarityIndexMeasure, StructuralSimilarityIndexMeasure, ) -from torchmetrics.image.uqi import UQI, UniversalImageQualityIndex # noqa: F401 +from torchmetrics.image.uqi import UniversalImageQualityIndex # noqa: F401 from torchmetrics.utilities.imports import _LPIPS_AVAILABLE, _TORCH_FIDELITY_AVAILABLE if _TORCH_FIDELITY_AVAILABLE: diff --git a/torchmetrics/image/uqi.py b/torchmetrics/image/uqi.py index 2ce30362156..a2712f68bf8 100644 --- a/torchmetrics/image/uqi.py +++ b/torchmetrics/image/uqi.py @@ -14,7 +14,6 @@ from typing import Any, List, Optional, Sequence import torch -from deprecate import deprecated, void from torch import Tensor from torchmetrics.functional.image.uqi import _uqi_compute, _uqi_update @@ -99,29 +98,3 @@ def compute(self) -> Tensor: return _uqi_compute(preds, target, self.kernel_size, self.sigma, self.reduction, self.data_range) -class UQI(UniversalImageQualityIndex): - """Computes Universal Image Quality Index (UQI_). - - .. deprecated:: v0.7 - Use :class:`torchmetrics.UniversalImageQualityIndex`. Will be removed in v0.8. - - Example: - >>> preds = torch.rand([16, 1, 16, 16]) - >>> target = preds * 0.75 - >>> uqi = UQI() - >>> uqi(preds, target) - tensor(0.9216) - """ - - @deprecated(target=UniversalImageQualityIndex, deprecated_in="0.7", remove_in="0.8", stream=_future_warning) - def __init__( - self, - kernel_size: Sequence[int] = (11, 11), - sigma: Sequence[float] = (1.5, 1.5), - reduction: str = "elementwise_mean", - data_range: Optional[float] = None, - compute_on_step: bool = True, - dist_sync_on_step: bool = False, - process_group: Optional[Any] = None, - ) -> None: - void(kernel_size, sigma, reduction, data_range, compute_on_step, dist_sync_on_step, process_group) From a530e941537e551a125a0ed220b532504eecffbf Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Sat, 5 Feb 2022 11:03:12 +0100 Subject: [PATCH 11/26] Apply suggestions from code review --- torchmetrics/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 4b169c3c218..8cffec12be7 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -54,7 +54,6 @@ from torchmetrics.image import ( # noqa: E402 PSNR, SSIM, - UQI, MultiScaleStructuralSimilarityIndexMeasure, PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure, @@ -178,7 +177,6 @@ "SumMetric", "SymmetricMeanAbsolutePercentageError", "TranslationEditRate", - "UQI", "UniversalImageQualityIndex", "WordErrorRate", "CharErrorRate", From e07b3374b5395065b69c5365ddb9dfffe069c594 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 5 Feb 2022 10:03:45 +0000 Subject: [PATCH 12/26] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/functional/image/uqi.py | 2 -- torchmetrics/image/uqi.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/torchmetrics/functional/image/uqi.py b/torchmetrics/functional/image/uqi.py index 23a14544bf1..b0f558ae843 100644 --- a/torchmetrics/functional/image/uqi.py +++ b/torchmetrics/functional/image/uqi.py @@ -222,5 +222,3 @@ def universal_image_quality_index( """ preds, target = _uqi_update(preds, target) return _uqi_compute(preds, target, kernel_size, sigma, reduction, data_range) - - diff --git a/torchmetrics/image/uqi.py b/torchmetrics/image/uqi.py index a2712f68bf8..1bd2c286d85 100644 --- a/torchmetrics/image/uqi.py +++ b/torchmetrics/image/uqi.py @@ -96,5 +96,3 @@ def compute(self) -> Tensor: preds = dim_zero_cat(self.preds) target = dim_zero_cat(self.target) return _uqi_compute(preds, target, self.kernel_size, self.sigma, self.reduction, self.data_range) - - From 40e68aaadf36a419859ecd762a74a763f439aa43 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Sat, 5 Feb 2022 11:05:34 +0100 Subject: [PATCH 13/26] Apply suggestions from code review Co-authored-by: Daniel Stancl <46073029+stancld@users.noreply.github.com> --- torchmetrics/functional/image/uqi.py | 6 +++--- torchmetrics/image/uqi.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/torchmetrics/functional/image/uqi.py b/torchmetrics/functional/image/uqi.py index b0f558ae843..194c14e419f 100644 --- a/torchmetrics/functional/image/uqi.py +++ b/torchmetrics/functional/image/uqi.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Tuple, Union +from typing import Optional, Sequence, Tuple, Union, Literal import torch from torch import Tensor @@ -96,7 +96,7 @@ def _uqi_compute( target: Tensor, kernel_size: Sequence[int] = (11, 11), sigma: Sequence[float] = (1.5, 1.5), - reduction: str = "elementwise_mean", + reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean", data_range: Optional[float] = None, return_contrast_sensitivity: bool = False, ) -> Union[Tensor, Tuple[Tensor, Tensor]]: @@ -173,7 +173,7 @@ def universal_image_quality_index( target: Tensor, kernel_size: Sequence[int] = (11, 11), sigma: Sequence[float] = (1.5, 1.5), - reduction: str = "elementwise_mean", + reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean", data_range: Optional[float] = None, ) -> Tensor: """Universal Image Quality Index. diff --git a/torchmetrics/image/uqi.py b/torchmetrics/image/uqi.py index 1bd2c286d85..1d55d59ae86 100644 --- a/torchmetrics/image/uqi.py +++ b/torchmetrics/image/uqi.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence +from typing import Any, List, Optional, Sequence, Literal import torch from torch import Tensor @@ -56,7 +56,7 @@ def __init__( self, kernel_size: Sequence[int] = (11, 11), sigma: Sequence[float] = (1.5, 1.5), - reduction: str = "elementwise_mean", + reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean", data_range: Optional[float] = None, compute_on_step: bool = True, dist_sync_on_step: bool = False, From a73200d843cfb9c596e7f390e993cedc6f8cbb42 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 5 Feb 2022 10:06:06 +0000 Subject: [PATCH 14/26] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/functional/image/uqi.py | 2 +- torchmetrics/image/uqi.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmetrics/functional/image/uqi.py b/torchmetrics/functional/image/uqi.py index 194c14e419f..ff7ec9fe6b6 100644 --- a/torchmetrics/functional/image/uqi.py +++ b/torchmetrics/functional/image/uqi.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Tuple, Union, Literal +from typing import Literal, Optional, Sequence, Tuple, Union import torch from torch import Tensor diff --git a/torchmetrics/image/uqi.py b/torchmetrics/image/uqi.py index 1d55d59ae86..00a10eb3212 100644 --- a/torchmetrics/image/uqi.py +++ b/torchmetrics/image/uqi.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Literal +from typing import Any, List, Literal, Optional, Sequence import torch from torch import Tensor From 5f8d39bfed0f68110cdfe5990dfcd2ac015fa0f3 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Sat, 5 Feb 2022 11:21:30 +0100 Subject: [PATCH 15/26] Apply suggestions from code review --- torchmetrics/functional/image/uqi.py | 3 ++- torchmetrics/image/uqi.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/torchmetrics/functional/image/uqi.py b/torchmetrics/functional/image/uqi.py index ff7ec9fe6b6..e3522f8151c 100644 --- a/torchmetrics/functional/image/uqi.py +++ b/torchmetrics/functional/image/uqi.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Literal, Optional, Sequence, Tuple, Union +from typing import Optional, Sequence, Tuple, Union +from typing_extensions import Literal import torch from torch import Tensor diff --git a/torchmetrics/image/uqi.py b/torchmetrics/image/uqi.py index 00a10eb3212..0935f246ee2 100644 --- a/torchmetrics/image/uqi.py +++ b/torchmetrics/image/uqi.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Literal, Optional, Sequence +from typing import Any, List, Optional, Sequence +from typing_extensions import Literal import torch from torch import Tensor From 94cd3308c3de2aafafda1d2cd9032c1a4c24d44a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 5 Feb 2022 10:22:04 +0000 Subject: [PATCH 16/26] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/functional/image/uqi.py | 2 +- torchmetrics/image/uqi.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmetrics/functional/image/uqi.py b/torchmetrics/functional/image/uqi.py index e3522f8151c..b466cfacd70 100644 --- a/torchmetrics/functional/image/uqi.py +++ b/torchmetrics/functional/image/uqi.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Optional, Sequence, Tuple, Union -from typing_extensions import Literal import torch from torch import Tensor from torch.nn import functional as F +from typing_extensions import Literal from torchmetrics.utilities import _future_warning from torchmetrics.utilities.checks import _check_same_shape diff --git a/torchmetrics/image/uqi.py b/torchmetrics/image/uqi.py index 0935f246ee2..c8fa0757164 100644 --- a/torchmetrics/image/uqi.py +++ b/torchmetrics/image/uqi.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, List, Optional, Sequence -from typing_extensions import Literal import torch from torch import Tensor +from typing_extensions import Literal from torchmetrics.functional.image.uqi import _uqi_compute, _uqi_update from torchmetrics.metric import Metric From 6672a90fe65cd673cc72cea4527749feef4a1690 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Sat, 5 Feb 2022 11:22:48 +0100 Subject: [PATCH 17/26] Update requirements.txt --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 5d7a0bc464d..9b1500e5fac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ numpy>=1.17.2 torch>=1.3.1 pyDeprecate==0.3.* packaging +typing_extensions From 4951bc6edadb20f9a531cd2c509fa755e90d6721 Mon Sep 17 00:00:00 2001 From: Ankita Sharma Date: Sun, 6 Feb 2022 15:41:06 +0530 Subject: [PATCH 18/26] Address reviews --- CHANGELOG.md | 4 ++ docs/source/references/functional.rst | 6 +++ docs/source/references/modules.rst | 6 +++ tests/image/test_uqi.py | 4 +- torchmetrics/functional/image/helper.py | 50 +++++++++++++++++++++++++ torchmetrics/functional/image/ssim.py | 47 +---------------------- torchmetrics/functional/image/uqi.py | 47 +---------------------- 7 files changed, 70 insertions(+), 94 deletions(-) create mode 100644 torchmetrics/functional/image/helper.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 3eb3f080be4..57d5396a931 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added support for `MetricCollection` in `MetricTracker` ([#718](https://github.com/PyTorchLightning/metrics/pull/718)) + +- Added new image metric `UniversalImageQualityIndex` ([#824](https://github.com/PyTorchLightning/metrics/pull/824)) + + ### Changed - Used `torch.bucketize` in calibration error when `torch>1.8` for faster computations ([#769](https://github.com/PyTorchLightning/metrics/pull/769)) diff --git a/docs/source/references/functional.rst b/docs/source/references/functional.rst index 5aca08b133e..707b130bedb 100644 --- a/docs/source/references/functional.rst +++ b/docs/source/references/functional.rst @@ -270,6 +270,12 @@ peak_signal_noise_ratio [func] .. autofunction:: torchmetrics.functional.peak_signal_noise_ratio :noindex: +universal_image_quality_index [func] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: torchmetrics.functional.universal_image_quality_index + :noindex: + ********** Regression diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index b403da53bf9..a89913fe637 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -404,6 +404,12 @@ StructuralSimilarityIndexMeasure .. autoclass:: torchmetrics.StructuralSimilarityIndexMeasure :noindex: +UniversalImageQualityIndex +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: torchmetrics.UniversalImageQualityIndex + :noindex: + ********* Detection ********* diff --git a/tests/image/test_uqi.py b/tests/image/test_uqi.py index 6ebd574b3a7..29cafe18d24 100644 --- a/tests/image/test_uqi.py +++ b/tests/image/test_uqi.py @@ -170,5 +170,5 @@ def test_uqi_unequal_kernel_size(): ] ) # kernel order matters - assert universal_image_quality_index(preds, target, kernel_size=(3, 5)) == torch.tensor(0.10662283) - assert universal_image_quality_index(preds, target, kernel_size=(5, 3)) != torch.tensor(0.10662283) + torch.allclose(universal_image_quality_index(preds, target, kernel_size=(3, 5)), torch.tensor(0.10662283)) + torch.allclose(universal_image_quality_index(preds, target, kernel_size=(5, 3)), torch.tensor(0.10662283)) diff --git a/torchmetrics/functional/image/helper.py b/torchmetrics/functional/image/helper.py new file mode 100644 index 00000000000..f52f3462317 --- /dev/null +++ b/torchmetrics/functional/image/helper.py @@ -0,0 +1,50 @@ +from typing import Sequence + +import torch +from torch import Tensor + + +def _gaussian(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> Tensor: + """Computes 1D gaussian kernel. + + Args: + kernel_size: size of the gaussian kernel + sigma: Standard deviation of the gaussian kernel + dtype: data type of the output tensor + device: device of the output tensor + + Example: + >>> _gaussian(3, 1, torch.float, 'cpu') + tensor([[0.2741, 0.4519, 0.2741]]) + """ + dist = torch.arange(start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1, dtype=dtype, device=device) + gauss = torch.exp(-torch.pow(dist / sigma, 2) / 2) + return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size) + + +def _gaussian_kernel( + channel: int, kernel_size: Sequence[int], sigma: Sequence[float], dtype: torch.dtype, device: torch.device +) -> Tensor: + """Computes 2D gaussian kernel. + + Args: + channel: number of channels in the image + kernel_size: size of the gaussian kernel as a tuple (h, w) + sigma: Standard deviation of the gaussian kernel + dtype: data type of the output tensor + device: device of the output tensor + + Example: + >>> _gaussian_kernel(1, (5,5), (1,1), torch.float, "cpu") + tensor([[[[0.0030, 0.0133, 0.0219, 0.0133, 0.0030], + [0.0133, 0.0596, 0.0983, 0.0596, 0.0133], + [0.0219, 0.0983, 0.1621, 0.0983, 0.0219], + [0.0133, 0.0596, 0.0983, 0.0596, 0.0133], + [0.0030, 0.0133, 0.0219, 0.0133, 0.0030]]]]) + """ + + gaussian_kernel_x = _gaussian(kernel_size[0], sigma[0], dtype, device) + gaussian_kernel_y = _gaussian(kernel_size[1], sigma[1], dtype, device) + kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size) + + return kernel.expand(channel, 1, kernel_size[0], kernel_size[1]) diff --git a/torchmetrics/functional/image/ssim.py b/torchmetrics/functional/image/ssim.py index 8c033d5e154..1096d950849 100644 --- a/torchmetrics/functional/image/ssim.py +++ b/torchmetrics/functional/image/ssim.py @@ -22,52 +22,7 @@ from torchmetrics.utilities import _future_warning from torchmetrics.utilities.checks import _check_same_shape from torchmetrics.utilities.distributed import reduce - - -def _gaussian(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> Tensor: - """Computes 1D gaussian kernel. - - Args: - kernel_size: size of the gaussian kernel - sigma: Standard deviation of the gaussian kernel - dtype: data type of the output tensor - device: device of the output tensor - - Example: - >>> _gaussian(3, 1, torch.float, 'cpu') - tensor([[0.2741, 0.4519, 0.2741]]) - """ - dist = torch.arange(start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1, dtype=dtype, device=device) - gauss = torch.exp(-torch.pow(dist / sigma, 2) / 2) - return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size) - - -def _gaussian_kernel( - channel: int, kernel_size: Sequence[int], sigma: Sequence[float], dtype: torch.dtype, device: torch.device -) -> Tensor: - """Computes 2D gaussian kernel. - - Args: - channel: number of channels in the image - kernel_size: size of the gaussian kernel as a tuple (h, w) - sigma: Standard deviation of the gaussian kernel - dtype: data type of the output tensor - device: device of the output tensor - - Example: - >>> _gaussian_kernel(1, (5,5), (1,1), torch.float, "cpu") - tensor([[[[0.0030, 0.0133, 0.0219, 0.0133, 0.0030], - [0.0133, 0.0596, 0.0983, 0.0596, 0.0133], - [0.0219, 0.0983, 0.1621, 0.0983, 0.0219], - [0.0133, 0.0596, 0.0983, 0.0596, 0.0133], - [0.0030, 0.0133, 0.0219, 0.0133, 0.0030]]]]) - """ - - gaussian_kernel_x = _gaussian(kernel_size[0], sigma[0], dtype, device) - gaussian_kernel_y = _gaussian(kernel_size[1], sigma[1], dtype, device) - kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size) - - return kernel.expand(channel, 1, kernel_size[0], kernel_size[1]) +from torchmetrics.functional.image.helper import _gaussian_kernel def _ssim_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: diff --git a/torchmetrics/functional/image/uqi.py b/torchmetrics/functional/image/uqi.py index 26cbea9daca..a1f7070af67 100644 --- a/torchmetrics/functional/image/uqi.py +++ b/torchmetrics/functional/image/uqi.py @@ -21,52 +21,7 @@ from torchmetrics.utilities import _future_warning from torchmetrics.utilities.checks import _check_same_shape from torchmetrics.utilities.distributed import reduce - - -def _gaussian(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> Tensor: - """Computes 1D gaussian kernel. - - Args: - kernel_size: size of the gaussian kernel - sigma: Standard deviation of the gaussian kernel - dtype: data type of the output tensor - device: device of the output tensor - - Example: - >>> _gaussian(3, 1, torch.float, 'cpu') - tensor([[0.2741, 0.4519, 0.2741]]) - """ - dist = torch.arange(start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1, dtype=dtype, device=device) - gauss = torch.exp(-torch.pow(dist / sigma, 2) / 2) - return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size) - - -def _gaussian_kernel( - channel: int, kernel_size: Sequence[int], sigma: Sequence[float], dtype: torch.dtype, device: torch.device -) -> Tensor: - """Computes 2D gaussian kernel. - - Args: - channel: number of channels in the image - kernel_size: size of the gaussian kernel as a tuple (h, w) - sigma: Standard deviation of the gaussian kernel - dtype: data type of the output tensor - device: device of the output tensor - - Example: - >>> _gaussian_kernel(1, (5,5), (1,1), torch.float, "cpu") - tensor([[[[0.0030, 0.0133, 0.0219, 0.0133, 0.0030], - [0.0133, 0.0596, 0.0983, 0.0596, 0.0133], - [0.0219, 0.0983, 0.1621, 0.0983, 0.0219], - [0.0133, 0.0596, 0.0983, 0.0596, 0.0133], - [0.0030, 0.0133, 0.0219, 0.0133, 0.0030]]]]) - """ - - gaussian_kernel_x = _gaussian(kernel_size[0], sigma[0], dtype, device) - gaussian_kernel_y = _gaussian(kernel_size[1], sigma[1], dtype, device) - kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size) - - return kernel.expand(channel, 1, kernel_size[0], kernel_size[1]) +from torchmetrics.functional.image.helper import _gaussian_kernel def _uqi_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: From e208df85054faea897d235fdf01bdee55f0131a9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 6 Feb 2022 10:13:06 +0000 Subject: [PATCH 19/26] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/functional/image/ssim.py | 2 +- torchmetrics/functional/image/uqi.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmetrics/functional/image/ssim.py b/torchmetrics/functional/image/ssim.py index 03d8dabcf9b..f0b3b819497 100644 --- a/torchmetrics/functional/image/ssim.py +++ b/torchmetrics/functional/image/ssim.py @@ -18,9 +18,9 @@ from torch.nn import functional as F from typing_extensions import Literal +from torchmetrics.functional.image.helper import _gaussian_kernel from torchmetrics.utilities.checks import _check_same_shape from torchmetrics.utilities.distributed import reduce -from torchmetrics.functional.image.helper import _gaussian_kernel def _ssim_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: diff --git a/torchmetrics/functional/image/uqi.py b/torchmetrics/functional/image/uqi.py index 8c79bff4981..376acf22ddf 100644 --- a/torchmetrics/functional/image/uqi.py +++ b/torchmetrics/functional/image/uqi.py @@ -18,10 +18,10 @@ from torch.nn import functional as F from typing_extensions import Literal +from torchmetrics.functional.image.helper import _gaussian_kernel from torchmetrics.utilities import _future_warning from torchmetrics.utilities.checks import _check_same_shape from torchmetrics.utilities.distributed import reduce -from torchmetrics.functional.image.helper import _gaussian_kernel def _uqi_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: From 110488e9f84a5325a8cec8304e968ec9f9eb923b Mon Sep 17 00:00:00 2001 From: Ankita Sharma Date: Sun, 6 Feb 2022 15:54:51 +0530 Subject: [PATCH 20/26] minor changes --- docs/source/links.rst | 1 + torchmetrics/functional/image/uqi.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/links.rst b/docs/source/links.rst index d0d54d5223b..8b31ce8d712 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -78,3 +78,4 @@ .. _TER: https://aclanthology.org/2006.amta-papers.25.pdf .. _ExtendedEditDistance: https://aclanthology.org/W19-5359.pdf .. _MultiScaleSSIM: https://ece.uwaterloo.ca/~z70wang/publications/msssim.pdf +.. _UniversalImageQualityIndex: https://ieeexplore.ieee.org/document/995823 diff --git a/torchmetrics/functional/image/uqi.py b/torchmetrics/functional/image/uqi.py index 8c79bff4981..ca7d211d3a5 100644 --- a/torchmetrics/functional/image/uqi.py +++ b/torchmetrics/functional/image/uqi.py @@ -18,7 +18,6 @@ from torch.nn import functional as F from typing_extensions import Literal -from torchmetrics.utilities import _future_warning from torchmetrics.utilities.checks import _check_same_shape from torchmetrics.utilities.distributed import reduce from torchmetrics.functional.image.helper import _gaussian_kernel From a6d255b1ccfc69b09d5cf2ac1e452dcbfd69b477 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 6 Feb 2022 10:29:28 +0000 Subject: [PATCH 21/26] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/source/links.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/links.rst b/docs/source/links.rst index 8b31ce8d712..c5a2e6977ed 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -78,4 +78,4 @@ .. _TER: https://aclanthology.org/2006.amta-papers.25.pdf .. _ExtendedEditDistance: https://aclanthology.org/W19-5359.pdf .. _MultiScaleSSIM: https://ece.uwaterloo.ca/~z70wang/publications/msssim.pdf -.. _UniversalImageQualityIndex: https://ieeexplore.ieee.org/document/995823 +.. _UniversalImageQualityIndex: https://ieeexplore.ieee.org/document/995823 From 0c7e72ea17c730cbfd613c24ba00b457b02252b6 Mon Sep 17 00:00:00 2001 From: Ankita Sharma Date: Sun, 6 Feb 2022 16:15:57 +0530 Subject: [PATCH 22/26] removed unused imports --- torchmetrics/image/uqi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/image/uqi.py b/torchmetrics/image/uqi.py index c8fa0757164..ed29ded616e 100644 --- a/torchmetrics/image/uqi.py +++ b/torchmetrics/image/uqi.py @@ -19,7 +19,7 @@ from torchmetrics.functional.image.uqi import _uqi_compute, _uqi_update from torchmetrics.metric import Metric -from torchmetrics.utilities import _future_warning, rank_zero_warn +from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat From c08d3913872e64d8ef245ffdcb73448032dfae99 Mon Sep 17 00:00:00 2001 From: Ankita Sharma Date: Sun, 6 Feb 2022 17:23:45 +0530 Subject: [PATCH 23/26] minor fixes --- docs/source/references/functional.rst | 2 +- docs/source/references/modules.rst | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/references/functional.rst b/docs/source/references/functional.rst index 707b130bedb..9a821a825f6 100644 --- a/docs/source/references/functional.rst +++ b/docs/source/references/functional.rst @@ -271,7 +271,7 @@ peak_signal_noise_ratio [func] :noindex: universal_image_quality_index [func] -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: torchmetrics.functional.universal_image_quality_index :noindex: diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index a89913fe637..460a1a54134 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -405,7 +405,7 @@ StructuralSimilarityIndexMeasure :noindex: UniversalImageQualityIndex -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: torchmetrics.UniversalImageQualityIndex :noindex: From 2b59b27612eb2fd57f3e5a66933f3c4ec50523c0 Mon Sep 17 00:00:00 2001 From: Ankita Sharma Date: Mon, 7 Feb 2022 12:27:12 +0530 Subject: [PATCH 24/26] Apply suggestions from code review Co-authored-by: Daniel Stancl <46073029+stancld@users.noreply.github.com> --- requirements.txt | 1 - tests/image/test_uqi.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index 9b1500e5fac..5d7a0bc464d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,3 @@ numpy>=1.17.2 torch>=1.3.1 pyDeprecate==0.3.* packaging -typing_extensions diff --git a/tests/image/test_uqi.py b/tests/image/test_uqi.py index e8042e203cd..14a92011f50 100644 --- a/tests/image/test_uqi.py +++ b/tests/image/test_uqi.py @@ -20,8 +20,8 @@ from tests.helpers import seed_all from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester -from torchmetrics.functional import universal_image_quality_index -from torchmetrics.image import UniversalImageQualityIndex +from torchmetrics.functional.image.uqi import universal_image_quality_index +from torchmetrics.image.uqi import UniversalImageQualityIndex seed_all(42) From c6e577e1cbe59c10c73e4a1257ee4ff0160c716f Mon Sep 17 00:00:00 2001 From: Ankita Sharma Date: Mon, 7 Feb 2022 12:45:01 +0530 Subject: [PATCH 25/26] rename UQI --- torchmetrics/functional/image/uqi.py | 2 +- torchmetrics/image/uqi.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/torchmetrics/functional/image/uqi.py b/torchmetrics/functional/image/uqi.py index 795e518580d..b03421e34f2 100644 --- a/torchmetrics/functional/image/uqi.py +++ b/torchmetrics/functional/image/uqi.py @@ -147,7 +147,7 @@ def universal_image_quality_index( data_range: Range of the image. If ``None``, it is determined from the image (max - min) Return: - Tensor with UQI score + Tensor with UniversalImageQualityIndex score Raises: TypeError: diff --git a/torchmetrics/image/uqi.py b/torchmetrics/image/uqi.py index ed29ded616e..78f668d9e73 100644 --- a/torchmetrics/image/uqi.py +++ b/torchmetrics/image/uqi.py @@ -24,7 +24,7 @@ class UniversalImageQualityIndex(Metric): - """Computes Universal Image Quality Index (UQI_). + """Computes Universal Image Quality Index (UniversalImageQualityIndex_). Args: kernel_size: size of the gaussian kernel @@ -38,7 +38,7 @@ class UniversalImageQualityIndex(Metric): data_range: Range of the image. If ``None``, it is determined from the image (max - min) Return: - Tensor with UQI score + Tensor with UniversalImageQualityIndex score Example: >>> from torchmetrics import UniversalImageQualityIndex @@ -69,7 +69,7 @@ def __init__( process_group=process_group, ) rank_zero_warn( - "Metric `UQI` will save all targets and" + "Metric `UniversalImageQualityIndex` will save all targets and" " predictions in buffer. For large datasets this may lead" " to large memory footprint." ) From 459fbb55c2d68c777a9bc91e48e0ee24c549de13 Mon Sep 17 00:00:00 2001 From: Ankita Sharma Date: Mon, 7 Feb 2022 23:14:24 +0530 Subject: [PATCH 26/26] Update torchmetrics/image/uqi.py Co-authored-by: Jirka Borovec --- torchmetrics/image/uqi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/image/uqi.py b/torchmetrics/image/uqi.py index 78f668d9e73..eefb98a2b6e 100644 --- a/torchmetrics/image/uqi.py +++ b/torchmetrics/image/uqi.py @@ -51,7 +51,7 @@ class UniversalImageQualityIndex(Metric): preds: List[Tensor] target: List[Tensor] - higher_is_better = True + higher_is_better: bool = True def __init__( self,