Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding psnrb #1421

Merged
merged 50 commits into from
Apr 17, 2023
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
f487bad
Adding psnrb
soma2000-lang Jan 1, 2023
c4e3fe3
Adding the changes suggested
soma2000-lang Jan 4, 2023
b0e963c
Adding the changes suggested
soma2000-lang Jan 4, 2023
f85d3c8
changes
soma2000-lang Jan 4, 2023
63df2b7
changes
soma2000-lang Jan 4, 2023
ed1a6cc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 4, 2023
0b2b9dc
imports
Borda Jan 6, 2023
fb450d4
cls name
Borda Jan 6, 2023
64edfaf
Merge branch 'master' into psnrb
Borda Jan 6, 2023
87142b8
precommit
Borda Jan 6, 2023
5a91e17
Merge branch 'psnrb' of https://github.com/soma2000-lang/metrics into…
Borda Jan 6, 2023
9b4207c
fix changelog
SkafteNicki Jan 24, 2023
3c805b2
remove unwanted files
SkafteNicki Jan 24, 2023
32bd8b3
rename file
SkafteNicki Jan 24, 2023
59c7144
fix docstring
SkafteNicki Jan 24, 2023
1f10784
Merge branch 'master' into psnrb
SkafteNicki Jan 24, 2023
af0ea3f
Merge branch 'master' into psnrb
Borda Jan 30, 2023
303d9ba
small fixes
SkafteNicki Feb 3, 2023
8678087
Merge branch 'master' into psnrb
Borda Feb 6, 2023
0a7944b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 6, 2023
c812d8b
doctest
Borda Feb 6, 2023
530b6a4
Apply suggestions from code review
Borda Feb 6, 2023
bed471a
Merge branch 'master' into psnrb
Borda Feb 7, 2023
1378054
Merge branch 'master' into psnrb
Borda Feb 17, 2023
c45b70f
Merge branch 'master' into psnrb
Borda Feb 18, 2023
3b550cf
Merge branch 'master' into psnrb
Borda Feb 22, 2023
97fe8bf
Merge branch 'psnrb' of https://github.com/soma2000-lang/metrics into…
SkafteNicki Feb 24, 2023
d085733
Merge branch 'master' into psnrb
SkafteNicki Feb 24, 2023
317756f
fixes
SkafteNicki Feb 24, 2023
2e487a4
Merge branch 'master' into psnrb
Borda Feb 27, 2023
a990ce6
Merge branch 'master' into psnrb
soma2000-lang Feb 27, 2023
bd84dc0
Merge branch 'master' into psnrb
Borda Feb 28, 2023
5f7d496
Merge branch 'master' into psnrb
Borda Mar 6, 2023
e158f41
Merge branch 'master' into psnrb
Borda Mar 6, 2023
d6de718
Merge branch 'master' into psnrb
Borda Mar 21, 2023
cc7eab5
Merge branch 'master' into psnrb
Borda Mar 31, 2023
53ce7ec
Merge branch 'psnrb' of https://github.com/soma2000-lang/metrics into…
SkafteNicki Apr 14, 2023
5012c5c
merge master
SkafteNicki Apr 14, 2023
4133e75
fix changelog
SkafteNicki Apr 14, 2023
cad1cb3
fix
SkafteNicki Apr 14, 2023
e89623b
fix implementation and tests
SkafteNicki Apr 14, 2023
ebb2bfe
Merge branch 'master' into psnrb
SkafteNicki Apr 15, 2023
4ecde1d
docs: Import through torchmetrics.image
stancld Apr 15, 2023
fac071f
Apply suggestions from code review
stancld Apr 15, 2023
8393c5e
Update regex match accordingly in tests/unittests/image/test_psnrb.py
stancld Apr 16, 2023
09818dd
Merge branch 'master' into psnrb
mergify[bot] Apr 17, 2023
489b06e
Merge branch 'master' into psnrb
Borda Apr 17, 2023
8075256
Merge branch 'master' into psnrb
mergify[bot] Apr 17, 2023
36dfff0
Merge branch 'master' into psnrb
mergify[bot] Apr 17, 2023
8181e61
Merge branch 'master' into psnrb
mergify[bot] Apr 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
)


- Added `PSNRB` metric ([#1421](https://github.com/Lightning-AI/metrics/pull/1421))


- Added new detection metric `PanopticQuality` ([#929](https://github.com/PyTorchLightning/metrics/pull/929))


- Added `ClassificationTask` Enum and use in metrics ([#1479](https://github.com/Lightning-AI/metrics/pull/1479))


Expand Down
23 changes: 23 additions & 0 deletions docs/source/image/peak_signal_to_noise_with_block.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
.. customcarditem::
:header: Peak Signal To Noise Ratio With Blocked Effect
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg
:tags: Image

.. include:: ../links.rst

##############################################
Peak Signal To Noise Ratio With Blocked Effect
##############################################

Module Interface
________________

.. autoclass:: torchmetrics.PeakSignalToNoiseRatioWithBlockedEffect
:noindex:
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.peak_signal_noise_ratio_with_blocked_effect
:noindex:
1 change: 1 addition & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -130,4 +130,5 @@
.. _Multilabel coverage error: https://link.springer.com/chapter/10.1007/978-0-387-09823-4_34
.. _Panoptic Quality: https://arxiv.org/abs/1801.00868
.. _torchmetrics mAP example: https://github.com/Lightning-AI/metrics/blob/master/examples/detection_map.py
.. _Peak Signal to Noise Ratio With Blocked Effect: https://ieeexplore.ieee.org/abstract/document/5535179
.. _Minkowski Distance: https://en.wikipedia.org/wiki/Minkowski_distance
2 changes: 2 additions & 0 deletions src/torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
ErrorRelativeGlobalDimensionlessSynthesis,
MultiScaleStructuralSimilarityIndexMeasure,
PeakSignalNoiseRatio,
PeakSignalNoiseRatioWithBlockedEffect,
SpectralAngleMapper,
SpectralDistortionIndex,
StructuralSimilarityIndexMeasure,
Expand Down Expand Up @@ -159,6 +160,7 @@
"Precision",
"PrecisionRecallCurve",
"PeakSignalNoiseRatio",
"PeakSignalNoiseRatioWithBlockedEffect",
"R2Score",
"Recall",
"RetrievalFallOut",
Expand Down
2 changes: 2 additions & 0 deletions src/torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis
from torchmetrics.functional.image.gradients import image_gradients
from torchmetrics.functional.image.psnr import peak_signal_noise_ratio
from torchmetrics.functional.image.psnrb import peak_signal_noise_ratio_with_blocked_effect
from torchmetrics.functional.image.sam import spectral_angle_mapper
from torchmetrics.functional.image.ssim import (
multiscale_structural_similarity_index_measure,
Expand Down Expand Up @@ -153,6 +154,7 @@
"precision",
"precision_recall_curve",
"peak_signal_noise_ratio",
"peak_signal_noise_ratio_with_blocked_effect",
"r2_score",
"recall",
"retrieval_average_precision",
Expand Down
1 change: 1 addition & 0 deletions src/torchmetrics/functional/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis # noqa: F401
from torchmetrics.functional.image.gradients import image_gradients # noqa: F401
from torchmetrics.functional.image.psnr import peak_signal_noise_ratio # noqa: F401
from torchmetrics.functional.image.psnrb import peak_signal_noise_ratio_with_blocked_effect # noqa: F401
from torchmetrics.functional.image.sam import spectral_angle_mapper # noqa: F401
from torchmetrics.functional.image.ssim import ( # noqa: F401
multiscale_structural_similarity_index_measure,
Expand Down
209 changes: 209 additions & 0 deletions src/torchmetrics/functional/image/psnrb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union

import torch
from torch import Tensor, tensor
from typing_extensions import Literal

from torchmetrics.utilities import rank_zero_warn, reduce


def _compute_bef(target: Tensor, dim: Optional[Union[int, Tuple[int, ...]]] = None, block_size=8) -> Tensor:
if dim == 3:
height, width, channels = target.Size
elif dim == 2:
height, width = target.Size
channels = 1
else:
raise ValueError("Not a 1-channel/3-channel grayscale image")
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved

if channels > 1:
raise ValueError("Not for color images")

h = torch.arange(width - 1)
h_b = torch.tensor(range(block_size - 1, width - 1, block_size))
h_bc = torch.tensor(list(set(h).symmetric_difference(h_b)))

v = torch.arange(height - 1)
v_b = torch.tensor(range(block_size - 1, height - 1, block_size))
v_bc = torch.tensor(list(set(v).symmetric_difference(v_b)))

d_b = 0
d_bc = 0

# h_b for loop
h_b = torch.arange(0, target.shape[1] - 1, dtype=torch.long)
h_bc = h_b + 1
v_b = torch.arange(0, target.shape[0] - 1, dtype=torch.long)
v_bc = v_b + 1
diff = target.gather(1, h_b.unsqueeze(-1)) - torch.gather(1, h_b.unsqueeze(-1))
d_b += torch.sum(torch.square(diff))
diff = torch.gather(0, v_b.unsqueeze(0)) - torch.gather(0, v_b.unsqueeze(0))
d_b += torch.sum(torch.square(diff))

diff = torch.gather(1, h_bc.unsqueeze(-1)) - torch.gather(1, h_b.unsqueeze(-1))
d_bc += torch.sum(torch.square(diff))
diff = torch.gather(0, v_bc.unsqueeze(0)) - torch.gather(0, v_b.unsqueeze(0))
d_bc += torch.sum(torch.square(diff))

# N code
n_hb = height * (width / block_size) - 1
n_hbc = (height * (width - 1)) - n_hb
n_vb = width * (height / block_size) - 1
n_vbc = (width * (height - 1)) - n_vb

# D code
d_b /= n_hb + n_vb
d_bc /= n_hbc + n_vbc

# Log
t = torch.log2(block_size) / torch.log2(min(height, width)) if d_b > d_bc else 0

# BEF
bef = t * (d_b - d_bc)

return bef


def _psnrb_compute(
sum_squared_error: Tensor,
bef: Tensor,
n_obs: Tensor,
data_range: Tensor,
base: float = 10.0,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
) -> Tensor:
"""Computes peak signal-to-noise ratio.

Args:
sum_squared_error: Sum of square of errors over all observations
n_obs: Number of predictions or observations
data_range: the range of the data. If None, it is determined from the data (max - min).
``data_range`` must be given when ``dim`` is not None.
base: a base of a logarithm to use
reduction: a method to reduce metric scores over labels:

- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'`` or ``None``: no reduction will be applied

Example:
>>> preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
>>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
>>> data_range = target.max() - target.min()
>>> sum_squared_error, n_obs = _psnrb_update(preds, target)
>>> _psnrb_compute(sum_squared_error, n_obs, data_range)
tensor(2.5527)
"""
sum_squared_error = sum_squared_error / n_obs + bef
psnr_base_e = 2 * torch.log(data_range) - torch.log(sum_squared_error)
psnr_vals = psnr_base_e * (10 / torch.log(tensor(base)))
return reduce(psnr_vals, reduction=reduction)


def _psnrb_update(
preds: Tensor, target: Tensor, dim: Optional[Union[int, Tuple[int, ...]]] = None, block_size: int = 8
) -> Tuple[Tensor, Tensor, Tensor]:
"""Updates and returns variables required to compute peak signal-to-noise ratio.

Args:
preds: Predicted tensor
target: Ground truth tensor
dim: Dimensions to reduce PSNR scores over provided as either an integer or a list of integers. Default is
None meaning scores will be reduced across all dimensions.
"""
if dim is None:
sum_squared_error = torch.sum(torch.pow(preds - target, 2))
n_obs = tensor(target.numel(), device=target.device)
bef = _compute_bef(preds, dim=0, block_size=block_size)
return sum_squared_error, n_obs

diff = preds - target
sum_squared_error = torch.sum(diff * diff, dim=dim)

dim_list = [dim] if isinstance(dim, int) else list(dim)
if not dim_list:
n_obs = tensor(target.numel(), device=target.device)
else:
n_obs = tensor(target.size(), device=target.device)[dim_list].prod()
n_obs = n_obs.expand_as(sum_squared_error)

bef = _compute_bef(preds, dim=dim, block_size=block_size)

return sum_squared_error, bef, n_obs


def peak_signal_noise_ratio_with_blocked_effect(
preds: Tensor,
target: Tensor,
block_size: int = 8,
data_range: Optional[float] = None,
base: float = 10.0,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
dim: Optional[Union[int, Tuple[int, ...]]] = None,
) -> Tensor:
"""Computes `Peak Signal to Noise Ratio With Blocked Effect` (PSNRB) metrics, which is defined as.

.. math:: \text{PSNRB}(I, J) = 10 * \\log_{10} \\left(\frac{\\max(I)^2}{\text{MSE}(I, J)-\text{B}(I, J)}\right)

Where :math:`\text{MSE}` denotes the `mean-squared-error`_ function.

Args:
preds: estimated signal
target: groun truth signal
data_range: the range of the data. If None, it is determined from the data (max - min).
``data_range`` must be given when ``dim`` is not None.
base: a base of a logarithm to use
reduction: a method to reduce metric score over labels:

- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'`` or None``: no reduction will be applied

dim:
Dimensions to reduce PSNR scores over provided as either an integer or a list of integers. Default is
None meaning scores will be reduced across all dimensions.

Return:
Tensor with PSNR score

Raises:
ValueError:
If ``dim`` is not ``None`` and ``data_range`` is not provided.

Example:
>>> from torchmetrics.functional import peak_signal_noise_ratio_with_blocked_effect
>>> pred = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
>>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
>>> peak_signal_noise_ratio_with_blocked_effect(pred, target)
tensor(2.5527)

.. note::
Half precision is only support on GPU for this metric
"""
if dim is None and reduction != "elementwise_mean":
rank_zero_warn(f"The `reduction={reduction}` will not have any effect when `dim` is None.")

if data_range is None:
if dim is not None:
# Maybe we could use `torch.amax(target, dim=dim) - torch.amin(target, dim=dim)` in PyTorch 1.7 to calculate
# `data_range` in the future.
raise ValueError("The `data_range` must be given when `dim` is not None.")

data_range = target.max() - target.min()
else:
data_range = tensor(float(data_range))
sum_squared_error, bef, n_obs = _psnrb_update(preds, target, dim=dim, block_size=block_size)
return _psnrb_compute(sum_squared_error, bef, n_obs, data_range, base=base, reduction=reduction)
1 change: 1 addition & 0 deletions src/torchmetrics/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from torchmetrics.image.d_lambda import SpectralDistortionIndex # noqa: F401
from torchmetrics.image.ergas import ErrorRelativeGlobalDimensionlessSynthesis # noqa: F401
from torchmetrics.image.psnr import PeakSignalNoiseRatio # noqa: F401
from torchmetrics.image.psnrb import PeakSignalNoiseRatioWithBlockedEffect # noqa: F401
from torchmetrics.image.sam import SpectralAngleMapper # noqa: F401
from torchmetrics.image.ssim import ( # noqa: F401
MultiScaleStructuralSimilarityIndexMeasure,
Expand Down
Loading