@@ -95,13 +95,19 @@ def __init__(
95
95
self .model .requires_grad_ (False )
96
96
97
97
# Initialize state variables used to compute FID
98
- self ._add_state ("real_sum" , torch .zeros (feature_dim , device = device ))
99
98
self ._add_state (
100
- "real_cov_sum " , torch .zeros (( feature_dim , feature_dim ), device = device )
99
+ "real_sum " , torch .zeros (feature_dim , device = device , dtype = torch . float64 )
101
100
)
102
- self ._add_state ("fake_sum" , torch .zeros (feature_dim , device = device ))
103
101
self ._add_state (
104
- "fake_cov_sum" , torch .zeros ((feature_dim , feature_dim ), device = device )
102
+ "real_cov_sum" ,
103
+ torch .zeros ((feature_dim , feature_dim ), device = device , dtype = torch .float64 ),
104
+ )
105
+ self ._add_state (
106
+ "fake_sum" , torch .zeros (feature_dim , device = device , dtype = torch .float64 )
107
+ )
108
+ self ._add_state (
109
+ "fake_cov_sum" ,
110
+ torch .zeros ((feature_dim , feature_dim ), device = device , dtype = torch .float64 ),
105
111
)
106
112
self ._add_state ("num_real_images" , torch .tensor (0 , device = device ).int ())
107
113
self ._add_state ("num_fake_images" , torch .tensor (0 , device = device ).int ())
@@ -200,6 +206,7 @@ def compute(self: TFrechetInceptionDistance) -> Tensor:
200
206
fid = gaussian_frechet_distance (
201
207
real_mean .squeeze (), real_cov , fake_mean .squeeze (), fake_cov
202
208
)
209
+ fid = fid .to (torch .float32 )
203
210
return fid
204
211
205
212
def _FID_parameter_check (
0 commit comments