@@ -230,6 +230,12 @@ class FrechetInceptionDistance(Metric):
230
230
your dataset does not change.
231
231
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
232
232
233
+ .. note::
234
+ If a custom feature extractor is provided through the `feature` argument it is expected to either have a
235
+ attribute called ``num_features`` that indicates the number of features returned by the forward pass or
236
+ alternatively we will pass through tensor of shape ``(1, 3, 299, 299)`` and dtype ``torch.uint8``` to the
237
+ forward pass and expect a tensor of shape ``(1, num_features)`` as output.
238
+
233
239
Raises:
234
240
ValueError:
235
241
If torch version is lower than 1.9
@@ -297,8 +303,11 @@ def __init__(
297
303
298
304
elif isinstance (feature , Module ):
299
305
self .inception = feature
300
- dummy_image = torch .randint (0 , 255 , (1 , 3 , 299 , 299 ), dtype = torch .uint8 )
301
- num_features = self .inception (dummy_image ).shape [- 1 ]
306
+ if hasattr (self .inception , "num_features" ):
307
+ num_features = self .inception .num_features
308
+ else :
309
+ dummy_image = torch .randint (0 , 255 , (1 , 3 , 299 , 299 ), dtype = torch .uint8 )
310
+ num_features = self .inception (dummy_image ).shape [- 1 ]
302
311
else :
303
312
raise TypeError ("Got unknown input to argument `feature`" )
304
313
0 commit comments