Skip to content

Commit

Permalink
Merge b65e3f7 into 90743f0
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki authored Jun 10, 2021
2 parents 90743f0 + b65e3f7 commit 0237671
Show file tree
Hide file tree
Showing 12 changed files with 648 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `squared` argument to `MeanSquaredError` for computing `RMSE` ([#249](https://github.com/PyTorchLightning/metrics/pull/249))


- Added FID metric ([#213](https://github.com/PyTorchLightning/metrics/pull/213))


- Added `is_differentiable` property to `ConfusionMatrix`, `F1`, `FBeta`, `Hamming`, `Hinge`, `IOU`, `MatthewsCorrcoef`, `Precision`, `Recall`, `PrecisionRecallCurve`, `ROC`, `StatScores` ([#253](https://github.com/PyTorchLightning/metrics/pull/253))


Expand Down
11 changes: 11 additions & 0 deletions docs/source/references/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,17 @@ StatScores
:noindex:


*********************
Image Quality Metrics
*********************

Image quality metrics can be used to assess the quality of synthetically generated images from machine
learning algorithms such as `Generative Adversarial Networks (GANs) <https://en.wikipedia.org/wiki/Generative_adversarial_network>`_.

.. autoclass:: torchmetrics.FID
:noindex:


******************
Regression Metrics
******************
Expand Down
2 changes: 2 additions & 0 deletions requirements/image.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
scipy
torch-fidelity
3 changes: 3 additions & 0 deletions requirements/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,6 @@ cloudpickle>=1.3
scikit-learn>=0.24
scikit-image>0.17.1
nltk>=3.6

# add extra requirements
-r image.txt
10 changes: 10 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from setuptools import find_packages, setup

_PATH_ROOT = os.path.realpath(os.path.dirname(__file__))
_PATH_REQUIRE = os.path.join(_PATH_ROOT, 'requirements')


def _load_py_module(fname, pkg="torchmetrics"):
Expand All @@ -22,6 +23,14 @@ def _load_py_module(fname, pkg="torchmetrics"):
version=f'v{about.__version__}',
)


def _prepare_extras():
    """Build the mapping passed to ``setup(extras_require=...)``.

    Returns:
        A dict with a single ``'image'`` key whose value is whatever
        ``setup_tools._load_requirements`` returns for ``image.txt`` inside
        the requirements folder (``_PATH_REQUIRE``).
    """
    image_requirements = setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name='image.txt')
    return {'image': image_requirements}


# https://packaging.python.org/discussions/install-requires-vs-requirements /
# keep the meta-data here for simplicity in reading this file... it's not obvious
# what happens and to non-engineers they won't know to look in init ...
Expand Down Expand Up @@ -72,4 +81,5 @@ def _load_py_module(fname, pkg="torchmetrics"):
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
],
extras_require=_prepare_extras(),
)
191 changes: 191 additions & 0 deletions tests/helpers/datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import random
import time
import urllib.request
from typing import Optional, Sequence, Tuple

import torch
from torch import Tensor
from torch.utils.data import Dataset


class MNIST(Dataset):
    """
    Customized `MNIST <http://yann.lecun.com/exdb/mnist/>`_ dataset for testing torchmetrics
    without the torchvision dependency.

    Part of the code was copied from
    https://github.com/pytorch/vision/blob/build/v0.5.0/torchvision/datasets/mnist.py

    Args:
        root: Root directory of the dataset. The files ``training.pt`` and ``test.pt``
            are looked up (and, if requested, downloaded) under
            ``<root>/MNIST/<cache_folder_name>/``.
        train: If ``True``, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        normalize: mean and std deviation of the MNIST dataset. Normalization is only
            applied when this is a sequence of exactly two values.
        download: If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        kwargs: accepted for signature compatibility and ignored.

    Raises:
        RuntimeError: if the data files are not present after the (optional) download.

    Examples:
        >>> dataset = MNIST(".", download=True)
        >>> len(dataset)
        60000
        >>> torch.bincount(dataset.targets)
        tensor([5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949])
    """

    RESOURCES = (
        "https://pl-public-data.s3.amazonaws.com/MNIST/processed/training.pt",
        "https://pl-public-data.s3.amazonaws.com/MNIST/processed/test.pt",
    )

    TRAIN_FILE_NAME = 'training.pt'
    TEST_FILE_NAME = 'test.pt'
    # sub-folder of ``<root>/MNIST`` holding the data files; subclasses may override it
    cache_folder_name = 'complete'

    def __init__(
        self,
        root: str,
        train: bool = True,
        normalize: tuple = (0.1307, 0.3081),
        download: bool = True,
        **kwargs,
    ):
        super().__init__()
        self.root = root
        self.train = train  # training set or test set
        self.normalize = normalize

        self.prepare_data(download)

        data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME
        self.data, self.targets = self._try_load(os.path.join(self.cached_folder_path, data_file))

    def __getitem__(self, idx: int) -> Tuple[Tensor, int]:
        """Return the ``idx``-th image as a float tensor with a leading channel dim, and its label."""
        img = self.data[idx].float().unsqueeze(0)
        target = int(self.targets[idx])

        # normalization is skipped unless exactly (mean, std) was provided
        if self.normalize is not None and len(self.normalize) == 2:
            img = self.normalize_tensor(img, *self.normalize)

        return img, target

    def __len__(self) -> int:
        return len(self.data)

    @property
    def cached_folder_path(self) -> str:
        """Folder in which the data files of this dataset variant are stored."""
        return os.path.join(self.root, 'MNIST', self.cache_folder_name)

    def _check_exists(self, data_folder: str) -> bool:
        """Return ``True`` only if both the train and the test file exist in ``data_folder``."""
        return all(
            os.path.isfile(os.path.join(data_folder, fname))
            for fname in (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME)
        )

    def prepare_data(self, download: bool = True):
        """Download the data files if requested and missing, then verify that they exist.

        Raises:
            RuntimeError: if the files are still missing afterwards.
        """
        if download and not self._check_exists(self.cached_folder_path):
            self._download(self.cached_folder_path)
        if not self._check_exists(self.cached_folder_path):
            raise RuntimeError('Dataset not found.')

    def _download(self, data_folder: str) -> None:
        """Fetch every file listed in ``RESOURCES`` into ``data_folder``."""
        # exist_ok: the folder may already exist (partially populated cache, or another
        # process created it first); that must not abort the download
        os.makedirs(data_folder, exist_ok=True)
        for url in self.RESOURCES:
            logging.info(f'Downloading {url}')
            fpath = os.path.join(data_folder, os.path.basename(url))
            urllib.request.urlretrieve(url, fpath)

    @staticmethod
    def _try_load(path_data, trials: int = 30, delta: float = 1.):
        """Load ``path_data`` with ``torch.load``, retrying on failure.

        Retrying is meant to cope with several concurrent processes touching the
        same file at the same time.

        Args:
            path_data: path of the file to load; must exist.
            trials: maximum number of load attempts.
            delta: upper bound (seconds) of the random pause after a failed attempt.

        Returns:
            Whatever ``torch.load`` returns for the file.

        Raises:
            Exception: the error of the last attempt, if every attempt failed.
        """
        assert trials, "at least some trial has to be set"
        assert os.path.isfile(path_data), f'missing file: {path_data}'
        last_error = None
        for _ in range(trials):
            try:
                # returning right away ensures an error from an earlier attempt
                # is not re-raised after a later attempt succeeded
                return torch.load(path_data)
            # todo: specify the possible exception
            except Exception as err:
                last_error = err
                time.sleep(delta * random.random())
        if last_error is not None:
            # every attempt failed: surface the most recent error
            raise last_error
        return None

    @staticmethod
    def normalize_tensor(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> Tensor:
        """Return ``(tensor - mean) / std`` computed in the dtype and on the device of ``tensor``."""
        mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
        std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
        return tensor.sub(mean).div(std)


class TrialMNIST(MNIST):
    """Constrained MNIST dataset

    Only a limited number of examples of a few selected digits is kept.

    Args:
        num_samples: number of examples per selected class/digit
        digits: list selected MNIST digits/classes; a falsy value selects all ten digits
        kwargs: Same as MNIST

    Examples:
        >>> dataset = TrialMNIST(".", download=True)
        >>> len(dataset)
        300
        >>> sorted(set([d.item() for d in dataset.targets]))
        [0, 1, 2]
        >>> torch.bincount(dataset.targets)
        tensor([100, 100, 100])
    """

    def __init__(self, root: str, num_samples: int = 100, digits: Optional[Sequence] = (0, 1, 2), **kwargs):
        # how many examples to keep for each class
        self.num_samples = num_samples
        # which classes to keep (sorted), defaulting to every digit
        self.digits = sorted(digits) if digits else list(range(10))

        # the cache folder is specific to the chosen digits and sample count
        digit_tag = '-'.join(map(str, self.digits))
        self.cache_folder_name = f"digits-{digit_tag}_nb-{self.num_samples}"

        super().__init__(root, normalize=(0.5, 1.0), **kwargs)

    @staticmethod
    def _prepare_subset(full_data: torch.Tensor, full_targets: torch.Tensor, num_samples: int, digits: Sequence):
        """Select, in dataset order, up to ``num_samples`` examples for each entry of ``digits``.

        Returns:
            The selected rows of ``full_data`` and the matching entries of ``full_targets``.
        """
        taken = dict.fromkeys(digits, 0)
        keep = []
        for position, raw_label in enumerate(full_targets):
            digit = raw_label.item()
            # unwanted labels count as "already full" and are skipped
            if taken.get(digit, float('inf')) >= num_samples:
                continue
            keep.append(position)
            taken[digit] += 1
            # stop scanning once every wanted class is full
            if all(count >= num_samples for count in taken.values()):
                break
        return full_data[keep], full_targets[keep]

    def _download(self, data_folder: str) -> None:
        """Download the full files, then overwrite each of them with its reduced subset."""
        super()._download(data_folder)
        cache_dir = self.cached_folder_path
        for fname in (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME):
            path_fname = os.path.join(cache_dir, fname)
            assert os.path.isfile(path_fname), f'Missing cached file: {path_fname}'
            full_data, full_targets = self._try_load(path_fname)
            subset = self._prepare_subset(full_data, full_targets, self.num_samples, self.digits)
            # the subset replaces the full file in place
            torch.save(subset, path_fname)
Empty file added tests/image/__init__.py
Empty file.
128 changes: 128 additions & 0 deletions tests/image/test_fid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle

import pytest
import torch
from scipy.linalg import sqrtm as scipy_sqrtm
from torch.utils.data import Dataset

from tests.helpers.datasets import TrialMNIST
from torchmetrics.image.fid import FID, sqrtm
from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE

torch.manual_seed(42)


@pytest.mark.parametrize("matrix_size", [2, 10, 100, 500])
def test_matrix_sqrt(matrix_size):
    """ test that the matrix sqrt function agrees with scipy's ``sqrtm`` """

    def _random_scatter(dim):
        # (2 * dim) random samples of dimension dim, centred per column;
        # the result is the unnormalized covariance (scatter) matrix
        samples = torch.randn(2 * dim, dim)
        centred = samples - samples.mean(dim=0)
        return centred.T @ centred

    first = _random_scatter(matrix_size)
    second = _random_scatter(matrix_size)
    product = first @ second

    # only the real part of the scipy reference is compared
    expected = torch.tensor(scipy_sqrtm(product.numpy()).real).float()
    actual = sqrtm(product)
    assert torch.allclose(expected, actual, atol=1e-3)


@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason='test requires torch-fidelity')
def test_fid_pickle():
    """ Assert that the metric can be initialized and survives a pickle round trip """
    metric = FID()
    assert metric

    # the round trip only has to complete without raising; nothing is asserted afterwards
    metric = pickle.loads(pickle.dumps(metric))


def test_fid_raises_errors_and_warnings():
    """ Test that expected warnings and errors are raised """
    # imported locally so this test does not depend on a module-level ``import re``
    import re

    # NOTE(review): if ``FID()`` raises when torch-fidelity is missing, the error would
    # escape this ``pytest.warns`` block before the ``else`` branch below is reached -
    # confirm against the FID constructor.
    with pytest.warns(
        UserWarning,
        match='Metric `FID` will save all extracted features in buffer.'
        ' For large datasets this may lead to large memory footprint.'
    ):
        _ = FID()

    if _TORCH_FIDELITY_AVAILABLE:
        with pytest.raises(ValueError, match='Integer input to argument `feature` must be one of .*'):
            _ = FID(feature=2)
    else:
        # ``match`` is used as a regular expression: unescaped, ``[image-quality]`` is a
        # character class and can never match the literal brackets in the message
        with pytest.raises(
            ValueError,
            match=re.escape(
                'FID metric requires that Torch-fidelity is installed.'
                'Either install as `pip install torchmetrics[image-quality]`'
                ' or `pip install torch-fidelity`'
            )
        ):
            _ = FID()


@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason='test requires torch-fidelity')
def test_fid_same_input():
    """ if real and fake are updated with the same data, the fid score should be 0 """
    metric = FID(feature=192)

    num_batches = 2
    for _ in range(num_batches):
        # one random uint8 batch, fed first as "real" and then as "fake"
        img = torch.randint(0, 255, (5, 3, 299, 299), dtype=torch.uint8)
        for is_real in (True, False):
            metric.update(img, real=is_real)

    # identical inputs must give identical extracted features
    real_feats = torch.cat(metric.real_features, dim=0)
    fake_feats = torch.cat(metric.fake_features, dim=0)
    assert torch.allclose(real_feats, fake_feats)

    val = metric.compute()
    assert torch.allclose(val, torch.zeros_like(val), atol=1e-3)


class _ImgDataset(Dataset):
    """Minimal map-style dataset around an indexable stack of images.

    Items are returned exactly as stored; the length is the size of the first
    dimension of ``imgs``.
    """

    def __init__(self, imgs):
        # kept as given; only indexing and ``.shape`` are used on it
        self.imgs = imgs

    def __len__(self):
        num_images = self.imgs.shape[0]
        return num_images

    def __getitem__(self, idx):
        sample = self.imgs[idx]
        return sample


@pytest.mark.skipif(not torch.cuda.is_available(), reason='test is too slow without gpu')
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason='test requires torch-fidelity')
def test_compare_fid(tmpdir, feature=2048):
    """ check that the whole pipeline gives the same result as torch-fidelity """
    from torch_fidelity import calculate_metrics

    metric = FID(feature=feature).cuda()

    def _as_three_channels(dataset):
        # add a channel dimension and repeat it three times
        return dataset.data.unsqueeze(1).repeat(1, 3, 1, 1)

    # We need more samples than the size of the feature vectors to not end up with a singular covariance
    img1 = _as_three_channels(TrialMNIST(tmpdir, num_samples=1000, digits=(0, 1, 2)))
    img2 = _as_three_channels(TrialMNIST(tmpdir, num_samples=1000, digits=(1, 2, 3)))

    batch_size = 100

    # feed full batches only; a trailing partial batch is not used
    for batch_idx in range(img1.shape[0] // batch_size):
        start = batch_size * batch_idx
        metric.update(img1[start:start + batch_size].cuda(), real=True)

    for batch_idx in range(img2.shape[0] // batch_size):
        start = batch_size * batch_idx
        metric.update(img2[start:start + batch_size].cuda(), real=False)

    # reference value computed by torch-fidelity on the same images
    torch_fid = calculate_metrics(_ImgDataset(img1), _ImgDataset(img2), fid=True, feature_layer_fid=str(feature))

    tm_res = metric.compute()

    reference = torch.tensor([torch_fid['frechet_inception_distance']])
    assert torch.allclose(tm_res.cpu(), reference, atol=1e-3)
1 change: 1 addition & 0 deletions torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
StatScores,
)
from torchmetrics.collections import MetricCollection # noqa: F401 E402
from torchmetrics.image import FID # noqa: F401 E402
from torchmetrics.metric import Metric # noqa: F401 E402
from torchmetrics.regression import ( # noqa: F401 E402
PSNR,
Expand Down
Loading

0 comments on commit 0237671

Please sign in to comment.