torchmetrics.image.fid FrechetInceptionDistance normalize argument not working #1339

Closed
Jelle-Plomp opened this issue Nov 15, 2022 · 3 comments
Labels
bug / fix, help wanted, topic: Image, v0.9.x

Comments


Jelle-Plomp commented Nov 15, 2022

🐛 Bug

The normalize argument of torchmetrics.image.fid.FrechetInceptionDistance is not working. The documentation states that with normalize=True the input images are expected to be of float type (with values in [0, 1]). Yet when I pass float images, I still get ValueError: Expecting image as torch.Tensor with dtype=torch.uint8
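For reference, the documented `normalize=True` behavior is roughly equivalent to rescaling float images in [0, 1] to uint8 values in [0, 255] before they reach the Inception feature extractor, which (per the traceback below) only accepts uint8. On a release that does not honor the argument (0.9.3 here), one possible workaround is to do that conversion yourself and pass the uint8 batch without `normalize=True`. This is a minimal sketch assuming float inputs in [0, 1]; `to_uint8_images` is a hypothetical helper, not part of torchmetrics, and its rounding may differ slightly from whatever conversion the library applies internally.

```python
import torch


def to_uint8_images(imgs: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: map float images in [0, 1] to uint8 in [0, 255]."""
    if not imgs.is_floating_point():
        raise TypeError(f"expected a floating-point tensor, got {imgs.dtype}")
    # Clamp first so values slightly outside [0, 1] cannot wrap around on the cast,
    # then scale and round to the nearest integer level.
    return (imgs.clamp(0.0, 1.0) * 255).round().to(torch.uint8)


imgs_dist1 = torch.rand((4, 3, 299, 299))  # float32 values in [0, 1)
imgs_u8 = to_uint8_images(imgs_dist1)
print(imgs_u8.dtype, imgs_u8.shape)
# The uint8 batch can then be passed on without normalize=True, e.g.
# fid.update(imgs_u8, real=True)
```

Clamping before the cast is a deliberate choice: casting an out-of-range float straight to uint8 does not saturate, so a stray negative value could otherwise turn into a large pixel value.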

To Reproduce

The following is adapted from the example in the docs (https://torchmetrics.readthedocs.io/en/latest/image/frechet_inception_distance.html), now using random images of dtype float.


Code sample

```python
import torch
from torchmetrics.image.fid import FrechetInceptionDistance

_ = torch.manual_seed(123)  # make the random images reproducible

# Per the (latest) docs, normalize=True should accept float images in [0, 1]
fid = FrechetInceptionDistance(feature=64, normalize=True)
imgs_dist1 = torch.rand((100, 3, 299, 299), dtype=torch.float)
imgs_dist2 = torch.rand((100, 3, 299, 299), dtype=torch.float)
fid.update(imgs_dist1, real=True)  # raises ValueError on 0.9.3
fid.update(imgs_dist2, real=False)
fid.compute()
```

Traceback
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [9], in <cell line: 10>()
      8 print(imgs_dist1.size())
      9 imgs_dist2 = torch.rand((100, 3, 299, 299), dtype=torch.float)
---> 10 fid.update(imgs_dist1, real=True)
     11 fid.update(imgs_dist2, real=False)
     12 fid.compute()

File ~\anaconda3\lib\site-packages\torchmetrics\metric.py:391, in Metric._wrap_update.<locals>.wrapped_func(*args, **kwargs)
    389 with torch.set_grad_enabled(self._enable_grad):
    390     try:
--> 391         update(*args, **kwargs)
    392     except RuntimeError as err:
    393         if "Expected all tensors to be on" in str(err):

File ~\anaconda3\lib\site-packages\torchmetrics\image\fid.py:253, in FrechetInceptionDistance.update(self, imgs, real)
    246 def update(self, imgs: Tensor, real: bool) -> None:  # type: ignore
    247     """Update the state with extracted features.
    248 
    249     Args:
    250         imgs: tensor with images feed to the feature extractor
    251         real: bool indicating if ``imgs`` belong to the real or the fake distribution
    252     """
--> 253     features = self.inception(imgs)
    255     if real:
    256         self.real_features.append(features)

File ~\anaconda3\lib\site-packages\torch\nn\modules\module.py:1102, in Module._call_impl(self, *input, **kwargs)
   1098 # If we don't have any hooks, we want to skip the rest of the logic in
   1099 # this function, and just call forward.
   1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102     return forward_call(*input, **kwargs)
   1103 # Do not call functions when jit is used
   1104 full_backward_hooks, non_full_backward_hooks = [], []

File ~\anaconda3\lib\site-packages\torchmetrics\image\fid.py:57, in NoTrainInceptionV3.forward(self, x)
     56 def forward(self, x: Tensor) -> Tensor:
---> 57     out = super().forward(x)
     58     return out[0].reshape(x.shape[0], -1)

File ~\anaconda3\lib\site-packages\torch_fidelity\feature_extractor_inceptionv3.py:92, in FeatureExtractorInceptionV3.forward(self, x)
     91 def forward(self, x):
---> 92     vassert(torch.is_tensor(x) and x.dtype == torch.uint8, 'Expecting image as torch.Tensor with dtype=torch.uint8')
     93     features = {}
     94     remaining_features = self.features_list.copy()

File ~\anaconda3\lib\site-packages\torch_fidelity\helpers.py:9, in vassert(truecond, message)
      7 def vassert(truecond, message):
      8     if not truecond:
----> 9         raise ValueError(message)

ValueError: Expecting image as torch.Tensor with dtype=torch.uint8

Environment

  • TorchMetrics version 0.9.3 (conda)
  • Python 3.9.12 & PyTorch 1.10.2
  • torch-fidelity installed with pip install torch-fidelity
Jelle-Plomp added the bug / fix and help wanted labels on Nov 15, 2022
github-actions commented

Hi! Thanks for your contribution, great first issue!

stancld (Contributor) commented Nov 15, 2022

Hi @Jelle-Plomp, thanks for raising the issue. I've run your code on master and on the latest 0.10.2 release. I get the same error as you on the latest release, but everything works fine on master (likely thanks to #1246). You can either install torchmetrics from source or wait for the new release we are planning to publish soon.

cc: @Borda @SkafteNicki

SkafteNicki (Member) commented

I can confirm, @stancld, that this is fixed on master (the feature has been added there). @Jelle-Plomp, I think the confusion comes from having navigated to the latest version of our documentation, which includes changes on master that have not yet been released in an official version:
https://torchmetrics.readthedocs.io/en/latest/image/frechet_inception_distance.html
The stable page (the default landing page), however, correctly shows that this feature is not included in v0.9.3 of torchmetrics:
https://torchmetrics.readthedocs.io/en/stable/image/frechet_inception_distance.html

Closing issue.
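Since the normalize argument is only available in releases newer than 0.9.3 (and, per stancld's comment, not yet in 0.10.2), a quick runtime check can tell which behavior to expect before choosing between normalize=True and a manual uint8 conversion. This is a minimal sketch that assumes the feature from #1246 first ships in 0.11.0; that cutoff and the helper names `parse_release` and `supports_fid_normalize` are assumptions for illustration, not torchmetrics APIs. `importlib.metadata.version` is standard library.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Assumption: the first release that includes `normalize` (#1246) is 0.11.0.
NORMALIZE_MIN_VERSION = (0, 11, 0)


def parse_release(ver: str) -> tuple:
    """Turn strings such as '0.9.3' or '0.11.0dev' into (major, minor, patch)."""
    parts = []
    for piece in ver.split(".")[:3]:
        match = re.match(r"\d+", piece)  # keep only the leading digits
        if match is None:
            break
        parts.append(int(match.group()))
    while len(parts) < 3:  # pad short versions such as '1.0'
        parts.append(0)
    return tuple(parts)


def supports_fid_normalize(ver: str) -> bool:
    """Hypothetical check: does this torchmetrics version accept normalize=True?"""
    return parse_release(ver) >= NORMALIZE_MIN_VERSION


if __name__ == "__main__":
    try:
        installed = version("torchmetrics")
    except PackageNotFoundError:
        print("torchmetrics is not installed")
    else:
        print(f"torchmetrics {installed}: normalize supported = {supports_fid_normalize(installed)}")
```

The hand-rolled parser keeps the sketch dependency-free; for pre-release tags or other unusual version strings, the `packaging` library's `Version` class is a more robust comparison.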
