Skip to content

Commit f377909

Browse files
committed
Add binary multi-MNIST dataset, minor refactoring
1 parent 2bfadc0 commit f377909

File tree

3 files changed

+29
-37
lines changed

3 files changed

+29
-37
lines changed

Diff for: data/multi_mnist/multi_binary_mnist_012.npz

4.54 MB
Binary file not shown.

Diff for: experiment/data.py

+26-36
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,13 @@
66
from datasets import StaticBinaryMnist
77

88

9+
multiobject_paths = {
10+
'multi_mnist_binary': './data/multi_mnist/multi_binary_mnist_012.npz',
11+
'multi_dsprites_binary_rgb': './data/multi-dsprites-binary-rgb/multi_dsprites_color_012.npz',
12+
}
13+
multiobject_datasets = multiobject_paths.keys()
14+
15+
916
class DatasetLoader:
1017
"""
1118
Wrapper for DataLoaders. Data attributes:
@@ -20,8 +27,8 @@ def __init__(self, args, cuda):
2027

2128
kwargs = {'num_workers': 1, 'pin_memory': False} if cuda else {}
2229

23-
# Init dataloaders to None
24-
self.train = self.test = None
30+
# Default dataloader class
31+
dataloader_class = DataLoader
2532

2633
if args.dataset_name == 'static_mnist':
2734
data_folder = './data/static_bin_mnist/'
@@ -64,47 +71,30 @@ def __init__(self, args, cuda):
6471
test_set = CelebA(data_folder, split='valid',
6572
download=True, transform=transform)
6673

67-
elif args.dataset_name == 'multi_dsprites_binary_rgb':
68-
data_path = './data/multi-dsprites-binary-rgb/multi_dsprites_color_012.npz'
74+
elif args.dataset_name in multiobject_datasets:
75+
data_path = multiobject_paths[args.dataset_name]
6976
train_set = MultiObjectDataset(data_path, train=True)
7077
test_set = MultiObjectDataset(data_path, train=False)
7178

72-
# Custom data loaders
73-
self.train = MultiObjectDataLoader(
74-
train_set,
75-
batch_size=args.batch_size,
76-
shuffle=True,
77-
drop_last=True,
78-
**kwargs
79-
)
80-
self.test = MultiObjectDataLoader(
81-
test_set,
82-
batch_size=args.test_batch_size,
83-
shuffle=False,
84-
**kwargs
85-
)
79+
# Custom data loader class
80+
dataloader_class = MultiObjectDataLoader
8681

8782
else:
8883
raise RuntimeError("Unrecognized data set '{}'".format(args.dataset_name))
8984

90-
# Default training set loader if it hasn't been defined yet
91-
if self.train is None:
92-
self.train = DataLoader(
93-
train_set,
94-
batch_size=args.batch_size,
95-
shuffle=True,
96-
drop_last=True,
97-
**kwargs
98-
)
99-
100-
# Default test set loader if it hasn't been defined yet
101-
if self.test is None:
102-
self.test = DataLoader(
103-
test_set,
104-
batch_size=args.test_batch_size,
105-
shuffle=False,
106-
**kwargs
107-
)
85+
self.train = dataloader_class(
86+
train_set,
87+
batch_size=args.batch_size,
88+
shuffle=True,
89+
drop_last=True,
90+
**kwargs
91+
)
92+
self.test = dataloader_class(
93+
test_set,
94+
batch_size=args.test_batch_size,
95+
shuffle=False,
96+
**kwargs
97+
)
10898

10999
self.data_shape = self.train.dataset[0][0].size()
110100
self.img_size = self.data_shape[1:]

Diff for: experiment/experiment_manager.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@ def list_options(lst):
9494
legal_nonlin = ['relu', 'leakyrelu', 'elu', 'selu']
9595
legal_resblock = ['cabdcabd', 'bacdbac', 'bacdbacd']
9696
legal_datasets = ['static_mnist', 'cifar10', 'celeba',
97-
'multi_dsprites_binary_rgb', 'svhn']
97+
'svhn', 'multi_dsprites_binary_rgb',
98+
'multi_mnist_binary']
9899
legal_likelihoods = ['bernoulli', 'gaussian',
99100
'discr_log', 'discr_log_mix']
100101

@@ -257,6 +258,7 @@ def list_options(lst):
257258
likelihood_map = {
258259
'static_mnist': 'bernoulli',
259260
'multi_dsprites_binary_rgb': 'bernoulli',
261+
'multi_mnist_binary': 'bernoulli',
260262
'cifar10': 'discr_log_mix',
261263
'celeba': 'discr_log_mix',
262264
'svhn': 'discr_log_mix',

0 commit comments

Comments
 (0)