
Batch size dependent FID #1620

Closed
nicolas-dufour opened this issue Mar 14, 2023 · 3 comments · Fixed by #1628
Labels: bug / fix (Something isn't working) · help wanted (Extra attention is needed) · topic: Image · v0.10.x

@nicolas-dufour (Contributor) commented Mar 14, 2023

🐛 Bug

As pointed out by @kimihailv in #1198, FID seems to be batch-size dependent.
After some experimentation, the dependency appears to be linked to the inception network.

To Reproduce

If one runs:

import torch
from torchmetrics.image.fid import NoTrainInceptionV3

device = "cuda" if torch.cuda.is_available() else "cpu"
inception = NoTrainInceptionV3(name="inception-v3-compat", features_list=["2048"])
inception.to(device)

# Random images converted to uint8 in [0, 255].
imgs = torch.randn(100, 3, 256, 256).to(device)
imgs = ((imgs.clamp(-1, 1) / 2 + 0.5) * 255).to(torch.uint8)

# Extract features one image at a time (batch size 1) ...
features = []
for img in imgs:
    feature = inception(img.unsqueeze(0))
    features.append(feature)
features_b_1 = torch.cat(features, dim=0)

# ... and once as a single batch of 100.
features = inception(imgs)

We then have:

torch.allclose(features, features_b_1)
>> False
torch.norm(features - features_b_1, p=2)
>> tensor(0.2950, device='cuda:0')

Note that when replacing

features = inception(imgs)

by

features = []
for img in imgs:
    feature = inception(img.unsqueeze(0))
    features.append(feature)
features = torch.cat(features, dim=0)

the FID becomes batch-size independent again. However, this is not a viable fix. First, it is very inefficient. Second, from experimentation, the batch bias in FID appears to be larger for small batch sizes: computing FID between two uniformly sampled distributions of 1000 points each gives an FID of 1.9 with a batch size of 1000, but an FID of 10 with a batch size of 2. Since both sets are sampled from the same distribution, the FID should be as close to zero as possible. A sketch of this experiment is given below.
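A minimal sketch of the experiment just described (the helper name, image sizes, and seeds are illustrative, not the original script): compute FID between two random sets of 1000 images drawn from the same distribution, once fed as a single batch of 1000 and once in batches of 2; ideally both runs would report nearly the same value.

import torch
from torchmetrics.image.fid import FrechetInceptionDistance

def random_uint8_images(n, seed):
    # Random images converted to uint8 in [0, 255], same preprocessing as the repro above.
    g = torch.Generator().manual_seed(seed)
    imgs = torch.randn(n, 3, 256, 256, generator=g)
    return ((imgs.clamp(-1, 1) / 2 + 0.5) * 255).to(torch.uint8)

set_a = random_uint8_images(1000, seed=0)
set_b = random_uint8_images(1000, seed=1)

for batch_size in (1000, 2):  # note: a 1000-image update is memory heavy
    metric = FrechetInceptionDistance(feature=2048)
    for i in range(0, 1000, batch_size):
        metric.update(set_a[i:i + batch_size], real=True)
        metric.update(set_b[i:i + batch_size], real=False)
    print(batch_size, metric.compute().item())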

Expected behavior

FID computation should be batch-size independent.

Environment

  • TorchMetrics version (and how you installed TM, e.g. conda, pip, build from source): 0.10.3
  • Python & PyTorch Version (e.g., 1.0): 3.10 and 1.13
  • Any other relevant information such as OS (e.g., Linux): Linux
@SkafteNicki (Member) commented:

Hi @nicolas-dufour, thanks for reporting this issue.
I tried reproducing your results but was unable to; please see this notebook:
https://colab.research.google.com/drive/16aVuIJQj7TUmhiy6TGL2QcyQOTb3ibJr?usp=sharing
On CPU I get a norm difference of tensor(0.0002) and on CUDA I get tensor(0.0172) between the two approaches, which seems small enough to still compute the metric correctly.

@nicolas-dufour (Contributor, Author) commented:

Hi @SkafteNicki, thanks for checking this out.

Hmm, that is strange! From further experiments, I found that the discrepancy disappears when using float64. Maybe the problem is accelerator dependent? The previous experiment was done on an RTX 3090.

Also, I have observed that the impact on FID is minimal for dataset sizes > 100. However, it is still odd that the metric changes with the batch size. One solution would be to offer the option to run the embedding network at float64 precision. A sketch of the precision hypothesis follows.
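For reference, a hedged sketch of that precision hypothesis, using torchvision's Inception v3 instead of the torchmetrics NoTrainInceptionV3 wrapper (the model choice, input sizes, and random weights are illustrative stand-ins): the gap between per-sample and batched outputs is compared in float32 and float64.

import torch
from torchvision.models import inception_v3

# Random weights are enough to illustrate the numerical effect; eval() fixes BatchNorm stats,
# so any remaining per-sample vs batched difference is purely numerical.
model = inception_v3(weights=None).eval()
imgs = torch.rand(8, 3, 299, 299)

with torch.no_grad():
    for dtype in (torch.float32, torch.float64):
        m = model.to(dtype)
        x = imgs.to(dtype)
        batched = m(x)
        per_sample = torch.cat([m(img.unsqueeze(0)) for img in x], dim=0)
        # The float32 gap is typically larger (especially on GPU) than the float64 gap.
        print(dtype, torch.norm(batched - per_sample, p=2).item())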

@SkafteNicki (Member) commented:

Hi again @nicolas-dufour,
So I created PR #1628 that will allow the user to run the embedding network with float64 by simply calling the .set_dtype method of the metric.

from torchmetrics.image.fid import NoTrainInceptionV3, FrechetInceptionDistance
import torch

metric = FrechetInceptionDistance()
metric.set_dtype(torch.float64)

imgs = torch.randn(1, 3, 256, 256)
imgs = ((imgs.clamp(-1, 1) / 2 + 0.5) * 255).to(torch.uint8)

metric.inception(imgs)

It still needs a bit of testing and documentation.
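A hedged usage sketch of the full pipeline, assuming #1628 lands as described above (image sizes, counts, and the helper name are illustrative):

import torch
from torchmetrics.image.fid import FrechetInceptionDistance

metric = FrechetInceptionDistance(feature=2048)
metric.set_dtype(torch.float64)  # per #1628: casts the embedding network to float64

def random_uint8_images(n, seed):
    # Random images converted to uint8 in [0, 255], matching the repro above.
    g = torch.Generator().manual_seed(seed)
    imgs = torch.randn(n, 3, 256, 256, generator=g)
    return ((imgs.clamp(-1, 1) / 2 + 0.5) * 255).to(torch.uint8)

metric.update(random_uint8_images(100, seed=0), real=True)
metric.update(random_uint8_images(100, seed=1), real=False)
print(metric.compute())  # FID between the two random sets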

@SkafteNicki added this to the v0.12 milestone Mar 21, 2023
toshas added a commit to toshas/torch-fidelity that referenced this issue Apr 30, 2023
…lp numerical issues with inception feature extractor and its output variation due to the batch size.

fix #43, related in torchmetrics:
- Lightning-AI/torchmetrics#1620
- Lightning-AI/torchmetrics#1628
add explicit eval in the inception fe to help a case if someone copies just that file for metrics evaluation
add explicit require_grad(False) to clip feature extractor
add test cases to troubleshoot batch size dependence of metrics values
rustoneee added a commit to rustoneee/Pytorch-Generative-models-GAN- that referenced this issue Nov 6, 2023