5
5
6
6
from lib .datasets import StaticBinaryMnist
7
7
8
-
9
8
multiobject_paths = {
10
- 'multi_mnist_binary' : './data/multi_mnist/multi_binary_mnist_012.npz' ,
11
- 'multi_dsprites_binary_rgb' : './data/multi-dsprites-binary-rgb/multi_dsprites_color_012.npz' ,
9
+ 'multi_mnist_binary' :
10
+ './data/multi_mnist/multi_binary_mnist_012.npz' ,
11
+ 'multi_dsprites_binary_rgb' :
12
+ './data/multi-dsprites-binary-rgb/multi_dsprites_color_012.npz' ,
12
13
}
13
14
multiobject_datasets = multiobject_paths .keys ()
14
15
@@ -32,10 +33,14 @@ def __init__(self, args, cuda):
32
33
33
34
if args .dataset_name == 'static_mnist' :
34
35
data_folder = './data/static_bin_mnist/'
35
- train_set = StaticBinaryMnist (data_folder , train = True ,
36
- download = True , shuffle_init = True )
37
- test_set = StaticBinaryMnist (data_folder , train = False ,
38
- download = True , shuffle_init = True )
36
+ train_set = StaticBinaryMnist (data_folder ,
37
+ train = True ,
38
+ download = True ,
39
+ shuffle_init = True )
40
+ test_set = StaticBinaryMnist (data_folder ,
41
+ train = False ,
42
+ download = True ,
43
+ shuffle_init = True )
39
44
40
45
elif args .dataset_name == 'cifar10' :
41
46
# Discrete values 0, 1/255, ..., 254/255, 1
@@ -46,18 +51,26 @@ def __init__(self, args, cuda):
46
51
transforms .ToTensor (),
47
52
])
48
53
data_folder = './data/cifar10/'
49
- train_set = CIFAR10 (data_folder , train = True ,
50
- download = True , transform = transform )
51
- test_set = CIFAR10 (data_folder , train = False ,
52
- download = True , transform = transform )
54
+ train_set = CIFAR10 (data_folder ,
55
+ train = True ,
56
+ download = True ,
57
+ transform = transform )
58
+ test_set = CIFAR10 (data_folder ,
59
+ train = False ,
60
+ download = True ,
61
+ transform = transform )
53
62
54
63
elif args .dataset_name == 'svhn' :
55
64
transform = transforms .ToTensor ()
56
65
data_folder = './data/svhn/'
57
- train_set = SVHN (data_folder , split = 'train' ,
58
- download = True , transform = transform )
59
- test_set = SVHN (data_folder , split = 'test' ,
60
- download = True , transform = transform )
66
+ train_set = SVHN (data_folder ,
67
+ split = 'train' ,
68
+ download = True ,
69
+ transform = transform )
70
+ test_set = SVHN (data_folder ,
71
+ split = 'test' ,
72
+ download = True ,
73
+ transform = transform )
61
74
62
75
elif args .dataset_name == 'celeba' :
63
76
transform = transforms .Compose ([
@@ -66,10 +79,14 @@ def __init__(self, args, cuda):
66
79
transforms .ToTensor (),
67
80
])
68
81
data_folder = '/scratch/adit/data/celeba/'
69
- train_set = CelebA (data_folder , split = 'train' ,
70
- download = True , transform = transform )
71
- test_set = CelebA (data_folder , split = 'valid' ,
72
- download = True , transform = transform )
82
+ train_set = CelebA (data_folder ,
83
+ split = 'train' ,
84
+ download = True ,
85
+ transform = transform )
86
+ test_set = CelebA (data_folder ,
87
+ split = 'valid' ,
88
+ download = True ,
89
+ transform = transform )
73
90
74
91
elif args .dataset_name in multiobject_datasets :
75
92
data_path = multiobject_paths [args .dataset_name ]
@@ -80,21 +97,18 @@ def __init__(self, args, cuda):
80
97
dataloader_class = MultiObjectDataLoader
81
98
82
99
else :
83
- raise RuntimeError ("Unrecognized data set '{}'" .format (args .dataset_name ))
84
-
85
- self .train = dataloader_class (
86
- train_set ,
87
- batch_size = args .batch_size ,
88
- shuffle = True ,
89
- drop_last = True ,
90
- ** kwargs
91
- )
92
- self .test = dataloader_class (
93
- test_set ,
94
- batch_size = args .test_batch_size ,
95
- shuffle = False ,
96
- ** kwargs
97
- )
100
+ raise RuntimeError ("Unrecognized data set '{}'" .format (
101
+ args .dataset_name ))
102
+
103
+ self .train = dataloader_class (train_set ,
104
+ batch_size = args .batch_size ,
105
+ shuffle = True ,
106
+ drop_last = True ,
107
+ ** kwargs )
108
+ self .test = dataloader_class (test_set ,
109
+ batch_size = args .test_batch_size ,
110
+ shuffle = False ,
111
+ ** kwargs )
98
112
99
113
self .data_shape = self .train .dataset [0 ][0 ].size ()
100
114
self .img_size = self .data_shape [1 :]
0 commit comments