Commit 14db37b

Add files via upload
1 parent 6a4be04 commit 14db37b

File tree: 4 files changed, +672 −0 lines

dfc_vae.py

+136
@@ -0,0 +1,136 @@

import tensorflow as tf
import tensorlayer as tl
from tensorlayer.layers import *

flags = tf.app.flags
FLAGS = flags.FLAGS


def encoder(input_imgs, is_train=True, reuse=False):
    '''
    input_imgs: the input images to be encoded into a vector as the latent
    representation; the shape here is [b_size, 64, 64, 3]
    '''
    z_dim = FLAGS.z_dim  # 100
    ef_dim = 32          # encoder filter number

    w_init = tf.random_normal_initializer(stddev=0.02)
    gamma_init = tf.random_normal_initializer(1., 0.02)

    with tf.variable_scope("encoder", reuse=reuse):
        tl.layers.set_name_reuse(reuse)

        net_in = InputLayer(input_imgs, name='en/in')  # (b_size, 64, 64, 3)
        net_h0 = Conv2d(net_in, ef_dim, (4, 4), (2, 2), act=None,
                        padding='SAME', W_init=w_init, name='en/h0/conv2d')
        net_h0 = BatchNormLayer(net_h0, act=lambda x: tl.act.lrelu(x, 0.2),
                                is_train=is_train, gamma_init=gamma_init, name='en/h0/batch_norm')
        # net_h0.outputs shape: (b_size, 32, 32, 32)

        net_h1 = Conv2d(net_h0, ef_dim*2, (4, 4), (2, 2), act=None,
                        padding='SAME', W_init=w_init, name='en/h1/conv2d')
        net_h1 = BatchNormLayer(net_h1, act=lambda x: tl.act.lrelu(x, 0.2),
                                is_train=is_train, gamma_init=gamma_init, name='en/h1/batch_norm')
        # net_h1.outputs shape: (b_size, 16, 16, 64)

        net_h2 = Conv2d(net_h1, ef_dim*4, (4, 4), (2, 2), act=None,
                        padding='SAME', W_init=w_init, name='en/h2/conv2d')
        net_h2 = BatchNormLayer(net_h2, act=lambda x: tl.act.lrelu(x, 0.2),
                                is_train=is_train, gamma_init=gamma_init, name='en/h2/batch_norm')
        # net_h2.outputs shape: (b_size, 8, 8, 128)

        net_h3 = Conv2d(net_h2, ef_dim*8, (4, 4), (2, 2), act=None,
                        padding='SAME', W_init=w_init, name='en/h3/conv2d')
        net_h3 = BatchNormLayer(net_h3, act=lambda x: tl.act.lrelu(x, 0.2),
                                is_train=is_train, gamma_init=gamma_init, name='en/h3/batch_norm')
        # net_h3.outputs shape: (b_size, 4, 4, 256)

        # mean of z
        net_h4 = FlattenLayer(net_h3, name='en/h4/flatten')
        # net_h4.outputs shape: (b_size, 4*4*256)
        net_out1 = DenseLayer(net_h4, n_units=z_dim, act=tf.identity,
                              W_init=w_init, name='en/out1/lin_sigmoid')
        # net_out1 = BatchNormLayer(net_out1, act=tf.identity,
        #                           is_train=is_train, gamma_init=gamma_init, name='en/out1/batch_norm')
        # net_out1 = DenseLayer(net_h4, n_units=z_dim, act=tf.nn.relu,
        #                       W_init=w_init, name='en/h4/lin_sigmoid')
        z_mean = net_out1.outputs  # (b_size, 100)

        # log of the variance of z (the covariance matrix is diagonal)
        net_h5 = FlattenLayer(net_h3, name='en/h5/flatten')
        net_out2 = DenseLayer(net_h5, n_units=z_dim, act=tf.identity,
                              W_init=w_init, name='en/out2/lin_sigmoid')
        # net_out2 = BatchNormLayer(net_out2, act=tf.nn.softplus,
        #                           is_train=is_train, gamma_init=gamma_init, name='en/out2/batch_norm')
        # net_out2 = DenseLayer(net_h5, n_units=z_dim, act=tf.nn.relu,
        #                       W_init=w_init, name='en/h5/lin_sigmoid')
        z_log_sigma_sq = net_out2.outputs + 1e-6  # (b_size, 100)

        return net_out1, net_out2, z_mean, z_log_sigma_sq

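Note: the encoder returns z_mean and z_log_sigma_sq but leaves sampling to the caller. Below is a minimal sketch of the reparameterization step and the diagonal-Gaussian KL term that typically consume these outputs in a VAE; sample_z and kl_loss are illustrative names, not part of this commit.

    import tensorflow as tf

    def sample_z(z_mean, z_log_sigma_sq):
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
        # where sigma = exp(0.5 * log(sigma^2)); shapes are (b_size, z_dim).
        eps = tf.random_normal(tf.shape(z_mean))
        return z_mean + tf.exp(0.5 * z_log_sigma_sq) * eps

    def kl_loss(z_mean, z_log_sigma_sq):
        # KL( N(mu, diag(sigma^2)) || N(0, I) ), summed over the latent
        # dimensions and averaged over the batch.
        kl = -0.5 * tf.reduce_sum(
            1. + z_log_sigma_sq - tf.square(z_mean) - tf.exp(z_log_sigma_sq),
            axis=1)
        return tf.reduce_mean(kl)
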
def generator(inputs, is_train=True, reuse=False):
    '''
    generator of the GAN, which can also be seen as the decoder of the VAE
    inputs: the latent representation from the encoder; shape [b_size, z_dim]
    '''
    image_size = FLAGS.output_size  # 64, the output size of the generator
    s2, s4, s8, s16 = int(image_size/2), int(image_size/4), int(image_size/8), int(image_size/16)  # 32, 16, 8, 4
    gf_dim = 32  # generator filter number
    c_dim = FLAGS.c_dim  # n_color, 3
    batch_size = FLAGS.batch_size  # 64

    w_init = tf.random_normal_initializer(stddev=0.02)
    gamma_init = tf.random_normal_initializer(1., 0.02)

    with tf.variable_scope("generator", reuse=reuse):
        tl.layers.set_name_reuse(reuse)

        net_in = InputLayer(inputs, name='g/in')
        net_h0 = DenseLayer(net_in, n_units=gf_dim*8*s16*s16, W_init=w_init,
                            act=tf.identity, name='g/h0/lin')
        # net_h0.outputs shape: (b_size, 256*4*4)
        net_h0 = ReshapeLayer(net_h0, shape=[-1, s16, s16, gf_dim*8], name='g/h0/reshape')
        # net_h0.outputs shape: (b_size, 4, 4, 256)
        net_h0 = BatchNormLayer(net_h0, act=lambda x: tl.act.lrelu(x, 0.2), is_train=is_train,
                                gamma_init=gamma_init, name='g/h0/batch_norm')

        # upsampling: nearest-neighbor resize (method=1) followed by a stride-1 convolution
        net_h1 = UpSampling2dLayer(net_h0, size=[8, 8], is_scale=False, method=1,
                                   align_corners=False, name='g/h1/upsample2d')
        net_h1 = Conv2d(net_h1, gf_dim*4, (3, 3), (1, 1), padding='SAME', W_init=w_init, name='g/h1/conv2d')
        # net_h1 = DeConv2d(net_h0, gf_dim*4, (3, 3), out_size=(s4, s4), strides=(2, 2),
        #                   padding='SAME', batch_size=batch_size, act=None, W_init=w_init, name='g/h1/decon2d')
        net_h1 = BatchNormLayer(net_h1, act=lambda x: tl.act.lrelu(x, 0.2), is_train=is_train,
                                gamma_init=gamma_init, name='g/h1/batch_norm')
        # net_h1.outputs shape: (b_size, 8, 8, 128)

        net_h2 = UpSampling2dLayer(net_h1, size=[16, 16], is_scale=False, method=1,
                                   align_corners=False, name='g/h2/upsample2d')
        net_h2 = Conv2d(net_h2, gf_dim*2, (3, 3), (1, 1), padding='SAME', W_init=w_init, name='g/h2/conv2d')
        # net_h2 = DeConv2d(net_h1, gf_dim*2, (3, 3), out_size=(s2, s2), strides=(2, 2),
        #                   padding='SAME', batch_size=batch_size, act=None, W_init=w_init, name='g/h2/decon2d')
        net_h2 = BatchNormLayer(net_h2, act=lambda x: tl.act.lrelu(x, 0.2), is_train=is_train,
                                gamma_init=gamma_init, name='g/h2/batch_norm')
        # net_h2.outputs shape: (b_size, 16, 16, 64)

        net_h3 = UpSampling2dLayer(net_h2, size=[32, 32], is_scale=False, method=1,
                                   align_corners=False, name='g/h3/upsample2d')
        net_h3 = Conv2d(net_h3, gf_dim, (3, 3), (1, 1), padding='SAME', W_init=w_init, name='g/h3/conv2d')
        # net_h3 = DeConv2d(net_h2, gf_dim//2, (3, 3), out_size=(image_size, image_size), strides=(2, 2),
        #                   padding='SAME', batch_size=batch_size, act=None, W_init=w_init, name='g/h3/decon2d')
        net_h3 = BatchNormLayer(net_h3, act=lambda x: tl.act.lrelu(x, 0.2), is_train=is_train,
                                gamma_init=gamma_init, name='g/h3/batch_norm')
        # net_h3.outputs shape: (b_size, 32, 32, 32)

        # no batch norm on the last layer
        net_h4 = UpSampling2dLayer(net_h3, size=[64, 64], is_scale=False, method=1,
                                   align_corners=False, name='g/h4/upsample2d')
        net_h4 = Conv2d(net_h4, c_dim, (3, 3), (1, 1), padding='SAME', W_init=w_init, name='g/h4/conv2d')
        # net_h4 = DeConv2d(net_h3, c_dim, (3, 3), out_size=(image_size, image_size), strides=(1, 1),
        #                   padding='SAME', batch_size=batch_size, act=None, W_init=w_init, name='g/h4/decon2d')
        # net_h4 = Conv2d(net_h3, c_dim, (5, 5), (1, 1), padding='SAME', W_init=w_init, name='g/h4/conv2d')
        # net_h4.outputs shape: (b_size, 64, 64, 3)
        logits = net_h4.outputs
        net_h4.outputs = tf.nn.tanh(net_h4.outputs)
        return net_h4, logits
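
For context, a hypothetical end-to-end wiring of the two functions into a VAE forward pass. The flag values mirror the comments above (z_dim=100, output_size=64, c_dim=3, batch_size=64), and the module name dfc_vae as well as the placeholder setup are assumptions for illustration.

    import tensorflow as tf
    from dfc_vae import encoder, generator

    flags = tf.app.flags
    flags.DEFINE_integer("z_dim", 100, "dimension of the latent vector z")
    flags.DEFINE_integer("output_size", 64, "output image size of the generator")
    flags.DEFINE_integer("c_dim", 3, "number of color channels")
    flags.DEFINE_integer("batch_size", 64, "batch size")
    FLAGS = flags.FLAGS

    # Encode a batch of images, sample z, then decode back to image space.
    images = tf.placeholder(tf.float32, [FLAGS.batch_size, 64, 64, 3])
    _, _, z_mean, z_log_sigma_sq = encoder(images, is_train=True, reuse=False)
    eps = tf.random_normal(tf.shape(z_mean))
    z = z_mean + tf.exp(0.5 * z_log_sigma_sq) * eps
    net_g, logits = generator(z, is_train=True, reuse=False)
    reconstruction = net_g.outputs  # (batch_size, 64, 64, 3), in [-1, 1] via tanh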
