From 7a2fd0c56cbb7433269f009e6bda783f08821d18 Mon Sep 17 00:00:00 2001 From: 1FengL Date: Fri, 12 Apr 2019 09:57:36 +0100 Subject: [PATCH 1/7] create STN tutorial in static mode --- ...ial_spatial_transformer_network_dynamic.py | 161 ++++++++++++++++++ ...rial_spatial_transformer_network_static.py | 161 ++++++++++++++++++ 2 files changed, 322 insertions(+) create mode 100644 examples/spatial_transformer_network/tutorial_spatial_transformer_network_dynamic.py create mode 100644 examples/spatial_transformer_network/tutorial_spatial_transformer_network_static.py diff --git a/examples/spatial_transformer_network/tutorial_spatial_transformer_network_dynamic.py b/examples/spatial_transformer_network/tutorial_spatial_transformer_network_dynamic.py new file mode 100644 index 000000000..dfc615fc8 --- /dev/null +++ b/examples/spatial_transformer_network/tutorial_spatial_transformer_network_dynamic.py @@ -0,0 +1,161 @@ +#! /usr/bin/python +# -*- coding: utf8 -*- +import time +import numpy as np +import tensorflow as tf +import tensorlayer as tl +from tensorlayer.layers import * +from tensorlayer.models import Model + +##================== PREPARE DATA ============================================## +X_train, y_train, X_val, y_val, X_test, y_test = \ + tl.files.load_mnist_dataset(shape=(-1, 28, 28, 1)) + +def pad_distort_im_fn(x): + """ Zero pads an image to 40x40, and distort it. 
+ + Examples + --------- + x = pad_distort_im_fn(X_train[0]) + print(x, x.shape, x.max()) + tl.vis.save_image(x, '_xd.png') + tl.vis.save_image(X_train[0], '_x.png') + """ + b = np.zeros((40, 40, 1), dtype=np.float32) + o = int((40 - 28) / 2) + b[o:o + 28, o:o + 28] = x + x = b + x = tl.prepro.rotation(x, rg=30, is_random=True, fill_mode='constant') + x = tl.prepro.shear(x, 0.05, is_random=True, fill_mode='constant') + x = tl.prepro.shift(x, wrg=0.25, hrg=0.25, is_random=True, fill_mode='constant') + x = tl.prepro.zoom(x, zoom_range=(0.95, 1.05)) + return x + + +def pad_distort_ims_fn(X): + """ Zero pads images to 40x40, and distort them. """ + X_40 = [] + for X_a, _ in tl.iterate.minibatches(X, X, 50, shuffle=False): + X_40.extend(tl.prepro.threading_data(X_a, fn=pad_distort_im_fn)) + X_40 = np.asarray(X_40) + return X_40 + + +# create dataset with size of 40x40 with distortion +X_train_40 = pad_distort_ims_fn(X_train) +X_val_40 = pad_distort_ims_fn(X_val) +X_test_40 = pad_distort_ims_fn(X_test) + +tl.vis.save_images(X_test[0:32], [4, 8], '_imgs_original.png') +tl.vis.save_images(X_test_40[0:32], [4, 8], '_imgs_distorted.png') + + +##================== DEFINE MODEL ============================================## +def get_model(inputs_shape): + ni = Input(inputs_shape) + + ## 1. Localisation network + # use MLP as the localisation net + nn = Flatten()(ni) + nn = Dense(n_units=20, act=tf.nn.tanh)(nn) + nn = Dropout(keep=0.8)(nn) + # you can also use CNN instead for MLP as the localisation net + + ## 2. Spatial transformer module (sampler) + stn = SpatialTransformer2dAffine(out_size=(40, 40), in_channels=20) + s = stn((nn, ni)) + nn = stn((nn, ni)) + + ## 3. 
Classifier + nn = Conv2d(16, (3, 3), (2, 2), act=tf.nn.relu, padding='SAME')(nn) + nn = Conv2d(16, (3, 3), (2, 2), act=tf.nn.relu, padding='SAME')(nn) + nn = Flatten()(nn) + nn = Dense(n_units=1024, act=tf.nn.relu)(nn) + nn = Dense(n_units=10, act=tf.identity)(nn) + + M = Model(inputs=ni, outputs=[nn, s]) + return M + + +net = get_model([None, 40, 40, 1]) + +##================== DEFINE TRAIN OPS ========================================## +n_epoch = 100 +learning_rate = 0.0001 +print_freq = 10 +batch_size = 64 +train_weights = net.weights +optimizer = tf.optimizers.Adam(lr=learning_rate) + +##================== TRAINING ================================================## +print("Training ...") +for epoch in range(n_epoch): + start_time = time.time() + + net.train() # enable dropout + + for X_train_a, y_train_a in tl.iterate.minibatches(X_train_40, y_train, batch_size, shuffle=True): + # input_dim must be of length 4 + X_train_a = tf.expand_dims(X_train_a, 3) + + with tf.GradientTape() as tape: + ## compute outputs + _logits, _ = net(X_train_a) # alternatively, you can use MLP(x, is_train=True) and remove MLP.train() + ## compute loss and update model + _loss = tl.cost.cross_entropy(_logits, y_train_a, name='train_loss') + + grad = tape.gradient(_loss, train_weights) + optimizer.apply_gradients(zip(grad, train_weights)) + + ## use training and evaluation sets to evaluate the model every print_freq epoch + if epoch + 1 == 1 or (epoch + 1) % print_freq == 0: + + net.eval() # disable dropout + + print("Epoch %d of %d took %fs" % (epoch + 1, n_epoch, time.time() - start_time)) + + train_loss, train_acc, n_iter = 0, 0, 0 + for X_train_a, y_train_a in tl.iterate.minibatches(X_train_40, y_train, batch_size, shuffle=False): + # input_dim must be of length 4 + X_train_a = tf.expand_dims(X_train_a, 3) + + _logits, _ = net(X_train_a) # alternatively, you can use MLP(x, is_train=False) and remove MLP.eval() + train_loss += tl.cost.cross_entropy(_logits, y_train_a, 
name='eval_train_loss') + train_acc += np.mean(np.equal(np.argmax(_logits, 1), y_train_a)) + n_iter += 1 + print(" train loss: %f" % (train_loss / n_iter)) + print(" train acc: %f" % (train_acc / n_iter)) + + val_loss, val_acc, n_iter = 0, 0, 0 + for X_val_a, y_val_a in tl.iterate.minibatches(X_val_40, y_val, batch_size, shuffle=False): + # input_dim must be of length 4 + X_val_a = tf.expand_dims(X_val_a, 3) + + _logits, _ = net(X_val_a) # is_train=False, disable dropout + val_loss += tl.cost.cross_entropy(_logits, y_val_a, name='eval_loss') + val_acc += np.mean(np.equal(np.argmax(_logits, 1), y_val_a)) + n_iter += 1 + print(" val loss: %f" % (val_loss / n_iter)) + print(" val acc: %f" % (val_acc / n_iter)) + + print('save images') + _, trans_imgs = net(tf.expand_dims(X_test_40[0:64], 3)) + trans_imgs = trans_imgs.numpy() + tl.vis.save_images(trans_imgs[0:32], [4, 8], '_imgs_distorted_after_stn_%s.png' % epoch) + +##================== EVALUATION ==============================================## +print('Evaluation') + +net.eval() + +test_loss, test_acc, n_iter = 0, 0, 0 +for X_test_a, y_test_a in tl.iterate.minibatches(X_test_40, y_test, batch_size, shuffle=False): + # input_dim must be of length 4 + X_test_a = tf.expand_dims(X_test_a, 3) + + _logits, _ = net(X_test_a) + test_loss += tl.cost.cross_entropy(_logits, y_test_a, name='test_loss') + test_acc += np.mean(np.equal(np.argmax(_logits, 1), y_test_a)) + n_iter += 1 +print(" test loss: %f" % (test_loss / n_iter)) +print(" test acc: %f" % (test_acc / n_iter)) diff --git a/examples/spatial_transformer_network/tutorial_spatial_transformer_network_static.py b/examples/spatial_transformer_network/tutorial_spatial_transformer_network_static.py new file mode 100644 index 000000000..dfc615fc8 --- /dev/null +++ b/examples/spatial_transformer_network/tutorial_spatial_transformer_network_static.py @@ -0,0 +1,161 @@ +#! 
/usr/bin/python +# -*- coding: utf8 -*- +import time +import numpy as np +import tensorflow as tf +import tensorlayer as tl +from tensorlayer.layers import * +from tensorlayer.models import Model + +##================== PREPARE DATA ============================================## +X_train, y_train, X_val, y_val, X_test, y_test = \ + tl.files.load_mnist_dataset(shape=(-1, 28, 28, 1)) + +def pad_distort_im_fn(x): + """ Zero pads an image to 40x40, and distort it. + + Examples + --------- + x = pad_distort_im_fn(X_train[0]) + print(x, x.shape, x.max()) + tl.vis.save_image(x, '_xd.png') + tl.vis.save_image(X_train[0], '_x.png') + """ + b = np.zeros((40, 40, 1), dtype=np.float32) + o = int((40 - 28) / 2) + b[o:o + 28, o:o + 28] = x + x = b + x = tl.prepro.rotation(x, rg=30, is_random=True, fill_mode='constant') + x = tl.prepro.shear(x, 0.05, is_random=True, fill_mode='constant') + x = tl.prepro.shift(x, wrg=0.25, hrg=0.25, is_random=True, fill_mode='constant') + x = tl.prepro.zoom(x, zoom_range=(0.95, 1.05)) + return x + + +def pad_distort_ims_fn(X): + """ Zero pads images to 40x40, and distort them. """ + X_40 = [] + for X_a, _ in tl.iterate.minibatches(X, X, 50, shuffle=False): + X_40.extend(tl.prepro.threading_data(X_a, fn=pad_distort_im_fn)) + X_40 = np.asarray(X_40) + return X_40 + + +# create dataset with size of 40x40 with distortion +X_train_40 = pad_distort_ims_fn(X_train) +X_val_40 = pad_distort_ims_fn(X_val) +X_test_40 = pad_distort_ims_fn(X_test) + +tl.vis.save_images(X_test[0:32], [4, 8], '_imgs_original.png') +tl.vis.save_images(X_test_40[0:32], [4, 8], '_imgs_distorted.png') + + +##================== DEFINE MODEL ============================================## +def get_model(inputs_shape): + ni = Input(inputs_shape) + + ## 1. Localisation network + # use MLP as the localisation net + nn = Flatten()(ni) + nn = Dense(n_units=20, act=tf.nn.tanh)(nn) + nn = Dropout(keep=0.8)(nn) + # you can also use CNN instead for MLP as the localisation net + + ## 2. 
Spatial transformer module (sampler) + stn = SpatialTransformer2dAffine(out_size=(40, 40), in_channels=20) + s = stn((nn, ni)) + nn = stn((nn, ni)) + + ## 3. Classifier + nn = Conv2d(16, (3, 3), (2, 2), act=tf.nn.relu, padding='SAME')(nn) + nn = Conv2d(16, (3, 3), (2, 2), act=tf.nn.relu, padding='SAME')(nn) + nn = Flatten()(nn) + nn = Dense(n_units=1024, act=tf.nn.relu)(nn) + nn = Dense(n_units=10, act=tf.identity)(nn) + + M = Model(inputs=ni, outputs=[nn, s]) + return M + + +net = get_model([None, 40, 40, 1]) + +##================== DEFINE TRAIN OPS ========================================## +n_epoch = 100 +learning_rate = 0.0001 +print_freq = 10 +batch_size = 64 +train_weights = net.weights +optimizer = tf.optimizers.Adam(lr=learning_rate) + +##================== TRAINING ================================================## +print("Training ...") +for epoch in range(n_epoch): + start_time = time.time() + + net.train() # enable dropout + + for X_train_a, y_train_a in tl.iterate.minibatches(X_train_40, y_train, batch_size, shuffle=True): + # input_dim must be of length 4 + X_train_a = tf.expand_dims(X_train_a, 3) + + with tf.GradientTape() as tape: + ## compute outputs + _logits, _ = net(X_train_a) # alternatively, you can use MLP(x, is_train=True) and remove MLP.train() + ## compute loss and update model + _loss = tl.cost.cross_entropy(_logits, y_train_a, name='train_loss') + + grad = tape.gradient(_loss, train_weights) + optimizer.apply_gradients(zip(grad, train_weights)) + + ## use training and evaluation sets to evaluate the model every print_freq epoch + if epoch + 1 == 1 or (epoch + 1) % print_freq == 0: + + net.eval() # disable dropout + + print("Epoch %d of %d took %fs" % (epoch + 1, n_epoch, time.time() - start_time)) + + train_loss, train_acc, n_iter = 0, 0, 0 + for X_train_a, y_train_a in tl.iterate.minibatches(X_train_40, y_train, batch_size, shuffle=False): + # input_dim must be of length 4 + X_train_a = tf.expand_dims(X_train_a, 3) + + _logits, _ = 
net(X_train_a) # alternatively, you can use MLP(x, is_train=False) and remove MLP.eval() + train_loss += tl.cost.cross_entropy(_logits, y_train_a, name='eval_train_loss') + train_acc += np.mean(np.equal(np.argmax(_logits, 1), y_train_a)) + n_iter += 1 + print(" train loss: %f" % (train_loss / n_iter)) + print(" train acc: %f" % (train_acc / n_iter)) + + val_loss, val_acc, n_iter = 0, 0, 0 + for X_val_a, y_val_a in tl.iterate.minibatches(X_val_40, y_val, batch_size, shuffle=False): + # input_dim must be of length 4 + X_val_a = tf.expand_dims(X_val_a, 3) + + _logits, _ = net(X_val_a) # is_train=False, disable dropout + val_loss += tl.cost.cross_entropy(_logits, y_val_a, name='eval_loss') + val_acc += np.mean(np.equal(np.argmax(_logits, 1), y_val_a)) + n_iter += 1 + print(" val loss: %f" % (val_loss / n_iter)) + print(" val acc: %f" % (val_acc / n_iter)) + + print('save images') + _, trans_imgs = net(tf.expand_dims(X_test_40[0:64], 3)) + trans_imgs = trans_imgs.numpy() + tl.vis.save_images(trans_imgs[0:32], [4, 8], '_imgs_distorted_after_stn_%s.png' % epoch) + +##================== EVALUATION ==============================================## +print('Evaluation') + +net.eval() + +test_loss, test_acc, n_iter = 0, 0, 0 +for X_test_a, y_test_a in tl.iterate.minibatches(X_test_40, y_test, batch_size, shuffle=False): + # input_dim must be of length 4 + X_test_a = tf.expand_dims(X_test_a, 3) + + _logits, _ = net(X_test_a) + test_loss += tl.cost.cross_entropy(_logits, y_test_a, name='test_loss') + test_acc += np.mean(np.equal(np.argmax(_logits, 1), y_test_a)) + n_iter += 1 +print(" test loss: %f" % (test_loss / n_iter)) +print(" test acc: %f" % (test_acc / n_iter)) From dec72eafb6c79ca79071453fa66f97581c31e686 Mon Sep 17 00:00:00 2001 From: 1FengL Date: Fri, 12 Apr 2019 09:58:02 +0100 Subject: [PATCH 2/7] create STN tutorial in dynamic mode --- ...ial_spatial_transformer_network_dynamic.py | 48 ++++++++++--------- 1 file changed, 26 insertions(+), 22 deletions(-) diff --git 
a/examples/spatial_transformer_network/tutorial_spatial_transformer_network_dynamic.py b/examples/spatial_transformer_network/tutorial_spatial_transformer_network_dynamic.py index dfc615fc8..e89767a6a 100644 --- a/examples/spatial_transformer_network/tutorial_spatial_transformer_network_dynamic.py +++ b/examples/spatial_transformer_network/tutorial_spatial_transformer_network_dynamic.py @@ -11,6 +11,7 @@ X_train, y_train, X_val, y_val, X_test, y_test = \ tl.files.load_mnist_dataset(shape=(-1, 28, 28, 1)) + def pad_distort_im_fn(x): """ Zero pads an image to 40x40, and distort it. @@ -51,33 +52,36 @@ def pad_distort_ims_fn(X): ##================== DEFINE MODEL ============================================## -def get_model(inputs_shape): - ni = Input(inputs_shape) +class Net(Model): + def __init__(self): + super(Net, self).__init__() - ## 1. Localisation network - # use MLP as the localisation net - nn = Flatten()(ni) - nn = Dense(n_units=20, act=tf.nn.tanh)(nn) - nn = Dropout(keep=0.8)(nn) - # you can also use CNN instead for MLP as the localisation net + ## 1. Localisation network + # use MLP as the localisation net + self.flatten1 = Flatten() + self.dense1 = Dense(n_units=20, in_channels=1600, act=tf.nn.tanh) + self.dropout1 = Dropout(keep=0.8) + # you can also use CNN instead for MLP as the localisation net - ## 2. Spatial transformer module (sampler) - stn = SpatialTransformer2dAffine(out_size=(40, 40), in_channels=20) - s = stn((nn, ni)) - nn = stn((nn, ni)) + ## 2. Spatial transformer module (sampler) + self.stn = SpatialTransformer2dAffine(out_size=(40, 40), in_channels=20) + stn = SpatialTransformer2dAffine(out_size=(40, 40), in_channels=20) - ## 3. Classifier - nn = Conv2d(16, (3, 3), (2, 2), act=tf.nn.relu, padding='SAME')(nn) - nn = Conv2d(16, (3, 3), (2, 2), act=tf.nn.relu, padding='SAME')(nn) - nn = Flatten()(nn) - nn = Dense(n_units=1024, act=tf.nn.relu)(nn) - nn = Dense(n_units=10, act=tf.identity)(nn) + ## 3. 
Classifier + self.conv1 = Conv2d(16, (3, 3), (2, 2), act=tf.nn.relu, padding='SAME', in_channels=1) + self.conv2 = Conv2d(16, (3, 3), (2, 2), act=tf.nn.relu, padding='SAME', in_channels=16) + self.flatten2 = Flatten() + self.dense2 = Dense(n_units=1024, in_channels=1600, act=tf.nn.relu) + self.dense3 = Dense(n_units=10, in_channels=1024, act=tf.identity) - M = Model(inputs=ni, outputs=[nn, s]) - return M + def forward(self, inputs): + theta_input = self.dropout1(self.dense1(self.flatten1(inputs))) + V = self.stn((theta_input, inputs)) + _logits = self.dense3(self.dense2(self.flatten2(self.conv2(self.conv1(V))))) + return _logits, V -net = get_model([None, 40, 40, 1]) +net = Net() ##================== DEFINE TRAIN OPS ========================================## n_epoch = 100 @@ -120,7 +124,7 @@ def get_model(inputs_shape): X_train_a = tf.expand_dims(X_train_a, 3) _logits, _ = net(X_train_a) # alternatively, you can use MLP(x, is_train=False) and remove MLP.eval() - train_loss += tl.cost.cross_entropy(_logits, y_train_a, name='eval_train_loss') + train_loss += tl.cost.cross_entropy(_logits, y_train_a, name='eval_train_loss') train_acc += np.mean(np.equal(np.argmax(_logits, 1), y_train_a)) n_iter += 1 print(" train loss: %f" % (train_loss / n_iter)) From 73b462c254a1538be1daa58e4a81e73ecf68a99b Mon Sep 17 00:00:00 2001 From: 1FengL Date: Fri, 12 Apr 2019 10:03:22 +0100 Subject: [PATCH 3/7] refactor SpatialTransformer2dAffine as TL2.0 Layer --- tensorlayer/layers/spatial_transformer.py | 331 ++++++++++------------ 1 file changed, 152 insertions(+), 179 deletions(-) diff --git a/tensorlayer/layers/spatial_transformer.py b/tensorlayer/layers/spatial_transformer.py index ede33dd95..fe6d09534 100644 --- a/tensorlayer/layers/spatial_transformer.py +++ b/tensorlayer/layers/spatial_transformer.py @@ -3,6 +3,7 @@ import numpy as np import tensorflow as tf +import tensorlayer as tl from six.moves import xrange from tensorflow.python.ops import array_ops @@ -60,129 +61,124 @@ def 
transformer(U, theta, out_size, name='SpatialTransformer2dAffine'): """ def _repeat(x, n_repeats): - with tf.compat.v1.variable_scope('_repeat'): - rep = tf.transpose(a=tf.expand_dims(tf.ones(shape=tf.stack([ - n_repeats, - ])), 1), perm=[1, 0]) - rep = tf.cast(rep, 'int32') - x = tf.matmul(tf.reshape(x, (-1, 1)), rep) - return tf.reshape(x, [-1]) + rep = tf.transpose(a=tf.expand_dims(tf.ones(shape=tf.stack([ + n_repeats, + ])), 1), perm=[1, 0]) + rep = tf.cast(rep, 'int32') + x = tf.matmul(tf.reshape(x, (-1, 1)), rep) + return tf.reshape(x, [-1]) def _interpolate(im, x, y, out_size): - with tf.compat.v1.variable_scope('_interpolate'): - # constants - num_batch = tf.shape(input=im)[0] - height = tf.shape(input=im)[1] - width = tf.shape(input=im)[2] - channels = tf.shape(input=im)[3] - - x = tf.cast(x, 'float32') - y = tf.cast(y, 'float32') - height_f = tf.cast(height, 'float32') - width_f = tf.cast(width, 'float32') - out_height = out_size[0] - out_width = out_size[1] - zero = tf.zeros([], dtype='int32') - max_y = tf.cast(tf.shape(input=im)[1] - 1, 'int32') - max_x = tf.cast(tf.shape(input=im)[2] - 1, 'int32') - - # scale indices from [-1, 1] to [0, width/height] - x = (x + 1.0) * (width_f) / 2.0 - y = (y + 1.0) * (height_f) / 2.0 - - # do sampling - x0 = tf.cast(tf.floor(x), 'int32') - x1 = x0 + 1 - y0 = tf.cast(tf.floor(y), 'int32') - y1 = y0 + 1 - - x0 = tf.clip_by_value(x0, zero, max_x) - x1 = tf.clip_by_value(x1, zero, max_x) - y0 = tf.clip_by_value(y0, zero, max_y) - y1 = tf.clip_by_value(y1, zero, max_y) - dim2 = width - dim1 = width * height - base = _repeat(tf.range(num_batch) * dim1, out_height * out_width) - base_y0 = base + y0 * dim2 - base_y1 = base + y1 * dim2 - idx_a = base_y0 + x0 - idx_b = base_y1 + x0 - idx_c = base_y0 + x1 - idx_d = base_y1 + x1 - - # use indices to lookup pixels in the flat image and restore - # channels dim - im_flat = tf.reshape(im, tf.stack([-1, channels])) - im_flat = tf.cast(im_flat, 'float32') - Ia = tf.gather(im_flat, 
idx_a) - Ib = tf.gather(im_flat, idx_b) - Ic = tf.gather(im_flat, idx_c) - Id = tf.gather(im_flat, idx_d) - - # and finally calculate interpolated values - x0_f = tf.cast(x0, 'float32') - x1_f = tf.cast(x1, 'float32') - y0_f = tf.cast(y0, 'float32') - y1_f = tf.cast(y1, 'float32') - wa = tf.expand_dims(((x1_f - x) * (y1_f - y)), 1) - wb = tf.expand_dims(((x1_f - x) * (y - y0_f)), 1) - wc = tf.expand_dims(((x - x0_f) * (y1_f - y)), 1) - wd = tf.expand_dims(((x - x0_f) * (y - y0_f)), 1) - output = tf.add_n([wa * Ia, wb * Ib, wc * Ic, wd * Id]) - return output + # constants + num_batch = tf.shape(input=im)[0] + height = tf.shape(input=im)[1] + width = tf.shape(input=im)[2] + channels = tf.shape(input=im)[3] + + x = tf.cast(x, 'float32') + y = tf.cast(y, 'float32') + height_f = tf.cast(height, 'float32') + width_f = tf.cast(width, 'float32') + out_height = out_size[0] + out_width = out_size[1] + zero = tf.zeros([], dtype='int32') + max_y = tf.cast(tf.shape(input=im)[1] - 1, 'int32') + max_x = tf.cast(tf.shape(input=im)[2] - 1, 'int32') + + # scale indices from [-1, 1] to [0, width/height] + x = (x + 1.0) * (width_f) / 2.0 + y = (y + 1.0) * (height_f) / 2.0 + + # do sampling + x0 = tf.cast(tf.floor(x), 'int32') + x1 = x0 + 1 + y0 = tf.cast(tf.floor(y), 'int32') + y1 = y0 + 1 + + x0 = tf.clip_by_value(x0, zero, max_x) + x1 = tf.clip_by_value(x1, zero, max_x) + y0 = tf.clip_by_value(y0, zero, max_y) + y1 = tf.clip_by_value(y1, zero, max_y) + dim2 = width + dim1 = width * height + base = _repeat(tf.range(num_batch) * dim1, out_height * out_width) + base_y0 = base + y0 * dim2 + base_y1 = base + y1 * dim2 + idx_a = base_y0 + x0 + idx_b = base_y1 + x0 + idx_c = base_y0 + x1 + idx_d = base_y1 + x1 + + # use indices to lookup pixels in the flat image and restore + # channels dim + im_flat = tf.reshape(im, tf.stack([-1, channels])) + im_flat = tf.cast(im_flat, 'float32') + Ia = tf.gather(im_flat, idx_a) + Ib = tf.gather(im_flat, idx_b) + Ic = tf.gather(im_flat, idx_c) + Id = 
tf.gather(im_flat, idx_d) + + # and finally calculate interpolated values + x0_f = tf.cast(x0, 'float32') + x1_f = tf.cast(x1, 'float32') + y0_f = tf.cast(y0, 'float32') + y1_f = tf.cast(y1, 'float32') + wa = tf.expand_dims(((x1_f - x) * (y1_f - y)), 1) + wb = tf.expand_dims(((x1_f - x) * (y - y0_f)), 1) + wc = tf.expand_dims(((x - x0_f) * (y1_f - y)), 1) + wd = tf.expand_dims(((x - x0_f) * (y - y0_f)), 1) + output = tf.add_n([wa * Ia, wb * Ib, wc * Ic, wd * Id]) + return output def _meshgrid(height, width): - with tf.compat.v1.variable_scope('_meshgrid'): - # This should be equivalent to: - # x_t, y_t = np.meshgrid(np.linspace(-1, 1, width), - # np.linspace(-1, 1, height)) - # ones = np.ones(np.prod(x_t.shape)) - # grid = np.vstack([x_t.flatten(), y_t.flatten(), ones]) - x_t = tf.matmul( - tf.ones(shape=tf.stack([height, 1])), - tf.transpose(a=tf.expand_dims(tf.linspace(-1.0, 1.0, width), 1), perm=[1, 0]) - ) - y_t = tf.matmul(tf.expand_dims(tf.linspace(-1.0, 1.0, height), 1), tf.ones(shape=tf.stack([1, width]))) - - x_t_flat = tf.reshape(x_t, (1, -1)) - y_t_flat = tf.reshape(y_t, (1, -1)) - - ones = tf.ones_like(x_t_flat) - grid = tf.concat(axis=0, values=[x_t_flat, y_t_flat, ones]) - return grid + # This should be equivalent to: + # x_t, y_t = np.meshgrid(np.linspace(-1, 1, width), + # np.linspace(-1, 1, height)) + # ones = np.ones(np.prod(x_t.shape)) + # grid = np.vstack([x_t.flatten(), y_t.flatten(), ones]) + x_t = tf.matmul( + tf.ones(shape=tf.stack([height, 1])), + tf.transpose(a=tf.expand_dims(tf.linspace(-1.0, 1.0, width), 1), perm=[1, 0]) + ) + y_t = tf.matmul(tf.expand_dims(tf.linspace(-1.0, 1.0, height), 1), tf.ones(shape=tf.stack([1, width]))) - def _transform(theta, input_dim, out_size): - with tf.compat.v1.variable_scope('_transform'): - num_batch = tf.shape(input=input_dim)[0] - num_channels = tf.shape(input=input_dim)[3] - theta = tf.reshape(theta, (-1, 2, 3)) - theta = tf.cast(theta, 'float32') - - # grid of (x_t, y_t, 1), eq (1) in ref [1] - 
out_height = out_size[0] - out_width = out_size[1] - grid = _meshgrid(out_height, out_width) - grid = tf.expand_dims(grid, 0) - grid = tf.reshape(grid, [-1]) - grid = tf.tile(grid, tf.stack([num_batch])) - grid = tf.reshape(grid, tf.stack([num_batch, 3, -1])) - - # Transform A x (x_t, y_t, 1)^T -> (x_s, y_s) - T_g = tf.matmul(theta, grid) - x_s = tf.slice(T_g, [0, 0, 0], [-1, 1, -1]) - y_s = tf.slice(T_g, [0, 1, 0], [-1, 1, -1]) - x_s_flat = tf.reshape(x_s, [-1]) - y_s_flat = tf.reshape(y_s, [-1]) - - input_transformed = _interpolate(input_dim, x_s_flat, y_s_flat, out_size) - - output = tf.reshape(input_transformed, tf.stack([num_batch, out_height, out_width, num_channels])) - return output + x_t_flat = tf.reshape(x_t, (1, -1)) + y_t_flat = tf.reshape(y_t, (1, -1)) - with tf.compat.v1.variable_scope(name): - output = _transform(theta, U, out_size) + ones = tf.ones_like(x_t_flat) + grid = tf.concat(axis=0, values=[x_t_flat, y_t_flat, ones]) + return grid + + def _transform(theta, input_dim, out_size): + num_batch = tf.shape(input=input_dim)[0] + num_channels = tf.shape(input=input_dim)[3] + theta = tf.reshape(theta, (-1, 2, 3)) + theta = tf.cast(theta, 'float32') + + # grid of (x_t, y_t, 1), eq (1) in ref [1] + out_height = out_size[0] + out_width = out_size[1] + grid = _meshgrid(out_height, out_width) + grid = tf.expand_dims(grid, 0) + grid = tf.reshape(grid, [-1]) + grid = tf.tile(grid, tf.stack([num_batch])) + grid = tf.reshape(grid, tf.stack([num_batch, 3, -1])) + + # Transform A x (x_t, y_t, 1)^T -> (x_s, y_s) + T_g = tf.matmul(theta, grid) + x_s = tf.slice(T_g, [0, 0, 0], [-1, 1, -1]) + y_s = tf.slice(T_g, [0, 1, 0], [-1, 1, -1]) + x_s_flat = tf.reshape(x_s, [-1]) + y_s_flat = tf.reshape(y_s, [-1]) + + input_transformed = _interpolate(input_dim, x_s_flat, y_s_flat, out_size) + + output = tf.reshape(input_transformed, tf.stack([num_batch, out_height, out_width, num_channels])) return output + output = _transform(theta, U, out_size) + return output + def 
batch_transformer(U, thetas, out_size, name='BatchSpatialTransformer2dAffine'): """Batch Spatial Transformer function for `2D Affine Transformation `__. @@ -234,76 +230,53 @@ class SpatialTransformer2dAffine(Layer): """ - @deprecated_alias(layer='prev_layer', end_support_version=1.9) # TODO remove this line for the 1.9 release def __init__( self, - prev_layer, - theta_layer, - out_size=None, - name='spatial_trans_2d_affine', + in_channels=None, + out_size=(40, 40), + name=None, ): + super(SpatialTransformer2dAffine, self).__init__(name) - super(SpatialTransformer2dAffine, self).__init__(prev_layer=[prev_layer, theta_layer], name=name) - - self.inputs = prev_layer.outputs # Do not remove - self.theta_layer = theta_layer + self.in_channels = in_channels + self.out_size = out_size - if out_size is None: - out_size = [40, 40] + if self.in_channels is not None: + self.build(self.in_channels) + self._built = True logging.info( - "SpatialTransformer2dAffine %s: in_size: %s out_size: %s" % - (self.name, self.inputs.get_shape().as_list(), out_size) + "SpatialTransformer2dAffine %s" % self.name ) - with tf.compat.v1.variable_scope(name) as vs: - - # 1. make the localisation network to [batch, 6] via Flatten and Dense. - if self.theta_layer.outputs.get_shape().ndims > 2: - self.theta_layer.outputs = flatten_reshape(self.theta_layer.outputs, 'flatten') - - # 2. To initialize the network to the identity transform init. - # 2.1 W - n_in = int(self.theta_layer.outputs.get_shape()[-1]) - shape = (n_in, 6) - - W = tf.compat.v1.get_variable(name='W', initializer=tf.zeros(shape), dtype=LayersConfig.tf_dtype) - # 2.2 b - - identity = tf.constant(np.array([[1., 0, 0], [0, 1., 0]]).astype('float32').flatten()) - - b = tf.compat.v1.get_variable(name='b', initializer=identity, dtype=LayersConfig.tf_dtype) - # 2.3 transformation matrix - - self.theta = tf.nn.tanh(tf.matmul(self.theta_layer.outputs, W) + b) - # 3. 
Spatial Transformer Sampling - # 3.1 transformation - - self.outputs = transformer(self.inputs, self.theta, out_size=out_size) - - # 3.2 automatically set batch_size and channels - # e.g. [?, 40, 40, ?] --> [64, 40, 40, 1] or [64, 20, 20, 4]/ Hao Dong - # - fixed_batch_size = self.inputs.get_shape().with_rank_at_least(1)[0] - - if fixed_batch_size.value: - batch_size = fixed_batch_size.value - - else: - batch_size = array_ops.shape(self.inputs)[0] - - n_channels = self.inputs.get_shape().as_list()[-1] - # logging.info(self.outputs) - self.outputs = tf.reshape(self.outputs, shape=[batch_size, out_size[0], out_size[1], n_channels]) - # logging.info(self.outputs) - # exit() - # 4. Get all parameters - variables = tf.compat.v1.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name) - - # # theta_layer - # self._add_layers(theta_layer.all_layers) - # self._add_params(theta_layer.all_params) - # self.all_drop.update(theta_layer.all_drop) - - self._add_layers(self.outputs) - self._add_params(variables) + def __repr__(self): + s = '{classname}(out_size={out_size}, ' + if self.in_channels is not None: + s += 'in_channels=\'{in_channels}\'' + if self.name is not None: + s += ', name=\'{name}\'' + s += ')' + return s.format(classname=self.__class__.__name__, **self.__dict__) + + def build(self, inputs_shape): + if self.in_channels is None and len(inputs_shape) != 2: + raise AssertionError("The dimension of theta layer input must be rank 2, please reshape or flatten it") + if self.in_channels: + shape = [self.in_channels, 6] + else: + self.in_channels = inputs_shape[1] + shape = [inputs_shape[1], 6] + self.W = self._get_weights("weights", shape=tuple(shape), init=tl.initializers.Zeros()) + identity = tf.constant(np.array([[1., 0, 0], [0, 1., 0]]).astype('float32').flatten()) + self.b = self._get_weights("biases", shape=(6,), init=identity) + + def forward(self, inputs): + theta_input, U = inputs + theta = tf.nn.tanh(tf.matmul(theta_input, self.W) + self.b) + outputs = 
transformer(U, theta, out_size=self.out_size) + # automatically set batch_size and channels + # e.g. [?, 40, 40, ?] --> [64, 40, 40, 1] or [64, 20, 20, 4] + batch_size = theta_input.shape[0] + n_channels = U.shape[-1] + outputs = tf.reshape(outputs, shape=[batch_size, self.out_size[0], self.out_size[1], n_channels]) + return outputs From 951871f9ec014cb247025205b27b173df7bab241 Mon Sep 17 00:00:00 2001 From: 1FengL Date: Fri, 12 Apr 2019 10:03:43 +0100 Subject: [PATCH 4/7] fix a typo --- .../tutorial_spatial_transformer_network_dynamic.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/spatial_transformer_network/tutorial_spatial_transformer_network_dynamic.py b/examples/spatial_transformer_network/tutorial_spatial_transformer_network_dynamic.py index e89767a6a..e0db623fb 100644 --- a/examples/spatial_transformer_network/tutorial_spatial_transformer_network_dynamic.py +++ b/examples/spatial_transformer_network/tutorial_spatial_transformer_network_dynamic.py @@ -65,7 +65,6 @@ def __init__(self): ## 2. Spatial transformer module (sampler) self.stn = SpatialTransformer2dAffine(out_size=(40, 40), in_channels=20) - stn = SpatialTransformer2dAffine(out_size=(40, 40), in_channels=20) ## 3. 
Classifier self.conv1 = Conv2d(16, (3, 3), (2, 2), act=tf.nn.relu, padding='SAME', in_channels=1) From 87b638f653f83d820e3e7a029b256fab7611d383 Mon Sep 17 00:00:00 2001 From: 1FengL Date: Fri, 12 Apr 2019 10:05:26 +0100 Subject: [PATCH 5/7] enable _get_weights() to use tf.Tensor as initializer --- tensorlayer/layers/utils.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tensorlayer/layers/utils.py b/tensorlayer/layers/utils.py index 4a21fda4b..c6057b9b0 100644 --- a/tensorlayer/layers/utils.py +++ b/tensorlayer/layers/utils.py @@ -3,6 +3,7 @@ import numpy as np import tensorflow as tf +import tensorlayer as tl from tensorflow.python.ops.rnn_cell import LSTMStateTuple from tensorlayer import logging @@ -128,7 +129,7 @@ def get_layers_with_name(net, name="", verbose=False): return layers -def get_variable_with_initializer(scope_name, var_name, shape, init=tf.compat.v1.initializers.random_normal()): +def get_variable_with_initializer(scope_name, var_name, shape, init=tl.initializers.random_normal()): # FIXME: documentation needed # if tf.executing_eagerly(): var_name = scope_name + "/" + var_name @@ -138,7 +139,13 @@ def get_variable_with_initializer(scope_name, var_name, shape, init=tf.compat.v1 # initial_value = init()(shape=shape) # var = tf.Variable(initial_value=initial_value, name=var_name) # FIXME: not sure whether this is correct? 
- initial_value = init(shape=shape) + if isinstance(init, tf.Tensor): + if shape != init.shape: + raise ValueError('The shape of initial value: %s is not equal to the shape of variable: %s' + % (init.shape, shape)) + initial_value = init + else: + initial_value = init(shape=shape) var = tf.Variable(initial_value=initial_value, name=var_name) #, **init_args) # else: From 742ad699ea3f3d2b4563bc1ce66c3dfa0b727616 Mon Sep 17 00:00:00 2001 From: 1FengL Date: Fri, 12 Apr 2019 12:11:44 +0100 Subject: [PATCH 6/7] get_variable_with_initializer() rollback --- tensorlayer/layers/spatial_transformer.py | 4 ++-- tensorlayer/layers/utils.py | 8 +------- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/tensorlayer/layers/spatial_transformer.py b/tensorlayer/layers/spatial_transformer.py index fe6d09534..99ed7af72 100644 --- a/tensorlayer/layers/spatial_transformer.py +++ b/tensorlayer/layers/spatial_transformer.py @@ -267,8 +267,8 @@ def build(self, inputs_shape): self.in_channels = inputs_shape[1] shape = [inputs_shape[1], 6] self.W = self._get_weights("weights", shape=tuple(shape), init=tl.initializers.Zeros()) - identity = tf.constant(np.array([[1., 0, 0], [0, 1., 0]]).astype('float32').flatten()) - self.b = self._get_weights("biases", shape=(6,), init=identity) + identity = np.reshape(np.array([[1, 0, 0], [0, 1, 0]], dtype=np.float32), newshape=(6, )) + self.b = self._get_weights("biases", shape=(6,), init=tl.initializers.Constant(identity)) def forward(self, inputs): theta_input, U = inputs diff --git a/tensorlayer/layers/utils.py b/tensorlayer/layers/utils.py index c6057b9b0..10cc1fc18 100644 --- a/tensorlayer/layers/utils.py +++ b/tensorlayer/layers/utils.py @@ -139,13 +139,7 @@ def get_variable_with_initializer(scope_name, var_name, shape, init=tl.initializ # initial_value = init()(shape=shape) # var = tf.Variable(initial_value=initial_value, name=var_name) # FIXME: not sure whether this is correct? 
- if isinstance(init, tf.Tensor): - if shape != init.shape: - raise ValueError('The shape of initial value: %s is not equal to the shape of variable: %s' - % (init.shape, shape)) - initial_value = init - else: - initial_value = init(shape=shape) + initial_value = init(shape=shape) var = tf.Variable(initial_value=initial_value, name=var_name) #, **init_args) # else: From 17362d9564ee6462f4de9058ea5122b8893cf8e7 Mon Sep 17 00:00:00 2001 From: 1FengL Date: Fri, 12 Apr 2019 20:39:35 +0100 Subject: [PATCH 7/7] update the docs for SpatialTransformer2dAffine --- tensorlayer/layers/spatial_transformer.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/tensorlayer/layers/spatial_transformer.py b/tensorlayer/layers/spatial_transformer.py index 99ed7af72..892bc02f4 100644 --- a/tensorlayer/layers/spatial_transformer.py +++ b/tensorlayer/layers/spatial_transformer.py @@ -213,15 +213,11 @@ class SpatialTransformer2dAffine(Layer): Parameters ----------- - prev_layer : :class:`Layer` - Previous layer. - theta_layer : :class:`Layer` - The localisation network. - - We will use a :class:`Dense` to make the theta size to [batch, 6], value range to [0, 1] (via tanh). + in_channels : int or None - The size of the last dimension of theta_input; if None, it is inferred from the input shape when the layer is built. out_size : tuple of int or None - The size of the output of the network (height, width), the feature maps will be resized by this. + - The size of the output of the network (height, width), the feature maps will be resized by this. name : str - A unique layer name. + - A unique layer name. References ----------- @@ -271,6 +267,14 @@ def build(self, inputs_shape): self.b = self._get_weights("biases", shape=(6,), init=tl.initializers.Constant(identity)) def forward(self, inputs): + """ + :param inputs: a tuple (theta_input, U). + - theta_input is of size [batch, in_channels]. We will use a :class:`Dense` to + make the theta size to [batch, 6], value range to [-1, 1] (via tanh). + - U is the previous layer, which the affine transformation is applied to. 
+ :return: tensor of size [batch, out_size[0], out_size[1], n_channels] after affine transformation, + n_channels is identical to that of U. + """ theta_input, U = inputs theta = tf.nn.tanh(tf.matmul(theta_input, self.W) + self.b) outputs = transformer(U, theta, out_size=self.out_size)