|
3 | 3 | import tensorflow as tf
|
4 | 4 |
|
class SegNetAutoencoder:
  """SegNet-style convolutional encoder/decoder.

  Attributes:
    n: channel count of the final decoder layer ('deconv1_1'), i.e. the
       number of output maps produced by inference.
    strided: when True, `inference` uses stride-2 convolutions for
       downsampling instead of max pooling.
    max_images: cap on how many images each TensorBoard image summary emits.
  """

  def __init__(self, n, strided=False, max_images=3):
    self.n = n
    self.strided = strided
    self.max_images = max_images
10 | 10 |
|
11 | 11 | def conv(self, x, channels_shape, name):
|
12 | 12 | return cnn.conv(x, [3, 3], channels_shape, 1, name)
|
13 | 13 |
|
| 14 | + def conv2(self, x, channels_shape, name): |
| 15 | + return cnn.conv(x, [3, 3], channels_shape, 2, name) |
| 16 | + |
14 | 17 | def deconv(self, x, channels_shape, name):
|
15 | 18 | return cnn.deconv(x, [3, 3], channels_shape, 1, name)
|
16 | 19 |
|
17 | 20 | def pool(self, x):
|
18 | 21 | return cnn.max_pool(x, 2, 2)
|
19 | 22 |
|
20 |
| - def unpool(self, bottom): |
21 |
| - sh = bottom.get_shape().as_list() |
22 |
| - dim = len(sh[1:-1]) |
23 |
| - out = tf.reshape(bottom, [-1] + sh[-dim:]) |
24 |
| - for i in range(dim, 0, -1): |
25 |
| - out = tf.concat(i, [out, tf.zeros_like(out)]) |
26 |
| - out_size = [-1] + [s * 2 for s in sh[1:-1]] + [sh[-1]] |
27 |
| - return tf.reshape(out, out_size) |
28 |
| - |
29 |
| - def encode(self, images): |
30 |
| - tf.image_summary('input', images, max_images=self.max_images) |
| 23 | + def unpool(self, x): |
| 24 | + return cnn.unpool(x, 2) |
| 25 | + |
| 26 | + def resize_conv(self, x, channels_shape, name): |
| 27 | + shape = x.get_shape().as_list() |
| 28 | + height = shape[1] * 2 |
| 29 | + width = shape[2] * 2 |
| 30 | + resized = tf.image.resize_nearest_neighbor(x, [height, width]) |
| 31 | + return cnn.conv(resized, [3, 3], channels_shape, 1, name, repad=True) |
| 32 | + |
| 33 | + def inference_with_pooling(self, images): |
| 34 | + tf.summary.image('input', images, max_outputs=self.max_images) |
31 | 35 |
|
32 | 36 | with tf.variable_scope('pool1'):
|
33 | 37 | conv1 = self.conv(images, [3, 64], 'conv1_1')
|
@@ -88,26 +92,64 @@ def decode(self, code):
|
88 | 92 | deconv12 = self.deconv(unpool5, [64, 64], 'deconv1_2')
|
89 | 93 | deconv13 = self.deconv(deconv12, [self.n, 64], 'deconv1_1')
|
90 | 94 |
|
91 |
| - rgb_output = classifier.rgb(deconv13) |
92 |
| - tf.image_summary('output', rgb_output, max_images=self.max_images) |
93 |
| - |
| 95 | + rgb_image = classifier.rgb(deconv13) |
| 96 | + tf.summary.image('output', rgb_image, max_outputs=self.max_images) |
94 | 97 | return deconv13
|
95 | 98 |
|
96 |
| - def prepare_encoder_parameters(self): |
97 |
| - param_format = 'conv%d_%d_%s' |
98 |
| - conv_layers = [2, 2, 3, 3, 3] |
| 99 | + def strided_inference(self, images): |
| 100 | + tf.summary.image('input', images, max_outputs=self.max_images) |
| 101 | + |
| 102 | + with tf.variable_scope('pool1'): |
| 103 | + conv1 = self.conv(images, [3, 64], 'conv1_1') |
| 104 | + conv2 = self.conv2(conv1, [64, 64], 'conv1_2') |
| 105 | + |
| 106 | + with tf.variable_scope('pool2'): |
| 107 | + conv3 = self.conv(conv2, [64, 128], 'conv2_1') |
| 108 | + conv4 = self.conv2(conv3, [128, 128], 'conv2_2') |
| 109 | + |
| 110 | + with tf.variable_scope('pool3'): |
| 111 | + conv5 = self.conv(conv4, [128, 256], 'conv3_1') |
| 112 | + conv6 = self.conv(conv5, [256, 256], 'conv3_2') |
| 113 | + conv7 = self.conv2(conv6, [256, 256], 'conv3_3') |
| 114 | + |
| 115 | + with tf.variable_scope('pool4'): |
| 116 | + conv8 = self.conv(conv7, [256, 512], 'conv4_1') |
| 117 | + conv9 = self.conv(conv8, [512, 512], 'conv4_2') |
| 118 | + conv10 = self.conv2(conv9, [512, 512], 'conv4_3') |
| 119 | + |
| 120 | + with tf.variable_scope('pool5'): |
| 121 | + conv11 = self.conv(conv10, [512, 512], 'conv5_1') |
| 122 | + conv12 = self.conv(conv11, [512, 512], 'conv5_2') |
| 123 | + conv13 = self.conv2(conv12, [512, 512], 'conv5_3') |
| 124 | + |
| 125 | + with tf.variable_scope('unpool1'): |
| 126 | + deconv1 = self.resize_conv(conv13, [512, 512], 'deconv5_3') |
| 127 | + deconv2 = self.deconv(deconv1, [512, 512], 'deconv5_2') |
| 128 | + deconv3 = self.deconv(deconv2, [512, 512], 'deconv5_1') |
99 | 129 |
|
100 |
| - for pool in range(1, 6): |
101 |
| - with tf.variable_scope('pool%d' % pool, reuse=True): |
102 |
| - for conv in range(1, conv_layers[pool - 1] + 1): |
103 |
| - weights = tf.get_variable(param_format % (pool, conv, 'W')) |
104 |
| - biases = tf.get_variable(param_format % (pool, conv, 'b')) |
105 |
| - self.params += [weights, biases] |
| 130 | + with tf.variable_scope('unpool2'): |
| 131 | + deconv4 = self.resize_conv(deconv3, [512, 512], 'deconv4_3') |
| 132 | + deconv5 = self.deconv(deconv4, [512, 512], 'deconv4_2') |
| 133 | + deconv6 = self.deconv(deconv5, [256, 512], 'deconv4_1') |
106 | 134 |
|
107 |
| - def get_encoder_parameters(self): |
108 |
| - return self.params |
| 135 | + with tf.variable_scope('unpool3'): |
| 136 | + deconv7 = self.resize_conv(deconv6, [256, 256], 'deconv3_3') |
| 137 | + deconv8 = self.deconv(deconv7, [256, 256], 'deconv3_2') |
| 138 | + deconv9 = self.deconv(deconv8, [128, 256], 'deconv3_1') |
| 139 | + |
| 140 | + with tf.variable_scope('unpool4'): |
| 141 | + deconv10 = self.resize_conv(deconv9, [128, 128], 'deconv2_2') |
| 142 | + deconv11 = self.deconv(deconv10, [64, 128], 'deconv2_1') |
| 143 | + |
| 144 | + with tf.variable_scope('unpool5'): |
| 145 | + deconv12 = self.resize_conv(deconv11, [64, 64], 'deconv1_2') |
| 146 | + deconv13 = self.deconv(deconv12, [self.n, 64], 'deconv1_1') |
| 147 | + |
| 148 | + rgb_image = classifier.rgb(deconv13) |
| 149 | + tf.summary.image('output', rgb_image, max_outputs=self.max_images) |
| 150 | + return deconv13 |
109 | 151 |
|
110 | 152 | def inference(self, images):
|
111 |
| - code = self.encode(images) |
112 |
| - self.prepare_encoder_parameters() |
113 |
| - return self.decode(code) |
| 153 | + if self.strided: |
| 154 | + return self.strided_inference(images) |
| 155 | + return self.inference_with_pooling(images) |
0 commit comments