Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for torch.float weighted networks for FID and KID calculations. #2483

Merged
merged 28 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
c51c656
add support to normalized custom models
furkan-celik Mar 25, 2024
a45f8ab
- documentation fix
furkan-celik Mar 31, 2024
cff2d49
added dummz feature extractor network to test custom extractor
furkan-celik Mar 31, 2024
6fcd9a4
add dummy feature extractor to tests for testing custom feature extra…
furkan-celik Mar 31, 2024
09cf848
fixed init error
furkan-celik Mar 31, 2024
5a786e2
convert int8 tensor imgs to float32 on model side
furkan-celik Mar 31, 2024
b829dcf
prehook commit changes
furkan-celik Mar 31, 2024
f317e73
precommit hook changes
furkan-celik Mar 31, 2024
472951b
fix typing error
furkan-celik Mar 31, 2024
5bf2b50
Merge branch 'master' into master
Borda Apr 10, 2024
2ab9fb0
fix argument quotation
furkan-celik Apr 12, 2024
9f563ef
Merge branch 'master' of github.com:furkan-celik/torchmetrics
furkan-celik Apr 12, 2024
5ce7eef
Merge branch 'master' into master
Borda Apr 12, 2024
efd324b
changelog
SkafteNicki Apr 12, 2024
cb7942d
Update src/torchmetrics/image/fid.py
furkan-celik Apr 12, 2024
cd9d656
Update src/torchmetrics/image/fid.py
furkan-celik Apr 12, 2024
afd003f
Update src/torchmetrics/image/fid.py
furkan-celik Apr 12, 2024
14e968d
Update src/torchmetrics/image/fid.py
furkan-celik Apr 12, 2024
43e28ee
Update src/torchmetrics/image/fid.py
furkan-celik Apr 12, 2024
5944b4e
Update src/torchmetrics/image/fid.py
furkan-celik Apr 12, 2024
fb9db7a
Merge branch 'master' into master
SkafteNicki Apr 13, 2024
b88d910
try fixing issues in docs
SkafteNicki Apr 13, 2024
24e2c49
Merge branch 'master' into master
Borda Apr 14, 2024
ff178ed
Merge branch 'master' into master
SkafteNicki Apr 15, 2024
dda2740
Merge branch 'master' into master
mergify[bot] Apr 15, 2024
de761af
Merge branch 'master' into master
mergify[bot] Apr 15, 2024
afc5629
Merge branch 'master' into master
mergify[bot] Apr 15, 2024
51dcc91
Merge branch 'master' into master
mergify[bot] Apr 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added support for calculating segmentation quality and recognition quality in `PanopticQuality` metric ([#2381](https://github.com/Lightning-AI/torchmetrics/pull/2381))


- Added support for `torch.float` weighted networks for FID and KID calculations ([#2483](https://github.com/Lightning-AI/torchmetrics/pull/2483))


### Changed

- Made `__getattr__` and `__setattr__` of `ClasswiseWrapper` more general ([#2424](https://github.com/Lightning-AI/torchmetrics/pull/2424))
Expand Down
56 changes: 43 additions & 13 deletions src/torchmetrics/image/fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,15 @@ class FrechetInceptionDistance(Metric):
flag ``real`` determines if the images should update the statistics of the real distribution or the
fake distribution.

Using custom feature extractor is also possible. One can give a torch.nn.Module as `feature` argument. This
custom feature extractor is expected to have output shape of ``(1, num_features)``. This would change the
used feature extractor from default (Inception v3) to the given network. In case network doesn't have
``num_features`` attribute, a random tensor will be given to the network to infer feature dimensionality.
Size of this tensor can be controlled by ``input_img_size`` argument and type of the tensor can be controlled
with ``normalize`` argument (``True`` uses float32 tensors and ``False`` uses int8 tensors). In this case, update
method expects to have the tensor given to `imgs` argument to be in the correct shape and type that is compatible
to the custom feature extractor.

This metric is known to be unstable in its calculatations, and we recommend for the best results using this metric
that you calculate using `torch.float64` (default is `torch.float32`) which can be set using the `.set_dtype`
method of the metric.
Expand Down Expand Up @@ -228,13 +237,20 @@ class FrechetInceptionDistance(Metric):
reset_real_features: Whether to also reset the real features. Since in many cases the real dataset does not
change, the features can be cached them to avoid recomputing them which is costly. Set this to ``False`` if
your dataset does not change.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
normalize:
Argument for controlling the input image dtype normalization:

- If default feature extractor is used, controls whether input imgs have values in range [0, 1] or not:

- True: if input imgs have values ranged in [0, 1]. They are cast to int8/byte tensors.
- False: if input imgs have values ranged in [0, 255]. No casting is done.

- If custom feature extractor module is used, controls type of the input img tensors:

.. note::
If a custom feature extractor is provided through the `feature` argument it is expected to either have a
attribute called ``num_features`` that indicates the number of features returned by the forward pass or
alternatively we will pass through tensor of shape ``(1, 3, 299, 299)`` and dtype ``torch.uint8``` to the
forward pass and expect a tensor of shape ``(1, num_features)`` as output.
- True: if input imgs are expected to be in the data type of torch.float32.
- False: if input imgs are expected to be in the data type of torch.int8.
input_img_size: tuple of integers. Indicates input img size to the custom feature extractor network if provided.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
ValueError:
Expand Down Expand Up @@ -284,9 +300,16 @@ def __init__(
feature: Union[int, Module] = 2048,
reset_real_features: bool = True,
normalize: bool = False,
input_img_size: Tuple[int, int, int] = (3, 299, 299),
**kwargs: Any,
) -> None:
super().__init__(**kwargs)

if not isinstance(normalize, bool):
raise ValueError("Argument `normalize` expected to be a bool")
self.normalize = normalize
self.used_custom_model = False

if isinstance(feature, int):
num_features = feature
if not _TORCH_FIDELITY_AVAILABLE:
Expand All @@ -304,10 +327,14 @@ def __init__(

elif isinstance(feature, Module):
self.inception = feature
self.used_custom_model = True
if hasattr(self.inception, "num_features"):
num_features = self.inception.num_features
else:
dummy_image = torch.randint(0, 255, (1, 3, 299, 299), dtype=torch.uint8)
if self.normalize:
dummy_image = torch.rand(1, *input_img_size, dtype=torch.float32)
else:
dummy_image = torch.randint(0, 255, (1, *input_img_size), dtype=torch.uint8)
num_features = self.inception(dummy_image).shape[-1]
else:
raise TypeError("Got unknown input to argument `feature`")
Expand All @@ -316,10 +343,6 @@ def __init__(
raise ValueError("Argument `reset_real_features` expected to be a bool")
self.reset_real_features = reset_real_features

if not isinstance(normalize, bool):
raise ValueError("Argument `normalize` expected to be a bool")
self.normalize = normalize

mx_num_feats = (num_features, num_features)
self.add_state("real_features_sum", torch.zeros(num_features).double(), dist_reduce_fx="sum")
self.add_state("real_features_cov_sum", torch.zeros(mx_num_feats).double(), dist_reduce_fx="sum")
Expand All @@ -330,8 +353,15 @@ def __init__(
self.add_state("fake_features_num_samples", torch.tensor(0).long(), dist_reduce_fx="sum")

def update(self, imgs: Tensor, real: bool) -> None:
"""Update the state with extracted features."""
imgs = (imgs * 255).byte() if self.normalize else imgs
"""Update the state with extracted features.

Args:
imgs: Input img tensors to evaluate. If used custom feature extractor please
make sure dtype and size is correct for the model.
real: Whether given image is real or fake.

"""
imgs = (imgs * 255).byte() if self.normalize and (not self.used_custom_model) else imgs
features = self.inception(imgs)
self.orig_dtype = features.dtype
features = features.double()
Expand Down
26 changes: 23 additions & 3 deletions src/torchmetrics/image/kid.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ class KernelInceptionDistance(Metric):
flag ``real`` determines if the images should update the statistics of the real distribution or the
fake distribution.

Using custom feature extractor is also possible. One can give a torch.nn.Module as `feature` argument. This
custom feature extractor is expected to have output shape of ``(1, num_features)`` This would change the
used feature extractor from default (Inception v3) to the given network. ``normalize`` argument won't have any
effect and update method expects to have the tensor given to `imgs` argument to be in the correct shape and
type that is compatible to the custom feature extractor.

.. note:: using this metric with the default feature extractor requires that ``torch-fidelity``
is installed. Either install as ``pip install torchmetrics[image]`` or
``pip install torch-fidelity``
Expand All @@ -103,7 +109,7 @@ class KernelInceptionDistance(Metric):
As output of `forward` and `compute` the metric returns the following output

- ``kid_mean`` (:class:`~torch.Tensor`): float scalar tensor with mean value over subsets
- ``kid_std`` (:class:`~torch.Tensor`): float scalar tensor with mean value over subsets
- ``kid_std`` (:class:`~torch.Tensor`): float scalar tensor with standard deviation value over subsets

Args:
feature: Either an str, integer or ``nn.Module``:
Expand Down Expand Up @@ -187,6 +193,8 @@ def __init__(
UserWarning,
)

self.used_custom_model = False

if isinstance(feature, (str, int)):
if not _TORCH_FIDELITY_AVAILABLE:
raise ModuleNotFoundError(
Expand All @@ -202,6 +210,7 @@ def __init__(
self.inception: Module = NoTrainInceptionV3(name="inception-v3-compat", features_list=[str(feature)])
elif isinstance(feature, Module):
self.inception = feature
self.used_custom_model = True
else:
raise TypeError("Got unknown input to argument `feature`")

Expand Down Expand Up @@ -238,8 +247,15 @@ def __init__(
self.add_state("fake_features", [], dist_reduce_fx=None)

def update(self, imgs: Tensor, real: bool) -> None:
"""Update the state with extracted features."""
imgs = (imgs * 255).byte() if self.normalize else imgs
"""Update the state with extracted features.

Args:
imgs: Input img tensors to evaluate. If used custom feature extractor please
make sure dtype and size is correct for the model.
real: Whether given image is real or fake.

"""
imgs = (imgs * 255).byte() if self.normalize and (not self.used_custom_model) else imgs
features = self.inception(imgs)

if real:
Expand All @@ -252,6 +268,10 @@ def compute(self) -> Tuple[Tensor, Tensor]:

Implementation inspired by `Fid Score`_

Returns:
kid_mean (:class:`~torch.Tensor`): float scalar tensor with mean value over subsets
kid_std (:class:`~torch.Tensor`): float scalar tensor with standard deviation value over subsets

"""
real_features = dim_zero_cat(self.real_features)
fake_features = dim_zero_cat(self.fake_features)
Expand Down
13 changes: 12 additions & 1 deletion tests/unittests/image/test_fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,19 @@ def test_fid_raises_errors_and_warnings():
_ = FrechetInceptionDistance(feature=[1, 2])


class _DummyFeatureExtractor(Module):
def __init__(self) -> None:
super().__init__()
self.flatten = torch.nn.Flatten()
self.extractor = torch.nn.Linear(3 * 299 * 299, 64)

def __call__(self, img) -> torch.Tensor:
img = (img / 125.5).float() # Convert int img input to float as Linear layer expects float inputs
return self.extractor(self.flatten(img))


@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
@pytest.mark.parametrize("feature", [64, 192, 768, 2048])
@pytest.mark.parametrize("feature", [64, 192, 768, 2048, _DummyFeatureExtractor()])
def test_fid_same_input(feature):
"""If real and fake are update on the same data the fid score should be 0."""
metric = FrechetInceptionDistance(feature=feature)
Expand Down
13 changes: 12 additions & 1 deletion tests/unittests/image/test_kid.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,19 @@ def test_kid_extra_parameters():
KernelInceptionDistance(coef=-1)


class _DummyFeatureExtractor(Module):
def __init__(self) -> None:
super().__init__()
self.flatten = torch.nn.Flatten()
self.extractor = torch.nn.Linear(3 * 299 * 299, 64)

def __call__(self, img) -> torch.Tensor:
img = (img / 125.5).float() # Convert int img input to float as Linear layer expects float inputs
return self.extractor(self.flatten(img))


@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
@pytest.mark.parametrize("feature", [64, 192, 768, 2048])
@pytest.mark.parametrize("feature", [64, 192, 768, 2048, _DummyFeatureExtractor()])
def test_kid_same_input(feature):
"""Test that the metric works."""
metric = KernelInceptionDistance(feature=feature, subsets=5, subset_size=2)
Expand Down
Loading