Skip to content

Commit c0e4250

Browse files
ankitaS11pre-commit-ci[bot]Bordajustusschockstancld
authored
Added new image metric - UQI (#824)
* Added new metric - UQI * Registered UQI to functional init; tested locally * Testcases added for UQI * Apply suggestions from code review * Update requirements.txt Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Justus Schock <[email protected]> Co-authored-by: Daniel Stancl <[email protected]> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent 7330e2e commit c0e4250

File tree

13 files changed

+526
-46
lines changed

13 files changed

+526
-46
lines changed

CHANGELOG.md

+4
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1313

1414
- Added support for `MetricCollection` in `MetricTracker` ([#718](https://github.com/PyTorchLightning/metrics/pull/718))
1515

16+
17+
- Added new image metric `UniversalImageQualityIndex` ([#824](https://github.com/PyTorchLightning/metrics/pull/824))
18+
19+
1620
### Changed
1721

1822

docs/source/links.rst

+1
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,4 @@
7878
.. _TER: https://aclanthology.org/2006.amta-papers.25.pdf
7979
.. _ExtendedEditDistance: https://aclanthology.org/W19-5359.pdf
8080
.. _MultiScaleSSIM: https://ece.uwaterloo.ca/~z70wang/publications/msssim.pdf
81+
.. _UniversalImageQualityIndex: https://ieeexplore.ieee.org/document/995823

docs/source/references/functional.rst

+6
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,12 @@ peak_signal_noise_ratio [func]
270270
.. autofunction:: torchmetrics.functional.peak_signal_noise_ratio
271271
:noindex:
272272

273+
universal_image_quality_index [func]
274+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
275+
276+
.. autofunction:: torchmetrics.functional.universal_image_quality_index
277+
:noindex:
278+
273279

274280
**********
275281
Regression

docs/source/references/modules.rst

+6
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,12 @@ StructuralSimilarityIndexMeasure
404404
.. autoclass:: torchmetrics.StructuralSimilarityIndexMeasure
405405
:noindex:
406406

407+
UniversalImageQualityIndex
408+
~~~~~~~~~~~~~~~~~~~~~~~~~~
409+
410+
.. autoclass:: torchmetrics.UniversalImageQualityIndex
411+
:noindex:
412+
407413
*********
408414
Detection
409415
*********

tests/image/test_uqi.py

+174
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from collections import namedtuple
15+
from functools import partial
16+
17+
import pytest
18+
import torch
19+
from skimage.metrics import structural_similarity
20+
21+
from tests.helpers import seed_all
22+
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
23+
from torchmetrics.functional.image.uqi import universal_image_quality_index
24+
from torchmetrics.image.uqi import UniversalImageQualityIndex
25+
26+
seed_all(42)
27+
28+
# UQI is SSIM with both constants k1 and k2 as 0
29+
skimage_uqi = partial(structural_similarity, k1=0, k2=0)
30+
31+
Input = namedtuple("Input", ["preds", "target", "multichannel"])
32+
33+
_inputs = []
34+
for size, channel, coef, multichannel, dtype in [
35+
(12, 3, 0.9, True, torch.float),
36+
(13, 1, 0.8, False, torch.float32),
37+
(14, 1, 0.7, False, torch.double),
38+
(15, 3, 0.6, True, torch.float64),
39+
]:
40+
preds = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype)
41+
_inputs.append(
42+
Input(
43+
preds=preds,
44+
target=preds * coef,
45+
multichannel=multichannel,
46+
)
47+
)
48+
49+
50+
def _sk_uqi(preds, target, data_range, multichannel, kernel_size):
51+
c, h, w = preds.shape[-3:]
52+
sk_preds = preds.view(-1, c, h, w).permute(0, 2, 3, 1).numpy()
53+
sk_target = target.view(-1, c, h, w).permute(0, 2, 3, 1).numpy()
54+
if not multichannel:
55+
sk_preds = sk_preds[:, :, :, 0]
56+
sk_target = sk_target[:, :, :, 0]
57+
58+
return skimage_uqi(
59+
sk_target,
60+
sk_preds,
61+
data_range=data_range,
62+
multichannel=multichannel,
63+
gaussian_weights=True,
64+
win_size=kernel_size,
65+
sigma=1.5,
66+
use_sample_covariance=False,
67+
)
68+
69+
70+
@pytest.mark.parametrize(
71+
"preds, target, multichannel",
72+
[(i.preds, i.target, i.multichannel) for i in _inputs],
73+
)
74+
@pytest.mark.parametrize("kernel_size", [5, 11])
75+
class TestUQI(MetricTester):
76+
atol = 6e-3
77+
78+
@pytest.mark.parametrize("ddp", [True, False])
79+
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
80+
def test_uqi(self, preds, target, multichannel, kernel_size, ddp, dist_sync_on_step):
81+
self.run_class_metric_test(
82+
ddp,
83+
preds,
84+
target,
85+
UniversalImageQualityIndex,
86+
partial(_sk_uqi, data_range=1.0, multichannel=multichannel, kernel_size=kernel_size),
87+
metric_args={"data_range": 1.0, "kernel_size": (kernel_size, kernel_size)},
88+
dist_sync_on_step=dist_sync_on_step,
89+
)
90+
91+
def test_uqi_functional(self, preds, target, multichannel, kernel_size):
92+
self.run_functional_metric_test(
93+
preds,
94+
target,
95+
universal_image_quality_index,
96+
partial(_sk_uqi, data_range=1.0, multichannel=multichannel, kernel_size=kernel_size),
97+
metric_args={"data_range": 1.0, "kernel_size": (kernel_size, kernel_size)},
98+
)
99+
100+
# UQI half + cpu does not work due to missing support in torch.log
101+
@pytest.mark.xfail(reason="UQI metric does not support cpu + half precision")
102+
def test_uqi_half_cpu(self, preds, target, multichannel, kernel_size):
103+
self.run_precision_test_cpu(
104+
preds, target, UniversalImageQualityIndex, universal_image_quality_index, {"data_range": 1.0}
105+
)
106+
107+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
108+
def test_uqi_half_gpu(self, preds, target, multichannel, kernel_size):
109+
self.run_precision_test_gpu(
110+
preds, target, UniversalImageQualityIndex, universal_image_quality_index, {"data_range": 1.0}
111+
)
112+
113+
114+
@pytest.mark.parametrize(
115+
["pred", "target", "kernel", "sigma"],
116+
[
117+
([1, 16, 16], [1, 16, 16], [11, 11], [1.5, 1.5]), # len(shape)
118+
([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5]), # len(kernel), len(sigma)
119+
([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5, 1.5]), # len(kernel), len(sigma)
120+
([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5]), # len(kernel), len(sigma)
121+
([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, 1.5]), # invalid kernel input
122+
([1, 1, 16, 16], [1, 1, 16, 16], [11, 10], [1.5, 1.5]), # invalid kernel input
123+
([1, 1, 16, 16], [1, 1, 16, 16], [11, -11], [1.5, 1.5]), # invalid kernel input
124+
([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5, 0]), # invalid sigma input
125+
([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, -1.5]), # invalid sigma input
126+
],
127+
)
128+
def test_uqi_invalid_inputs(pred, target, kernel, sigma):
129+
pred_t = torch.rand(pred)
130+
target_t = torch.rand(target, dtype=torch.float64)
131+
with pytest.raises(TypeError):
132+
universal_image_quality_index(pred_t, target_t)
133+
134+
pred = torch.rand(pred)
135+
target = torch.rand(target)
136+
with pytest.raises(ValueError):
137+
universal_image_quality_index(pred, target, kernel, sigma)
138+
139+
140+
def test_uqi_unequal_kernel_size():
141+
"""Test the case where kernel_size[0] != kernel_size[1]"""
142+
preds = torch.tensor(
143+
[
144+
[
145+
[
146+
[1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0],
147+
[1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0],
148+
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
149+
[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
150+
[0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0],
151+
[0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0],
152+
[1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0],
153+
]
154+
]
155+
]
156+
)
157+
target = torch.tensor(
158+
[
159+
[
160+
[
161+
[1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0],
162+
[0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0],
163+
[1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0],
164+
[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0],
165+
[1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0],
166+
[0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0],
167+
[0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0],
168+
]
169+
]
170+
]
171+
)
172+
# kernel order matters
173+
torch.allclose(universal_image_quality_index(preds, target, kernel_size=(3, 5)), torch.tensor(0.10662283))
174+
torch.allclose(universal_image_quality_index(preds, target, kernel_size=(5, 3)), torch.tensor(0.10662283))

torchmetrics/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
MultiScaleStructuralSimilarityIndexMeasure,
5151
PeakSignalNoiseRatio,
5252
StructuralSimilarityIndexMeasure,
53+
UniversalImageQualityIndex,
5354
)
5455
from torchmetrics.metric import Metric # noqa: E402
5556
from torchmetrics.regression import ( # noqa: E402
@@ -159,6 +160,7 @@
159160
"SumMetric",
160161
"SymmetricMeanAbsolutePercentageError",
161162
"TranslationEditRate",
163+
"UniversalImageQualityIndex",
162164
"WordErrorRate",
163165
"CharErrorRate",
164166
"MatchErrorRate",

torchmetrics/functional/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
multiscale_structural_similarity_index_measure,
4040
structural_similarity_index_measure,
4141
)
42+
from torchmetrics.functional.image.uqi import universal_image_quality_index
4243
from torchmetrics.functional.pairwise.cosine import pairwise_cosine_similarity
4344
from torchmetrics.functional.pairwise.euclidean import pairwise_euclidean_distance
4445
from torchmetrics.functional.pairwise.linear import pairwise_linear_similarity
@@ -142,6 +143,7 @@
142143
"stat_scores",
143144
"symmetric_mean_absolute_percentage_error",
144145
"translation_edit_rate",
146+
"universal_image_quality_index",
145147
"word_error_rate",
146148
"char_error_rate",
147149
"match_error_rate",

torchmetrics/functional/image/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,4 @@
1717
multiscale_structural_similarity_index_measure,
1818
structural_similarity_index_measure,
1919
)
20+
from torchmetrics.functional.image.uqi import universal_image_quality_index # noqa: F401
+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from typing import Sequence
2+
3+
import torch
4+
from torch import Tensor
5+
6+
7+
def _gaussian(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> Tensor:
8+
"""Computes 1D gaussian kernel.
9+
10+
Args:
11+
kernel_size: size of the gaussian kernel
12+
sigma: Standard deviation of the gaussian kernel
13+
dtype: data type of the output tensor
14+
device: device of the output tensor
15+
16+
Example:
17+
>>> _gaussian(3, 1, torch.float, 'cpu')
18+
tensor([[0.2741, 0.4519, 0.2741]])
19+
"""
20+
dist = torch.arange(start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1, dtype=dtype, device=device)
21+
gauss = torch.exp(-torch.pow(dist / sigma, 2) / 2)
22+
return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size)
23+
24+
25+
def _gaussian_kernel(
26+
channel: int, kernel_size: Sequence[int], sigma: Sequence[float], dtype: torch.dtype, device: torch.device
27+
) -> Tensor:
28+
"""Computes 2D gaussian kernel.
29+
30+
Args:
31+
channel: number of channels in the image
32+
kernel_size: size of the gaussian kernel as a tuple (h, w)
33+
sigma: Standard deviation of the gaussian kernel
34+
dtype: data type of the output tensor
35+
device: device of the output tensor
36+
37+
Example:
38+
>>> _gaussian_kernel(1, (5,5), (1,1), torch.float, "cpu")
39+
tensor([[[[0.0030, 0.0133, 0.0219, 0.0133, 0.0030],
40+
[0.0133, 0.0596, 0.0983, 0.0596, 0.0133],
41+
[0.0219, 0.0983, 0.1621, 0.0983, 0.0219],
42+
[0.0133, 0.0596, 0.0983, 0.0596, 0.0133],
43+
[0.0030, 0.0133, 0.0219, 0.0133, 0.0030]]]])
44+
"""
45+
46+
gaussian_kernel_x = _gaussian(kernel_size[0], sigma[0], dtype, device)
47+
gaussian_kernel_y = _gaussian(kernel_size[1], sigma[1], dtype, device)
48+
kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size)
49+
50+
return kernel.expand(channel, 1, kernel_size[0], kernel_size[1])

torchmetrics/functional/image/ssim.py

+1-46
Original file line numberDiff line numberDiff line change
@@ -18,56 +18,11 @@
1818
from torch.nn import functional as F
1919
from typing_extensions import Literal
2020

21+
from torchmetrics.functional.image.helper import _gaussian_kernel
2122
from torchmetrics.utilities.checks import _check_same_shape
2223
from torchmetrics.utilities.distributed import reduce
2324

2425

25-
def _gaussian(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> Tensor:
26-
"""Computes 1D gaussian kernel.
27-
28-
Args:
29-
kernel_size: size of the gaussian kernel
30-
sigma: Standard deviation of the gaussian kernel
31-
dtype: data type of the output tensor
32-
device: device of the output tensor
33-
34-
Example:
35-
>>> _gaussian(3, 1, torch.float, 'cpu')
36-
tensor([[0.2741, 0.4519, 0.2741]])
37-
"""
38-
dist = torch.arange(start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1, dtype=dtype, device=device)
39-
gauss = torch.exp(-torch.pow(dist / sigma, 2) / 2)
40-
return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size)
41-
42-
43-
def _gaussian_kernel(
44-
channel: int, kernel_size: Sequence[int], sigma: Sequence[float], dtype: torch.dtype, device: torch.device
45-
) -> Tensor:
46-
"""Computes 2D gaussian kernel.
47-
48-
Args:
49-
channel: number of channels in the image
50-
kernel_size: size of the gaussian kernel as a tuple (h, w)
51-
sigma: Standard deviation of the gaussian kernel
52-
dtype: data type of the output tensor
53-
device: device of the output tensor
54-
55-
Example:
56-
>>> _gaussian_kernel(1, (5,5), (1,1), torch.float, "cpu")
57-
tensor([[[[0.0030, 0.0133, 0.0219, 0.0133, 0.0030],
58-
[0.0133, 0.0596, 0.0983, 0.0596, 0.0133],
59-
[0.0219, 0.0983, 0.1621, 0.0983, 0.0219],
60-
[0.0133, 0.0596, 0.0983, 0.0596, 0.0133],
61-
[0.0030, 0.0133, 0.0219, 0.0133, 0.0030]]]])
62-
"""
63-
64-
gaussian_kernel_x = _gaussian(kernel_size[0], sigma[0], dtype, device)
65-
gaussian_kernel_y = _gaussian(kernel_size[1], sigma[1], dtype, device)
66-
kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size)
67-
68-
return kernel.expand(channel, 1, kernel_size[0], kernel_size[1])
69-
70-
7126
def _ssim_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
7227
"""Updates and returns variables required to compute Structural Similarity Index Measure. Checks for same shape
7328
and type of the input tensors.

0 commit comments

Comments
 (0)