Skip to content

Commit

Permalink
Support for torch.float weighted networks for FID and KID calculati…
Browse files Browse the repository at this point in the history
…ons. (#2483)

* add support to normalized custom models

* - documentation fix
- support for float weighted custom networks
- support for custom sized input imgs

* added dummz feature extractor network to test custom extractor

* add dummy feature extractor to tests for testing custom feature extractor

* fixed init error

* convert int8 tensor imgs to float32 on model side

* prehook commit changes

* precommit hook changes

* fix typing error

* fix argument quotation

* changelog

* Update src/torchmetrics/image/fid.py

Co-authored-by: Nicki Skafte Detlefsen <[email protected]>

* Update src/torchmetrics/image/fid.py

Co-authored-by: Nicki Skafte Detlefsen <[email protected]>

* Update src/torchmetrics/image/fid.py

Co-authored-by: Nicki Skafte Detlefsen <[email protected]>

* Update src/torchmetrics/image/fid.py

Co-authored-by: Nicki Skafte Detlefsen <[email protected]>

* Update src/torchmetrics/image/fid.py

Co-authored-by: Nicki Skafte Detlefsen <[email protected]>

* Update src/torchmetrics/image/fid.py

Co-authored-by: Nicki Skafte Detlefsen <[email protected]>

* try fixing issues in docs

---------

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
  • Loading branch information
3 people authored Apr 16, 2024
1 parent 5259c22 commit 822dba2
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 18 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added support for calculating segmentation quality and recognition quality in `PanopticQuality` metric ([#2381](https://github.com/Lightning-AI/torchmetrics/pull/2381))


- Added support for `torch.float` weighted networks for FID and KID calculations ([#2483](https://github.com/Lightning-AI/torchmetrics/pull/2483))


### Changed

- Made `__getattr__` and `__setattr__` of `ClasswiseWrapper` more general ([#2424](https://github.com/Lightning-AI/torchmetrics/pull/2424))
Expand Down
56 changes: 43 additions & 13 deletions src/torchmetrics/image/fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,15 @@ class FrechetInceptionDistance(Metric):
flag ``real`` determines if the images should update the statistics of the real distribution or the
fake distribution.
Using custom feature extractor is also possible. One can give a torch.nn.Module as `feature` argument. This
custom feature extractor is expected to have output shape of ``(1, num_features)``. This would change the
used feature extractor from default (Inception v3) to the given network. In case network doesn't have
``num_features`` attribute, a random tensor will be given to the network to infer feature dimensionality.
Size of this tensor can be controlled by ``input_img_size`` argument and type of the tensor can be controlled
with ``normalize`` argument (``True`` uses float32 tensors and ``False`` uses int8 tensors). In this case, update
method expects to have the tensor given to `imgs` argument to be in the correct shape and type that is compatible
to the custom feature extractor.
This metric is known to be unstable in its calculatations, and we recommend for the best results using this metric
that you calculate using `torch.float64` (default is `torch.float32`) which can be set using the `.set_dtype`
method of the metric.
Expand Down Expand Up @@ -228,13 +237,20 @@ class FrechetInceptionDistance(Metric):
reset_real_features: Whether to also reset the real features. Since in many cases the real dataset does not
change, the features can be cached them to avoid recomputing them which is costly. Set this to ``False`` if
your dataset does not change.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
normalize:
Argument for controlling the input image dtype normalization:
- If default feature extractor is used, controls whether input imgs have values in range [0, 1] or not:
- True: if input imgs have values ranged in [0, 1]. They are cast to int8/byte tensors.
- False: if input imgs have values ranged in [0, 255]. No casting is done.
- If custom feature extractor module is used, controls type of the input img tensors:
.. note::
If a custom feature extractor is provided through the `feature` argument it is expected to either have a
attribute called ``num_features`` that indicates the number of features returned by the forward pass or
alternatively we will pass through tensor of shape ``(1, 3, 299, 299)`` and dtype ``torch.uint8``` to the
forward pass and expect a tensor of shape ``(1, num_features)`` as output.
- True: if input imgs are expected to be in the data type of torch.float32.
- False: if input imgs are expected to be in the data type of torch.int8.
input_img_size: tuple of integers. Indicates input img size to the custom feature extractor network if provided.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Raises:
ValueError:
Expand Down Expand Up @@ -284,9 +300,16 @@ def __init__(
feature: Union[int, Module] = 2048,
reset_real_features: bool = True,
normalize: bool = False,
input_img_size: Tuple[int, int, int] = (3, 299, 299),
**kwargs: Any,
) -> None:
super().__init__(**kwargs)

if not isinstance(normalize, bool):
raise ValueError("Argument `normalize` expected to be a bool")
self.normalize = normalize
self.used_custom_model = False

if isinstance(feature, int):
num_features = feature
if not _TORCH_FIDELITY_AVAILABLE:
Expand All @@ -304,10 +327,14 @@ def __init__(

elif isinstance(feature, Module):
self.inception = feature
self.used_custom_model = True
if hasattr(self.inception, "num_features"):
num_features = self.inception.num_features
else:
dummy_image = torch.randint(0, 255, (1, 3, 299, 299), dtype=torch.uint8)
if self.normalize:
dummy_image = torch.rand(1, *input_img_size, dtype=torch.float32)
else:
dummy_image = torch.randint(0, 255, (1, *input_img_size), dtype=torch.uint8)
num_features = self.inception(dummy_image).shape[-1]
else:
raise TypeError("Got unknown input to argument `feature`")
Expand All @@ -316,10 +343,6 @@ def __init__(
raise ValueError("Argument `reset_real_features` expected to be a bool")
self.reset_real_features = reset_real_features

if not isinstance(normalize, bool):
raise ValueError("Argument `normalize` expected to be a bool")
self.normalize = normalize

mx_num_feats = (num_features, num_features)
self.add_state("real_features_sum", torch.zeros(num_features).double(), dist_reduce_fx="sum")
self.add_state("real_features_cov_sum", torch.zeros(mx_num_feats).double(), dist_reduce_fx="sum")
Expand All @@ -330,8 +353,15 @@ def __init__(
self.add_state("fake_features_num_samples", torch.tensor(0).long(), dist_reduce_fx="sum")

def update(self, imgs: Tensor, real: bool) -> None:
"""Update the state with extracted features."""
imgs = (imgs * 255).byte() if self.normalize else imgs
"""Update the state with extracted features.
Args:
imgs: Input img tensors to evaluate. If used custom feature extractor please
make sure dtype and size is correct for the model.
real: Whether given image is real or fake.
"""
imgs = (imgs * 255).byte() if self.normalize and (not self.used_custom_model) else imgs
features = self.inception(imgs)
self.orig_dtype = features.dtype
features = features.double()
Expand Down
26 changes: 23 additions & 3 deletions src/torchmetrics/image/kid.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ class KernelInceptionDistance(Metric):
flag ``real`` determines if the images should update the statistics of the real distribution or the
fake distribution.
Using custom feature extractor is also possible. One can give a torch.nn.Module as `feature` argument. This
custom feature extractor is expected to have output shape of ``(1, num_features)`` This would change the
used feature extractor from default (Inception v3) to the given network. ``normalize`` argument won't have any
effect and update method expects to have the tensor given to `imgs` argument to be in the correct shape and
type that is compatible to the custom feature extractor.
.. note:: using this metric with the default feature extractor requires that ``torch-fidelity``
is installed. Either install as ``pip install torchmetrics[image]`` or
``pip install torch-fidelity``
Expand All @@ -103,7 +109,7 @@ class KernelInceptionDistance(Metric):
As output of `forward` and `compute` the metric returns the following output
- ``kid_mean`` (:class:`~torch.Tensor`): float scalar tensor with mean value over subsets
- ``kid_std`` (:class:`~torch.Tensor`): float scalar tensor with mean value over subsets
- ``kid_std`` (:class:`~torch.Tensor`): float scalar tensor with standard deviation value over subsets
Args:
feature: Either an str, integer or ``nn.Module``:
Expand Down Expand Up @@ -187,6 +193,8 @@ def __init__(
UserWarning,
)

self.used_custom_model = False

if isinstance(feature, (str, int)):
if not _TORCH_FIDELITY_AVAILABLE:
raise ModuleNotFoundError(
Expand All @@ -202,6 +210,7 @@ def __init__(
self.inception: Module = NoTrainInceptionV3(name="inception-v3-compat", features_list=[str(feature)])
elif isinstance(feature, Module):
self.inception = feature
self.used_custom_model = True
else:
raise TypeError("Got unknown input to argument `feature`")

Expand Down Expand Up @@ -238,8 +247,15 @@ def __init__(
self.add_state("fake_features", [], dist_reduce_fx=None)

def update(self, imgs: Tensor, real: bool) -> None:
"""Update the state with extracted features."""
imgs = (imgs * 255).byte() if self.normalize else imgs
"""Update the state with extracted features.
Args:
imgs: Input img tensors to evaluate. If used custom feature extractor please
make sure dtype and size is correct for the model.
real: Whether given image is real or fake.
"""
imgs = (imgs * 255).byte() if self.normalize and (not self.used_custom_model) else imgs
features = self.inception(imgs)

if real:
Expand All @@ -252,6 +268,10 @@ def compute(self) -> Tuple[Tensor, Tensor]:
Implementation inspired by `Fid Score`_
Returns:
kid_mean (:class:`~torch.Tensor`): float scalar tensor with mean value over subsets
kid_std (:class:`~torch.Tensor`): float scalar tensor with standard deviation value over subsets
"""
real_features = dim_zero_cat(self.real_features)
fake_features = dim_zero_cat(self.fake_features)
Expand Down
13 changes: 12 additions & 1 deletion tests/unittests/image/test_fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,19 @@ def test_fid_raises_errors_and_warnings():
_ = FrechetInceptionDistance(feature=[1, 2])


class _DummyFeatureExtractor(Module):
def __init__(self) -> None:
super().__init__()
self.flatten = torch.nn.Flatten()
self.extractor = torch.nn.Linear(3 * 299 * 299, 64)

def __call__(self, img) -> torch.Tensor:
img = (img / 125.5).float() # Convert int img input to float as Linear layer expects float inputs
return self.extractor(self.flatten(img))


@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
@pytest.mark.parametrize("feature", [64, 192, 768, 2048])
@pytest.mark.parametrize("feature", [64, 192, 768, 2048, _DummyFeatureExtractor()])
def test_fid_same_input(feature):
"""If real and fake are update on the same data the fid score should be 0."""
metric = FrechetInceptionDistance(feature=feature)
Expand Down
13 changes: 12 additions & 1 deletion tests/unittests/image/test_kid.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,19 @@ def test_kid_extra_parameters():
KernelInceptionDistance(coef=-1)


class _DummyFeatureExtractor(Module):
def __init__(self) -> None:
super().__init__()
self.flatten = torch.nn.Flatten()
self.extractor = torch.nn.Linear(3 * 299 * 299, 64)

def __call__(self, img) -> torch.Tensor:
img = (img / 125.5).float() # Convert int img input to float as Linear layer expects float inputs
return self.extractor(self.flatten(img))


@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
@pytest.mark.parametrize("feature", [64, 192, 768, 2048])
@pytest.mark.parametrize("feature", [64, 192, 768, 2048, _DummyFeatureExtractor()])
def test_kid_same_input(feature):
"""Test that the metric works."""
metric = KernelInceptionDistance(feature=feature, subsets=5, subset_size=2)
Expand Down

0 comments on commit 822dba2

Please sign in to comment.