Skip to content

Commit 75e05bf

Browse files
authored
add plain vae codes
1 parent d1508f8 commit 75e05bf

File tree

2 files changed

+537
-0
lines changed

2 files changed

+537
-0
lines changed

model_vae.py

+167
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
2+
import tensorflow as tf
3+
import tensorlayer as tl
4+
from tensorlayer.layers import *
5+
6+
flags = tf.app.flags
7+
FLAGS = flags.FLAGS
8+
9+
def encoder(input_imgs, is_train = True, reuse = False):
10+
'''
11+
input_imgs: the input images to be encoded into a vector as latent representation. size here is [b_size,64,64,3]
12+
'''
13+
z_dim = FLAGS.z_dim # 512
14+
ef_dim = 64 # encoder filter number
15+
16+
w_init = tf.random_normal_initializer(stddev=0.02)
17+
gamma_init = tf.random_normal_initializer(1., 0.02)
18+
19+
with tf.variable_scope("encoder", reuse = reuse):
20+
tl.layers.set_name_reuse(reuse)
21+
22+
net_in = InputLayer(input_imgs, name='en/in') # (b_size,64,64,3)
23+
net_h0 = Conv2d(net_in, ef_dim, (5, 5), (2, 2), act=None,
24+
padding='SAME', W_init=w_init, name='en/h0/conv2d')
25+
net_h0 = BatchNormLayer(net_h0, act=tf.nn.relu,
26+
is_train=is_train, gamma_init=gamma_init, name='en/h0/batch_norm')
27+
# net_h0.outputs._shape = (b_size,32,32,64)
28+
29+
net_h1 = Conv2d(net_h0, ef_dim*2, (5, 5), (2, 2), act=None,
30+
padding='SAME', W_init=w_init, name='en/h1/conv2d')
31+
net_h1 = BatchNormLayer(net_h1, act=tf.nn.relu,
32+
is_train=is_train, gamma_init=gamma_init, name='en/h1/batch_norm')
33+
# net_h1.outputs._shape = (b_size,16,16,64*2)
34+
35+
net_h2 = Conv2d(net_h1, ef_dim*4, (5, 5), (2, 2), act=None,
36+
padding='SAME', W_init=w_init, name='en/h2/conv2d')
37+
net_h2 = BatchNormLayer(net_h2, act=tf.nn.relu,
38+
is_train=is_train, gamma_init=gamma_init, name='en/h2/batch_norm')
39+
# net_h2.outputs._shape = (b_size,8,8,64*4)
40+
41+
net_h3 = Conv2d(net_h2, ef_dim*8, (5, 5), (2, 2), act=None,
42+
padding='SAME', W_init=w_init, name='en/h3/conv2d')
43+
net_h3 = BatchNormLayer(net_h3, act=tf.nn.relu,
44+
is_train=is_train, gamma_init=gamma_init, name='en/h3/batch_norm')
45+
# net_h2.outputs._shape = (b_size,4,4,64*8)
46+
47+
# mean of z
48+
net_h4 = FlattenLayer(net_h3, name='en/h4/flatten')
49+
# net_h4.outputs._shape = (b_size,8*8*64*4)
50+
net_out1 = DenseLayer(net_h4, n_units=z_dim, act=tf.identity,
51+
W_init = w_init, name='en/h3/lin_sigmoid')
52+
net_out1 = BatchNormLayer(net_out1, act=tf.identity,
53+
is_train=is_train, gamma_init=gamma_init, name='en/out1/batch_norm')
54+
55+
# net_out1 = DenseLayer(net_h4, n_units=z_dim, act=tf.nn.relu,
56+
# W_init = w_init, name='en/h4/lin_sigmoid')
57+
z_mean = net_out1.outputs # (b_size,512)
58+
59+
# log of variance of z(covariance matrix is diagonal)
60+
net_h5 = FlattenLayer(net_h3, name='en/h5/flatten')
61+
net_out2 = DenseLayer(net_h5, n_units=z_dim, act=tf.identity,
62+
W_init = w_init, name='en/h4/lin_sigmoid')
63+
net_out2 = BatchNormLayer(net_out2, act=tf.nn.softplus,
64+
is_train=is_train, gamma_init=gamma_init, name='en/out2/batch_norm')
65+
# net_out2 = DenseLayer(net_h5, n_units=z_dim, act=tf.nn.relu,
66+
# W_init = w_init, name='en/h5/lin_sigmoid')
67+
z_log_sigma_sq = net_out2.outputs + 1e-6# (b_size,512)
68+
69+
return net_out1, net_out2, z_mean, z_log_sigma_sq
70+
71+
def generator(inputs, is_train = True, reuse = False):
72+
'''
73+
generator of GAN, which can also be seen as a decoder of VAE
74+
inputs: latent representation from encoder. [b_size,z_dim]
75+
'''
76+
image_size = FLAGS.output_size # 64 the output size of generator
77+
s2, s4, s8, s16 = int(image_size/2), int(image_size/4), int(image_size/8), int(image_size/16) # 32,16,8,4
78+
gf_dim = 64
79+
c_dim = FLAGS.c_dim # n_color 3
80+
batch_size = FLAGS.batch_size # 64
81+
82+
w_init = tf.random_normal_initializer(stddev=0.02)
83+
gamma_init = tf.random_normal_initializer(1., 0.02)
84+
85+
with tf.variable_scope("generator", reuse = reuse):
86+
tl.layers.set_name_reuse(reuse)
87+
88+
net_in = InputLayer(inputs, name='g/in')
89+
net_h0 = DenseLayer(net_in, n_units=gf_dim*4*s8*s8, W_init=w_init,
90+
act = tf.identity, name='g/h0/lin')
91+
# net_h0.outputs._shape = (b_size,256*8*8)
92+
net_h0 = ReshapeLayer(net_h0, shape=[-1, s8, s8, gf_dim*4], name='g/h0/reshape')
93+
# net_h0.outputs._shape = (b_size,8,8,256)
94+
net_h0 = BatchNormLayer(net_h0, act=tf.nn.relu, is_train=is_train,
95+
gamma_init=gamma_init, name='g/h0/batch_norm')
96+
97+
# upsampling
98+
net_h1 = DeConv2d(net_h0, gf_dim*4, (5, 5), out_size=(s4, s4), strides=(2, 2),
99+
padding='SAME', batch_size=batch_size, act=None, W_init=w_init, name='g/h1/decon2d')
100+
net_h1 = BatchNormLayer(net_h1, act=tf.nn.relu, is_train=is_train,
101+
gamma_init=gamma_init, name='g/h1/batch_norm')
102+
# net_h1.outputs._shape = (b_size,16,16,256)
103+
104+
net_h2 = DeConv2d(net_h1, gf_dim*2, (5, 5), out_size=(s2, s2), strides=(2, 2),
105+
padding='SAME', batch_size=batch_size, act=None, W_init=w_init, name='g/h2/decon2d')
106+
net_h2 = BatchNormLayer(net_h2, act=tf.nn.relu, is_train=is_train,
107+
gamma_init=gamma_init, name='g/h2/batch_norm')
108+
# net_h2.outputs._shape = (b_size,32,32,128)
109+
110+
net_h3 = DeConv2d(net_h2, gf_dim//2, (5, 5), out_size=(image_size, image_size), strides=(2, 2),
111+
padding='SAME', batch_size=batch_size, act=None, W_init=w_init, name='g/h3/decon2d')
112+
net_h3 = BatchNormLayer(net_h3, act=tf.nn.relu, is_train=is_train,
113+
gamma_init=gamma_init, name='g/h3/batch_norm')
114+
# net_h3.outputs._shape = (b_size,64,64,32)
115+
116+
# no BN on last deconv
117+
net_h4 = DeConv2d(net_h3, c_dim, (5, 5), out_size=(image_size, image_size), strides=(1, 1),
118+
padding='SAME', batch_size=batch_size, act=None, W_init=w_init, name='g/h4/decon2d')
119+
# net_h4.outputs._shape = (b_size,64,64,3)
120+
# net_h4 = Conv2d(net_h3, c_dim, (5,5),(1,1), padding='SAME', W_init=w_init, name='g/h4/conv2d')
121+
logits = net_h4.outputs
122+
net_h4.outputs = tf.nn.tanh(net_h4.outputs)
123+
return net_h4, logits
124+
125+
# net_in = InputLayer(inputs, name='g/in')
126+
# net_h0 = DenseLayer(net_in, n_units=gf_dim*8*s16*s16, W_init=w_init,
127+
# act = tf.identity, name='g/h0/lin')
128+
# # net_h0.outputs._shape = (b_size,512*4*4)
129+
# net_h0 = ReshapeLayer(net_h0, shape=[-1, s16, s16, gf_dim*8], name='g/h0/reshape')
130+
# # net_h0.outputs._shape = (b_size,4,4,512)
131+
# net_h0 = BatchNormLayer(net_h0, act=tf.nn.relu, is_train=is_train,
132+
# gamma_init=gamma_init, name='g/h0/batch_norm')
133+
134+
# # upsampling
135+
# net_h1 = DeConv2d(net_h0, gf_dim*4, (5, 5), out_size=(s8, s8), strides=(2, 2),
136+
# padding='SAME', batch_size=batch_size, act=None, W_init=w_init, name='g/h1/decon2d')
137+
# net_h1 = BatchNormLayer(net_h1, act=tf.nn.relu, is_train=is_train,
138+
# gamma_init=gamma_init, name='g/h1/batch_norm')
139+
# # net_h1.outputs._shape = (b_size,8,8,256)
140+
141+
# net_h2 = DeConv2d(net_h1, gf_dim*2, (5, 5), out_size=(s4, s4), strides=(2, 2),
142+
# padding='SAME', batch_size=batch_size, act=None, W_init=w_init, name='g/h2/decon2d')
143+
# net_h2 = BatchNormLayer(net_h2, act=tf.nn.relu, is_train=is_train,
144+
# gamma_init=gamma_init, name='g/h2/batch_norm')
145+
# # net_h2.outputs._shape = (b_size,16,16,128)
146+
147+
# net_h3 = DeConv2d(net_h2, gf_dim, (5, 5), out_size=(s2, s2), strides=(2, 2),
148+
# padding='SAME', batch_size=batch_size, act=None, W_init=w_init, name='g/h3/decon2d')
149+
# net_h3 = BatchNormLayer(net_h3, act=tf.nn.relu, is_train=is_train,
150+
# gamma_init=gamma_init, name='g/h3/batch_norm')
151+
# # net_h3.outputs._shape = (b_size,32,32,64)
152+
153+
# net_h4 = DeConv2d(net_h3, gf_dim//2, (5, 5), out_size=(image_size, image_size), strides=(2, 2),
154+
# padding='SAME', batch_size=batch_size, act=None, W_init=w_init, name='g/h4/decon2d')
155+
# net_h4 = BatchNormLayer(net_h4, act=tf.nn.relu, is_train=is_train,
156+
# gamma_init=gamma_init, name='g/h4/batch_norm')
157+
# # net_h4.outputs._shape = (b_size,64,64,32)
158+
159+
# # no BN on last deconv
160+
# net_h5 = DeConv2d(net_h4, c_dim, (5, 5), out_size=(image_size, image_size), strides=(1, 1),
161+
# padding='SAME', batch_size=batch_size, act=None, W_init=w_init, name='g/h5/decon2d')
162+
# # net_h4.outputs._shape = (b_size,64,64,3)
163+
# # net_h4 = Conv2d(net_h3, c_dim, (5,5),(1,1), padding='SAME', W_init=w_init, name='g/h4/conv2d')
164+
# logits = net_h5.outputs
165+
# net_h5.outputs = tf.nn.tanh(net_h5.outputs)
166+
# return net_h5, logits
167+

0 commit comments

Comments
 (0)