-
Notifications
You must be signed in to change notification settings - Fork 92
/
Copy pathimage_reader.py
232 lines (179 loc) · 8.48 KB
/
image_reader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
import os
import random
from collections import namedtuple
from typing import Iterator

import click
import numpy as np
from IPython import embed
from keras.preprocessing.image import (
    load_img, img_to_array,
    flip_axis)
# One concrete draw of the (random) augmentation: which flips to apply.
# Note that the generated class is named 'TransformParameters', while the
# module-level alias is TransformParams.
TransformParams = namedtuple('TransformParameters',
                             ['do_hor_flip', 'do_vert_flip'])

# Per-channel mean subtracted from every image by the data pipeline, after
# the channel order has been reversed (presumably the PASCAL VOC mean in
# BGR order -- TODO confirm).
pascal_mean = np.array([102.93, 111.36, 116.52])

# Width in pixels of the context border padded around images and labels;
# also used as the offset at which label sub-sampling starts.
label_margin = 186
def load_img_array(fname, grayscale=False, target_size=None, dim_ordering='default'):
    """Load an image file and return it as an array.

    The file is read with Keras' `load_img` and the result is converted with
    Keras' `img_to_array`.

    Args:
        fname: path of the image file to read.
        grayscale: forwarded to `load_img`.
        target_size: forwarded to `load_img`.
        dim_ordering: forwarded to `img_to_array`.

    Returns:
        Whatever `img_to_array` produces for the loaded image.
    """
    loaded = load_img(fname, grayscale=grayscale, target_size=target_size)
    return img_to_array(loaded, dim_ordering=dim_ordering)
class RandomTransformer:
    """Consistent random augmentation for image pairs, done in two steps.

    First, `random_params_gen` produces a stream of random augmentation
    parameters that can be zipped together with the images. Second,
    `transform` performs the actual transformation, which involves no
    randomness because the parameters are passed in. The same parameters can
    therefore be applied to an image and to its mask.
    """

    def __init__(self,
                 horizontal_flip=False,
                 vertical_flip=False):
        """Record which flips `random_params_gen` is allowed to switch on.

        Args:
            horizontal_flip: if truthy, horizontal flips may be drawn.
            vertical_flip: if truthy, vertical flips may be drawn.
        """
        self.horizontal_flip = horizontal_flip
        self.vertical_flip = vertical_flip

    def random_params_gen(self) -> Iterator[TransformParams]:
        """Yield an endless stream of random transformation parameters.

        An enabled flip is switched on when a fresh `np.random.random()` draw
        is below 0.5; a disabled flip is always falsy and consumes no random
        number (the `and` short-circuits).
        """
        while True:
            do_hor_flip = self.horizontal_flip and (np.random.random() < 0.5)
            do_vert_flip = self.vertical_flip and (np.random.random() < 0.5)
            yield TransformParams(do_hor_flip=do_hor_flip,
                                  do_vert_flip=do_vert_flip)

    @staticmethod
    def transform(x: np.ndarray, params: TransformParams) -> np.ndarray:
        """Transform a single image according to the parameters given.

        Args:
            x: the image array (assumed to be laid out rows-first, i.e.
                axis 0 = rows, axis 1 = columns -- TODO confirm with callers).
            params: the flips to apply.

        Returns:
            `x` with axis 1 flipped when `params.do_hor_flip` is set and
            axis 0 flipped when `params.do_vert_flip` is set; `x` itself when
            neither is set.
        """
        if params.do_hor_flip:
            x = flip_axis(x, 1)
        if params.do_vert_flip:
            x = flip_axis(x, 0)
        return x
class SegmentationDataGenerator:
    """A data generator for segmentation tasks, similar to ImageDataGenerator
    in Keras, but with distinct pipelines for images and masks.

    The idea is that this object holds no data, and only knows how to run
    the pipeline to load, augment, and batch samples. The actual data (lists
    of file names) must be passed in to the flow functions directly.
    """

    # Number of pairs dropped by flow_from_list because their label contained
    # the value 255. This is the class-level default; the first increment on
    # an instance creates a per-instance counter.
    skipped_count = 0

    def __init__(self,
                 random_transformer: 'RandomTransformer'):
        """Keep the transformer that supplies the random augmentation.

        Args:
            random_transformer: object providing `random_params_gen()` and
                `transform(x, params)`.
        """
        self.random_transformer = random_transformer

    def get_processed_pairs(self,
                            img_fnames,
                            mask_fnames):
        """Lazily load and preprocess (image, label) pairs.

        The returned generator chains these steps, in this order: load,
        pad (context margin, then up to at least 500x500), center-crop to
        500x500, random flips (same parameters for image and label), channel
        reversal, mean subtraction (image only) and label sub-sampling to a
        16x16 grid with stride 8.

        Args:
            img_fnames: iterable of image file paths.
            mask_fnames: iterable of label file paths, in the same order.

        Returns:
            A generator of (image, label) tuples. Nothing is read from disk
            until the generator is iterated; it stops at the shorter input.
        """
        # Generators for image data
        img_arrs = (load_img_array(f) for f in img_fnames)
        mask_arrs = (load_img_array(f, grayscale=True) for f in mask_fnames)

        def add_context_margin(image, margin_size, **pad_kwargs):
            """Add a margin_size border around the first two axes of the
            image, used for providing context."""
            return np.pad(image,
                          ((margin_size, margin_size),
                           (margin_size, margin_size),
                           (0, 0)), **pad_kwargs)

        def pad_to_square(image, min_size, **pad_kwargs):
            """Pad so that both height and width are at least min_size.

            The original content is centered in the result; when the missing
            amount is odd, the extra row/column goes to the bottom/right.
            Images that are already large enough are returned untouched.
            """
            h, w = image.shape[:2]

            if h >= min_size and w >= min_size:
                return image

            top = bottom = left = right = 0

            if h < min_size:
                top = (min_size - h) // 2
                bottom = min_size - h - top
            if w < min_size:
                left = (min_size - w) // 2
                right = min_size - w - left

            return np.pad(image,
                          ((top, bottom),
                           (left, right),
                           (0, 0)), **pad_kwargs)

        def pad_image(image):
            # Images are padded by mirroring their own content
            image_pad_kwargs = dict(mode='reflect')
            image = add_context_margin(image, label_margin, **image_pad_kwargs)
            return pad_to_square(image, 500, **image_pad_kwargs)

        def pad_label(image):
            # Same steps as the image, but the borders are constant white
            label_pad_kwargs = dict(mode='constant', constant_values=255)
            image = add_context_margin(image, label_margin, **label_pad_kwargs)
            return pad_to_square(image, 500, **label_pad_kwargs)

        pairs = ((pad_image(image), pad_label(label)) for
                 image, label in zip(img_arrs, mask_arrs))

        # random/center crop
        def crop_to(image, target_h=500, target_w=500):
            # TODO: random cropping
            h_off = (image.shape[0] - target_h) // 2
            w_off = (image.shape[1] - target_w) // 2
            return image[h_off:h_off + target_h,
                         w_off:w_off + target_w, :]

        pairs = ((crop_to(image), crop_to(label)) for
                 image, label in pairs)

        # random augmentation: one parameter draw per pair, applied to both
        # the image and its label so that they stay aligned
        augmentation_params = self.random_transformer.random_params_gen()
        transf_fn = self.random_transformer.transform
        pairs = ((transf_fn(image, params), transf_fn(label, params)) for
                 ((image, label), params) in zip(pairs, augmentation_params))

        def rgb_to_bgr(image):
            # Swap color channels to use pretrained VGG weights
            return image[:, :, ::-1]

        # NOTE(review): the label is loaded with grayscale=True, so reversing
        # its last axis is presumably a no-op -- confirm before removing.
        pairs = ((rgb_to_bgr(image), rgb_to_bgr(label)) for
                 image, label in pairs)

        def remove_mean(image):
            # Note that there's no 0..1 normalization in VGG
            return image - pascal_mean

        pairs = ((remove_mean(image), label) for
                 image, label in pairs)

        def slice_label(image, offset, label_size, stride):
            # Builds label_size * label_size pixels labels, starting from
            # offset from the original image, and stride stride
            return image[offset:offset + label_size * stride:stride,
                         offset:offset + label_size * stride:stride]

        pairs = ((image, slice_label(label, label_margin, 16, 8)) for
                 image, label in pairs)

        return pairs

    def flow_from_list(self,
                       img_fnames,
                       mask_fnames,
                       batch_size,
                       img_target_size,
                       mask_target_size,
                       shuffle=False):
        """Yield (image batch, mask batch) tuples forever.

        Every pass over the file lists is one epoch. Pairs whose processed
        label contains the value 255 are skipped and counted in
        `skipped_count`. Pairs left over in an incomplete batch at the end of
        an epoch are discarded.

        Each yielded batch is a freshly allocated pair of arrays, so callers
        may keep references to earlier batches (e.g. in a prefetch queue)
        without seeing them overwritten by later ones.

        Note: if every pair is skipped, this generator loops forever without
        yielding.

        Args:
            img_fnames: iterable of image file paths.
            mask_fnames: iterable of label file paths, in the same order.
            batch_size: number of pairs per batch; must be positive.
            img_target_size: (height, width) of the processed images.
            mask_target_size: (height, width) of the processed labels.
            shuffle: when true, reshuffle the pairs at the start of each epoch.

        Yields:
            A tuple `(img_batch, mask_batch)` of shapes
            `(batch_size, img_h, img_w, 3)` and
            `(batch_size, mask_h * mask_w, 1)`.

        Raises:
            ValueError: if the file lists yield no pairs.
        """
        assert batch_size > 0
        paired_fnames = list(zip(img_fnames, mask_fnames))
        if not paired_fnames:
            raise ValueError('flow_from_list needs at least one '
                             '(image, mask) file name pair')

        img_batch_shape = (batch_size, img_target_size[0], img_target_size[1], 3)
        mask_batch_shape = (batch_size, mask_target_size[0] * mask_target_size[1], 1)

        while True:
            # Starting a new epoch..
            if shuffle:
                random.shuffle(paired_fnames)  # Shuffles in place
            img_fnames, mask_fnames = zip(*paired_fnames)

            pairs = self.get_processed_pairs(img_fnames, mask_fnames)

            i = 0
            img_batch = np.zeros(img_batch_shape)
            mask_batch = np.zeros(mask_batch_shape)
            for img, mask in pairs:
                # TODO: remove this ugly workaround to skip pairs whose mask
                # has non-labeled pixels.
                if 255. in mask:
                    self.skipped_count += 1
                    continue

                # Fill up the batch one pair at a time
                img_batch[i] = img
                # Pass the label image as 1D array to avoid the problematic Reshape
                # layer after Softmax (see model.py)
                mask_batch[i] = np.reshape(mask, (-1, 1))

                i += 1
                if i == batch_size:
                    yield img_batch, mask_batch
                    # Start the next batch in new buffers: the arrays just
                    # yielded may still be referenced by the consumer.
                    i = 0
                    img_batch = np.zeros(img_batch_shape)
                    mask_batch = np.zeros(mask_batch_shape)
@click.command()
@click.option('--list-fname', type=click.Path(exists=True),
              default='/mnt/pascal_voc/benchmark_RELEASE/dataset/train.txt')
@click.option('--img-root', type=click.Path(exists=True),
              default='/mnt/pascal_voc/benchmark_RELEASE/dataset/img')
@click.option('--mask-root', type=click.Path(exists=True),
              default='/mnt/pascal_voc/benchmark_RELEASE/dataset/pngs')
def test_datagen(list_fname, img_root, mask_root):
    """Smoke-test the data generator by producing a single batch.

    Args:
        list_fname: text file with one sample base name per line.
        img_root: directory holding `<base name>.jpg` images.
        mask_root: directory holding `<base name>.png` label masks.
    """
    datagen = SegmentationDataGenerator(RandomTransformer(horizontal_flip=True))

    # Close the list file deterministically and ignore blank lines, which
    # would otherwise turn into bogus '.jpg' / '.png' paths.
    with open(list_fname) as list_file:
        basenames = [line.strip() for line in list_file if line.strip()]

    img_fnames = [os.path.join(img_root, f) + '.jpg' for f in basenames]
    mask_fnames = [os.path.join(mask_root, f) + '.png' for f in basenames]

    # The sizes mirror what get_processed_pairs produces: images are cropped
    # to 500x500 and labels are sub-sampled to a 16x16 grid.
    batches = datagen.flow_from_list(img_fnames, mask_fnames,
                                     batch_size=1,
                                     img_target_size=(500, 500),
                                     mask_target_size=(16, 16))

    # flow_from_list is a generator: nothing runs until a batch is requested.
    img_batch, mask_batch = next(batches)
    click.echo('image batch: {}, mask batch: {}'.format(img_batch.shape,
                                                        mask_batch.shape))


if __name__ == '__main__':
    test_datagen()