@@ -29,7 +29,7 @@ for PyTorch) and [multiobject](https://github.com/addtt/multi-object-datasets)
## Likelihood results
- Log likelihood bounds on the test set. Final results coming soon.
+ Log likelihood bounds on the test set (average over 4 random seeds).
| dataset | num layers | -ELBO | -log _p(x)_ ≤ <br> [100 iws] | -log _p(x)_ ≤ <br> [1000 iws] |
| -------------------- | :----------:| :------------:| :-------------:| :--------------:|
@@ -47,10 +47,12 @@ Note:
- Bits per dimension in brackets; the conversion is sketched after this list.
- 'iws' stands for importance weighted samples. More samples means tighter log
likelihood lower bound. The bound converges to the actual log likelihood as
- the number of samples goes to infinity [5].
+ the number of samples goes to infinity [5]. Note that the model is always
+ trained with the ELBO (1 sample). A sketch of the bound follows this list.
- Each pixel in the images is modeled independently. The likelihood is Bernoulli
for binary images, and discretized mixture of logistics with 10
components [6] otherwise.
+ - One day I'll get around to evaluating the IW bound on all datasets with 10000 samples.
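
For concreteness, here is a minimal sketch of the importance-weighted bound from [5] as it could be computed in PyTorch. The function name and the `[K, B]` tensor layout are assumptions for illustration, not this repo's actual evaluation code:

```python
import math
import torch

def iw_bound(log_px_z, log_pz, log_qz_x):
    # Shapes: [K, B], per-sample log densities for K importance samples of
    # each of B data points, with z drawn from q(z|x). The bound is
    #   log p(x) >= E[ log (1/K) sum_k p(x, z_k) / q(z_k | x) ],
    # computed stably with logsumexp. K = 1 recovers the ELBO, and the
    # bound tightens as K grows.
    log_w = log_px_z + log_pz - log_qz_x          # log importance weights
    return (torch.logsumexp(log_w, dim=0) - math.log(log_w.size(0))).mean()
```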
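Similarly, a sketch of the bits-per-dimension conversion used in the table, with the Bernoulli likelihood that applies to the binary datasets (the color datasets would use the discretized logistic mixture of [6] instead). The helper name and `logits` input are illustrative assumptions:

```python
import math
import torch
import torch.nn.functional as F

def bernoulli_bits_per_dim(logits, x):
    # logits, x: [B, C, H, W], with x binary. Each pixel is an independent
    # Bernoulli, so -log p(x) is a sum of binary cross-entropies (in nats);
    # dividing by (num dims * ln 2) converts nats to bits per dimension.
    nll_nats = F.binary_cross_entropy_with_logits(
        logits, x, reduction='none').sum(dim=(1, 2, 3))
    return (nll_nats / (x[0].numel() * math.log(2))).mean()
```
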
## Supported datasets
@@ -118,7 +120,7 @@ more variability in each row as we move to higher layers. When the sampling
happens in the top layer (_i = L_), all samples are completely independent,
even within a row.
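
A toy sketch of one way such a row of samples could be produced, assuming hypothetical `priors[l]` callables that return the conditional prior of layer _l_ as a `torch.distributions.Normal` given the layers above it, and a `decoder` mapping the bottom latent to images (the repo's actual API differs; see the code):

```python
import torch

@torch.no_grad()
def sample_row(priors, decoder, i, n_cols=8):
    # Layers above i are drawn once and shared by the whole row, layer i is
    # sampled independently per image, and layers below i are set to the
    # mean (mode) of their conditional prior. With i = L (the top layer)
    # nothing is shared, so all images come out fully independent.
    z = None
    for l in reversed(range(len(priors))):
        p = priors[l](z)
        if l > i:
            s = p.sample()[:1]                           # one shared draw
            z = s.repeat(n_cols, *([1] * (s.dim() - 1)))
        elif l == i:
            z = p.sample()                               # fresh draw per image
        else:
            z = p.mean                                   # mode below layer i
    return decoder(z)
```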
- #### Binarized MNIST: layers 4, 8, 10, and 12 (top)
+ #### Binarized MNIST: layers 4, 8, 10, and 12 (top layer)
![MNIST layers 4](_readme_figs/layers_mnist/sample_mode_layer3.png) &nbsp;&nbsp;
![MNIST layers 8](_readme_figs/layers_mnist/sample_mode_layer7.png)
@@ -127,7 +129,7 @@ even within a row.
![MNIST layers 12](_readme_figs/layers_mnist/sample_mode_layer11.png)
- #### SVHN: layers 4, 10, 13, and 15 (top)
+ #### SVHN: layers 4, 10, 13, and 15 (top layer)
![SVHN layers 4](_readme_figs/layers_svhn/sample_mode_layer3.png) &nbsp;&nbsp;
![SVHN layers 10](_readme_figs/layers_svhn/sample_mode_layer9.png)
@@ -136,7 +138,7 @@ even within a row.
![SVHN layers 15](_readme_figs/layers_svhn/sample_mode_layer14.png)
- #### CIFAR: layers 3, 7, 10, and 15 (top)
+ #### CIFAR: layers 3, 7, 10, and 15 (top layer)
![CIFAR layers 3](_readme_figs/layers_cifar/sample_mode_layer2.png) &nbsp;&nbsp;
![CIFAR layers 7](_readme_figs/layers_cifar/sample_mode_layer6.png)
@@ -145,7 +147,7 @@ even within a row.
![CIFAR layers 15](_readme_figs/layers_cifar/sample_mode_layer14.png)
- #### CelebA: layers 6, 11, 16, and 20 (top)
+ #### CelebA: layers 6, 11, 16, and 20 (top layer)
![CelebA layers 6](_readme_figs/layers_celeba/sample_mode_layer5.png)
@@ -156,7 +158,7 @@ even within a row.
![CelebA layers 20](_readme_figs/layers_celeba/sample_mode_layer19.png)
- #### Multi-dSprites: layers 3, 7, 10, and 12 (top)
+ #### Multi-dSprites: layers 3, 7, 10, and 12 (top layer)
![Multi-dSprites layers 3](_readme_figs/layers_multidsprites/sample_mode_layer2.png) &nbsp;&nbsp;
![Multi-dSprites layers 7](_readme_figs/layers_multidsprites/sample_mode_layer6.png)
@@ -191,28 +193,14 @@ I did not perform an extensive hyperparameter search, but this worked pretty wel
See code for details.
- freebits=1.0 in experiments with more than 6 stochastic layers, and 0.5 for
smaller models (the free-bits computation is sketched after this list).
- - For everything else, see `_parse_args()` in `experiment/experiment_manager.py`.
+ - For everything else, see `_add_args()` in `experiment/experiment_manager.py`.
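
As a reference for what the free-bits setting does, here is a minimal sketch of the usual formulation (after Kingma et al.'s free-bits trick); the function and tensor layout are assumptions, and the repo's implementation may aggregate differently:

```python
import torch

def kl_with_free_bits(kl_per_layer, free_bits=1.0):
    # kl_per_layer: iterable of scalar tensors, one KL term per stochastic
    # layer (already averaged over the batch). Clamping from below means a
    # layer whose KL falls under `free_bits` nats stops being penalized,
    # which discourages early posterior collapse of individual layers.
    return sum(kl.clamp(min=free_bits) for kl in kl_per_layer)
```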
With these settings, the number of parameters is roughly 1M per stochastic
197
199
layer. I tried to control for this by experimenting e.g. with half the number
198
200
of layers but twice the number of residual blocks, but it looks like the number
199
201
of stochastic layers is what matters the most.
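
A quick, generic way to verify such parameter counts for any of the models (not code from this repo):

```python
from torch import nn

def count_params_millions(model: nn.Module) -> float:
    # Number of trainable parameters, in millions.
    return sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6
```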
- ## Requirements
-
- Tested with:
- ```
- python 3.7.6
- numpy 1.18.1
- torch 1.4.0
- torchvision 0.5.0
- matplotlib 3.1.2
- seaborn 0.9.0
- boilr 0.6.0
- multiobject 0.0.3
- ```
-
## References
[1] CK Sønderby,