Commit 41c52d6

Spatial Transformer model
Shorten the STN summary in the README, relink the data files, add license headers, update the AUTHORS file, and note the TensorFlow version.
1 parent d51fdd2 commit 41c52d6

File tree

7 files changed (+625, -0 lines changed)

AUTHORS

+1
@@ -7,3 +7,4 @@
 # The email address is not required for organizations.

 Google Inc.
+David Dao <[email protected]>

transformer/README.md

+64
@@ -0,0 +1,64 @@

# Spatial Transformer Network

The Spatial Transformer Network [1] allows the spatial manipulation of data within the network.

<div align="center">
<img width="600px" src="http://i.imgur.com/ExGDVul.png"><br><br>
</div>

### API

A Spatial Transformer Network implemented in TensorFlow 0.7, based on [2].

#### How to use

<div align="center">
<img src="http://i.imgur.com/gfqLV3f.png"><br><br>
</div>

```python
transformer(U, theta, downsample_factor=1)
```

#### Parameters

U : float
    The output of a convolutional net; should have the
    shape [num_batch, height, width, num_channels].
theta : float
    The output of the localisation network; should have the
    shape [num_batch, 6].
downsample_factor : float
    A value of 1 keeps the original size of the image.
    Values larger than 1 downsample the image;
    values below 1 upsample the image.
    Example: for an input image with height = 100, width = 200 and
    downsample_factor = 2, the output image will be 50 x 100.
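
A usage sketch (not part of the original README; the placeholder shapes and the expected output shape simply follow the parameter description above):

```python
import numpy as np
import tensorflow as tf

from spatial_transformer import transformer

# A batch of 100x200 single-channel images (illustrative shapes).
U = tf.placeholder(tf.float32, [None, 100, 200, 1])
theta = tf.placeholder(tf.float32, [None, 6])

# downsample_factor=2 halves each spatial dimension: output is 50x100.
output = transformer(U, theta, downsample_factor=2)

# Identity transform, flattened to the expected [num_batch, 6] layout.
identity = np.array([[1., 0., 0.],
                     [0., 1., 0.]], dtype=np.float32).flatten()

sess = tf.Session()
out = sess.run(output, feed_dict={
    U: np.zeros((3, 100, 200, 1), dtype=np.float32),
    theta: np.tile(identity, (3, 1))})
print(out.shape)  # expected: (3, 50, 100, 1)
```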

#### Notes

To initialize the network to the identity transform, initialize ``theta`` to:

```python
identity = np.array([[1., 0., 0.],
                     [0., 1., 0.]])
identity = identity.flatten()
theta = tf.Variable(initial_value=identity)
```
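
Starting from the identity transform means the STN initially passes its input through unchanged, which tends to make early training more stable.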

#### Experiments

<div align="center">
<img width="600px" src="http://i.imgur.com/HtCBYk2.png"><br><br>
</div>

We used cluttered MNIST. The left column shows the input images; the right column shows the parts of the image the STN attends to.

All experiments were run with TensorFlow 0.7.

### References

[1] Jaderberg, Max, et al. "Spatial Transformer Networks." arXiv preprint arXiv:1506.02025 (2015).

[2] https://github.com/skaae/transformer_network/blob/master/transformerlayer.py

transformer/cluttered_mnist.py

+172
@@ -0,0 +1,172 @@

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf
from spatial_transformer import transformer
from scipy import ndimage
import numpy as np
import matplotlib.pyplot as plt
from tf_utils import conv2d, linear, weight_variable, bias_variable, dense_to_one_hot

# %% Load data
mnist_cluttered = np.load('./data/mnist_sequence1_sample_5distortions5x5.npz')

X_train = mnist_cluttered['X_train']
y_train = mnist_cluttered['y_train']
X_valid = mnist_cluttered['X_valid']
y_valid = mnist_cluttered['y_valid']
X_test = mnist_cluttered['X_test']
y_test = mnist_cluttered['y_test']

# %% Turn the dense labels into a one-hot representation
Y_train = dense_to_one_hot(y_train, n_classes=10)
Y_valid = dense_to_one_hot(y_valid, n_classes=10)
Y_test = dense_to_one_hot(y_test, n_classes=10)
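
# tf_utils.py is not part of this diff; a minimal sketch of dense_to_one_hot
# consistent with its use here (an assumption about its behavior, not the
# file's actual contents), where labels is an integer label array:
#
#   def dense_to_one_hot(labels, n_classes=10):
#       labels = np.asarray(labels).ravel().astype(int)
#       one_hot = np.zeros((labels.size, n_classes), dtype=np.float32)
#       one_hot[np.arange(labels.size), labels] = 1.
#       return one_hot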

# %% Graph representation of our network

# %% Placeholders for 40x40 resolution
x = tf.placeholder(tf.float32, [None, 1600])
y = tf.placeholder(tf.float32, [None, 10])

# %% Since x is currently [batch, height*width], we need to reshape to a
# 4-D tensor to use it in a convolutional graph. If one component of
# `shape` is the special value -1, the size of that dimension is
# computed so that the total size remains constant. Since we haven't
# defined the batch dimension's shape yet, we use -1 to denote that this
# dimension should not change size.
x_tensor = tf.reshape(x, [-1, 40, 40, 1])
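# e.g. a batch of 64 flattened images goes from [64, 1600] to [64, 40, 40, 1]
# (the batch size 64 is illustrative)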

# %% We'll set up a two-layer localisation network to figure out the
# parameters for an affine transformation of the input
# %% Create variables for the fully connected layers
W_fc_loc1 = weight_variable([1600, 20])
b_fc_loc1 = bias_variable([20])

W_fc_loc2 = weight_variable([20, 6])
# Use the identity transformation as the starting point
initial = np.array([[1., 0., 0.], [0., 1., 0.]])
initial = initial.astype('float32')
initial = initial.flatten()
b_fc_loc2 = tf.Variable(initial_value=initial, name='b_fc_loc2')

# %% Define the two-layer localisation network
h_fc_loc1 = tf.nn.tanh(tf.matmul(x, W_fc_loc1) + b_fc_loc1)
# %% We can add dropout for regularization and to reduce overfitting like so:
keep_prob = tf.placeholder(tf.float32)
h_fc_loc1_drop = tf.nn.dropout(h_fc_loc1, keep_prob)
# %% Second layer
h_fc_loc2 = tf.nn.tanh(tf.matmul(h_fc_loc1_drop, W_fc_loc2) + b_fc_loc2)

# %% We'll create a spatial transformer module to identify discriminative patches
h_trans = transformer(x_tensor, h_fc_loc2, downsample_factor=1)

# %% We'll set up the first convolutional layer
# Weight matrix is [height x width x input_channels x output_channels]
filter_size = 3
n_filters_1 = 16
W_conv1 = weight_variable([filter_size, filter_size, 1, n_filters_1])

# %% Bias is [output_channels]
b_conv1 = bias_variable([n_filters_1])

# %% Now we can build a graph which does the first layer of convolution:
# we define our stride as batch x height x width x channels.
# Instead of pooling, we use strides of 2 and more layers
# with smaller filters.

h_conv1 = tf.nn.relu(
    tf.nn.conv2d(input=h_trans,
                 filter=W_conv1,
                 strides=[1, 2, 2, 1],
                 padding='SAME') +
    b_conv1)

# %% And just like the first layer, add additional layers to create
# a deep net
n_filters_2 = 16
W_conv2 = weight_variable([filter_size, filter_size, n_filters_1, n_filters_2])
b_conv2 = bias_variable([n_filters_2])
h_conv2 = tf.nn.relu(
    tf.nn.conv2d(input=h_conv1,
                 filter=W_conv2,
                 strides=[1, 2, 2, 1],
                 padding='SAME') +
    b_conv2)

# %% We'll now reshape so we can connect to a fully-connected layer:
# (the two stride-2 convolutions have reduced the 40x40 input to 10x10)
h_conv2_flat = tf.reshape(h_conv2, [-1, 10 * 10 * n_filters_2])

# %% Create a fully-connected layer:
n_fc = 1024
W_fc1 = weight_variable([10 * 10 * n_filters_2, n_fc])
b_fc1 = bias_variable([n_fc])
h_fc1 = tf.nn.relu(tf.matmul(h_conv2_flat, W_fc1) + b_fc1)

h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

# %% And finally our softmax layer:
W_fc2 = weight_variable([n_fc, 10])
b_fc2 = bias_variable([10])
y_pred = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)

# %% Define loss/eval/training functions
cross_entropy = -tf.reduce_sum(y * tf.log(y_pred))
opt = tf.train.AdamOptimizer()
optimizer = opt.minimize(cross_entropy)
grads = opt.compute_gradients(cross_entropy, [b_fc_loc2])

# %% Monitor accuracy
correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))

# %% We now create a new session to actually perform the initialization of
# the variables:
sess = tf.Session()
sess.run(tf.initialize_all_variables())


# %% We'll now train in minibatches and report accuracy and loss:
iter_per_epoch = 100
n_epochs = 500
train_size = 10000

indices = np.linspace(0, train_size - 1, iter_per_epoch)
indices = indices.astype('int')

for epoch_i in range(n_epochs):
    for iter_i in range(iter_per_epoch - 1):
        batch_xs = X_train[indices[iter_i]:indices[iter_i + 1]]
        batch_ys = Y_train[indices[iter_i]:indices[iter_i + 1]]

        if iter_i % 10 == 0:
            loss = sess.run(cross_entropy,
                            feed_dict={
                                x: batch_xs,
                                y: batch_ys,
                                keep_prob: 1.0
                            })
            print('Iteration: ' + str(iter_i) + ' Loss: ' + str(loss))

        sess.run(optimizer, feed_dict={
            x: batch_xs, y: batch_ys, keep_prob: 0.8})

    print('Accuracy: ' + str(sess.run(accuracy,
                                      feed_dict={
                                          x: X_valid,
                                          y: Y_valid,
                                          keep_prob: 1.0
                                      })))
    # theta = sess.run(h_fc_loc2, feed_dict={
    #     x: batch_xs, keep_prob: 1.0})
    # print(theta[0])
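
# X_test / Y_test are loaded above but not used; a final test-set evaluation
# could mirror the validation call (a sketch, not part of the original file):
#
#   print('Test accuracy: ' + str(sess.run(accuracy, feed_dict={
#       x: X_test, y: Y_test, keep_prob: 1.0})))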

transformer/data/README.md

+20
@@ -0,0 +1,20 @@

### How to get the data

#### Cluttered MNIST

The cluttered MNIST dataset can be downloaded from [1] or generated via [2].

Settings used for `cluttered_mnist.py`:

```python
ORG_SHP = [28, 28]
OUT_SHP = [40, 40]
NUM_DISTORTIONS = 8
dist_size = (5, 5)
```
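
A quick sanity check after downloading or generating the file (a sketch; the key names follow `cluttered_mnist.py`, and the flattened shape assumes the 40x40 output setting above):

```python
import numpy as np

data = np.load('mnist_sequence1_sample_5distortions5x5.npz')
print(data.files)             # expect X_train, y_train, X_valid, ...
print(data['X_train'].shape)  # expect (num_samples, 1600) for 40x40 inputs
```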

[1] https://github.com/daviddao/spatial-transformer-tensorflow

[2] https://github.com/skaae/recurrent-spatial-transformer-code/blob/master/MNIST_SEQUENCE/create_mnist_sequence.py

transformer/example.py

+58
@@ -0,0 +1,58 @@

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf
from spatial_transformer import transformer
from scipy import ndimage
import numpy as np
import matplotlib.pyplot as plt
from tf_utils import conv2d, linear, weight_variable, bias_variable

# %% Create a batch of three images (1600 x 1200)
# %% Image retrieved from https://raw.githubusercontent.com/skaae/transformer_network/master/cat.jpg
im = ndimage.imread('cat.jpg')
im = im / 255.
im = im.reshape(1, 1200, 1600, 3)
im = im.astype('float32')

# %% Simulate batch
batch = np.append(im, im, axis=0)
batch = np.append(batch, im, axis=0)
num_batch = 3

# The image data is supplied through feed_dict below.
x = tf.placeholder(tf.float32, [None, 1200, 1600, 3])

# %% Create localisation network and convolutional layer
with tf.variable_scope('spatial_transformer_0'):

    # %% Create a fully-connected layer with 6 output nodes
    n_fc = 6
    W_fc1 = tf.Variable(tf.zeros([1200 * 1600 * 3, n_fc]), name='W_fc1')

    # %% Zoom into the image
    initial = np.array([[0.5, 0, 0], [0, 0.5, 0]])
    initial = initial.astype('float32')
    initial = initial.flatten()

    b_fc1 = tf.Variable(initial_value=initial, name='b_fc1')
    h_fc1 = tf.matmul(tf.zeros([num_batch, 1200 * 1600 * 3]), W_fc1) + b_fc1
    h_trans = transformer(x, h_fc1, downsample_factor=2)
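    # With downsample_factor=2 the output is half size in each spatial
    # dimension, so h_trans should come out as [num_batch, 600, 800, 3]
    # (per the README's parameter description).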

# %% Run session
sess = tf.Session()
sess.run(tf.initialize_all_variables())
y = sess.run(h_trans, feed_dict={x: batch})

# plt.imshow(y[0])
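# plt.show()  # uncomment along with the line above to display the result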
