Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added new image metric - UQI #824

Merged
merged 36 commits into from
Feb 7, 2022
Merged
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
fea56ee
Added new metric - UQI
ankitaS11 Feb 3, 2022
30c5c50
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 3, 2022
b2078b8
minor fixes
ankitaS11 Feb 3, 2022
de73401
Merge branch 'feature/799_UQI_metric' of github.com:ankitaS11/metrics…
ankitaS11 Feb 3, 2022
a688cb4
Added missing import
ankitaS11 Feb 3, 2022
ae11a3b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 3, 2022
29420c2
Registered UQI to functional init; tested locally
ankitaS11 Feb 3, 2022
79263a3
Merge branch 'feature/799_UQI_metric' of github.com:ankitaS11/metrics…
ankitaS11 Feb 3, 2022
3a08596
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 3, 2022
53bfe2a
Testcases added for UQI
ankitaS11 Feb 4, 2022
46e09a2
Merge branch 'feature/799_UQI_metric' of github.com:ankitaS11/metrics…
ankitaS11 Feb 4, 2022
d624e67
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 4, 2022
c53eb3b
Merge branch 'master' into feature/799_UQI_metric
Borda Feb 5, 2022
ba326f1
Apply suggestions from code review
justusschock Feb 5, 2022
a530e94
Apply suggestions from code review
justusschock Feb 5, 2022
e07b337
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 5, 2022
40e68aa
Apply suggestions from code review
justusschock Feb 5, 2022
a73200d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 5, 2022
e793baf
Merge branch 'master' into feature/799_UQI_metric
justusschock Feb 5, 2022
5f8d39b
Apply suggestions from code review
justusschock Feb 5, 2022
94cd330
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 5, 2022
6672a90
Update requirements.txt
justusschock Feb 5, 2022
4951bc6
Address reviews
ankitaS11 Feb 6, 2022
8d0593c
Merge branch 'feature/799_UQI_metric' of github.com:ankitaS11/metrics…
ankitaS11 Feb 6, 2022
e208df8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 6, 2022
110488e
minor changes
ankitaS11 Feb 6, 2022
2b1db7b
merge conflict resolved
ankitaS11 Feb 6, 2022
a6d255b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 6, 2022
4c650a7
Merge branch 'master' into feature/799_UQI_metric
ankitaS11 Feb 6, 2022
0c7e72e
removed unused imports
ankitaS11 Feb 6, 2022
3aa8dd8
Merge branch 'feature/799_UQI_metric' of github.com:ankitaS11/metrics…
ankitaS11 Feb 6, 2022
c08d391
minor fixes
ankitaS11 Feb 6, 2022
2b59b27
Apply suggestions from code review
ankitaS11 Feb 7, 2022
c6e577e
rename UQI
ankitaS11 Feb 7, 2022
459fbb5
Update torchmetrics/image/uqi.py
ankitaS11 Feb 7, 2022
dbabf5b
Merge branch 'master' into feature/799_UQI_metric
mergify[bot] Feb 7, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added support for `MetricCollection` in `MetricTracker` ([#718](https://github.com/PyTorchLightning/metrics/pull/718))


- Added new image metric `UniversalImageQualityIndex` ([#824](https://github.com/PyTorchLightning/metrics/pull/824))


### Changed


Expand Down
1 change: 1 addition & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,4 @@
.. _TER: https://aclanthology.org/2006.amta-papers.25.pdf
.. _ExtendedEditDistance: https://aclanthology.org/W19-5359.pdf
.. _MultiScaleSSIM: https://ece.uwaterloo.ca/~z70wang/publications/msssim.pdf
.. _UniversalImageQualityIndex: https://ieeexplore.ieee.org/document/995823
6 changes: 6 additions & 0 deletions docs/source/references/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,12 @@ peak_signal_noise_ratio [func]
.. autofunction:: torchmetrics.functional.peak_signal_noise_ratio
:noindex:

universal_image_quality_index [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
ankitaS11 marked this conversation as resolved.
Show resolved Hide resolved

.. autofunction:: torchmetrics.functional.universal_image_quality_index
:noindex:


**********
Regression
Expand Down
6 changes: 6 additions & 0 deletions docs/source/references/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,12 @@ StructuralSimilarityIndexMeasure
.. autoclass:: torchmetrics.StructuralSimilarityIndexMeasure
:noindex:

UniversalImageQualityIndex
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
ankitaS11 marked this conversation as resolved.
Show resolved Hide resolved

.. autoclass:: torchmetrics.UniversalImageQualityIndex
:noindex:

*********
Detection
*********
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ numpy>=1.17.2
torch>=1.3.1
pyDeprecate==0.3.*
packaging
typing_extensions
ankitaS11 marked this conversation as resolved.
Show resolved Hide resolved
174 changes: 174 additions & 0 deletions tests/image/test_uqi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from functools import partial

import pytest
import torch
from skimage.metrics import structural_similarity

from tests.helpers import seed_all
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from torchmetrics.functional import universal_image_quality_index
from torchmetrics.image import UniversalImageQualityIndex
ankitaS11 marked this conversation as resolved.
Show resolved Hide resolved

seed_all(42)

# UQI is SSIM with both constants k1 and k2 as 0
skimage_uqi = partial(structural_similarity, k1=0, k2=0)

Input = namedtuple("Input", ["preds", "target", "multichannel"])

_inputs = []
for size, channel, coef, multichannel, dtype in [
(12, 3, 0.9, True, torch.float),
(13, 1, 0.8, False, torch.float32),
(14, 1, 0.7, False, torch.double),
(15, 3, 0.6, True, torch.float64),
]:
preds = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype)
_inputs.append(
Input(
preds=preds,
target=preds * coef,
multichannel=multichannel,
)
)


def _sk_uqi(preds, target, data_range, multichannel, kernel_size):
c, h, w = preds.shape[-3:]
sk_preds = preds.view(-1, c, h, w).permute(0, 2, 3, 1).numpy()
sk_target = target.view(-1, c, h, w).permute(0, 2, 3, 1).numpy()
if not multichannel:
sk_preds = sk_preds[:, :, :, 0]
sk_target = sk_target[:, :, :, 0]

return skimage_uqi(
sk_target,
sk_preds,
data_range=data_range,
multichannel=multichannel,
gaussian_weights=True,
win_size=kernel_size,
sigma=1.5,
use_sample_covariance=False,
)


@pytest.mark.parametrize(
"preds, target, multichannel",
[(i.preds, i.target, i.multichannel) for i in _inputs],
)
@pytest.mark.parametrize("kernel_size", [5, 11])
class TestUQI(MetricTester):
atol = 6e-3

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_uqi(self, preds, target, multichannel, kernel_size, ddp, dist_sync_on_step):
self.run_class_metric_test(
ddp,
preds,
target,
UniversalImageQualityIndex,
partial(_sk_uqi, data_range=1.0, multichannel=multichannel, kernel_size=kernel_size),
metric_args={"data_range": 1.0, "kernel_size": (kernel_size, kernel_size)},
dist_sync_on_step=dist_sync_on_step,
)

def test_uqi_functional(self, preds, target, multichannel, kernel_size):
self.run_functional_metric_test(
preds,
target,
universal_image_quality_index,
partial(_sk_uqi, data_range=1.0, multichannel=multichannel, kernel_size=kernel_size),
metric_args={"data_range": 1.0, "kernel_size": (kernel_size, kernel_size)},
)

# UQI half + cpu does not work due to missing support in torch.log
@pytest.mark.xfail(reason="UQI metric does not support cpu + half precision")
def test_uqi_half_cpu(self, preds, target, multichannel, kernel_size):
self.run_precision_test_cpu(
preds, target, UniversalImageQualityIndex, universal_image_quality_index, {"data_range": 1.0}
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
def test_uqi_half_gpu(self, preds, target, multichannel, kernel_size):
self.run_precision_test_gpu(
preds, target, UniversalImageQualityIndex, universal_image_quality_index, {"data_range": 1.0}
)


@pytest.mark.parametrize(
["pred", "target", "kernel", "sigma"],
[
([1, 16, 16], [1, 16, 16], [11, 11], [1.5, 1.5]), # len(shape)
([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5]), # len(kernel), len(sigma)
([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5, 1.5]), # len(kernel), len(sigma)
([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5]), # len(kernel), len(sigma)
([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, 1.5]), # invalid kernel input
([1, 1, 16, 16], [1, 1, 16, 16], [11, 10], [1.5, 1.5]), # invalid kernel input
([1, 1, 16, 16], [1, 1, 16, 16], [11, -11], [1.5, 1.5]), # invalid kernel input
([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5, 0]), # invalid sigma input
([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, -1.5]), # invalid sigma input
],
)
def test_uqi_invalid_inputs(pred, target, kernel, sigma):
pred_t = torch.rand(pred)
target_t = torch.rand(target, dtype=torch.float64)
with pytest.raises(TypeError):
universal_image_quality_index(pred_t, target_t)

pred = torch.rand(pred)
target = torch.rand(target)
with pytest.raises(ValueError):
universal_image_quality_index(pred, target, kernel, sigma)


def test_uqi_unequal_kernel_size():
"""Test the case where kernel_size[0] != kernel_size[1]"""
preds = torch.tensor(
[
[
[
[1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0],
[1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0],
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
[0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0],
[0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0],
[1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0],
]
]
]
)
target = torch.tensor(
[
[
[
[1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0],
[0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0],
[1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0],
[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0],
[1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0],
[0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0],
]
]
]
)
# kernel order matters
torch.allclose(universal_image_quality_index(preds, target, kernel_size=(3, 5)), torch.tensor(0.10662283))
torch.allclose(universal_image_quality_index(preds, target, kernel_size=(5, 3)), torch.tensor(0.10662283))
2 changes: 2 additions & 0 deletions torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
MultiScaleStructuralSimilarityIndexMeasure,
PeakSignalNoiseRatio,
StructuralSimilarityIndexMeasure,
UniversalImageQualityIndex,
)
from torchmetrics.metric import Metric # noqa: E402
from torchmetrics.regression import ( # noqa: E402
Expand Down Expand Up @@ -159,6 +160,7 @@
"SumMetric",
"SymmetricMeanAbsolutePercentageError",
"TranslationEditRate",
"UniversalImageQualityIndex",
"WordErrorRate",
"CharErrorRate",
"MatchErrorRate",
Expand Down
2 changes: 2 additions & 0 deletions torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
multiscale_structural_similarity_index_measure,
structural_similarity_index_measure,
)
from torchmetrics.functional.image.uqi import universal_image_quality_index
from torchmetrics.functional.pairwise.cosine import pairwise_cosine_similarity
from torchmetrics.functional.pairwise.euclidean import pairwise_euclidean_distance
from torchmetrics.functional.pairwise.linear import pairwise_linear_similarity
Expand Down Expand Up @@ -142,6 +143,7 @@
"stat_scores",
"symmetric_mean_absolute_percentage_error",
"translation_edit_rate",
"universal_image_quality_index",
"word_error_rate",
"char_error_rate",
"match_error_rate",
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/functional/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@
multiscale_structural_similarity_index_measure,
structural_similarity_index_measure,
)
from torchmetrics.functional.image.uqi import universal_image_quality_index # noqa: F401
50 changes: 50 additions & 0 deletions torchmetrics/functional/image/helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import Sequence

import torch
from torch import Tensor


def _gaussian(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> Tensor:
"""Computes 1D gaussian kernel.

Args:
kernel_size: size of the gaussian kernel
sigma: Standard deviation of the gaussian kernel
dtype: data type of the output tensor
device: device of the output tensor

Example:
>>> _gaussian(3, 1, torch.float, 'cpu')
tensor([[0.2741, 0.4519, 0.2741]])
"""
dist = torch.arange(start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1, dtype=dtype, device=device)
gauss = torch.exp(-torch.pow(dist / sigma, 2) / 2)
return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size)


def _gaussian_kernel(
channel: int, kernel_size: Sequence[int], sigma: Sequence[float], dtype: torch.dtype, device: torch.device
) -> Tensor:
"""Computes 2D gaussian kernel.

Args:
channel: number of channels in the image
kernel_size: size of the gaussian kernel as a tuple (h, w)
sigma: Standard deviation of the gaussian kernel
dtype: data type of the output tensor
device: device of the output tensor

Example:
>>> _gaussian_kernel(1, (5,5), (1,1), torch.float, "cpu")
tensor([[[[0.0030, 0.0133, 0.0219, 0.0133, 0.0030],
[0.0133, 0.0596, 0.0983, 0.0596, 0.0133],
[0.0219, 0.0983, 0.1621, 0.0983, 0.0219],
[0.0133, 0.0596, 0.0983, 0.0596, 0.0133],
[0.0030, 0.0133, 0.0219, 0.0133, 0.0030]]]])
"""

gaussian_kernel_x = _gaussian(kernel_size[0], sigma[0], dtype, device)
gaussian_kernel_y = _gaussian(kernel_size[1], sigma[1], dtype, device)
kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size)

return kernel.expand(channel, 1, kernel_size[0], kernel_size[1])
47 changes: 1 addition & 46 deletions torchmetrics/functional/image/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,56 +18,11 @@
from torch.nn import functional as F
from typing_extensions import Literal

from torchmetrics.functional.image.helper import _gaussian_kernel
from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.distributed import reduce


def _gaussian(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> Tensor:
"""Computes 1D gaussian kernel.

Args:
kernel_size: size of the gaussian kernel
sigma: Standard deviation of the gaussian kernel
dtype: data type of the output tensor
device: device of the output tensor

Example:
>>> _gaussian(3, 1, torch.float, 'cpu')
tensor([[0.2741, 0.4519, 0.2741]])
"""
dist = torch.arange(start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1, dtype=dtype, device=device)
gauss = torch.exp(-torch.pow(dist / sigma, 2) / 2)
return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size)


def _gaussian_kernel(
channel: int, kernel_size: Sequence[int], sigma: Sequence[float], dtype: torch.dtype, device: torch.device
) -> Tensor:
"""Computes 2D gaussian kernel.

Args:
channel: number of channels in the image
kernel_size: size of the gaussian kernel as a tuple (h, w)
sigma: Standard deviation of the gaussian kernel
dtype: data type of the output tensor
device: device of the output tensor

Example:
>>> _gaussian_kernel(1, (5,5), (1,1), torch.float, "cpu")
tensor([[[[0.0030, 0.0133, 0.0219, 0.0133, 0.0030],
[0.0133, 0.0596, 0.0983, 0.0596, 0.0133],
[0.0219, 0.0983, 0.1621, 0.0983, 0.0219],
[0.0133, 0.0596, 0.0983, 0.0596, 0.0133],
[0.0030, 0.0133, 0.0219, 0.0133, 0.0030]]]])
"""

gaussian_kernel_x = _gaussian(kernel_size[0], sigma[0], dtype, device)
gaussian_kernel_y = _gaussian(kernel_size[1], sigma[1], dtype, device)
kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size)

return kernel.expand(channel, 1, kernel_size[0], kernel_size[1])


def _ssim_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
"""Updates and returns variables required to compute Structural Similarity Index Measure. Checks for same shape
and type of the input tensors.
Expand Down
Loading