wGAN_train.py
import glob
import tensorflow as tf
import numpy as np
from PIL import Image
from utils import dataset
from models import generator, discriminator
def preprocess_fn(img):
    # Resize to 96x96 with bicubic interpolation and rescale pixel values
    # from [0, 255] to [-1, 1].
    re_size = 96
    img = tf.to_float(tf.image.resize_images(img, [re_size, re_size],
                                             method=tf.image.ResizeMethod.BICUBIC)) / 127.5 - 1
    return img
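
# A quick sanity check of the scaling above: a pixel value of 0 maps to
# 0 / 127.5 - 1 = -1.0 and a value of 255 maps to 255 / 127.5 - 1 = 1.0, so
# inputs lie in [-1, 1]; the generator output is presumably bounded to the
# same range (e.g. by a final tanh), though that is defined in models.py.
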
class wGAN:
    def __init__(self, G, D, dataset):
        self.G = G
        self.D = D
        self.dataset = dataset
        self.x_dim = 96 * 96
        self.z_dim = 100

        self.x = tf.placeholder(tf.float32, [None, 96, 96, 3], name='x')
        self.z = tf.placeholder(tf.float32, [None, self.z_dim], name='z')

        self.x_ = self.G(self.z)
        self.d = self.D(self.x, reuse=False)
        self.d_ = self.D(self.x_)

        # WGAN losses (note the sign convention): the critic minimizes d_loss,
        # which widens the score gap between real and generated samples; the
        # generator minimizes g_loss, which narrows it.
        self.g_loss = tf.reduce_mean(self.d_)
        self.d_loss = tf.reduce_mean(self.d) - tf.reduce_mean(self.d_)

        # L1 regularization on all weight variables
        self.reg = tf.contrib.layers.apply_regularization(
            tf.contrib.layers.l1_regularizer(2.5e-5),
            weights_list=[var for var in tf.global_variables() if 'weights' in var.name]
        )
        self.g_loss_reg = self.g_loss + self.reg
        self.d_loss_reg = self.d_loss + self.reg

        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            self.d_rmsprop = tf.train.RMSPropOptimizer(learning_rate=5e-5) \
                .minimize(self.d_loss_reg, var_list=self.D.vars)
            self.g_rmsprop = tf.train.RMSPropOptimizer(learning_rate=5e-5) \
                .minimize(self.g_loss_reg, var_list=self.G.vars)

        # Clip critic weights to [-0.01, 0.01] (the WGAN Lipschitz constraint).
        self.d_clip = [v.assign(tf.clip_by_value(v, -0.01, 0.01)) for v in self.D.vars]

        gpu_options = tf.GPUOptions(allow_growth=True)
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
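
    # Training follows the standard WGAN recipe: the critic is updated
    # n_critic times per generator update, with its weights clipped to
    # [-0.01, 0.01] on every step, using RMSProp (lr = 5e-5) for both networks.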
    def train(self, batch_size=64, num_batches=1000000, n_critic=5):
        self.sess.run(tf.global_variables_initializer())
        for t in range(num_batches):
            print('now:', t)
            # if t % 500 == 0 or t < 25:
            #     n_critic = 100
            for _ in range(n_critic):
                bx = self.dataset.batch()
                bz = np.random.normal(size=[batch_size, self.z_dim])
                self.sess.run(self.d_clip)
                self.sess.run(self.d_rmsprop, feed_dict={
                    self.x: bx,
                    self.z: bz
                })
            bz = np.random.normal(size=[batch_size, self.z_dim])
            self.sess.run(self.g_rmsprop, feed_dict={
                self.z: bz
            })
            if t % 1 == 0:  # log losses every iteration
                bx = self.dataset.batch()
                bz = np.random.normal(size=[batch_size, self.z_dim])
                d_loss = self.sess.run(
                    self.d_loss, feed_dict={
                        self.x: bx,
                        self.z: bz
                    }
                )
                g_loss = self.sess.run(
                    self.g_loss, feed_dict={
                        self.z: bz
                    }
                )
                print('Iter time: %8d, d_loss: %.4f, g_loss: %.4f' % (t, d_loss, g_loss))
            if t % 10 == 0 and t > 0:
                # Save one generated sample every 10 iterations.
                bz = np.random.normal(size=[batch_size, self.z_dim])
                bx = self.sess.run(self.x_, feed_dict={self.z: bz})
                bx = np.reshape(bx, (batch_size, 96, 96, 3))
                # Map generator output back to [0, 255]; the preprocessing
                # scales inputs to [-1, 1], so the output is assumed to lie
                # in the same range.
                bx = ((bx + 1) / 2 * 255).astype(np.uint8)
                im = Image.fromarray(bx[0])
                im.save('./res/' + str(t) + '.png')
                print(bx.shape)
datapaths = glob.glob('./faces/*.jpg')
data = dataset(image_paths=datapaths,
               batch_size=64,
               shape=[96, 96, 3],
               preprocess_fn=preprocess_fn)
G = generator()
D = discriminator()
gan = wGAN(G, D, data)
gan.train()
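
# Not part of the original script: the loop above never saves model weights.
# A minimal sketch of how checkpointing could be added, using the standard
# tf.train.Saver API (the './checkpoints' path and the save interval are
# arbitrary choices), would be to build a saver after the graph is defined
# and save periodically inside train():
#
#     saver = tf.train.Saver()                       # after graph construction
#     ...
#     if t % 1000 == 0:                              # inside the training loop
#         saver.save(self.sess, './checkpoints/wgan', global_step=t)
#
# and restore later with:
#
#     saver.restore(gan.sess, tf.train.latest_checkpoint('./checkpoints'))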