Skip to content

Commit

Permalink
Add normalize argument to certain image metrics (#1246)
Browse files Browse the repository at this point in the history
* implementation
* testing
* changelog
  • Loading branch information
SkafteNicki authored Nov 8, 2022
1 parent ec5dfc8 commit e0f6406
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 7 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added `TotalVariation` to image package ([#978](https://github.com/Lightning-AI/metrics/pull/978))


- Added option to pass `distributed_available_fn` to metrics to allow checks for custom communication backend for making `dist_sync_fn` actually useful ([#1301](https://github.com/Lightning-AI/metrics/pull/1301))


- Added `normalize` argument to `Inception`, `FID`, `KID` metrics ([#1246](https://github.com/Lightning-AI/metrics/pull/1246))


- Added `KendallRankCorrCoef` to regression package ([#1271](https://github.com/Lightning-AI/metrics/pull/1271))


Expand Down
15 changes: 12 additions & 3 deletions src/torchmetrics/image/fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,12 @@ class FrechetInceptionDistance(Metric):
originally proposed in [1].
Using the default feature extraction (Inception v3 using the original weights from [2]), the input is
expected to be mini-batches of 3-channel RGB images of shape (``3 x H x W``) with dtype uint8. All images
will be resized to 299 x 299 which is the size of the original training data. The boolian flag ``real``
determines if the images should update the statistics of the real distribution or the fake distribution.
expected to be mini-batches of 3-channel RGB images of shape (``3 x H x W``). If argument ``normalize``
is ``True`` images are expected to be dtype ``float`` and have values in the ``[0, 1]`` range, else if
``normalize`` is set to ``False`` images are expected to have dtype ``uint8`` and take values in the ``[0, 255]``
range. All images will be resized to 299 x 299 which is the size of the original training data. The boolean
flag ``real`` determines if the images should update the statistics of the real distribution or the
fake distribution.
.. note:: using this metric requires you to have ``scipy`` installed. Either install as ``pip install
torchmetrics[image]`` or ``pip install scipy``
Expand Down Expand Up @@ -211,6 +214,7 @@ def __init__(
self,
feature: Union[int, Module] = 2048,
reset_real_features: bool = True,
normalize: bool = False,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
Expand Down Expand Up @@ -241,6 +245,10 @@ def __init__(
raise ValueError("Argument `reset_real_features` expected to be a bool")
self.reset_real_features = reset_real_features

if not isinstance(normalize, bool):
raise ValueError("Argument `normalize` expected to be a bool")
self.normalize = normalize

mx_nb_feets = (num_features, num_features)
self.add_state("real_features_sum", torch.zeros(num_features).double(), dist_reduce_fx="sum")
self.add_state("real_features_cov_sum", torch.zeros(mx_nb_feets).double(), dist_reduce_fx="sum")
Expand All @@ -257,6 +265,7 @@ def update(self, imgs: Tensor, real: bool) -> None: # type: ignore
imgs: tensor with images fed to the feature extractor
real: bool indicating if ``imgs`` belong to the real or the fake distribution
"""
imgs = (imgs * 255).byte() if self.normalize else imgs
features = self.inception(imgs)
self.orig_dtype = features.dtype
features = features.double()
Expand Down
12 changes: 10 additions & 2 deletions src/torchmetrics/image/inception.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@ class InceptionScore(Metric):
both a mean and standard deviation of the score are returned. The metric was originally proposed in [1].
Using the default feature extraction (Inception v3 using the original weights from [2]), the input is
expected to be mini-batches of 3-channel RGB images of shape (3 x H x W) with dtype uint8. All images
will be resized to 299 x 299 which is the size of the original training data.
expected to be mini-batches of 3-channel RGB images of shape (``3 x H x W``). If argument ``normalize``
is ``True`` images are expected to be dtype ``float`` and have values in the ``[0, 1]`` range, else if
``normalize`` is set to ``False`` images are expected to have dtype ``uint8`` and take values in the ``[0, 255]``
range. All images will be resized to 299 x 299 which is the size of the original training data.
.. note:: using this metric with the default feature extractor requires that ``torch-fidelity``
is installed. Either install as ``pip install torchmetrics[image]`` or
Expand Down Expand Up @@ -96,6 +98,7 @@ def __init__(
self,
feature: Union[str, int, Module] = "logits_unbiased",
splits: int = 10,
normalize: bool = False,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
Expand Down Expand Up @@ -124,6 +127,10 @@ def __init__(
else:
raise TypeError("Got unknown input to argument `feature`")

if not isinstance(normalize, bool):
raise ValueError("Argument `normalize` expected to be a bool")
self.normalize = normalize

self.splits = splits
self.add_state("features", [], dist_reduce_fx=None)

Expand All @@ -133,6 +140,7 @@ def update(self, imgs: Tensor) -> None: # type: ignore
Args:
imgs: tensor with images fed to the feature extractor
"""
imgs = (imgs * 255).byte() if self.normalize else imgs
features = self.inception(imgs)
self.features.append(features)

Expand Down
14 changes: 12 additions & 2 deletions src/torchmetrics/image/kid.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,12 @@ class KernelInceptionDistance(Metric):
subsets to be able to both get the mean and standard deviation of KID.
Using the default feature extraction (Inception v3 using the original weights from [2]), the input is
expected to be mini-batches of 3-channel RGB images of shape (3 x H x W) with dtype uint8. All images
will be resized to 299 x 299 which is the size of the original training data.
expected to be mini-batches of 3-channel RGB images of shape (``3 x H x W``). If argument ``normalize``
is ``True`` images are expected to be dtype ``float`` and have values in the ``[0, 1]`` range, else if
``normalize`` is set to ``False`` images are expected to have dtype ``uint8`` and take values in the ``[0, 255]``
range. All images will be resized to 299 x 299 which is the size of the original training data. The boolean
flag ``real`` determines if the images should update the statistics of the real distribution or the
fake distribution.
.. note:: using this metric with the default feature extractor requires that ``torch-fidelity``
is installed. Either install as ``pip install torchmetrics[image]`` or
Expand Down Expand Up @@ -164,6 +168,7 @@ def __init__(
gamma: Optional[float] = None,
coef: float = 1.0,
reset_real_features: bool = True,
normalize: bool = False,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
Expand Down Expand Up @@ -216,6 +221,10 @@ def __init__(
raise ValueError("Arugment `reset_real_features` expected to be a bool")
self.reset_real_features = reset_real_features

if not isinstance(normalize, bool):
raise ValueError("Argument `normalize` expected to be a bool")
self.normalize = normalize

# states for extracted features
self.add_state("real_features", [], dist_reduce_fx=None)
self.add_state("fake_features", [], dist_reduce_fx=None)
Expand All @@ -227,6 +236,7 @@ def update(self, imgs: Tensor, real: bool) -> None:
imgs: tensor with images fed to the feature extractor
real: bool indicating if ``imgs`` belong to the real or the fake distribution
"""
imgs = (imgs * 255).byte() if self.normalize else imgs
features = self.inception(imgs)

if real:
Expand Down
17 changes: 17 additions & 0 deletions tests/unittests/image/test_fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
from contextlib import nullcontext as does_not_raise

import pytest
import torch
Expand Down Expand Up @@ -181,3 +182,19 @@ def test_reset_real_features_arg(reset_real_features):
assert metric.real_features_num_samples == 2
assert metric.real_features_sum.shape == torch.Size([64])
assert metric.real_features_cov_sum.shape == torch.Size([64, 64])


@pytest.mark.parametrize(
    "normalize, expectation, message",
    [
        (True, does_not_raise(), None),
        (False, pytest.raises(ValueError), "Expecting image as torch.Tensor with dtype=torch.uint8"),
    ],
)
def test_normalize_arg(normalize, expectation, message):
    """Test that the ``normalize`` argument controls which image dtype ``update`` accepts.

    A float image batch sampled with ``torch.rand`` is passed to the metric. With
    ``normalize=True`` the metric rescales it by 255 and casts it to ``uint8`` before
    feature extraction, so no error is expected. With ``normalize=False`` the float
    tensor is forwarded unchanged and a ``ValueError`` is expected.

    Args:
        normalize: value forwarded to the metric's ``normalize`` argument.
        expectation: context manager that either expects a ``ValueError`` or nothing.
        message: substring expected in the error message, or ``None`` when no error is expected.
    """
    img = torch.rand(2, 3, 299, 299)
    metric = FrechetInceptionDistance(normalize=normalize)
    with expectation as e:
        metric.update(img, real=True)
    # ``e`` is ``None`` for the non-raising context, so the ``message is None`` guard must
    # short-circuit first. For the raising case compare against the exception itself
    # (``e.value``) rather than the string form of pytest's ``ExceptionInfo`` wrapper.
    assert message is None or message in str(e.value)
17 changes: 17 additions & 0 deletions tests/unittests/image/test_inception.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
from contextlib import nullcontext as does_not_raise

import pytest
import torch
Expand Down Expand Up @@ -126,3 +127,19 @@ def test_compare_is(tmpdir, compute_on_cpu):
tm_mean, _ = metric.compute()

assert torch.allclose(tm_mean.cpu(), torch.tensor([torch_fid["inception_score_mean"]]), atol=1e-3)


@pytest.mark.parametrize(
    "normalize, expectation, message",
    [
        (True, does_not_raise(), None),
        (False, pytest.raises(ValueError), "Expecting image as torch.Tensor with dtype=torch.uint8"),
    ],
)
def test_normalize_arg(normalize, expectation, message):
    """Test that the ``normalize`` argument controls which image dtype ``update`` accepts.

    A float image batch sampled with ``torch.rand`` is passed to the metric. With
    ``normalize=True`` the metric rescales it by 255 and casts it to ``uint8`` before
    feature extraction, so no error is expected. With ``normalize=False`` the float
    tensor is forwarded unchanged and a ``ValueError`` is expected.

    Args:
        normalize: value forwarded to the metric's ``normalize`` argument.
        expectation: context manager that either expects a ``ValueError`` or nothing.
        message: substring expected in the error message, or ``None`` when no error is expected.
    """
    img = torch.rand(2, 3, 299, 299)
    metric = InceptionScore(normalize=normalize)
    with expectation as e:
        metric.update(img)
    # ``e`` is ``None`` for the non-raising context, so the ``message is None`` guard must
    # short-circuit first. For the raising case compare against the exception itself
    # (``e.value``) rather than the string form of pytest's ``ExceptionInfo`` wrapper.
    assert message is None or message in str(e.value)
17 changes: 17 additions & 0 deletions tests/unittests/image/test_kid.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
from contextlib import nullcontext as does_not_raise

import pytest
import torch
Expand Down Expand Up @@ -189,3 +190,19 @@ def test_reset_real_features_arg(reset_real_features):
else:
assert len(metric.real_features) == 1
assert list(metric.real_features[0].shape) == [2, 64]


@pytest.mark.parametrize(
    "normalize, expectation, message",
    [
        (True, does_not_raise(), None),
        (False, pytest.raises(ValueError), "Expecting image as torch.Tensor with dtype=torch.uint8"),
    ],
)
def test_normalize_arg(normalize, expectation, message):
    """Test that the ``normalize`` argument controls which image dtype ``update`` accepts.

    A float image batch sampled with ``torch.rand`` is passed to the metric. With
    ``normalize=True`` the metric rescales it by 255 and casts it to ``uint8`` before
    feature extraction, so no error is expected. With ``normalize=False`` the float
    tensor is forwarded unchanged and a ``ValueError`` is expected.

    Args:
        normalize: value forwarded to the metric's ``normalize`` argument.
        expectation: context manager that either expects a ``ValueError`` or nothing.
        message: substring expected in the error message, or ``None`` when no error is expected.
    """
    img = torch.rand(2, 3, 299, 299)
    metric = KernelInceptionDistance(normalize=normalize)
    with expectation as e:
        metric.update(img, real=True)
    # ``e`` is ``None`` for the non-raising context, so the ``message is None`` guard must
    # short-circuit first. For the raising case compare against the exception itself
    # (``e.value``) rather than the string form of pytest's ``ExceptionInfo`` wrapper.
    assert message is None or message in str(e.value)

0 comments on commit e0f6406

Please sign in to comment.