-
Notifications
You must be signed in to change notification settings - Fork 26
/
train_lsun.py
104 lines (74 loc) · 3.3 KB
/
train_lsun.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from __future__ import absolute_import
from __future__ import division
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
import vgg
import os
import lsun
from utils import *
# Command-line configuration.  Defaults are given as real numbers (not numeric
# strings) so every flag's default already has the type its DEFINE_* call
# declares, matching how 'max_steps' was declared.
FLAGS = tf.app.flags.FLAGS

# Optimizer / model hyper-parameters.  Not read in this file; presumably
# consumed by the `vgg` module through the shared FLAGS object -- TODO confirm.
tf.app.flags.DEFINE_float('learning_rate', 0.001, """learning rate""")
tf.app.flags.DEFINE_float('beta1', 0.5, """beta for Adam""")
tf.app.flags.DEFINE_integer('c_dim', 3, """c dimsion""")
tf.app.flags.DEFINE_integer('output_size', 112, """output size""")

# Read by train(): size of the noise batch fed at every step.
tf.app.flags.DEFINE_integer('batch_size', 64, """batch size""")
tf.app.flags.DEFINE_integer('z_dim', 1024, """z dimsion""")

# Read by train(): number of loop iterations.
tf.app.flags.DEFINE_integer('max_steps', 1000000, """Number of batches to run.""")

# Filesystem locations (directories are created by main()).
tf.app.flags.DEFINE_string('log_dir', './log/', """Directory where to write event logs """)
tf.app.flags.DEFINE_string('checkpoint_dir', './checkpoint/', """Directory where to write the checkpoint""")
tf.app.flags.DEFINE_string('data_path', './church112.tfrecords', """Path to the lsun data file""")
def sess_init():
    """Open a TensorFlow session and initialize all variables in it.

    The session is created with ``gpu_options.allow_growth`` enabled and the
    op returned by ``tf.initialize_all_variables()`` is run before returning.

    Returns:
        The ready-to-use ``tf.Session``.
    """
    session_config = tf.ConfigProto()
    # allow_growth: see the GPUOptions documentation; avoids the default
    # GPU memory allocation policy.
    session_config.gpu_options.allow_growth = True

    session = tf.Session(config=session_config)
    session.run(tf.initialize_all_variables())
    return session
def train():
    """Build the GAN graph and run the alternating D/G training loop.

    Inside a fresh ``tf.Graph`` this function:

    * loads the image input from ``FLAGS.data_path`` via ``lsun.LSUN().load``,
    * builds the networks with ``vgg.inference`` and the losses with
      ``vgg.loss_l2``,
    * splits the trainable variables by name ('g_' / 'd_') and hands both
      lists to ``vgg.train`` to obtain the two train ops,
    * then, for ``FLAGS.max_steps`` steps, runs one discriminator update
      followed by one generator update on the same noise batch.

    Side effects: prints losses every 100 steps, writes a sample image grid
    to ``./samples/`` every 1000 steps and a checkpoint under
    ``FLAGS.checkpoint_dir`` every 10000 steps.
    """
    with tf.Graph().as_default():
        data_set = lsun.LSUN()
        images = data_set.load(FLAGS.data_path)

        global_step = tf.Variable(0, trainable=False)

        # Fed through feed_dict below, so presumably a placeholder of shape
        # (batch_size, z_dim) -- TODO confirm against vgg.inputs.
        random_z = vgg.inputs()

        # Only the three logits tensors are used by the loss in this file.
        D_logits_real, D_logits_fake, D_logits_fake_for_G, \
            D_sigmoid_real, D_sigmoid_fake, D_sigmoid_fake_for_G = \
            vgg.inference(images, random_z)

        G_loss, D_loss = vgg.loss_l2(D_logits_real, D_logits_fake,
                                     D_logits_fake_for_G)

        # Partition variables by a substring match on their names.
        # NOTE(review): a plain substring test would also match e.g. a
        # discriminator variable whose name happens to contain 'g_'; verify
        # the naming scheme used in `vgg`.
        t_vars = tf.trainable_variables()
        G_vars = [var for var in t_vars if 'g_' in var.name]
        D_vars = [var for var in t_vars if 'd_' in var.name]

        G_train_op, D_train_op = vgg.train(G_loss, D_loss, G_vars, D_vars,
                                           global_step)

        sampler = vgg.sampler(random_z)

        #summary_op = tf.merge_all_summaries()

        sess = sess_init()

        # Started right after the session exists; presumably the input
        # pipeline built by lsun.LSUN.load is queue-based -- TODO confirm.
        tf.train.start_queue_runners(sess=sess)

        #summary_writer = tf.train.SummaryWriter(FLAGS.log_dir, sess.graph)

        saver = tf.train.Saver()

        for step in xrange(1, FLAGS.max_steps+1):
            batch_z = np.random.uniform(
                -1, 1, [FLAGS.batch_size, FLAGS.z_dim]).astype(np.float32)

            # One discriminator step, then one generator step, both fed the
            # same noise batch.
            _, errD = sess.run([D_train_op, D_loss],
                               feed_dict={random_z: batch_z})
            _, errG = sess.run([G_train_op, G_loss],
                               feed_dict={random_z: batch_z})

            if step % 100 == 0:
                # Single-argument, parenthesised form: valid and identical in
                # output on both Python 2 and Python 3.
                print("step = %d, errD = %f, errG = %f" % (step, errD, errG))

            if step % 1000 == 0:
                samples = sess.run(sampler,
                                   feed_dict={random_z: batch_z})
                # save_images is not defined in this file; presumably it comes
                # from `from utils import *`.  Tiles the samples as 8 x 8.
                save_images(samples, [8, 8],
                            './samples/train_{:d}.bmp'.format(step))

            # if step % 1000 == 0:
            #     summary_str = sess.run(summary_op,
            #                            feed_dict={random_z: batch_z})
            #     summary_writer.add_summary(summary_str, step)

            if step % 10000 == 0:
                # global_step is passed as Saver.save's global_step argument.
                saver.save(sess,
                           '{0}/vgg-{1}.model'.format(FLAGS.checkpoint_dir, step),
                           global_step)
def main(argv=None):
    """Create the output directories, then start training.

    Ensures that ``samples``, ``FLAGS.checkpoint_dir`` and ``FLAGS.log_dir``
    exist (creating intermediate directories as needed) and calls ``train()``.

    Args:
        argv: Unused; accepted because ``tf.app.run()`` passes it to ``main``.

    Raises:
        OSError: If one of the directories is missing and cannot be created.
    """
    for directory in ('samples', FLAGS.checkpoint_dir, FLAGS.log_dir):
        # os.makedirs replaces the former `mkdir -p` shell call: no shell
        # quoting problems with unusual paths, works without a POSIX shell,
        # and failures surface here rather than at the first save.
        try:
            os.makedirs(directory)
        except OSError:
            # Like `mkdir -p`, an already existing directory is not an error
            # (written this way because Python 2 has no `exist_ok`).
            if not os.path.isdir(directory):
                raise
    train()
# Script entry point.  Control is handed to tf.app.run(), which (per the
# TensorFlow documentation) parses the command-line flags and then calls
# this module's main().
if __name__ == "__main__":
    tf.app.run()