Skip to content

Commit 7543641

Browse files
SkafteNickiBorda
andauthored
Add note for expectation of custom feature extrator in FID metric (#2277)
* update options * changelog --------- Co-authored-by: Jirka Borovec <[email protected]>
1 parent 4059085 commit 7543641

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
6666
- Fixed warning incorrectly being raised in `Running` metrics ([#2256](https://github.com/Lightning-AI/torchmetrics/pull/2265))
6767

6868

69+
- Fixed integration with custom feature extractor in `FID` metric ([#2277](https://github.com/Lightning-AI/torchmetrics/pull/2277))
70+
71+
6972
## [1.2.1] - 2023-11-30
7073

7174
### Added

src/torchmetrics/image/fid.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,12 @@ class FrechetInceptionDistance(Metric):
230230
your dataset does not change.
231231
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
232232
233+
.. note::
234+
If a custom feature extractor is provided through the `feature` argument it is expected to either have a
235+
attribute called ``num_features`` that indicates the number of features returned by the forward pass or
236+
alternatively we will pass through tensor of shape ``(1, 3, 299, 299)`` and dtype ``torch.uint8``` to the
237+
forward pass and expect a tensor of shape ``(1, num_features)`` as output.
238+
233239
Raises:
234240
ValueError:
235241
If torch version is lower than 1.9
@@ -297,8 +303,11 @@ def __init__(
297303

298304
elif isinstance(feature, Module):
299305
self.inception = feature
300-
dummy_image = torch.randint(0, 255, (1, 3, 299, 299), dtype=torch.uint8)
301-
num_features = self.inception(dummy_image).shape[-1]
306+
if hasattr(self.inception, "num_features"):
307+
num_features = self.inception.num_features
308+
else:
309+
dummy_image = torch.randint(0, 255, (1, 3, 299, 299), dtype=torch.uint8)
310+
num_features = self.inception(dummy_image).shape[-1]
302311
else:
303312
raise TypeError("Got unknown input to argument `feature`")
304313

0 commit comments

Comments
 (0)