adding KID metric #301

Merged 27 commits from branch `kid` into `master` on Jun 21, 2021 (changes shown from 26 commits).
Commits:
- 9d11ab0 implementation (SkafteNicki, Jun 16, 2021)
- 948a178 parameter testing (SkafteNicki, Jun 16, 2021)
- c534a60 fix test (SkafteNicki, Jun 16, 2021)
- 72608f7 implementation (SkafteNicki, Jun 16, 2021)
- c0de4f7 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 16, 2021)
- 36c5a82 update to torch fidelity 0.3.0 (SkafteNicki, Jun 16, 2021)
- 1cd256c Merge branch 'kid' of https://github.com/PyTorchLightning/metrics int… (SkafteNicki, Jun 16, 2021)
- 6aa956d changelog (SkafteNicki, Jun 16, 2021)
- 19fefc1 docs (SkafteNicki, Jun 16, 2021)
- 90b7a76 Merge branch 'master' into kid (mergify[bot], Jun 17, 2021)
- a8a2ab5 Apply suggestions from code review (SkafteNicki, Jun 17, 2021)
- 3cabc96 Apply suggestions from code review (Borda, Jun 17, 2021)
- 6a2a20a [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 17, 2021)
- 751b145 add test (SkafteNicki, Jun 17, 2021)
- dae0fd8 Merge branch 'master' into kid (SkafteNicki, Jun 18, 2021)
- 7e99282 Merge branches 'kid' and 'kid' of https://github.com/PyTorchLightning… (SkafteNicki, Jun 18, 2021)
- 39cc0f9 update (SkafteNicki, Jun 18, 2021)
- beb9d27 Merge branch 'master' into kid (mergify[bot], Jun 21, 2021)
- 42ec431 fix tests (SkafteNicki, Jun 21, 2021)
- 81968df Merge branch 'kid' of https://github.com/PyTorchLightning/metrics int… (SkafteNicki, Jun 21, 2021)
- 5dca79a typing (SkafteNicki, Jun 21, 2021)
- 4203628 fix typing (SkafteNicki, Jun 21, 2021)
- 889c066 Merge branch 'master' into kid (mergify[bot], Jun 21, 2021)
- 94d44a5 fix bus error (SkafteNicki, Jun 21, 2021)
- 08a5bb7 Merge branch 'kid' of https://github.com/PyTorchLightning/metrics int… (SkafteNicki, Jun 21, 2021)
- 128fc84 Merge branch 'master' into kid (mergify[bot], Jun 21, 2021)
- 3f1e0e0 Apply suggestions from code review (Borda, Jun 21, 2021)
2 changes: 2 additions & 0 deletions CHANGELOG.md

@@ -39,6 +39,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 - Added Inception Score metric to image module ([#299](https://github.com/PyTorchLightning/metrics/pull/299))


+- Added KID metric to image module ([#301](https://github.com/PyTorchLightning/metrics/pull/301))
+
 ### Changed

 - Forward cache is now reset when `reset` method is called ([#260](https://github.com/PyTorchLightning/metrics/pull/260))
4 changes: 4 additions & 0 deletions docs/source/references/modules.rst

@@ -269,6 +269,10 @@ IS
 .. autoclass:: torchmetrics.IS
     :noindex:

+
+.. autoclass:: torchmetrics.KID
+    :noindex:
+
 ******************
 Regression Metrics
 ******************
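For orientation, a minimal usage sketch of the new module follows. It is based purely on the API exercised by the tests in this PR (the `feature`, `subsets` and `subset_size` arguments and the `real` flag all appear in `tests/image/test_kid.py` below), not on separate documentation:

import torch

from torchmetrics.image.kid import KID

# KID over the default 2048-dim InceptionV3 features, averaged across 5 random
# subsets of 50 samples each (subset_size must not exceed the number of samples)
kid = KID(feature=2048, subsets=5, subset_size=50)

# images are uint8 tensors in [0, 255] of shape [N, 3, 299, 299], as in the tests
real_imgs = torch.randint(0, 255, (100, 3, 299, 299), dtype=torch.uint8)
fake_imgs = torch.randint(0, 255, (100, 3, 299, 299), dtype=torch.uint8)

kid.update(real_imgs, real=True)   # accumulate features of the real distribution
kid.update(fake_imgs, real=False)  # accumulate features of the generated distribution

kid_mean, kid_std = kid.compute()  # mean and std of KID across the random subsets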
3 changes: 2 additions & 1 deletion tests/image/test_fid.py

@@ -148,7 +148,8 @@ def test_compare_fid(tmpdir, feature=2048):
         input2=_ImgDataset(img2),
         fid=True,
         feature_layer_fid=str(feature),
-        batch_size=batch_size
+        batch_size=batch_size,
+        save_cpu_ram=True
     )

     tm_res = metric.compute()
8 changes: 5 additions & 3 deletions tests/image/test_inception.py

@@ -87,8 +87,8 @@ def test_is_update_compute():
         metric.update(img)

     mean, std = metric.compute()
-    assert mean != 0.0
-    assert std != 0.0
+    assert mean >= 0.0
+    assert std >= 0.0


 class _ImgDataset(Dataset):

@@ -118,7 +118,9 @@ def test_compare_is(tmpdir):
     for i in range(img1.shape[0] // batch_size):
         metric.update(img1[batch_size * i:batch_size * (i + 1)].cuda())

-    torch_fid = calculate_metrics(input1=_ImgDataset(img1), isc=True, isc_splits=1, batch_size=batch_size)
+    torch_fid = calculate_metrics(
+        input1=_ImgDataset(img1), isc=True, isc_splits=1, batch_size=batch_size, save_cpu_ram=True
+    )

     tm_mean, tm_std = metric.compute()
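A side note on the `save_cpu_ram=True` flag added to the torch-fidelity `calculate_metrics` calls in these comparison tests: it is an existing torch-fidelity option that trades speed for lower host-memory usage, and, judging by the "fix bus error" commit in the history above, it appears to be what resolved the CI crashes.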
168 changes: 168 additions & 0 deletions tests/image/test_kid.py
@@ -0,0 +1,168 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle

import pytest
import torch
from torch.utils.data import Dataset

from torchmetrics.image.kid import KID
from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE

torch.manual_seed(42)


@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
def test_no_train():
    """ Assert that metric never leaves evaluation mode """

    class MyModel(torch.nn.Module):

        def __init__(self):
            super().__init__()
            self.metric = KID()

        def forward(self, x):
            return x

    model = MyModel()
    model.train()
    assert model.training
    assert not model.metric.inception.training, 'KID metric was changed to training mode which should not happen'


@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason='test requires torch-fidelity')
def test_kid_pickle():
    """ Assert that we can initialize the metric and pickle it """
    metric = KID()
    assert metric

    # verify metrics work after being loaded from pickled state
    pickled_metric = pickle.dumps(metric)
    metric = pickle.loads(pickled_metric)


def test_kid_raises_errors_and_warnings():
    """ Test that expected warnings and errors are raised """
    with pytest.warns(
        UserWarning,
        match='Metric `KID` will save all extracted features in buffer.'
        ' For large datasets this may lead to large memory footprint.'
    ):
        KID()

    if _TORCH_FIDELITY_AVAILABLE:
        with pytest.raises(ValueError, match='Integer input to argument `feature` must be one of .*'):
            KID(feature=2)
    else:
        with pytest.raises(
            ValueError,
            match='KID metric requires that Torch-fidelity is installed.'
            'Either install as `pip install torchmetrics[image]`'
            ' or `pip install torch-fidelity`'
        ):
            KID()

    with pytest.raises(TypeError, match='Got unknown input to argument `feature`'):
        KID(feature=[1, 2])

    with pytest.raises(ValueError, match='Argument `subset_size` should be smaller than the number of samples'):
        m = KID()
        m.update(torch.randint(0, 255, (5, 3, 299, 299), dtype=torch.uint8), real=True)
        m.update(torch.randint(0, 255, (5, 3, 299, 299), dtype=torch.uint8), real=False)
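        # only 5 real and 5 fake samples are registered here, which is (presumably)
        # far below the default `subset_size`, so `compute` should raise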
        m.compute()


@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason='test requires torch-fidelity')
def test_kid_extra_parameters():
    with pytest.raises(ValueError, match="Argument `subsets` expected to be integer larger than 0"):
        KID(subsets=-1)

    with pytest.raises(ValueError, match="Argument `subset_size` expected to be integer larger than 0"):
        KID(subset_size=-1)

    with pytest.raises(ValueError, match="Argument `degree` expected to be integer larger than 0"):
        KID(degree=-1)

    with pytest.raises(ValueError, match="Argument `gamma` expected to be `None` or float larger than 0"):
        KID(gamma=-1)

    with pytest.raises(ValueError, match="Argument `coef` expected to be float larger than 0"):
        KID(coef=-1)


@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason='test requires torch-fidelity')
@pytest.mark.parametrize("feature", [64, 192, 768, 2048])
def test_kid_same_input(feature):
    """ Test that the metric works """
    metric = KID(feature=feature, subsets=5, subset_size=2)

    for _ in range(2):
        img = torch.randint(0, 255, (10, 3, 299, 299), dtype=torch.uint8)
        metric.update(img, real=True)
        metric.update(img, real=False)

    assert torch.allclose(torch.cat(metric.real_features, dim=0), torch.cat(metric.fake_features, dim=0))

    mean, std = metric.compute()
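    # note: even though the pooled real and fake feature sets are identical, the
    # real and fake subsets are presumably drawn independently, so the per-subset
    # MMD estimates (and hence the reported mean) are in general nonzero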
    assert mean != 0.0
    assert std >= 0.0


class _ImgDataset(Dataset):

    def __init__(self, imgs):
        self.imgs = imgs

    def __getitem__(self, idx):
        return self.imgs[idx]

    def __len__(self):
        return self.imgs.shape[0]


@pytest.mark.skipif(not torch.cuda.is_available(), reason='test is too slow without gpu')
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason='test requires torch-fidelity')
def test_compare_kid(tmpdir, feature=2048):
    """ Check that the whole pipeline gives the same result as torch-fidelity """
    from torch_fidelity import calculate_metrics

    metric = KID(feature=feature, subsets=1, subset_size=100).cuda()

    # Generate some synthetic data
    img1 = torch.randint(0, 180, (100, 3, 299, 299), dtype=torch.uint8)
    img2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)

    batch_size = 10
    for i in range(img1.shape[0] // batch_size):
        metric.update(img1[batch_size * i:batch_size * (i + 1)].cuda(), real=True)

    for i in range(img2.shape[0] // batch_size):
        metric.update(img2[batch_size * i:batch_size * (i + 1)].cuda(), real=False)

    torch_fid = calculate_metrics(
        input1=_ImgDataset(img1),
        input2=_ImgDataset(img2),
        kid=True,
        feature_layer_fid=str(feature),
        batch_size=batch_size,
        kid_subsets=1,
        kid_subset_size=100,
        save_cpu_ram=True
    )

    tm_mean, tm_std = metric.compute()

    assert torch.allclose(tm_mean.cpu(), torch.tensor([torch_fid['kernel_inception_distance_mean']]), atol=1e-3)
    assert torch.allclose(tm_std.cpu(), torch.tensor([torch_fid['kernel_inception_distance_std']]), atol=1e-3)
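For readers unfamiliar with the metric, KID is the squared maximum mean discrepancy (MMD) between InceptionV3 features of the two image sets, computed with a polynomial kernel and averaged over random subsets. The sketch below illustrates a single-subset, unbiased estimate; it assumes the kernel defaults suggested by the `degree`, `gamma` and `coef` arguments tested above (`degree=3`, `gamma=None` meaning `1 / feature_dim`, `coef=1.0`) and is an illustration of the idea, not a copy of the PR's implementation:

from typing import Optional

import torch
from torch import Tensor


def poly_kernel(f1: Tensor, f2: Tensor, degree: int = 3, gamma: Optional[float] = None, coef: float = 1.0) -> Tensor:
    """Polynomial kernel k(x, y) = (gamma * <x, y> + coef) ** degree on feature rows."""
    if gamma is None:
        gamma = 1.0 / f1.shape[1]
    return (f1 @ f2.T * gamma + coef) ** degree


def kid_subset_estimate(real: Tensor, fake: Tensor) -> Tensor:
    """Unbiased MMD^2 estimate for one subset; `real` and `fake` are [n, d] feature matrices."""
    n = real.shape[0]
    k_rr = poly_kernel(real, real)
    k_ff = poly_kernel(fake, fake)
    k_rf = poly_kernel(real, fake)
    # diagonal (self-similarity) terms are excluded from the unbiased within-set sums
    mmd = (k_rr.sum() - k_rr.trace()) / (n * (n - 1))
    mmd = mmd + (k_ff.sum() - k_ff.trace()) / (n * (n - 1))
    mmd = mmd - 2 * k_rf.mean()
    return mmd

The full metric would repeat this over `subsets` random draws of `subset_size` features and report the mean and standard deviation of the estimates, which matches the `(mean, std)` pair that `compute` returns in the tests above.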
2 changes: 1 addition & 1 deletion torchmetrics/__init__.py

@@ -36,7 +36,7 @@
     StatScores,
 )
 from torchmetrics.collections import MetricCollection  # noqa: F401 E402
-from torchmetrics.image import FID, IS  # noqa: F401 E402
+from torchmetrics.image import FID, IS, KID  # noqa: F401 E402
 from torchmetrics.metric import Metric  # noqa: F401 E402
 from torchmetrics.regression import (  # noqa: F401 E402
     PSNR,
1 change: 1 addition & 0 deletions torchmetrics/image/__init__.py

@@ -13,3 +13,4 @@
 # limitations under the License.
 from torchmetrics.image.fid import FID  # noqa: F401
 from torchmetrics.image.inception import IS  # noqa: F401
+from torchmetrics.image.kid import KID  # noqa: F401
5 changes: 3 additions & 2 deletions torchmetrics/image/fid.py

@@ -20,6 +20,7 @@

 from torchmetrics.metric import Metric
 from torchmetrics.utilities import rank_zero_info, rank_zero_warn
+from torchmetrics.utilities.data import dim_zero_cat
 from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE

 if _TORCH_FIDELITY_AVAILABLE:

@@ -257,8 +258,8 @@ def update(self, imgs: Tensor, real: bool) -> None:  # type: ignore

     def compute(self) -> Tensor:
         """ Calculate FID score based on accumulated extracted features from the two distributions """
-        real_features = torch.cat(self.real_features, dim=0)
-        fake_features = torch.cat(self.fake_features, dim=0)
+        real_features = dim_zero_cat(self.real_features)
+        fake_features = dim_zero_cat(self.fake_features)
         # computation is extremely sensitive so it needs to happen in double precision
         orig_dtype = real_features.dtype
         real_features = real_features.double()
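The swap from `torch.cat` to `dim_zero_cat` here (and in `inception.py` below) matters once metric states are synced across processes: after syncing, a list state can arrive as a single already-concatenated tensor rather than a list of tensors. A rough sketch of what such a helper does, inferred from its name and its use here rather than quoted from the torchmetrics source:

from typing import List, Union

import torch
from torch import Tensor


def dim_zero_cat(x: Union[Tensor, List[Tensor]]) -> Tensor:
    # wrap a bare tensor so that list and tensor states concatenate uniformly
    x = [x] if isinstance(x, Tensor) else x
    return torch.cat(x, dim=0)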
3 changes: 2 additions & 1 deletion torchmetrics/image/inception.py

@@ -19,6 +19,7 @@
 from torchmetrics.image.fid import NoTrainInceptionV3
 from torchmetrics.metric import Metric
 from torchmetrics.utilities import rank_zero_warn
+from torchmetrics.utilities.data import dim_zero_cat
 from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE


@@ -153,7 +154,7 @@ def update(self, imgs: Tensor) -> None:  # type: ignore
         self.features.append(features)

     def compute(self) -> Tuple[Tensor, Tensor]:
-        features = torch.cat(self.features, dim=0)
+        features = dim_zero_cat(self.features)
         # random permute the features
         idx = torch.randperm(features.shape[0])
         features = features[idx]