Skip to content

Commit d2de821

Browse files
author
tiphaine
committed
Fix precision in FID metric (pytorch#192)
1 parent 63a31a0 commit d2de821

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

torcheval/metrics/image/fid.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -95,13 +95,19 @@ def __init__(
9595
self.model.requires_grad_(False)
9696

9797
# Initialize state variables used to compute FID
98-
self._add_state("real_sum", torch.zeros(feature_dim, device=device))
9998
self._add_state(
100-
"real_cov_sum", torch.zeros((feature_dim, feature_dim), device=device)
99+
"real_sum", torch.zeros(feature_dim, device=device, dtype=torch.float64)
101100
)
102-
self._add_state("fake_sum", torch.zeros(feature_dim, device=device))
103101
self._add_state(
104-
"fake_cov_sum", torch.zeros((feature_dim, feature_dim), device=device)
102+
"real_cov_sum",
103+
torch.zeros((feature_dim, feature_dim), device=device, dtype=torch.float64),
104+
)
105+
self._add_state(
106+
"fake_sum", torch.zeros(feature_dim, device=device, dtype=torch.float64)
107+
)
108+
self._add_state(
109+
"fake_cov_sum",
110+
torch.zeros((feature_dim, feature_dim), device=device, dtype=torch.float64),
105111
)
106112
self._add_state("num_real_images", torch.tensor(0, device=device).int())
107113
self._add_state("num_fake_images", torch.tensor(0, device=device).int())
@@ -200,6 +206,7 @@ def compute(self: TFrechetInceptionDistance) -> Tensor:
200206
fid = gaussian_frechet_distance(
201207
real_mean.squeeze(), real_cov, fake_mean.squeeze(), fake_cov
202208
)
209+
fid = fid.to(torch.float32)
203210
return fid
204211

205212
def _FID_parameter_check(

0 commit comments

Comments
 (0)