add feature_extractor_internal_dtype kwarg to calculate_metrics to help with numerical issues in the Inception feature extractor and the variation of its outputs with batch size.

fixes #43; related issues in torchmetrics:
- Lightning-AI/torchmetrics#1620
- Lightning-AI/torchmetrics#1628

add explicit eval() in the Inception feature extractor, to cover the case where someone copies just that file for metrics evaluation
add explicit requires_grad_(False) to the CLIP feature extractor
add test cases to troubleshoot the batch size dependence of metrics values
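For intuition about the numerical issue being addressed: floating-point addition is not associative, so a reduction over features can change with how the samples are batched. A minimal standalone sketch (plain PyTorch, not part of this commit):

```python
import torch

# Summing the same values in different batch groupings gives slightly
# different fp32 results; float64 shrinks the discrepancy by orders of magnitude.
torch.manual_seed(0)
x = torch.randn(2 ** 16, dtype=torch.float32)
full = x.sum()
batched = torch.stack([chunk.sum() for chunk in x.split(256)]).sum()
print((full - batched).abs().item())      # typically small but non-zero

x64 = x.to(torch.float64)
full64 = x64.sum()
batched64 = torch.stack([chunk.sum() for chunk in x64.split(256)]).sum()
print((full64 - batched64).abs().item())  # much closer to zero
```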
rustoneee committed Apr 30, 2023
1 parent 145bf42 commit 80ad57f
Showing 8 changed files with 230 additions and 16 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -10,9 +10,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - `calculate_metrics`
   - `samples_resize_and_crop`: Transform all images found in the directory to a given size and square shape
   - `feature_extractor`: Accepts a new feature extractor `clip-vit-b-32`
+  - `feature_extractor_internal_dtype`: Allows changing the internal dtype used in the feature extractor's weights and activations; may help counter numerical issues in fp32 implementations, e.g. output variation that grows with batch size
 - Command line
   - `--samples-resize-and-crop`: Transform all images found in the directory to a given size and square shape
   - `--feature-extractor`: Accepts a new feature extractor `clip-vit-b-32`
+  - `--feature-extractor-internal-dtype`: Allows changing the internal dtype used in the feature extractor's weights and activations; may help counter numerical issues in fp32 implementations, e.g. output variation that grows with batch size
 - Registered inputs: `cifar100-train`, `cifar100-val`
 - Default features for all metrics are now read from the selected feature extractor
 - Tests run in docker now
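For context, a minimal usage sketch of the new kwarg described above (it assumes only the `calculate_metrics` API and the `cifar100-train` registered input from this changelog; not part of the diff):

```python
from torch_fidelity import calculate_metrics

# Run ISC with the feature extractor operating internally in float64 to
# reduce batch-size-dependent numerical noise; 'float32' stays the default.
metrics = calculate_metrics(
    input1='cifar100-train',
    isc=True,
    feature_extractor='inception-v3-compat',
    feature_extractor_internal_dtype='float64',
)
print(metrics)
```

The command-line equivalent would pass `--feature-extractor-internal-dtype float64`.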
176 changes: 176 additions & 0 deletions tests/tf1/functional/test_batchsize_independence.py
@@ -0,0 +1,176 @@
import math
import unittest

import torch

from torch_fidelity import calculate_metrics, KEY_METRIC_ISC_MEAN
from torch_fidelity.datasets import RandomlyGeneratedDataset


class TestBatchSizeIndependence(unittest.TestCase):

    def _test_batch_size_independence(self, fe, num_samples, dtype, cuda):
        if cuda and not torch.cuda.is_available():
            raise RuntimeError('CUDA not available')
        input1 = RandomlyGeneratedDataset(num_samples, 3, 299, 299, dtype=torch.uint8, seed=2023)
        metrics_b_1 = calculate_metrics(
            input1=input1, isc=True, isc_splits=2, verbose=True, cache=False,
            feature_extractor=fe, feature_extractor_internal_dtype=dtype, batch_size=1, cuda=cuda
        )[KEY_METRIC_ISC_MEAN]
        metrics_b_all = calculate_metrics(
            input1=input1, isc=True, isc_splits=2, verbose=True, cache=False,
            feature_extractor=fe, feature_extractor_internal_dtype=dtype, batch_size=num_samples, cuda=cuda
        )[KEY_METRIC_ISC_MEAN]
        discrepancy = math.fabs(metrics_b_1 - metrics_b_all)
        self.assertTrue(
            discrepancy < 1e-5,
            f'Batch size affects metrics outputs: size_1 gives {metrics_b_1}, size_all gives {metrics_b_all}',
        )

    def test_batch_size_independence_inceptionfe_4_fp32_cpu(self):
        self._test_batch_size_independence('inception-v3-compat', 4, 'float32', False)

    def test_batch_size_independence_inceptionfe_8_fp32_cpu(self):
        self._test_batch_size_independence('inception-v3-compat', 8, 'float32', False)

    def test_batch_size_independence_inceptionfe_16_fp32_cpu(self):
        self._test_batch_size_independence('inception-v3-compat', 16, 'float32', False)

    def test_batch_size_independence_inceptionfe_32_fp32_cpu(self):
        self._test_batch_size_independence('inception-v3-compat', 32, 'float32', False)

    def test_batch_size_independence_inceptionfe_64_fp32_cpu(self):
        self._test_batch_size_independence('inception-v3-compat', 64, 'float32', False)

    def test_batch_size_independence_inceptionfe_128_fp32_cpu(self):
        self._test_batch_size_independence('inception-v3-compat', 128, 'float32', False)

    def test_batch_size_independence_inceptionfe_4_fp32_cuda(self):
        self._test_batch_size_independence('inception-v3-compat', 4, 'float32', True)

    def test_batch_size_independence_inceptionfe_8_fp32_cuda(self):
        self._test_batch_size_independence('inception-v3-compat', 8, 'float32', True)

    def test_batch_size_independence_inceptionfe_16_fp32_cuda(self):
        self._test_batch_size_independence('inception-v3-compat', 16, 'float32', True)

    def test_batch_size_independence_inceptionfe_32_fp32_cuda(self):
        self._test_batch_size_independence('inception-v3-compat', 32, 'float32', True)

    def test_batch_size_independence_inceptionfe_64_fp32_cuda(self):
        self._test_batch_size_independence('inception-v3-compat', 64, 'float32', True)

    def test_batch_size_independence_inceptionfe_128_fp32_cuda(self):
        self._test_batch_size_independence('inception-v3-compat', 128, 'float32', True)

    def test_batch_size_independence_inceptionfe_4_fp64_cpu(self):
        self._test_batch_size_independence('inception-v3-compat', 4, 'float64', False)

    def test_batch_size_independence_inceptionfe_8_fp64_cpu(self):
        self._test_batch_size_independence('inception-v3-compat', 8, 'float64', False)

    def test_batch_size_independence_inceptionfe_16_fp64_cpu(self):
        self._test_batch_size_independence('inception-v3-compat', 16, 'float64', False)

    def test_batch_size_independence_inceptionfe_32_fp64_cpu(self):
        self._test_batch_size_independence('inception-v3-compat', 32, 'float64', False)

    def test_batch_size_independence_inceptionfe_64_fp64_cpu(self):
        self._test_batch_size_independence('inception-v3-compat', 64, 'float64', False)

    def test_batch_size_independence_inceptionfe_128_fp64_cpu(self):
        self._test_batch_size_independence('inception-v3-compat', 128, 'float64', False)

    def test_batch_size_independence_inceptionfe_4_fp64_cuda(self):
        self._test_batch_size_independence('inception-v3-compat', 4, 'float64', True)

    def test_batch_size_independence_inceptionfe_8_fp64_cuda(self):
        self._test_batch_size_independence('inception-v3-compat', 8, 'float64', True)

    def test_batch_size_independence_inceptionfe_16_fp64_cuda(self):
        self._test_batch_size_independence('inception-v3-compat', 16, 'float64', True)

    def test_batch_size_independence_inceptionfe_32_fp64_cuda(self):
        self._test_batch_size_independence('inception-v3-compat', 32, 'float64', True)

    def test_batch_size_independence_inceptionfe_64_fp64_cuda(self):
        self._test_batch_size_independence('inception-v3-compat', 64, 'float64', True)

    def test_batch_size_independence_inceptionfe_128_fp64_cuda(self):
        self._test_batch_size_independence('inception-v3-compat', 128, 'float64', True)

    def test_batch_size_independence_clipfe_4_fp32_cpu(self):
        self._test_batch_size_independence('clip-vit-b-32', 4, 'float32', False)

    def test_batch_size_independence_clipfe_8_fp32_cpu(self):
        self._test_batch_size_independence('clip-vit-b-32', 8, 'float32', False)

    def test_batch_size_independence_clipfe_16_fp32_cpu(self):
        self._test_batch_size_independence('clip-vit-b-32', 16, 'float32', False)

    def test_batch_size_independence_clipfe_32_fp32_cpu(self):
        self._test_batch_size_independence('clip-vit-b-32', 32, 'float32', False)

    def test_batch_size_independence_clipfe_64_fp32_cpu(self):
        self._test_batch_size_independence('clip-vit-b-32', 64, 'float32', False)

    def test_batch_size_independence_clipfe_128_fp32_cpu(self):
        self._test_batch_size_independence('clip-vit-b-32', 128, 'float32', False)

    def test_batch_size_independence_clipfe_4_fp32_cuda(self):
        self._test_batch_size_independence('clip-vit-b-32', 4, 'float32', True)

    def test_batch_size_independence_clipfe_8_fp32_cuda(self):
        self._test_batch_size_independence('clip-vit-b-32', 8, 'float32', True)

    def test_batch_size_independence_clipfe_16_fp32_cuda(self):
        self._test_batch_size_independence('clip-vit-b-32', 16, 'float32', True)

    def test_batch_size_independence_clipfe_32_fp32_cuda(self):
        self._test_batch_size_independence('clip-vit-b-32', 32, 'float32', True)

    def test_batch_size_independence_clipfe_64_fp32_cuda(self):
        self._test_batch_size_independence('clip-vit-b-32', 64, 'float32', True)

    def test_batch_size_independence_clipfe_128_fp32_cuda(self):
        self._test_batch_size_independence('clip-vit-b-32', 128, 'float32', True)

    def test_batch_size_independence_clipfe_4_fp64_cpu(self):
        self._test_batch_size_independence('clip-vit-b-32', 4, 'float64', False)

    def test_batch_size_independence_clipfe_8_fp64_cpu(self):
        self._test_batch_size_independence('clip-vit-b-32', 8, 'float64', False)

    def test_batch_size_independence_clipfe_16_fp64_cpu(self):
        self._test_batch_size_independence('clip-vit-b-32', 16, 'float64', False)

    def test_batch_size_independence_clipfe_32_fp64_cpu(self):
        self._test_batch_size_independence('clip-vit-b-32', 32, 'float64', False)

    def test_batch_size_independence_clipfe_64_fp64_cpu(self):
        self._test_batch_size_independence('clip-vit-b-32', 64, 'float64', False)

    def test_batch_size_independence_clipfe_128_fp64_cpu(self):
        self._test_batch_size_independence('clip-vit-b-32', 128, 'float64', False)

    def test_batch_size_independence_clipfe_4_fp64_cuda(self):
        self._test_batch_size_independence('clip-vit-b-32', 4, 'float64', True)

    def test_batch_size_independence_clipfe_8_fp64_cuda(self):
        self._test_batch_size_independence('clip-vit-b-32', 8, 'float64', True)

    def test_batch_size_independence_clipfe_16_fp64_cuda(self):
        self._test_batch_size_independence('clip-vit-b-32', 16, 'float64', True)

    def test_batch_size_independence_clipfe_32_fp64_cuda(self):
        self._test_batch_size_independence('clip-vit-b-32', 32, 'float64', True)

    def test_batch_size_independence_clipfe_64_fp64_cuda(self):
        self._test_batch_size_independence('clip-vit-b-32', 64, 'float64', True)

    def test_batch_size_independence_clipfe_128_fp64_cuda(self):
        self._test_batch_size_independence('clip-vit-b-32', 128, 'float64', True)


if __name__ == '__main__':
    unittest.main()
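For reference, a single case from this suite can be run in isolation with the standard unittest loader (a usage sketch, assuming the module path above is importable):

```python
import unittest

# Load and run one test case from this module by its dotted name.
suite = unittest.defaultTestLoader.loadTestsFromName(
    'tests.tf1.functional.test_batchsize_independence.'
    'TestBatchSizeIndependence.test_batch_size_independence_inceptionfe_4_fp32_cpu'
)
unittest.TextTestRunner(verbosity=2).run(suite)
```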
1 change: 1 addition & 0 deletions torch_fidelity/defaults.py
@@ -12,6 +12,7 @@
     'feature_layer_fid': None,
     'feature_layer_kid': None,
     'feature_extractor_weights_path': None,
+    'feature_extractor_internal_dtype': None,
     'isc_splits': 10,
     'kid_subsets': 100,
     'kid_subset_size': 1000,
6 changes: 6 additions & 0 deletions torch_fidelity/feature_extractor_base.py
@@ -1,9 +1,15 @@
import torch
import torch.nn as nn

from torch_fidelity.helpers import vassert


class FeatureExtractorBase(nn.Module):
+    SUPPORTED_DTYPES = {
+        'float32': torch.float32,
+        'float64': torch.float64,
+    }
+
    def __init__(self, name, features_list):
        """
        Base class for feature extractors that can be used in :func:`calculate_metrics`.
32 changes: 23 additions & 9 deletions torch_fidelity/feature_extractor_clip.py
@@ -210,10 +210,16 @@ def stem(x):
class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

-    def forward(self, x: torch.Tensor):
-        orig_type = x.dtype
-        ret = super().forward(x.type(torch.float32))
-        return ret.type(orig_type)
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        orig_dtype = x.dtype
+        if orig_dtype == torch.float16:
+            out = F.layer_norm(
+                x.to(torch.float32), self.normalized_shape, self.weight.to(torch.float32),
+                self.bias.to(torch.float32), self.eps
+            ).to(orig_dtype)
+        else:
+            out = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        return out
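The fp32 round-trip above matters because normalization statistics accumulated in fp16 drift measurably from an fp32 reference; a standalone illustration (an assumption-level sketch, not from this diff):

```python
import torch

# Means accumulated in fp16 deviate from fp32 results, which is why the
# LayerNorm above computes in fp32 and casts back to the input dtype.
torch.manual_seed(0)
x16 = (torch.randn(4, 512) * 10).to(torch.float16)
mean16 = x16.mean(dim=-1)                    # accumulated in fp16
mean32 = x16.to(torch.float32).mean(dim=-1)  # fp32 reference
print((mean16.to(torch.float32) - mean32).abs().max().item())
```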


class QuickGELU(nn.Module):
@@ -345,7 +351,7 @@ def _convert_weights_to_fp16(l):
    model.apply(_convert_weights_to_fp16)


-def build_model(state_dict: dict):
+def build_model(state_dict, feature_extractor_internal_dtype):
    vit = "visual.proj" in state_dict

    if vit:
@@ -385,7 +391,9 @@ def build_model(state_dict: dict):

    convert_weights(model)
    model.load_state_dict(state_dict)
-    model.float()
+    model.to(feature_extractor_internal_dtype)
+    for p in model.parameters():
+        p.requires_grad_(False)
    model.eval()
    return model

@@ -397,6 +405,7 @@ def __init__(
            name,
            features_list,
            feature_extractor_weights_path=None,
+            feature_extractor_internal_dtype=None,
            **kwargs,
    ):
        """
@@ -414,9 +423,13 @@
            feature_extractor_weights_path (str): Path to the pretrained CLIP model weights in PyTorch format.
                Downloads from internet if `None`.
+            feature_extractor_internal_dtype (str): dtype to use inside the feature extractor. Specifying it may improve
+                numerical precision in some cases. Supported values are 'float32' (default), and 'float64'.
        """
        super(FeatureExtractorCLIP, self).__init__(name, features_list)
        vassert(name in MODEL_URLS, f'Model {name} not found; available models = {list(MODEL_URLS.keys())}')
+        self.feature_extractor_internal_dtype = self.SUPPORTED_DTYPES[feature_extractor_internal_dtype or 'float32']

        if feature_extractor_weights_path is None:
            with redirect_stdout(sys.stderr), warnings.catch_warnings():
@@ -433,7 +446,7 @@ def __init__(
        else:
            model_jit = torch.jit.load(feature_extractor_weights_path, map_location="cpu")

-        self.model = build_model(model_jit.state_dict())
+        self.model = build_model(model_jit.state_dict(), self.feature_extractor_internal_dtype)
        self.resolution = self.model.visual.input_resolution

        for p in self.parameters():
@@ -443,8 +456,9 @@ def forward(self, x):
        vassert(torch.is_tensor(x) and x.dtype == torch.uint8, 'Expecting image as torch.Tensor with dtype=torch.uint8')
        features = {}

+        x = x.to(self.feature_extractor_internal_dtype)
        x = torchvision.transforms.functional.normalize(
-            x.float(),
+            x,
            (255 * 0.48145466, 255 * 0.4578275, 255 * 0.40821073),
            (255 * 0.26862954, 255 * 0.26130258, 255 * 0.27577711),
            inplace=False,
@@ -459,7 +473,7 @@ def forward(self, x):
        # N x 3 x R x R

        x = self.model.visual(x)
-        features['clip'] = x
+        features['clip'] = x.to(torch.float32)

        return tuple(features[a] for a in self.features_list)
21 changes: 14 additions & 7 deletions torch_fidelity/feature_extractor_inceptionv3.py
@@ -28,6 +28,7 @@ def __init__(
            name,
            features_list,
            feature_extractor_weights_path=None,
+            feature_extractor_internal_dtype=None,
            **kwargs,
    ):
        """
@@ -50,8 +51,12 @@
            feature_extractor_weights_path (str): Path to the pretrained InceptionV3 model weights in PyTorch format.
                Refer to `util_convert_inception_weights` for making your own. Downloads from internet if `None`.
+            feature_extractor_internal_dtype (str): dtype to use inside the feature extractor. Specifying it may improve
+                numerical precision in some cases. Supported values are 'float32' (default), and 'float64'.
        """
        super(FeatureExtractorInceptionV3, self).__init__(name, features_list)
+        self.feature_extractor_internal_dtype = self.SUPPORTED_DTYPES[feature_extractor_internal_dtype or 'float32']

        self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2)
        self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
@@ -85,15 +90,17 @@ def __init__(
            state_dict = torch.load(feature_extractor_weights_path)
        self.load_state_dict(state_dict)

+        self.to(self.feature_extractor_internal_dtype)
        for p in self.parameters():
            p.requires_grad_(False)
+        self.eval()

    def forward(self, x):
        vassert(torch.is_tensor(x) and x.dtype == torch.uint8, 'Expecting image as torch.Tensor with dtype=torch.uint8')
        features = {}
        remaining_features = self.features_list.copy()

-        x = x.float()
+        x = x.to(self.feature_extractor_internal_dtype)
        # N x 3 x ? x ?

        x = interpolate_bilinear_2d_like_tensorflow1x(
@@ -117,7 +124,7 @@ def forward(self, x):
        # N x 64 x 73 x 73

        if '64' in remaining_features:
-            features['64'] = F.adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1)
+            features['64'] = F.adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1).to(torch.float32)
            remaining_features.remove('64')
            if len(remaining_features) == 0:
                return tuple(features[a] for a in self.features_list)
@@ -130,7 +137,7 @@ def forward(self, x):
        # N x 192 x 35 x 35

        if '192' in remaining_features:
-            features['192'] = F.adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1)
+            features['192'] = F.adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1).to(torch.float32)
            remaining_features.remove('192')
            if len(remaining_features) == 0:
                return tuple(features[a] for a in self.features_list)
@@ -153,7 +160,7 @@ def forward(self, x):
        # N x 768 x 17 x 17

        if '768' in remaining_features:
-            features['768'] = F.adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1)
+            features['768'] = F.adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1).to(torch.float32)
            remaining_features.remove('768')
            if len(remaining_features) == 0:
                return tuple(features[a] for a in self.features_list)
@@ -171,15 +178,15 @@ def forward(self, x):
        # N x 2048

        if '2048' in remaining_features:
-            features['2048'] = x
+            features['2048'] = x.to(torch.float32)
            remaining_features.remove('2048')
            if len(remaining_features) == 0:
                return tuple(features[a] for a in self.features_list)

        if 'logits_unbiased' in remaining_features:
            x = x.mm(self.fc.weight.T)
            # N x 1008 (num_classes)
-            features['logits_unbiased'] = x
+            features['logits_unbiased'] = x.to(torch.float32)
            remaining_features.remove('logits_unbiased')
            if len(remaining_features) == 0:
                return tuple(features[a] for a in self.features_list)
@@ -189,7 +196,7 @@ def forward(self, x):
        x = self.fc(x)
        # N x 1008 (num_classes)

-        features['logits'] = x
+        features['logits'] = x.to(torch.float32)
        return tuple(features[a] for a in self.features_list)

    @staticmethod
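Both extractors follow the same pattern: cast inputs up to the internal dtype, compute, and cast each emitted feature back to float32 so downstream metric code sees a stable interface. A distilled sketch of that pattern (names are illustrative, not from this diff):

```python
import torch
import torch.nn as nn

class DtypeStableExtractor(nn.Module):
    """Illustrative only: compute internally in `internal_dtype`, emit float32."""

    def __init__(self, backbone: nn.Module, internal_dtype: torch.dtype):
        super().__init__()
        self.backbone = backbone.to(internal_dtype)
        self.internal_dtype = internal_dtype
        # Freeze for inference-only use, mirroring the commit's changes.
        for p in self.backbone.parameters():
            p.requires_grad_(False)
        self.backbone.eval()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.to(self.internal_dtype)
        return self.backbone(x).to(torch.float32)
```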
(Diffs for the remaining 2 changed files are not shown.)
