6
6
from datasets import StaticBinaryMnist
7
7
8
8
9
+ multiobject_paths = {
10
+ 'multi_mnist_binary' : './data/multi_mnist/multi_binary_mnist_012.npz' ,
11
+ 'multi_dsprites_binary_rgb' : './data/multi-dsprites-binary-rgb/multi_dsprites_color_012.npz' ,
12
+ }
13
+ multiobject_datasets = multiobject_paths .keys ()
14
+
15
+
9
16
class DatasetLoader :
10
17
"""
11
18
Wrapper for DataLoaders. Data attributes:
@@ -20,8 +27,8 @@ def __init__(self, args, cuda):
20
27
21
28
kwargs = {'num_workers' : 1 , 'pin_memory' : False } if cuda else {}
22
29
23
- # Init dataloaders to None
24
- self . train = self . test = None
30
+ # Default dataloader class
31
+ dataloader_class = DataLoader
25
32
26
33
if args .dataset_name == 'static_mnist' :
27
34
data_folder = './data/static_bin_mnist/'
@@ -64,47 +71,30 @@ def __init__(self, args, cuda):
64
71
test_set = CelebA (data_folder , split = 'valid' ,
65
72
download = True , transform = transform )
66
73
67
- elif args .dataset_name == 'multi_dsprites_binary_rgb' :
68
- data_path = './data/multi-dsprites-binary-rgb/multi_dsprites_color_012.npz'
74
+ elif args .dataset_name in multiobject_datasets :
75
+ data_path = multiobject_paths [ args . dataset_name ]
69
76
train_set = MultiObjectDataset (data_path , train = True )
70
77
test_set = MultiObjectDataset (data_path , train = False )
71
78
72
- # Custom data loaders
73
- self .train = MultiObjectDataLoader (
74
- train_set ,
75
- batch_size = args .batch_size ,
76
- shuffle = True ,
77
- drop_last = True ,
78
- ** kwargs
79
- )
80
- self .test = MultiObjectDataLoader (
81
- test_set ,
82
- batch_size = args .test_batch_size ,
83
- shuffle = False ,
84
- ** kwargs
85
- )
79
+ # Custom data loader class
80
+ dataloader_class = MultiObjectDataLoader
86
81
87
82
else :
88
83
raise RuntimeError ("Unrecognized data set '{}'" .format (args .dataset_name ))
89
84
90
- # Default training set loader if it hasn't been defined yet
91
- if self .train is None :
92
- self .train = DataLoader (
93
- train_set ,
94
- batch_size = args .batch_size ,
95
- shuffle = True ,
96
- drop_last = True ,
97
- ** kwargs
98
- )
99
-
100
- # Default test set loader if it hasn't been defined yet
101
- if self .test is None :
102
- self .test = DataLoader (
103
- test_set ,
104
- batch_size = args .test_batch_size ,
105
- shuffle = False ,
106
- ** kwargs
107
- )
85
+ self .train = dataloader_class (
86
+ train_set ,
87
+ batch_size = args .batch_size ,
88
+ shuffle = True ,
89
+ drop_last = True ,
90
+ ** kwargs
91
+ )
92
+ self .test = dataloader_class (
93
+ test_set ,
94
+ batch_size = args .test_batch_size ,
95
+ shuffle = False ,
96
+ ** kwargs
97
+ )
108
98
109
99
self .data_shape = self .train .dataset [0 ][0 ].size ()
110
100
self .img_size = self .data_shape [1 :]
0 commit comments