-
Notifications
You must be signed in to change notification settings - Fork 413
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
648 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
scipy | ||
torch-fidelity |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,3 +18,6 @@ cloudpickle>=1.3 | |
scikit-learn>=0.24 | ||
scikit-image>0.17.1 | ||
nltk>=3.6 | ||
|
||
# add extra requirements | ||
-r image.txt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import logging | ||
import os | ||
import random | ||
import time | ||
import urllib.request | ||
from typing import Optional, Sequence, Tuple | ||
|
||
import torch | ||
from torch import Tensor | ||
from torch.utils.data import Dataset | ||
|
||
|
||
class MNIST(Dataset): | ||
""" | ||
Customized `MNIST <http://yann.lecun.com/exdb/mnist/>`_ dataset for testing torchmetrics | ||
without the torchvision dependency. | ||
Part of the code was copied from | ||
https://github.com/pytorch/vision/blob/build/v0.5.0/torchvision/datasets/mnist.py | ||
Args: | ||
root: Root directory of dataset where ``MNIST/processed/training.pt`` | ||
and ``MNIST/processed/test.pt`` exist. | ||
train: If ``True``, creates dataset from ``training.pt``, | ||
otherwise from ``test.pt``. | ||
normalize: mean and std deviation of the MNIST dataset. | ||
download: If true, downloads the dataset from the internet and | ||
puts it in root directory. If dataset is already downloaded, it is not | ||
downloaded again. | ||
Examples: | ||
>>> dataset = MNIST(".", download=True) | ||
>>> len(dataset) | ||
60000 | ||
>>> torch.bincount(dataset.targets) | ||
tensor([5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949]) | ||
""" | ||
|
||
RESOURCES = ( | ||
"https://pl-public-data.s3.amazonaws.com/MNIST/processed/training.pt", | ||
"https://pl-public-data.s3.amazonaws.com/MNIST/processed/test.pt", | ||
) | ||
|
||
TRAIN_FILE_NAME = 'training.pt' | ||
TEST_FILE_NAME = 'test.pt' | ||
cache_folder_name = 'complete' | ||
|
||
def __init__( | ||
self, | ||
root: str, | ||
train: bool = True, | ||
normalize: tuple = (0.1307, 0.3081), | ||
download: bool = True, | ||
**kwargs, | ||
): | ||
super().__init__() | ||
self.root = root | ||
self.train = train # training set or test set | ||
self.normalize = normalize | ||
|
||
self.prepare_data(download) | ||
|
||
data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME | ||
self.data, self.targets = self._try_load(os.path.join(self.cached_folder_path, data_file)) | ||
|
||
def __getitem__(self, idx: int) -> Tuple[Tensor, int]: | ||
img = self.data[idx].float().unsqueeze(0) | ||
target = int(self.targets[idx]) | ||
|
||
if self.normalize is not None and len(self.normalize) == 2: | ||
img = self.normalize_tensor(img, *self.normalize) | ||
|
||
return img, target | ||
|
||
def __len__(self) -> int: | ||
return len(self.data) | ||
|
||
@property | ||
def cached_folder_path(self) -> str: | ||
return os.path.join(self.root, 'MNIST', self.cache_folder_name) | ||
|
||
def _check_exists(self, data_folder: str) -> bool: | ||
existing = True | ||
for fname in (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME): | ||
existing = existing and os.path.isfile(os.path.join(data_folder, fname)) | ||
return existing | ||
|
||
def prepare_data(self, download: bool = True): | ||
if download and not self._check_exists(self.cached_folder_path): | ||
self._download(self.cached_folder_path) | ||
if not self._check_exists(self.cached_folder_path): | ||
raise RuntimeError('Dataset not found.') | ||
|
||
def _download(self, data_folder: str) -> None: | ||
os.makedirs(data_folder) | ||
for url in self.RESOURCES: | ||
logging.info(f'Downloading {url}') | ||
fpath = os.path.join(data_folder, os.path.basename(url)) | ||
urllib.request.urlretrieve(url, fpath) | ||
|
||
@staticmethod | ||
def _try_load(path_data, trials: int = 30, delta: float = 1.): | ||
"""Resolving loading from the same time from multiple concurrent processes.""" | ||
res, exception = None, None | ||
assert trials, "at least some trial has to be set" | ||
assert os.path.isfile(path_data), f'missing file: {path_data}' | ||
for _ in range(trials): | ||
try: | ||
res = torch.load(path_data) | ||
# todo: specify the possible exception | ||
except Exception as e: | ||
exception = e | ||
time.sleep(delta * random.random()) | ||
else: | ||
break | ||
if exception is not None: | ||
# raise the caught exception | ||
raise exception | ||
return res | ||
|
||
@staticmethod | ||
def normalize_tensor(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> Tensor: | ||
mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device) | ||
std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device) | ||
return tensor.sub(mean).div(std) | ||
|
||
|
||
class TrialMNIST(MNIST): | ||
"""Constrained MNIST dataset | ||
Args: | ||
num_samples: number of examples per selected class/digit | ||
digits: list selected MNIST digits/classes | ||
kwargs: Same as MNIST | ||
Examples: | ||
>>> dataset = TrialMNIST(".", download=True) | ||
>>> len(dataset) | ||
300 | ||
>>> sorted(set([d.item() for d in dataset.targets])) | ||
[0, 1, 2] | ||
>>> torch.bincount(dataset.targets) | ||
tensor([100, 100, 100]) | ||
""" | ||
|
||
def __init__(self, root: str, num_samples: int = 100, digits: Optional[Sequence] = (0, 1, 2), **kwargs): | ||
# number of examples per class | ||
self.num_samples = num_samples | ||
# take just a subset of MNIST dataset | ||
self.digits = sorted(digits) if digits else list(range(10)) | ||
|
||
self.cache_folder_name = f"digits-{'-'.join(str(d) for d in self.digits)}_nb-{self.num_samples}" | ||
|
||
super().__init__(root, normalize=(0.5, 1.0), **kwargs) | ||
|
||
@staticmethod | ||
def _prepare_subset(full_data: torch.Tensor, full_targets: torch.Tensor, num_samples: int, digits: Sequence): | ||
classes = {d: 0 for d in digits} | ||
indexes = [] | ||
for idx, target in enumerate(full_targets): | ||
label = target.item() | ||
if classes.get(label, float('inf')) >= num_samples: | ||
continue | ||
indexes.append(idx) | ||
classes[label] += 1 | ||
if all(classes[k] >= num_samples for k in classes): | ||
break | ||
data = full_data[indexes] | ||
targets = full_targets[indexes] | ||
return data, targets | ||
|
||
def _download(self, data_folder: str) -> None: | ||
super()._download(data_folder) | ||
for fname in (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME): | ||
path_fname = os.path.join(self.cached_folder_path, fname) | ||
assert os.path.isfile(path_fname), f'Missing cached file: {path_fname}' | ||
data, targets = self._try_load(path_fname) | ||
data, targets = self._prepare_subset(data, targets, self.num_samples, self.digits) | ||
torch.save((data, targets), os.path.join(self.cached_folder_path, fname)) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import pickle | ||
|
||
import pytest | ||
import torch | ||
from scipy.linalg import sqrtm as scipy_sqrtm | ||
from torch.utils.data import Dataset | ||
|
||
from tests.helpers.datasets import TrialMNIST | ||
from torchmetrics.image.fid import FID, sqrtm | ||
from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE | ||
|
||
torch.manual_seed(42) | ||
|
||
|
||
@pytest.mark.parametrize("matrix_size", [2, 10, 100, 500]) | ||
def test_matrix_sqrt(matrix_size): | ||
""" test that metrix sqrt function works as expected """ | ||
|
||
def generate_cov(n): | ||
data = torch.randn(2 * n, n) | ||
return (data - data.mean(dim=0)).T @ (data - data.mean(dim=0)) | ||
|
||
cov1 = generate_cov(matrix_size) | ||
cov2 = generate_cov(matrix_size) | ||
|
||
scipy_res = scipy_sqrtm((cov1 @ cov2).numpy()).real | ||
tm_res = sqrtm(cov1 @ cov2) | ||
assert torch.allclose(torch.tensor(scipy_res).float(), tm_res, atol=1e-3) | ||
|
||
|
||
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason='test requires torch-fidelity') | ||
def test_fid_pickle(): | ||
""" Assert that we can initialize the metric and pickle it""" | ||
metric = FID() | ||
assert metric | ||
|
||
# verify metrics work after being loaded from pickled state | ||
pickled_metric = pickle.dumps(metric) | ||
metric = pickle.loads(pickled_metric) | ||
|
||
|
||
def test_fid_raises_errors_and_warnings(): | ||
""" Test that expected warnings and errors are raised """ | ||
with pytest.warns( | ||
UserWarning, | ||
match='Metric `FID` will save all extracted features in buffer.' | ||
' For large datasets this may lead to large memory footprint.' | ||
): | ||
_ = FID() | ||
|
||
if _TORCH_FIDELITY_AVAILABLE: | ||
with pytest.raises(ValueError, match='Integer input to argument `feature` must be one of .*'): | ||
_ = FID(feature=2) | ||
else: | ||
with pytest.raises( | ||
ValueError, | ||
match='FID metric requires that Torch-fidelity is installed.' | ||
'Either install as `pip install torchmetrics[image-quality]`' | ||
' or `pip install torch-fidelity`' | ||
): | ||
_ = FID() | ||
|
||
|
||
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason='test requires torch-fidelity') | ||
def test_fid_same_input(): | ||
""" if real and fake are update on the same data the fid score should be 0 """ | ||
metric = FID(feature=192) | ||
|
||
for _ in range(2): | ||
img = torch.randint(0, 255, (5, 3, 299, 299), dtype=torch.uint8) | ||
metric.update(img, real=True) | ||
metric.update(img, real=False) | ||
|
||
assert torch.allclose(torch.cat(metric.real_features, dim=0), torch.cat(metric.fake_features, dim=0)) | ||
|
||
val = metric.compute() | ||
assert torch.allclose(val, torch.zeros_like(val), atol=1e-3) | ||
|
||
|
||
class _ImgDataset(Dataset): | ||
|
||
def __init__(self, imgs): | ||
self.imgs = imgs | ||
|
||
def __getitem__(self, idx): | ||
return self.imgs[idx] | ||
|
||
def __len__(self): | ||
return self.imgs.shape[0] | ||
|
||
|
||
@pytest.mark.skipif(not torch.cuda.is_available(), reason='test is too slow without gpu') | ||
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason='test requires torch-fidelity') | ||
def test_compare_fid(tmpdir, feature=2048): | ||
""" check that the hole pipeline give the same result as torch-fidelity """ | ||
from torch_fidelity import calculate_metrics | ||
|
||
metric = FID(feature=feature).cuda() | ||
|
||
# We need more samples than the size of the feature vectors to not end up with a singular covariance | ||
img1 = TrialMNIST(tmpdir, num_samples=1000, digits=(0, 1, 2)).data.unsqueeze(1).repeat(1, 3, 1, 1) | ||
img2 = TrialMNIST(tmpdir, num_samples=1000, digits=(1, 2, 3)).data.unsqueeze(1).repeat(1, 3, 1, 1) | ||
|
||
batch_size = 100 | ||
for i in range(img1.shape[0] // batch_size): | ||
metric.update(img1[batch_size * i:batch_size * (i + 1)].cuda(), real=True) | ||
|
||
for i in range(img2.shape[0] // batch_size): | ||
metric.update(img2[batch_size * i:batch_size * (i + 1)].cuda(), real=False) | ||
|
||
torch_fid = calculate_metrics(_ImgDataset(img1), _ImgDataset(img2), fid=True, feature_layer_fid=str(feature)) | ||
|
||
tm_res = metric.compute() | ||
|
||
assert torch.allclose(tm_res.cpu(), torch.tensor([torch_fid['frechet_inception_distance']]), atol=1e-3) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.