Skip to content

Commit 89a0d38

Browse files
committed
Style
1 parent d2c31fd commit 89a0d38

File tree

8 files changed

+318
-217
lines changed

8 files changed

+318
-217
lines changed

Diff for: experiment/data.py

+48-34
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55

66
from lib.datasets import StaticBinaryMnist
77

8-
98
multiobject_paths = {
10-
'multi_mnist_binary': './data/multi_mnist/multi_binary_mnist_012.npz',
11-
'multi_dsprites_binary_rgb': './data/multi-dsprites-binary-rgb/multi_dsprites_color_012.npz',
9+
'multi_mnist_binary':
10+
'./data/multi_mnist/multi_binary_mnist_012.npz',
11+
'multi_dsprites_binary_rgb':
12+
'./data/multi-dsprites-binary-rgb/multi_dsprites_color_012.npz',
1213
}
1314
multiobject_datasets = multiobject_paths.keys()
1415

@@ -32,10 +33,14 @@ def __init__(self, args, cuda):
3233

3334
if args.dataset_name == 'static_mnist':
3435
data_folder = './data/static_bin_mnist/'
35-
train_set = StaticBinaryMnist(data_folder, train=True,
36-
download=True, shuffle_init=True)
37-
test_set = StaticBinaryMnist(data_folder, train=False,
38-
download=True, shuffle_init=True)
36+
train_set = StaticBinaryMnist(data_folder,
37+
train=True,
38+
download=True,
39+
shuffle_init=True)
40+
test_set = StaticBinaryMnist(data_folder,
41+
train=False,
42+
download=True,
43+
shuffle_init=True)
3944

4045
elif args.dataset_name == 'cifar10':
4146
# Discrete values 0, 1/255, ..., 254/255, 1
@@ -46,18 +51,26 @@ def __init__(self, args, cuda):
4651
transforms.ToTensor(),
4752
])
4853
data_folder = './data/cifar10/'
49-
train_set = CIFAR10(data_folder, train=True,
50-
download=True, transform=transform)
51-
test_set = CIFAR10(data_folder, train=False,
52-
download=True, transform=transform)
54+
train_set = CIFAR10(data_folder,
55+
train=True,
56+
download=True,
57+
transform=transform)
58+
test_set = CIFAR10(data_folder,
59+
train=False,
60+
download=True,
61+
transform=transform)
5362

5463
elif args.dataset_name == 'svhn':
5564
transform = transforms.ToTensor()
5665
data_folder = './data/svhn/'
57-
train_set = SVHN(data_folder, split='train',
58-
download=True, transform=transform)
59-
test_set = SVHN(data_folder, split='test',
60-
download=True, transform=transform)
66+
train_set = SVHN(data_folder,
67+
split='train',
68+
download=True,
69+
transform=transform)
70+
test_set = SVHN(data_folder,
71+
split='test',
72+
download=True,
73+
transform=transform)
6174

6275
elif args.dataset_name == 'celeba':
6376
transform = transforms.Compose([
@@ -66,10 +79,14 @@ def __init__(self, args, cuda):
6679
transforms.ToTensor(),
6780
])
6881
data_folder = '/scratch/adit/data/celeba/'
69-
train_set = CelebA(data_folder, split='train',
70-
download=True, transform=transform)
71-
test_set = CelebA(data_folder, split='valid',
72-
download=True, transform=transform)
82+
train_set = CelebA(data_folder,
83+
split='train',
84+
download=True,
85+
transform=transform)
86+
test_set = CelebA(data_folder,
87+
split='valid',
88+
download=True,
89+
transform=transform)
7390

7491
elif args.dataset_name in multiobject_datasets:
7592
data_path = multiobject_paths[args.dataset_name]
@@ -80,21 +97,18 @@ def __init__(self, args, cuda):
8097
dataloader_class = MultiObjectDataLoader
8198

8299
else:
83-
raise RuntimeError("Unrecognized data set '{}'".format(args.dataset_name))
84-
85-
self.train = dataloader_class(
86-
train_set,
87-
batch_size=args.batch_size,
88-
shuffle=True,
89-
drop_last=True,
90-
**kwargs
91-
)
92-
self.test = dataloader_class(
93-
test_set,
94-
batch_size=args.test_batch_size,
95-
shuffle=False,
96-
**kwargs
97-
)
100+
raise RuntimeError("Unrecognized data set '{}'".format(
101+
args.dataset_name))
102+
103+
self.train = dataloader_class(train_set,
104+
batch_size=args.batch_size,
105+
shuffle=True,
106+
drop_last=True,
107+
**kwargs)
108+
self.test = dataloader_class(test_set,
109+
batch_size=args.test_batch_size,
110+
shuffle=False,
111+
**kwargs)
98112

99113
self.data_shape = self.train.dataset[0][0].size()
100114
self.img_size = self.data_shape[1:]

Diff for: lib/datasets.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,16 @@ class StaticBinaryMnist(TensorDataset):
1111
def __init__(self, folder, train, download=False, shuffle_init=False):
1212
self.download = download
1313
if train:
14-
x = np.concatenate([
14+
sets = [
1515
self._get_binarized_mnist(folder, shuffle_init, split='train'),
1616
self._get_binarized_mnist(folder, shuffle_init, split='valid')
17-
], axis=0)
17+
]
18+
x = np.concatenate(sets, axis=0)
1819
else:
1920
x = self._get_binarized_mnist(folder, shuffle_init, split='test')
2021
labels = torch.zeros(len(x),).fill_(float('nan'))
2122
super().__init__(torch.from_numpy(x), labels)
2223

23-
2424
def _get_binarized_mnist(self, folder, shuffle_init, split=None):
2525
"""
2626
Get statically binarized MNIST. Code partially taken from
@@ -56,9 +56,12 @@ def _get_binarized_mnist(self, folder, shuffle_init, split=None):
5656
lines = f.readlines()
5757

5858
os.remove(path_mat)
59-
lines = np.array([[int(i) for i in line.split()] for line in lines])
60-
data[subdataset] = lines.astype('float32').reshape((-1, 1, 28, 28))
61-
np.savez_compressed(path_mat.split(".amat")[0], data=data[subdataset])
59+
lines = np.array(
60+
[[int(i) for i in line.split()] for line in lines])
61+
data[subdataset] = lines.astype('float32').reshape(
62+
(-1, 1, 28, 28))
63+
np.savez_compressed(path_mat.split(".amat")[0],
64+
data=data[subdataset])
6265

6366
else:
6467
data[split] = np.load(path)['data']

0 commit comments

Comments
 (0)