Skip to content

Commit a51d063

Browse files
committed
Introduce resize-convolution-based autoencoders
1 parent da348ae commit a51d063

File tree

9 files changed

+185
-154
lines changed

9 files changed

+185
-154
lines changed

Diff for: README.md

+44-4
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,52 @@ SegNet is a TensorFlow implementation of the [segmentation network proposed by K
44

55
## Configuration
66

7-
Before running, download the [VGG16 weights file](https://www.cs.toronto.edu/~frossard/vgg16/vgg16_weights.npz)
8-
and save it as `input/vgg16_weights.npz` if you want to initialize the encoder weights with the VGG16 ones trained on ImageNet classification dataset.
7+
Create a `config.py` file, containing color maps, working dataset and other options.
98

10-
In `config.py`, choose your working dataset. The dataset name needs to match the data directories you create in your `input` folder.
9+
```
10+
colors = {
11+
'segnet-32': [
12+
[64, 128, 64], # Animal
13+
[192, 0, 128], # Archway
14+
[0, 128, 192], # Bicyclist
15+
[0, 128, 64], # Bridge
16+
[128, 0, 0], # Building
17+
[64, 0, 128], # Car
18+
[64, 0, 192], # CartLuggagePram
19+
[192, 128, 64], # Child
20+
[192, 192, 128], # Column_Pole
21+
[64, 64, 128], # Fence
22+
[128, 0, 192], # LaneMkgsDriv
23+
[192, 0, 64], # LaneMkgsNonDriv
24+
[128, 128, 64], # Misc_Text
25+
[192, 0, 192], # MotorcycleScooter
26+
[128, 64, 64], # OtherMoving
27+
[64, 192, 128], # ParkingBlock
28+
[64, 64, 0], # Pedestrian
29+
[128, 64, 128], # Road
30+
[128, 128, 192], # RoadShoulder
31+
[0, 0, 192], # Sidewalk
32+
[192, 128, 128], # SignSymbol
33+
[128, 128, 128], # Sky
34+
[64, 128, 192], # SUVPickupTruck
35+
[0, 0, 64], # TrafficCone
36+
[0, 64, 64], # TrafficLight
37+
[192, 64, 128], # Train
38+
[128, 128, 0], # Tree
39+
[192, 128, 192], # Truck_Bus
40+
[64, 0, 64], # Tunnel
41+
[192, 192, 0], # VegetationMisc
42+
[0, 0, 0], # Void
43+
[64, 192, 0] # Wall
44+
]
45+
}
46+
gpu_memory_fraction = 0.7
47+
working_dataset = 'segnet-32'
48+
```
49+
50+
The `working_dataset` value needs to match the data directories you create in your `input` folder.
1151
You can use `segnet-32` and `segnet-13` to replicate the aforementioned Kendall et al. experiments.
1252

1353
## Train and test
1454

15-
Train SegNet with `python -m src/train.py`. Analogously, test it with `python -m src/test.py`.
55+
Train SegNet with `python src/train.py`. Analogously, test it with `python src/test.py`.

Diff for: config.py

-1
This file was deleted.

Diff for: src/convnet.py

+23-3
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
11
import tensorflow as tf
22

3-
def conv(x, receptive_field_shape, channels_shape, stride, name):
3+
def conv(x, receptive_field_shape, channels_shape, stride, name, repad=False):
44
kernel_shape = receptive_field_shape + channels_shape
55
bias_shape = [channels_shape[-1]]
66

77
weights = tf.get_variable('%s_W' % name, kernel_shape, initializer=tf.truncated_normal_initializer(stddev=.1))
88
biases = tf.get_variable('%s_b' % name, bias_shape, initializer=tf.constant_initializer(.1))
9-
conv = tf.nn.conv2d(x, weights, strides=[1, stride, stride, 1], padding='SAME')
9+
10+
if repad:
11+
padded = tf.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='SYMMETRIC')
12+
conv = tf.nn.conv2d(padded, weights, strides=[1, stride, stride, 1], padding='VALID')
13+
else:
14+
conv = tf.nn.conv2d(x, weights, strides=[1, stride, stride, 1], padding='SAME')
15+
1016
conv_bias = tf.nn.bias_add(conv, biases)
1117
return tf.nn.relu(tf.contrib.layers.batch_norm(conv_bias))
1218

@@ -21,9 +27,23 @@ def deconv(x, receptive_field_shape, channels_shape, stride, name):
2127

2228
weights = tf.get_variable('%s_W' % name, kernel_shape, initializer=tf.truncated_normal_initializer(stddev=.1))
2329
biases = tf.get_variable('%s_b' % name, bias_shape, initializer=tf.constant_initializer(.1))
24-
conv = tf.nn.conv2d_transpose(x, weights, [batch_size, height, width, channels_shape[0]], [1, stride, stride, 1], padding='SAME')
30+
conv = tf.nn.conv2d_transpose(x, weights, [batch_size, height * stride, width * stride, channels_shape[0]], [1, stride, stride, 1], padding='SAME')
2531
conv_bias = tf.nn.bias_add(conv, biases)
2632
return tf.nn.relu(tf.contrib.layers.batch_norm(conv_bias))
2733

2834
def max_pool(x, size, stride, padding='SAME'):
2935
return tf.nn.max_pool(x, ksize=[1, size, size, 1], strides=[1, stride, stride, 1], padding=padding, name='maxpool')
36+
37+
def unpool(x, size):
    """Upsample `x` by zero-filling.

    Every input value ends up at the top-left corner of a `size` x `size`
    block in the output; the remaining entries of the block are zero.

    Args:
        x: rank-4 tensor (axes 1 and 2 are scaled, axis 3 is preserved).
            Presumably laid out as [batch, height, width, channels] --
            TODO confirm against callers.
        size: integer upsampling factor applied to axes 1 and 2.

    Returns:
        A tensor whose axes 1 and 2 are `size` times those of `x`.
    """
    # Append (size - 1) zero tensors per axis so the element count matches
    # the reshape below for any factor, not only size == 2.
    out = tf.concat_v2([x] + [tf.zeros_like(x)] * (size - 1), 3)
    out = tf.concat_v2([out] + [tf.zeros_like(out)] * (size - 1), 2)

    sh = x.get_shape().as_list()
    if None not in sh[1:]:
        # Static shape fully known: reshape with plain Python integers.
        out_size = [-1, sh[1] * size, sh[2] * size, sh[3]]
        return tf.reshape(out, out_size)

    # Some spatial dimension is unknown at graph-construction time: compute
    # the target shape dynamically and re-attach the static channel count.
    shv = tf.shape(x)
    ret = tf.reshape(out, tf.stack([-1, shv[1] * size, shv[2] * size, sh[3]]))
    ret.set_shape([None, None, None, sh[3]])
    return ret

Diff for: src/initializer.py

-9
This file was deleted.

Diff for: src/models.py

+73-31
Original file line numberDiff line numberDiff line change
@@ -3,31 +3,35 @@
33
import tensorflow as tf
44

55
class SegNetAutoencoder:
6-
def __init__(self, n, max_images=3):
7-
self.params = []
8-
self.n = n
6+
def __init__(self, n, strided=False, max_images=3):
97
self.max_images = max_images
8+
self.n = n
9+
self.strided = strided
1010

1111
def conv(self, x, channels_shape, name):
1212
return cnn.conv(x, [3, 3], channels_shape, 1, name)
1313

14+
def conv2(self, x, channels_shape, name):
    """Apply a 3x3 convolution with stride 2 by delegating to `cnn.conv`."""
    kernel_size = [3, 3]
    stride = 2
    return cnn.conv(x, kernel_size, channels_shape, stride, name)
16+
1417
def deconv(self, x, channels_shape, name):
1518
return cnn.deconv(x, [3, 3], channels_shape, 1, name)
1619

1720
def pool(self, x):
1821
return cnn.max_pool(x, 2, 2)
1922

20-
def unpool(self, bottom):
21-
sh = bottom.get_shape().as_list()
22-
dim = len(sh[1:-1])
23-
out = tf.reshape(bottom, [-1] + sh[-dim:])
24-
for i in range(dim, 0, -1):
25-
out = tf.concat(i, [out, tf.zeros_like(out)])
26-
out_size = [-1] + [s * 2 for s in sh[1:-1]] + [sh[-1]]
27-
return tf.reshape(out, out_size)
28-
29-
def encode(self, images):
30-
tf.image_summary('input', images, max_images=self.max_images)
23+
def unpool(self, x):
    """Upsample `x` by a factor of 2 by delegating to `cnn.unpool`."""
    factor = 2
    return cnn.unpool(x, factor)
25+
26+
def resize_conv(self, x, channels_shape, name):
    """Double axes 1 and 2 of `x` with nearest-neighbor resizing, then convolve.

    The resized tensor is passed to `cnn.conv` with a 3x3 receptive field,
    stride 1 and `repad=True`.

    Args:
        x: tensor whose static shape provides the sizes of axes 1 and 2.
        channels_shape: channel specification forwarded to `cnn.conv`.
        name: layer name forwarded to `cnn.conv`.
    """
    dims = x.get_shape().as_list()
    target_size = [dims[1] * 2, dims[2] * 2]
    upsampled = tf.image.resize_nearest_neighbor(x, target_size)
    return cnn.conv(upsampled, [3, 3], channels_shape, 1, name, repad=True)
32+
33+
def inference_with_pooling(self, images):
34+
tf.summary.image('input', images, max_outputs=self.max_images)
3135

3236
with tf.variable_scope('pool1'):
3337
conv1 = self.conv(images, [3, 64], 'conv1_1')
@@ -88,26 +92,64 @@ def decode(self, code):
8892
deconv12 = self.deconv(unpool5, [64, 64], 'deconv1_2')
8993
deconv13 = self.deconv(deconv12, [self.n, 64], 'deconv1_1')
9094

91-
rgb_output = classifier.rgb(deconv13)
92-
tf.image_summary('output', rgb_output, max_images=self.max_images)
93-
95+
rgb_image = classifier.rgb(deconv13)
96+
tf.summary.image('output', rgb_image, max_outputs=self.max_images)
9497
return deconv13
9598

96-
def prepare_encoder_parameters(self):
97-
param_format = 'conv%d_%d_%s'
98-
conv_layers = [2, 2, 3, 3, 3]
99+
def strided_inference(self, images):
100+
tf.summary.image('input', images, max_outputs=self.max_images)
101+
102+
with tf.variable_scope('pool1'):
103+
conv1 = self.conv(images, [3, 64], 'conv1_1')
104+
conv2 = self.conv2(conv1, [64, 64], 'conv1_2')
105+
106+
with tf.variable_scope('pool2'):
107+
conv3 = self.conv(conv2, [64, 128], 'conv2_1')
108+
conv4 = self.conv2(conv3, [128, 128], 'conv2_2')
109+
110+
with tf.variable_scope('pool3'):
111+
conv5 = self.conv(conv4, [128, 256], 'conv3_1')
112+
conv6 = self.conv(conv5, [256, 256], 'conv3_2')
113+
conv7 = self.conv2(conv6, [256, 256], 'conv3_3')
114+
115+
with tf.variable_scope('pool4'):
116+
conv8 = self.conv(conv7, [256, 512], 'conv4_1')
117+
conv9 = self.conv(conv8, [512, 512], 'conv4_2')
118+
conv10 = self.conv2(conv9, [512, 512], 'conv4_3')
119+
120+
with tf.variable_scope('pool5'):
121+
conv11 = self.conv(conv10, [512, 512], 'conv5_1')
122+
conv12 = self.conv(conv11, [512, 512], 'conv5_2')
123+
conv13 = self.conv2(conv12, [512, 512], 'conv5_3')
124+
125+
with tf.variable_scope('unpool1'):
126+
deconv1 = self.resize_conv(conv13, [512, 512], 'deconv5_3')
127+
deconv2 = self.deconv(deconv1, [512, 512], 'deconv5_2')
128+
deconv3 = self.deconv(deconv2, [512, 512], 'deconv5_1')
99129

100-
for pool in range(1, 6):
101-
with tf.variable_scope('pool%d' % pool, reuse=True):
102-
for conv in range(1, conv_layers[pool - 1] + 1):
103-
weights = tf.get_variable(param_format % (pool, conv, 'W'))
104-
biases = tf.get_variable(param_format % (pool, conv, 'b'))
105-
self.params += [weights, biases]
130+
with tf.variable_scope('unpool2'):
131+
deconv4 = self.resize_conv(deconv3, [512, 512], 'deconv4_3')
132+
deconv5 = self.deconv(deconv4, [512, 512], 'deconv4_2')
133+
deconv6 = self.deconv(deconv5, [256, 512], 'deconv4_1')
106134

107-
def get_encoder_parameters(self):
108-
return self.params
135+
with tf.variable_scope('unpool3'):
136+
deconv7 = self.resize_conv(deconv6, [256, 256], 'deconv3_3')
137+
deconv8 = self.deconv(deconv7, [256, 256], 'deconv3_2')
138+
deconv9 = self.deconv(deconv8, [128, 256], 'deconv3_1')
139+
140+
with tf.variable_scope('unpool4'):
141+
deconv10 = self.resize_conv(deconv9, [128, 128], 'deconv2_2')
142+
deconv11 = self.deconv(deconv10, [64, 128], 'deconv2_1')
143+
144+
with tf.variable_scope('unpool5'):
145+
deconv12 = self.resize_conv(deconv11, [64, 64], 'deconv1_2')
146+
deconv13 = self.deconv(deconv12, [self.n, 64], 'deconv1_1')
147+
148+
rgb_image = classifier.rgb(deconv13)
149+
tf.summary.image('output', rgb_image, max_outputs=self.max_images)
150+
return deconv13
109151

110152
def inference(self, images):
111-
code = self.encode(images)
112-
self.prepare_encoder_parameters()
113-
return self.decode(code)
153+
if self.strided:
154+
return self.strided_inference(images)
155+
return self.inference_with_pooling(images)

Diff for: src/scalar_ops.py

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import tensorflow as tf
2+
3+
def accuracy(logits, labels, batch_size):
    """Return the fraction of elements of `logits` that equal `labels`.

    Args:
        logits: tensor compared element-wise against `labels`.
        labels: tensor of the same shape and dtype as `logits`.
        batch_size: unused; kept so existing callers keep working.

    Returns:
        A scalar float tensor: matching elements / total elements.
    """
    # NOTE(review): this compares raw values element-wise; if `logits` are
    # unnormalized scores, an argmax over the class axis may be what is
    # intended -- confirm with the caller.
    equal_pixels = tf.reduce_sum(tf.to_float(tf.equal(logits, labels)))
    # Count elements from the tensor itself instead of assuming
    # batch_size x 224 x 224 x 3, so the ratio stays correct for any image
    # size or number of channels.
    total_pixels = tf.to_float(tf.reduce_prod(tf.shape(logits)))
    return equal_pixels / total_pixels
7+
8+
def loss(logits, labels):
    """Return the mean softmax cross-entropy between `logits` and `labels`.

    Args:
        logits: unnormalized class scores.
        labels: target distribution with the same shape as `logits`.

    Returns:
        A scalar tensor named 'loss'.
    """
    # Keyword arguments keep the call independent of parameter order and
    # are required by TensorFlow releases that reject positional use.
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
    return tf.reduce_mean(cross_entropy, name='loss')

Diff for: src/test.py

+16-17
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,43 @@
11
from inputs import inputs
22
from models import SegNetAutoencoder
3+
from scalar_ops import accuracy, loss
34

45
import classifier
56
import config
67
import tensorflow as tf
78
import utils
89

9-
test_file = utils.get_test_set(config.working_dataset)
10+
test_file, test_labels_file = utils.get_test_set(config.working_dataset, include_labels=True)
1011

11-
tf.app.flags.DEFINE_string('test', test_file, 'Test data')
1212
tf.app.flags.DEFINE_string('ckpt_dir', './ckpts', 'Train checkpoint directory')
13-
# tf.app.flags.DEFINE_string('test_labels', './input/test_labels.tfrecords', 'Test labels data')
13+
tf.app.flags.DEFINE_string('test', test_file, 'Test data')
14+
tf.app.flags.DEFINE_string('test_labels', test_labels_file, 'Test labels data')
1415
tf.app.flags.DEFINE_string('test_logs', './logs/test', 'Log directory')
1516

16-
tf.app.flags.DEFINE_integer('batch', 35, 'Batch size')
17+
tf.app.flags.DEFINE_boolean('strided', True, 'Use strided convolutions and deconvolutions')
1718

18-
FLAGS = tf.app.flags.FLAGS
19+
tf.app.flags.DEFINE_integer('batch', 200, 'Batch size')
1920

20-
def accuracy(logits, labels):
21-
equal_pixels = tf.reduce_sum(tf.to_float(tf.equal(logits, labels)))
22-
total_pixels = tf.to_float(tf.reduce_prod(tf.shape(logits)))
23-
return equal_pixels / total_pixels
21+
FLAGS = tf.app.flags.FLAGS
2422

2523
def test():
26-
#images, labels = inputs(FLAGS.batch, FLAGS.test, FLAGS.test_labels)
27-
images = inputs(FLAGS.batch, FLAGS.test)
28-
#one_hot_labels = classifier.one_hot(labels)
24+
images, labels = inputs(FLAGS.batch, FLAGS.test, FLAGS.test_labels)
25+
tf.summary.image('labels', labels)
26+
one_hot_labels = classifier.one_hot(labels)
2927

30-
autoencoder = SegNetAutoencoder(2, max_images=20)
28+
autoencoder = SegNetAutoencoder(4, strided=FLAGS.strided)
3129
logits = autoencoder.inference(images)
3230

33-
#accuracy_op = accuracy(logits, one_hot_labels)
34-
#tf.scalar_summary('accuracy', accuracy_op)
31+
accuracy_op = accuracy(logits, one_hot_labels, FLAGS.batch)
32+
tf.scalar_summary('accuracy', accuracy_op)
3533

3634
saver = tf.train.Saver(tf.global_variables())
3735
summary = tf.merge_all_summaries()
3836
summary_writer = tf.train.SummaryWriter(FLAGS.test_logs)
3937

40-
with tf.Session() as sess:
38+
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=config.gpu_memory_fraction)
39+
session_config = tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)
40+
with tf.Session(config=session_config) as sess:
4141
coord = tf.train.Coordinator()
4242
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
4343

@@ -50,7 +50,6 @@ def test():
5050
ckpt_path = ckpt.model_checkpoint_path
5151
saver.restore(sess, ckpt_path)
5252

53-
#accuracy_value, summary_str = sess.run([accuracy_op, summary])
5453
summary_str = sess.run(summary)
5554
summary_writer.add_summary(summary_str)
5655
summary_writer.flush()

0 commit comments

Comments
 (0)