Commit 4ae0c78 ("working")
zsdonghao committed Sep 30, 2017
1 parent e82421b
Showing 18 changed files with 101 additions and 79 deletions.
7 changes: 5 additions & 2 deletions README.md
@@ -1,11 +1,14 @@
# CycleGAN_Tensorlayer
Re-implement CycleGAN in Tensorlayer
Re-implement CycleGAN in TensorLayer

- Original CycleGAN
- Improved CycleGAN with resize-convolution



### Prerequisites:

* Tensorlayer
* TensorLayer
* TensorFlow
* Python

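For orientation, the loss code in main.py below implements the standard least-squares CycleGAN objective: an L1 cycle-consistency term weighted by the code's `lamda` constant, a least-squares adversarial term per generator, and discriminators trained to score real images as 1 and pooled fakes as 0. A sketch of the math, matching the `cyc_loss`, `g_loss_A2B`, and `d_b_loss` definitions in the diff (with G: A→B, F: B→A; the B2A terms are symmetric):

```latex
\mathcal{L}_{\mathrm{cyc}} = \mathbb{E}_a \lVert F(G(a)) - a \rVert_1 + \mathbb{E}_b \lVert G(F(b)) - b \rVert_1
\mathcal{L}_{G_{A2B}} = \mathbb{E}_a \, (D_B(G(a)) - 1)^2 + \lambda \, \mathcal{L}_{\mathrm{cyc}}
\mathcal{L}_{D_B} = \tfrac{1}{2} \big[ \mathbb{E}_b \, (D_B(b) - 1)^2 + \mathbb{E}_a \, D_B(G(a))^2 \big]
```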
150 changes: 77 additions & 73 deletions main.py
@@ -19,7 +19,7 @@
flags.DEFINE_float("weight_decay", 1e-5, "Weight decay for l2 loss")
flags.DEFINE_float("pool_size", 50, 'size of image buffer that stores previously generated images, default: 50')
flags.DEFINE_integer("train_size", np.inf, "The size of train images [np.inf]")
flags.DEFINE_integer("batch_size", 1, "The number of batch images [1]")
flags.DEFINE_integer("batch_size", 1, "The number of batch images [1] if we use InstanceNormLayer !")
flags.DEFINE_integer("image_size", 256, "The size of image to use (will be center cropped) [256]")
flags.DEFINE_integer("gf_dim", 32, "Size of generator filters in first layer")
flags.DEFINE_integer("df_dim", 64, "Size of discriminator filters in first layer")
Expand All @@ -29,9 +29,9 @@
flags.DEFINE_integer("c_dim", 3, "Dimension of image color. [3]")
flags.DEFINE_integer("sample_step", 500, "The interval of generating sample. [500]")
flags.DEFINE_integer("save_step", 200, "The interval of saveing checkpoints. [200]")
flags.DEFINE_string("dataset_dir", "horse2zebra", "The name of dataset [horse2zebra, apple2orange, sunflower2daisy]")
flags.DEFINE_string("checkpoint_dir", "data/Models", "Directory name to save the checkpoints [checkpoint]")
flags.DEFINE_string("sample_dir", "data/samples", "Directory name to save the image samples [samples]")
flags.DEFINE_string("dataset_dir", "horse2zebra", "The name of dataset [horse2zebra, apple2orange, sunflower2daisy and etc]")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]")
flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]")
flags.DEFINE_string("direction", "forward", "The direction of generator [forward, backward]")
flags.DEFINE_string("test_dir", "./test", "The direction of test")
flags.DEFINE_boolean("is_train", True, "True for training, False for testing [False]")
@@ -50,7 +50,7 @@ def train_cyclegan():
ni = int(np.sqrt(FLAGS.batch_size))
h, w = 256, 256

# data augmentation
## data augmentation
def prepro(x):
x = tl.prepro.flip_axis(x, axis=1, is_random=True)
x = tl.prepro.rotation(x, rg=16, is_random=True, fill_mode='nearest')
@@ -60,65 +60,64 @@ def prepro(x):
x = x - 1.
return x

real_A = tf.placeholder(tf.float32, [FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size, FLAGS.c_dim],
def rescale(x):
x = x / (255. / 2.)
x = x - 1.
return x

real_A = tf.placeholder(tf.float32, [None, FLAGS.image_size, FLAGS.image_size, FLAGS.c_dim],
name='real_A')
real_B = tf.placeholder(tf.float32, [FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size, FLAGS.c_dim],
real_B = tf.placeholder(tf.float32, [None, FLAGS.image_size, FLAGS.image_size, FLAGS.c_dim],
name='real_B')

fake_A_pool = tf.placeholder(tf.float32, [FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size, FLAGS.c_dim],
name='fake_A')
name='fake_A')
fake_B_pool = tf.placeholder(tf.float32, [FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size, FLAGS.c_dim],
name='fake_B')
gen_B,gen_B_logits = cyclegan_generator_resnet(real_A, 9, is_train=True, reuse=False, name='gen_A2B')
gen_A,gen_A_logits = cyclegan_generator_resnet(real_B, 9, is_train=True, reuse=False, name='gen_B2A')
cyc_B,cyc_B_logits = cyclegan_generator_resnet(gen_A_logits, 9, is_train=True, reuse=True, name='gen_A2B')
cyc_A,cyc_A_logits = cyclegan_generator_resnet(gen_B_logits, 9, is_train=True, reuse=True, name='gen_B2A')
name='fake_B')

d_real_A,d_real_A_logits = cyclegan_discriminator_patch(real_A, is_train=True, reuse=False, name='dis_A') # dx
d_real_B,d_real_B_logits = cyclegan_discriminator_patch(real_B, is_train=True, reuse=False, name='dis_B') # dy
d_fake_A,d_fake_A_logits = cyclegan_discriminator_patch(gen_A_logits, is_train=True, reuse=True, name='dis_A') # d_fy
d_fake_B,d_fake_B_logits = cyclegan_discriminator_patch(gen_B_logits, is_train=True, reuse=True, name='dis_B') # d_gx
gen_B, gen_B_out = cyclegan_generator_resnet(real_A, 9, is_train=True, reuse=False, name='gen_A2B')
gen_A, gen_A_out = cyclegan_generator_resnet(real_B, 9, is_train=True, reuse=False, name='gen_B2A')
cyc_B, cyc_B_out = cyclegan_generator_resnet(gen_A_out, 9, is_train=True, reuse=True, name='gen_A2B')
cyc_A, cyc_A_out = cyclegan_generator_resnet(gen_B_out, 9, is_train=True, reuse=True, name='gen_B2A')

d_A_pool,d_A_pool_logits = cyclegan_discriminator_patch(fake_A_pool, is_train=True, reuse=True, name='dis_A') # d_fakex
d_B_pool,d_B_pool_logits = cyclegan_discriminator_patch(fake_B_pool, is_train=True, reuse=True, name='dis_B') # d_fakey
d_real_A, d_real_A_logits = cyclegan_discriminator_patch(real_A, is_train=True, reuse=False, name='dis_A') # dx
d_real_B, d_real_B_logits = cyclegan_discriminator_patch(real_B, is_train=True, reuse=False, name='dis_B') # dy
d_fake_A, d_fake_A_logits = cyclegan_discriminator_patch(gen_A_out, is_train=True, reuse=True, name='dis_A') # d_fy
d_fake_B, d_fake_B_logits = cyclegan_discriminator_patch(gen_B_out, is_train=True, reuse=True, name='dis_B') # d_gx

## test inference
gen_B_test,gen_B_test_logits = cyclegan_generator_resnet(real_A, 9, is_train=False, reuse=True, name='gen_A2B')
gen_A_test,gen_A_test_logits = cyclegan_generator_resnet(real_B, 9, is_train=False, reuse=True, name='gen_B2A')
d_A_pool, d_A_pool_logits = cyclegan_discriminator_patch(fake_A_pool, is_train=True, reuse=True, name='dis_A') # d_fakex
d_B_pool, d_B_pool_logits = cyclegan_discriminator_patch(fake_B_pool, is_train=True, reuse=True, name='dis_B') # d_fakey

# calculate cycle loss
cyc_loss = tf.reduce_mean(tf.abs(cyc_A_logits - real_A)) + tf.reduce_mean(tf.abs(cyc_B_logits - real_B))
# cyc_loss = tf.reduce_mean(tf.reduce_mean(tf.abs(cyc_A - real_A), [1, 2, 3])) + tf.reduce_mean(
# tf.reduce_mean(tf.abs(cyc_B - real_B), [1, 2, 3]))
## test inference
# gen_B_test, gen_B_test_logits = cyclegan_generator_resnet(real_A, 9, is_train=False, reuse=True, name='gen_A2B')
# gen_A_test, gen_A_test_logits = cyclegan_generator_resnet(real_B, 9, is_train=False, reuse=True, name='gen_B2A')

# calculate adversarial loss
g_loss_A2B = tf.reduce_mean(tf.squared_difference(d_fake_B_logits, tf.ones_like(d_fake_B_logits)), name='g_loss_b')
## calculate cycle loss
cyc_loss = tf.reduce_mean(tf.abs(cyc_A_out - real_A)) + tf.reduce_mean(tf.abs(cyc_B_out - real_B))
# cyc_loss = tf.reduce_mean(tf.reduce_mean(tf.abs(cyc_A - real_A), [1, 2, 3])) + tf.reduce_mean(
# tf.reduce_mean(tf.abs(cyc_B - real_B), [1, 2, 3]))

# g_loss_A2B = tf.reduce_mean(tf.reduce_mean(tf.squared_difference(d_fake_B, tf.ones_like(d_fake_B)), [1, 2, 3]),
# name='g_loss_b')
## calculate adversarial loss
g_loss_A2B = tf.reduce_mean(tf.squared_difference(d_fake_B_logits, tf.ones_like(d_fake_B_logits)), name='g_loss_b')
# g_loss_A2B = tf.reduce_mean(tf.reduce_mean(tf.squared_difference(d_fake_B, tf.ones_like(d_fake_B)), [1, 2, 3]), name='g_loss_b')

g_loss_B2A = tf.reduce_mean(tf.squared_difference(d_fake_A_logits, tf.ones_like(d_fake_A_logits)),name='g_loss_a')
g_loss_B2A = tf.reduce_mean(tf.squared_difference(d_fake_A_logits, tf.ones_like(d_fake_A_logits)),name='g_loss_a')
# g_loss_B2A = tf.reduce_mean(tf.reduce_mean(tf.squared_difference(d_fake_A, tf.ones_like(d_fake_A)), [1, 2, 3]), name='g_loss_a')

# g_loss_B2A = tf.reduce_mean(tf.reduce_mean(tf.squared_difference(d_fake_A, tf.ones_like(d_fake_A)), [1, 2, 3]),
# name='g_loss_a')

# calculate total loss of generator
g_a2b_loss = lamda * cyc_loss + g_loss_A2B # forward
## calculate total loss of generator
g_a2b_loss = lamda * cyc_loss + g_loss_A2B # forward
g_b2a_loss = lamda * cyc_loss + g_loss_B2A # backward

# calculate discriminator loss
# d_a_loss = (tf.reduce_mean(
# tf.reduce_mean(tf.squared_difference(d_real_A, tf.ones_like(d_real_A)), [1, 2, 3])) + tf.reduce_mean(
# tf.reduce_mean(tf.square(d_fake_A), [1, 2, 3]))) / 2.0
## calculate discriminator loss
d_a_loss = (tf.reduce_mean(tf.squared_difference(d_real_A_logits, tf.ones_like(d_real_A_logits))) + tf.reduce_mean(tf.square(d_fake_A_logits))) / 2.0

# d_a_loss = (tf.reduce_mean(
# tf.reduce_mean(tf.squared_difference(d_real_A, tf.ones_like(d_real_A)), [1, 2, 3])) + tf.reduce_mean(
# tf.reduce_mean(tf.square(d_fake_A), [1, 2, 3]))) / 2.0
d_b_loss = (tf.reduce_mean(tf.squared_difference(d_real_B_logits, tf.ones_like(d_real_B_logits))) + tf.reduce_mean(tf.square(d_fake_B_logits))) / 2.0
# d_b_loss = (tf.reduce_mean(
# tf.reduce_mean(tf.squared_difference(d_real_B, tf.ones_like(d_real_B)), [1, 2, 3])) + tf.reduce_mean(
# tf.reduce_mean(tf.square(d_fake_B), [1, 2, 3]))) / 2.0

# d_b_loss = (tf.reduce_mean(
# tf.reduce_mean(tf.squared_difference(d_real_B, tf.ones_like(d_real_B)), [1, 2, 3])) + tf.reduce_mean(
# tf.reduce_mean(tf.square(d_fake_B), [1, 2, 3]))) / 2.0

t_vars = tf.trainable_variables()
# t_vars = tf.trainable_variables()

g_A2B_vars = tl.layers.get_variables_with_name('gen_A2B', True, True)
g_B2A_vars = tl.layers.get_variables_with_name('gen_B2A', True, True)
Expand All @@ -134,7 +133,6 @@ def prepro(x):
d_b_optim = tf.train.AdamOptimizer(lr_v, beta1=FLAGS.beta1).minimize(d_b_loss, var_list=d_B_vars)

## init params

tl.layers.initialize_global_variables(sess)

net_g_A2B_name = os.path.join(FLAGS.checkpoint_dir, '{}_net_g_A2B.npz'.format(FLAGS.dataset_dir))
@@ -153,13 +151,19 @@

dataA, dataB, im_test_A, im_test_B = tl.files.load_cyclegan_dataset(filename=FLAGS.dataset_dir, path='datasets')

sample_A = np.asarray(im_test_A[0: 16])
sample_B = np.asarray(im_test_B[0: 16])
sample_A = tl.prepro.threading_data(sample_A, fn=rescale)
sample_B = tl.prepro.threading_data(sample_B, fn=rescale)

tl.vis.save_images(sample_A, [4, 4], './{}/sample_A.jpg'.format(FLAGS.sample_dir))
tl.vis.save_images(sample_B, [4, 4], './{}/sample_B.jpg'.format(FLAGS.sample_dir))

shuffle(dataA)
shuffle(dataB)

for epoch in range(FLAGS.epoch):

## change learning rate

if epoch >= 100:
new_lr = FLAGS.learning_rate - FLAGS.learning_rate * (epoch - 100) / 100
sess.run(tf.assign(lr_v, new_lr))
@@ -170,35 +174,35 @@
batch_imgA = tl.prepro.threading_data(dataA[idx * FLAGS.batch_size:(idx + 1) * FLAGS.batch_size], fn=prepro)
batch_imgB = tl.prepro.threading_data(dataB[idx * FLAGS.batch_size:(idx + 1) * FLAGS.batch_size], fn=prepro)

gen_A_temp_logits,gen_B_temp_logits = sess.run([gen_A_logits,gen_B_logits], feed_dict={real_A: batch_imgA, real_B: batch_imgB})

# update forward network
_, errGA2B = sess.run([g_a2b_optim, g_a2b_loss], feed_dict={real_A: batch_imgA, real_B: batch_imgB})
# update DB network
_, errDB = sess.run([d_b_optim, d_b_loss],
feed_dict={real_A: batch_imgA, real_B: batch_imgB, fake_B_pool: gen_B_temp_logits})
# update (backward) network
gen_A_temp_out, gen_B_temp_out = sess.run([gen_A_out, gen_B_out],
feed_dict={real_A: batch_imgA, real_B: batch_imgB})

_, errGB2A = sess.run([g_b2a_optim, g_b2a_loss], feed_dict={real_A: batch_imgA, real_B: batch_imgB})
# update DA network
## update forward network
_, errGA2B = sess.run([g_a2b_optim, g_a2b_loss],
feed_dict={real_A: batch_imgA, real_B: batch_imgB})
## update DB network
_, errDB = sess.run([d_b_optim, d_b_loss],
feed_dict={real_A: batch_imgA, real_B: batch_imgB, fake_B_pool: gen_B_temp_out})
## update (backward) network
_, errGB2A = sess.run([g_b2a_optim, g_b2a_loss],
feed_dict={real_A: batch_imgA, real_B: batch_imgB})
## update DA network
_, errDA = sess.run([d_a_optim, d_a_loss],
feed_dict={real_A: batch_imgA, real_B: batch_imgB, fake_A_pool: gen_A_temp_logits})
feed_dict={real_A: batch_imgA, real_B: batch_imgB, fake_A_pool: gen_A_temp_out})

print(
"Epoch: [%2d/%2d] [%4d/%4d] time: %4.4f, d_a_loss: %.8f, d_b_loss: %.8f,g_a2b_loss: %.8f,g_b2a_loss: %.8f" \
% (epoch, FLAGS.epoch, idx, batch_idxs, time.time() - start_time, errDA, errDB, errGA2B, errGB2A))
print("Epoch: [%2d/%2d] [%4d/%4d] time: %4.4fs, d_a_loss: %.8f, d_b_loss: %.8f, g_a2b_loss: %.8f, g_b2a_loss: %.8f" \
% (epoch, FLAGS.epoch, idx, batch_idxs, time.time() - start_time, errDA, errDB, errGA2B, errGB2A))

iter_counter += 1
num_fake += 1

if np.mod(iter_counter, 100) == 1:

sample_gen_A_logits,sample_gen_B_logits = sess.run([gen_A_logits, gen_B_logits],
feed_dict={real_A: batch_imgA, real_B: batch_imgB})
tl.vis.save_images(sample_gen_A_logits, [ni, ni],
if np.mod(iter_counter, 500) == 0:
oA, oB = sess.run([gen_A_out, gen_B_out],
feed_dict={real_A: sample_A, real_B: sample_B})
tl.vis.save_images(oA, [4, 4],
'./{}/B2A_{:02d}_{:04d}.jpg'.format(FLAGS.sample_dir, epoch, idx))
print("save image gen_A, Epoch: %2d idx:%4d" % (epoch, idx))
tl.vis.save_images(sample_gen_B_logits, [ni, ni],
tl.vis.save_images(oB, [4, 4],
'./{}/A2B_{:02d}_{:04d}.jpg'.format(FLAGS.sample_dir, epoch, idx))
print("save image gen_B, Epoch: %2d idx:%4d" % (epoch, idx))

Expand All @@ -218,15 +222,14 @@ def pro(x):
x = x - 1.
return x


test_A = tf.placeholder(tf.float32, [FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size, FLAGS.c_dim],
name='test_x')
test_B = tf.placeholder(tf.float32, [FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size, FLAGS.c_dim],
name='test_y')
# testB = cyclegan_generator_resnet(test_A, options, True, name="gen_forward")
# testA = cyclegan_generator_resnet(test_B, options, True, name="gen_backward")
test_gen_A2B,test_gen_A2B_logits = cyclegan_generator_resnet(test_A, 9, is_train=False, reuse=False, name='gen_A2B')
test_gen_B2A,test_gen_B2A_logits = cyclegan_generator_resnet(test_B, 9, is_train=False, reuse=False, name='gen_B2A')
test_gen_A2B, test_gen_A2B_logits = cyclegan_generator_resnet(test_A, 9, is_train=False, reuse=False, name='gen_A2B')
test_gen_B2A, test_gen_B2A_logits = cyclegan_generator_resnet(test_B, 9, is_train=False, reuse=False, name='gen_B2A')

out_var, in_var = (test_B, test_A) if FLAGS.direction == 'forward' else (test_A, test_B)

@@ -274,6 +277,7 @@ def main(_):
# elif args.phase == 'test':
# test_cyclegan()
train_cyclegan()
# test_cyclegan()


if __name__ == '__main__':
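One note on the `pool_size` flag and the `fake_A_pool` / `fake_B_pool` placeholders above: the flag's help text describes a buffer of previously generated images, but in this diff the discriminators are fed the freshly generated `gen_A_temp_out` / `gen_B_temp_out` directly, and no buffer logic is visible. A minimal sketch of such a history buffer (a hypothetical helper in the style of the original CycleGAN implementation, not code from this commit):

```python
import numpy as np

class ImagePool:
    """Hypothetical history buffer for generated images; the buffer logic
    is not visible in this commit, so this is a sketch only."""
    def __init__(self, pool_size=50):
        self.pool_size = pool_size
        self.images = []

    def query(self, image):
        # Until the pool is full, remember the incoming batch and return it.
        if len(self.images) < self.pool_size:
            self.images.append(image)
            return image
        # Afterwards, with probability 0.5 return a random older fake and
        # replace it with the incoming one, so D also sees stale samples.
        if np.random.rand() > 0.5:
            idx = np.random.randint(0, self.pool_size)
            old = self.images[idx].copy()
            self.images[idx] = image
            return old
        return image
```

With such a helper, the loop would feed `fake_B_pool: pool_B.query(gen_B_temp_out)` rather than the current batch.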
6 changes: 3 additions & 3 deletions model_upsampling.py
@@ -43,7 +43,7 @@ def cyclegan_generator_resnet(image, num=9, is_train=True, reuse=False, batch_si

net_r9 = n
# net_d1 = DeConv2d(net_r9, gf_dim * 2, (3, 3), out_size=(128,128),
# strides=(2, 2), padding='SAME', batch_size=batch_size, act=None, name='u64') #
# strides=(2, 2), padding='SAME', batch_size=batch_size, act=None, name='u64') #

size_d1 = net_r9.outputs.get_shape().as_list()
net_up1 = UpSampling2dLayer(net_r9, size=[size_d1[1] * 2, size_d1[2] * 2], is_scale=False, method=1,
@@ -75,8 +75,8 @@ def cyclegan_discriminator_patch(inputs, is_train=True, reuse=False, name='discr
with tf.variable_scope(name, reuse=reuse):
tl.layers.set_name_reuse(reuse)

patch_inputs = tf.random_crop(inputs, [1, 70, 70, 3])
net_in = InputLayer(patch_inputs, name='d/in')
# patch_inputs = tf.random_crop(inputs, [1, 70, 70, 3])
net_in = InputLayer(inputs, name='d/in')
# 1st
net_h0 = Conv2d(net_in, df_dim, (4, 4), (2, 2), act=lrelu,
padding='SAME', W_init=w_init, name='d/h0/conv2d') # C64
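The generator hunk above is where the README's "improved CycleGAN with resize-convolution" lives: the transposed convolution (`DeConv2d`, now commented out) is replaced by `UpSampling2dLayer` followed by an ordinary convolution, a common remedy for checkerboard artifacts. A minimal sketch of one such block in the same TensorLayer 1.x style (the function name and `w_init` argument are illustrative, not the file's exact code):

```python
from tensorlayer.layers import UpSampling2dLayer, Conv2d

def resize_conv_block(net, n_filter, name, w_init=None):
    # Double the spatial size first; method=1 selects nearest-neighbour
    # resizing in tf.image.resize_images, as in the diff above.
    shape = net.outputs.get_shape().as_list()
    net = UpSampling2dLayer(net, size=[shape[1] * 2, shape[2] * 2],
                            is_scale=False, method=1, name=name + '/up')
    # Then apply a stride-1 convolution instead of a transposed convolution.
    net = Conv2d(net, n_filter, (3, 3), (1, 1), padding='SAME',
                 W_init=w_init, name=name + '/conv')
    return net
```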
Binary file added tensorlayer/__pycache__/__init__.cpython-35.pyc
Binary file added tensorlayer/__pycache__/activation.cpython-35.pyc
Binary file added tensorlayer/__pycache__/cost.cpython-35.pyc
Binary file added tensorlayer/__pycache__/files.cpython-35.pyc
Binary file added tensorlayer/__pycache__/iterate.cpython-35.pyc
Binary file added tensorlayer/__pycache__/layers.cpython-35.pyc
Binary file added tensorlayer/__pycache__/nlp.cpython-35.pyc
Binary file added tensorlayer/__pycache__/ops.cpython-35.pyc
Binary file added tensorlayer/__pycache__/prepro.cpython-35.pyc
Binary file added tensorlayer/__pycache__/rein.cpython-35.pyc
Binary file added tensorlayer/__pycache__/utils.cpython-35.pyc
Binary file added tensorlayer/__pycache__/visualize.cpython-35.pyc
11 changes: 11 additions & 0 deletions tensorlayer/files.py
@@ -656,6 +656,17 @@ def load_image_from_folder(path):
im_test_A = load_image_from_folder(path+"/"+filename+"/testA")
im_test_B = load_image_from_folder(path+"/"+filename+"/testB")

def if_2d_to_3d(images): # [h, w] --> [h, w, 3]
for i in range(len(images)):
if len(images[i].shape) == 2:
images[i] = images[i][:, :, np.newaxis]
images[i] = np.tile(images[i], (1, 1, 3))
return images

im_train_A = if_2d_to_3d(im_train_A)
im_train_B = if_2d_to_3d(im_train_B)
im_test_A = if_2d_to_3d(im_test_A)
im_test_B = if_2d_to_3d(im_test_B)
return im_train_A, im_train_B, im_test_A, im_test_B


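The new `if_2d_to_3d` helper protects the loader against grayscale images: any `[h, w]` array gains a channel axis and is tiled to `[h, w, 3]` so it stacks cleanly with the RGB images. The same transform in isolation (illustrative check only):

```python
import numpy as np

gray = np.zeros((256, 256))        # a grayscale image, shape [h, w]
img = gray[:, :, np.newaxis]       # add a channel axis -> [h, w, 1]
img = np.tile(img, (1, 1, 3))      # replicate it across RGB -> [h, w, 3]
assert img.shape == (256, 256, 3)
```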
5 changes: 4 additions & 1 deletion tensorlayer/prepro.py
@@ -129,7 +129,10 @@ def apply_fn(results, i, data, kwargs):
t.join()

if thread_count is None:
return np.asarray(results)
try:
return np.asarray(results)
except:
return results
else:
return np.concatenate(results)

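The `try`/`except` added to `threading_data` lets it degrade gracefully when the per-image results cannot be stacked into a single array, for example when images come back with different shapes; the plain Python list is returned instead. The bare `except:` is broader than necessary (catching `ValueError` would be tighter). The behaviour it guards against, in isolation (illustrative only; on some NumPy versions `np.asarray` builds an object array rather than raising):

```python
import numpy as np

results = [np.zeros((256, 256, 3)), np.zeros((128, 128, 3))]  # ragged shapes
try:
    batch = np.asarray(results)   # stacking ragged arrays may fail
except ValueError:
    batch = results               # fall back to the unstacked list
```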
1 change: 1 addition & 0 deletions tensorlayer/visualize.py
@@ -43,6 +43,7 @@ def read_images(img_list, path='', n_threads=10, printable=True):
b_imgs_list = img_list[idx : idx + n_threads]
b_imgs = prepro.threading_data(b_imgs_list, fn=read_image, path=path)
# print(b_imgs.shape)
# exit()
imgs.extend(b_imgs)
if printable:
print('read %d from %s' % (len(imgs), path))
