Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add normalize argument to certain image metrics #1246

Merged
merged 9 commits into from
Nov 8, 2022
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `TotalVariation` to image package ([#978](https://github.com/Lightning-AI/metrics/pull/978))


- Added `normalize` argument to `Inception`, `FID`, `KID` metrics ([#1246](https://github.com/Lightning-AI/metrics/pull/1246))


### Changed

- Changed `MeanAveragePrecision` to vectorize `_find_best_gt_match` operation ([#1259](https://github.com/Lightning-AI/metrics/pull/1259))
Expand Down
15 changes: 12 additions & 3 deletions src/torchmetrics/image/fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,12 @@ class FrechetInceptionDistance(Metric):
originally proposed in [1].

Using the default feature extraction (Inception v3 using the original weights from [2]), the input is
expected to be mini-batches of 3-channel RGB images of shape (``3 x H x W``) with dtype uint8. All images
will be resized to 299 x 299 which is the size of the original training data. The boolian flag ``real``
determines if the images should update the statistics of the real distribution or the fake distribution.
expected to be mini-batches of 3-channel RGB images of shape (``3 x H x W``). If argument ``normalize``
is ``True`` images are expected to be dtype ``float`` and have values in the ``[0, 1]`` range, else if
``normalize`` is set to ``False`` images are expected to have dtype ``uint8`` and take values in the ``[0, 255]``
range. All images will be resized to 299 x 299 which is the size of the original training data. The boolian
flag ``real`` determines if the images should update the statistics of the real distribution or the
fake distribution.

.. note:: using this metrics requires you to have ``scipy`` install. Either install as ``pip install
torchmetrics[image]`` or ``pip install scipy``
Expand Down Expand Up @@ -211,6 +214,7 @@ def __init__(
self,
feature: Union[int, Module] = 2048,
reset_real_features: bool = True,
normalize: bool = False,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
Expand Down Expand Up @@ -241,6 +245,10 @@ def __init__(
raise ValueError("Argument `reset_real_features` expected to be a bool")
self.reset_real_features = reset_real_features

if not isinstance(normalize, bool):
raise ValueError("Argument `normalize` expected to be a bool")
self.normalize = normalize

mx_nb_feets = (num_features, num_features)
self.add_state("real_features_sum", torch.zeros(num_features).double(), dist_reduce_fx="sum")
self.add_state("real_features_cov_sum", torch.zeros(mx_nb_feets).double(), dist_reduce_fx="sum")
Expand All @@ -257,6 +265,7 @@ def update(self, imgs: Tensor, real: bool) -> None: # type: ignore
imgs: tensor with images feed to the feature extractor
real: bool indicating if ``imgs`` belong to the real or the fake distribution
"""
imgs = (imgs * 255).byte() if self.normalize else imgs
features = self.inception(imgs)
self.orig_dtype = features.dtype
features = features.double()
Expand Down
12 changes: 10 additions & 2 deletions src/torchmetrics/image/inception.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@ class InceptionScore(Metric):
both a mean and standard deviation of the score are returned. The metric was originally proposed in [1].

Using the default feature extraction (Inception v3 using the original weights from [2]), the input is
expected to be mini-batches of 3-channel RGB images of shape (3 x H x W) with dtype uint8. All images
will be resized to 299 x 299 which is the size of the original training data.
expected to be mini-batches of 3-channel RGB images of shape (``3 x H x W``). If argument ``normalize``
is ``True`` images are expected to be dtype ``float`` and have values in the ``[0, 1]`` range, else if
``normalize`` is set to ``False`` images are expected to have dtype uint8 and take values in the ``[0, 255]``
range. All images will be resized to 299 x 299 which is the size of the original training data.

.. note:: using this metric with the default feature extractor requires that ``torch-fidelity``
is installed. Either install as ``pip install torchmetrics[image]`` or
Expand Down Expand Up @@ -96,6 +98,7 @@ def __init__(
self,
feature: Union[str, int, Module] = "logits_unbiased",
splits: int = 10,
normalize: bool = False,
Borda marked this conversation as resolved.
Show resolved Hide resolved
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
Expand Down Expand Up @@ -124,6 +127,10 @@ def __init__(
else:
raise TypeError("Got unknown input to argument `feature`")

if not isinstance(normalize, bool):
raise ValueError("Argument `normalize` expected to be a bool")
self.normalize = normalize

self.splits = splits
self.add_state("features", [], dist_reduce_fx=None)

Expand All @@ -133,6 +140,7 @@ def update(self, imgs: Tensor) -> None: # type: ignore
Args:
imgs: tensor with images feed to the feature extractor
"""
imgs = (imgs * 255).byte() if self.normalize else imgs
features = self.inception(imgs)
self.features.append(features)

Expand Down
14 changes: 12 additions & 2 deletions src/torchmetrics/image/kid.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,12 @@ class KernelInceptionDistance(Metric):
subsets to be able to both get the mean and standard deviation of KID.

Using the default feature extraction (Inception v3 using the original weights from [2]), the input is
expected to be mini-batches of 3-channel RGB images of shape (3 x H x W) with dtype uint8. All images
will be resized to 299 x 299 which is the size of the original training data.
expected to be mini-batches of 3-channel RGB images of shape (``3 x H x W``). If argument ``normalize``
is ``True`` images are expected to be dtype ``float`` and have values in the ``[0, 1]`` range, else if
``normalize`` is set to ``False`` images are expected to have dtype ``uint8`` and take values in the ``[0, 255]``
range. All images will be resized to 299 x 299 which is the size of the original training data. The boolian
flag ``real`` determines if the images should update the statistics of the real distribution or the
fake distribution.

.. note:: using this metric with the default feature extractor requires that ``torch-fidelity``
is installed. Either install as ``pip install torchmetrics[image]`` or
Expand Down Expand Up @@ -164,6 +168,7 @@ def __init__(
gamma: Optional[float] = None,
coef: float = 1.0,
reset_real_features: bool = True,
normalize: bool = False,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
Expand Down Expand Up @@ -216,6 +221,10 @@ def __init__(
raise ValueError("Arugment `reset_real_features` expected to be a bool")
self.reset_real_features = reset_real_features

if not isinstance(normalize, bool):
raise ValueError("Argument `normalize` expected to be a bool")
self.normalize = normalize

# states for extracted features
self.add_state("real_features", [], dist_reduce_fx=None)
self.add_state("fake_features", [], dist_reduce_fx=None)
Expand All @@ -227,6 +236,7 @@ def update(self, imgs: Tensor, real: bool) -> None:
imgs: tensor with images feed to the feature extractor
real: bool indicating if ``imgs`` belong to the real or the fake distribution
"""
imgs = (imgs * 255).byte() if self.normalize else imgs
features = self.inception(imgs)

if real:
Expand Down
17 changes: 17 additions & 0 deletions tests/unittests/image/test_fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
from contextlib import nullcontext as does_not_raise

import pytest
import torch
Expand Down Expand Up @@ -181,3 +182,19 @@ def test_reset_real_features_arg(reset_real_features):
assert metric.real_features_num_samples == 2
assert metric.real_features_sum.shape == torch.Size([64])
assert metric.real_features_cov_sum.shape == torch.Size([64, 64])


@pytest.mark.parametrize(
"normalize, expectation, message",
[
(True, does_not_raise(), None),
(False, pytest.raises(ValueError), "Expecting image as torch.Tensor with dtype=torch.uint8"),
],
)
def test_normalize_arg(normalize, expectation, message):
"""Test that normalize argument works as expected."""
img = torch.rand(2, 3, 299, 299)
metric = FrechetInceptionDistance(normalize=normalize)
with expectation as e:
metric.update(img, real=True)
assert message is None or message in str(e)
17 changes: 17 additions & 0 deletions tests/unittests/image/test_inception.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
from contextlib import nullcontext as does_not_raise

import pytest
import torch
Expand Down Expand Up @@ -126,3 +127,19 @@ def test_compare_is(tmpdir, compute_on_cpu):
tm_mean, _ = metric.compute()

assert torch.allclose(tm_mean.cpu(), torch.tensor([torch_fid["inception_score_mean"]]), atol=1e-3)


@pytest.mark.parametrize(
"normalize, expectation, message",
[
(True, does_not_raise(), None),
(False, pytest.raises(ValueError), "Expecting image as torch.Tensor with dtype=torch.uint8"),
],
)
def test_normalize_arg(normalize, expectation, message):
"""Test that normalize argument works as expected."""
img = torch.rand(2, 3, 299, 299)
metric = InceptionScore(normalize=normalize)
with expectation as e:
metric.update(img)
assert message is None or message in str(e)
17 changes: 17 additions & 0 deletions tests/unittests/image/test_kid.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
from contextlib import nullcontext as does_not_raise

import pytest
import torch
Expand Down Expand Up @@ -189,3 +190,19 @@ def test_reset_real_features_arg(reset_real_features):
else:
assert len(metric.real_features) == 1
assert list(metric.real_features[0].shape) == [2, 64]


@pytest.mark.parametrize(
"normalize, expectation, message",
[
(True, does_not_raise(), None),
(False, pytest.raises(ValueError), "Expecting image as torch.Tensor with dtype=torch.uint8"),
],
)
def test_normalize_arg(normalize, expectation, message):
"""Test that normalize argument works as expected."""
img = torch.rand(2, 3, 299, 299)
metric = KernelInceptionDistance(normalize=normalize)
with expectation as e:
metric.update(img, real=True)
assert message is None or message in str(e)