diff --git a/CHANGELOG.md b/CHANGELOG.md index 47a729cd2fa..8de1ee274b1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,9 +13,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `TotalVariation` to image package ([#978](https://github.com/Lightning-AI/metrics/pull/978)) + - Added option to pass `distributed_available_fn` to metrics to allow checks for custom communication backend for making `dist_sync_fn` actually useful ([#1301](https://github.com/Lightning-AI/metrics/pull/1301)) +- Added `normalize` argument to `Inception`, `FID`, `KID` metrics ([#1246](https://github.com/Lightning-AI/metrics/pull/1246)) + + - Added `KendallRankCorrCoef` to regression package ([#1271](https://github.com/Lightning-AI/metrics/pull/1271)) diff --git a/src/torchmetrics/image/fid.py b/src/torchmetrics/image/fid.py index 8d0f8f7dc49..afd56e46600 100644 --- a/src/torchmetrics/image/fid.py +++ b/src/torchmetrics/image/fid.py @@ -137,9 +137,12 @@ class FrechetInceptionDistance(Metric): originally proposed in [1]. Using the default feature extraction (Inception v3 using the original weights from [2]), the input is - expected to be mini-batches of 3-channel RGB images of shape (``3 x H x W``) with dtype uint8. All images - will be resized to 299 x 299 which is the size of the original training data. The boolian flag ``real`` - determines if the images should update the statistics of the real distribution or the fake distribution. + expected to be mini-batches of 3-channel RGB images of shape (``3 x H x W``). If argument ``normalize`` + is ``True`` images are expected to be dtype ``float`` and have values in the ``[0, 1]`` range, else if + ``normalize`` is set to ``False`` images are expected to have dtype ``uint8`` and take values in the ``[0, 255]`` + range. All images will be resized to 299 x 299 which is the size of the original training data. 
The boolean + flag ``real`` determines if the images should update the statistics of the real distribution or the + fake distribution. .. note:: using this metrics requires you to have ``scipy`` install. Either install as ``pip install torchmetrics[image]`` or ``pip install scipy`` @@ -211,6 +214,7 @@ def __init__( self, feature: Union[int, Module] = 2048, reset_real_features: bool = True, + normalize: bool = False, **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -241,6 +245,10 @@ def __init__( raise ValueError("Argument `reset_real_features` expected to be a bool") self.reset_real_features = reset_real_features + if not isinstance(normalize, bool): + raise ValueError("Argument `normalize` expected to be a bool") + self.normalize = normalize + mx_nb_feets = (num_features, num_features) self.add_state("real_features_sum", torch.zeros(num_features).double(), dist_reduce_fx="sum") self.add_state("real_features_cov_sum", torch.zeros(mx_nb_feets).double(), dist_reduce_fx="sum") @@ -257,6 +265,7 @@ def update(self, imgs: Tensor, real: bool) -> None: # type: ignore imgs: tensor with images feed to the feature extractor real: bool indicating if ``imgs`` belong to the real or the fake distribution """ + imgs = (imgs * 255).byte() if self.normalize else imgs features = self.inception(imgs) self.orig_dtype = features.dtype features = features.double() diff --git a/src/torchmetrics/image/inception.py b/src/torchmetrics/image/inception.py index d2dec74a7e0..5022508567a 100644 --- a/src/torchmetrics/image/inception.py +++ b/src/torchmetrics/image/inception.py @@ -39,8 +39,10 @@ class InceptionScore(Metric): both a mean and standard deviation of the score are returned. The metric was originally proposed in [1]. Using the default feature extraction (Inception v3 using the original weights from [2]), the input is - expected to be mini-batches of 3-channel RGB images of shape (3 x H x W) with dtype uint8. 
All images - will be resized to 299 x 299 which is the size of the original training data. + expected to be mini-batches of 3-channel RGB images of shape (``3 x H x W``). If argument ``normalize`` + is ``True`` images are expected to be dtype ``float`` and have values in the ``[0, 1]`` range, else if + ``normalize`` is set to ``False`` images are expected to have dtype ``uint8`` and take values in the ``[0, 255]`` + range. All images will be resized to 299 x 299 which is the size of the original training data. .. note:: using this metric with the default feature extractor requires that ``torch-fidelity`` is installed. Either install as ``pip install torchmetrics[image]`` or @@ -96,6 +98,7 @@ def __init__( self, feature: Union[str, int, Module] = "logits_unbiased", splits: int = 10, + normalize: bool = False, **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -124,6 +127,10 @@ def __init__( else: raise TypeError("Got unknown input to argument `feature`") + if not isinstance(normalize, bool): + raise ValueError("Argument `normalize` expected to be a bool") + self.normalize = normalize + self.splits = splits self.add_state("features", [], dist_reduce_fx=None) @@ -133,6 +140,7 @@ def update(self, imgs: Tensor) -> None: # type: ignore Args: imgs: tensor with images feed to the feature extractor """ + imgs = (imgs * 255).byte() if self.normalize else imgs features = self.inception(imgs) self.features.append(features) diff --git a/src/torchmetrics/image/kid.py b/src/torchmetrics/image/kid.py index 9dd2ac92816..fb7be620c5e 100644 --- a/src/torchmetrics/image/kid.py +++ b/src/torchmetrics/image/kid.py @@ -82,8 +82,12 @@ class KernelInceptionDistance(Metric): subsets to be able to both get the mean and standard deviation of KID. Using the default feature extraction (Inception v3 using the original weights from [2]), the input is - expected to be mini-batches of 3-channel RGB images of shape (3 x H x W) with dtype uint8. 
All images - will be resized to 299 x 299 which is the size of the original training data. + expected to be mini-batches of 3-channel RGB images of shape (``3 x H x W``). If argument ``normalize`` + is ``True`` images are expected to be dtype ``float`` and have values in the ``[0, 1]`` range, else if + ``normalize`` is set to ``False`` images are expected to have dtype ``uint8`` and take values in the ``[0, 255]`` + range. All images will be resized to 299 x 299 which is the size of the original training data. The boolean + flag ``real`` determines if the images should update the statistics of the real distribution or the + fake distribution. .. note:: using this metric with the default feature extractor requires that ``torch-fidelity`` is installed. Either install as ``pip install torchmetrics[image]`` or @@ -164,6 +168,7 @@ def __init__( gamma: Optional[float] = None, coef: float = 1.0, reset_real_features: bool = True, + normalize: bool = False, **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -216,6 +221,10 @@ def __init__( raise ValueError("Arugment `reset_real_features` expected to be a bool") self.reset_real_features = reset_real_features + if not isinstance(normalize, bool): + raise ValueError("Argument `normalize` expected to be a bool") + self.normalize = normalize + # states for extracted features self.add_state("real_features", [], dist_reduce_fx=None) self.add_state("fake_features", [], dist_reduce_fx=None) @@ -227,6 +236,7 @@ def update(self, imgs: Tensor, real: bool) -> None: imgs: tensor with images feed to the feature extractor real: bool indicating if ``imgs`` belong to the real or the fake distribution """ + imgs = (imgs * 255).byte() if self.normalize else imgs features = self.inception(imgs) if real: diff --git a/tests/unittests/image/test_fid.py b/tests/unittests/image/test_fid.py index 39b45cb3146..2e658a394ff 100644 --- a/tests/unittests/image/test_fid.py +++ b/tests/unittests/image/test_fid.py @@ -12,6 +12,7 @@ # See the License for 
the specific language governing permissions and # limitations under the License. import pickle +from contextlib import nullcontext as does_not_raise import pytest import torch @@ -181,3 +182,19 @@ def test_reset_real_features_arg(reset_real_features): assert metric.real_features_num_samples == 2 assert metric.real_features_sum.shape == torch.Size([64]) assert metric.real_features_cov_sum.shape == torch.Size([64, 64]) + + +@pytest.mark.parametrize( + "normalize, expectation, message", + [ + (True, does_not_raise(), None), + (False, pytest.raises(ValueError), "Expecting image as torch.Tensor with dtype=torch.uint8"), + ], +) +def test_normalize_arg(normalize, expectation, message): + """Test that normalize argument works as expected.""" + img = torch.rand(2, 3, 299, 299) + metric = FrechetInceptionDistance(normalize=normalize) + with expectation as e: + metric.update(img, real=True) + assert message is None or message in str(e) diff --git a/tests/unittests/image/test_inception.py b/tests/unittests/image/test_inception.py index 147dedd0d4e..93198b874c2 100644 --- a/tests/unittests/image/test_inception.py +++ b/tests/unittests/image/test_inception.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import pickle +from contextlib import nullcontext as does_not_raise import pytest import torch @@ -126,3 +127,19 @@ def test_compare_is(tmpdir, compute_on_cpu): tm_mean, _ = metric.compute() assert torch.allclose(tm_mean.cpu(), torch.tensor([torch_fid["inception_score_mean"]]), atol=1e-3) + + +@pytest.mark.parametrize( + "normalize, expectation, message", + [ + (True, does_not_raise(), None), + (False, pytest.raises(ValueError), "Expecting image as torch.Tensor with dtype=torch.uint8"), + ], +) +def test_normalize_arg(normalize, expectation, message): + """Test that normalize argument works as expected.""" + img = torch.rand(2, 3, 299, 299) + metric = InceptionScore(normalize=normalize) + with expectation as e: + metric.update(img) + assert message is None or message in str(e) diff --git a/tests/unittests/image/test_kid.py b/tests/unittests/image/test_kid.py index 263c1b55c86..69ebc320766 100644 --- a/tests/unittests/image/test_kid.py +++ b/tests/unittests/image/test_kid.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import pickle +from contextlib import nullcontext as does_not_raise import pytest import torch @@ -189,3 +190,19 @@ def test_reset_real_features_arg(reset_real_features): else: assert len(metric.real_features) == 1 assert list(metric.real_features[0].shape) == [2, 64] + + +@pytest.mark.parametrize( + "normalize, expectation, message", + [ + (True, does_not_raise(), None), + (False, pytest.raises(ValueError), "Expecting image as torch.Tensor with dtype=torch.uint8"), + ], +) +def test_normalize_arg(normalize, expectation, message): + """Test that normalize argument works as expected.""" + img = torch.rand(2, 3, 299, 299) + metric = KernelInceptionDistance(normalize=normalize) + with expectation as e: + metric.update(img, real=True) + assert message is None or message in str(e)