Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fixes infinite loop using imagedetiter #13550

Merged
merged 1 commit into from
Dec 8, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 54 additions & 10 deletions python/mxnet/image/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,19 +658,26 @@ class ImageDetIter(ImageIter):
Data name for provided symbols.
label_name : str
Name for detection labels
last_batch_handle : str, optional
How to handle the last batch.
This parameter can be 'pad'(default), 'discard' or 'roll_over'.
If 'pad', the last batch will be padded with data starting from the beginning
If 'discard', the last batch will be discarded
If 'roll_over', the remaining elements will be rolled over to the next iteration
kwargs : ...
More arguments for creating augmenter. See mx.image.CreateDetAugmenter.
"""
def __init__(self, batch_size, data_shape,
path_imgrec=None, path_imglist=None, path_root=None, path_imgidx=None,
shuffle=False, part_index=0, num_parts=1, aug_list=None, imglist=None,
data_name='data', label_name='label', **kwargs):
data_name='data', label_name='label', last_batch_handle='pad', **kwargs):
super(ImageDetIter, self).__init__(batch_size=batch_size, data_shape=data_shape,
path_imgrec=path_imgrec, path_imglist=path_imglist,
path_root=path_root, path_imgidx=path_imgidx,
shuffle=shuffle, part_index=part_index,
num_parts=num_parts, aug_list=[], imglist=imglist,
data_name=data_name, label_name=label_name)
data_name=data_name, label_name=label_name,
last_batch_handle=last_batch_handle)

if aug_list is None:
self.auglist = CreateDetAugmenter(data_shape, **kwargs)
Expand Down Expand Up @@ -751,14 +758,10 @@ def reshape(self, data_shape=None, label_shape=None):
self.provide_label = [(self.provide_label[0][0], (self.batch_size,) + label_shape)]
self.label_shape = label_shape

def next(self):
"""Override the function for returning next batch."""
def _batchify(self, batch_data, batch_label, start=0):
"""Override the helper function for batchifying data"""
i = start
batch_size = self.batch_size
c, h, w = self.data_shape
batch_data = nd.zeros((batch_size, c, h, w))
batch_label = nd.empty(self.provide_label[0][1])
batch_label[:] = -1
i = 0
try:
while i < batch_size:
label, s = self.next_sample()
Expand All @@ -783,7 +786,48 @@ def next(self):
if not i:
raise StopIteration

return io.DataBatch([batch_data], [batch_label], batch_size - i)
return i

def next(self):
    """Return the next detection batch, honouring ``last_batch_handle``.

    A batch is either resumed from the roll-over cache (``_cache_data``,
    ``_cache_label``, ``_cache_idx``) or freshly allocated and filled via
    ``_batchify``.  When fewer than ``batch_size`` samples were written,
    the outcome depends on ``last_batch_handle``:

    * ``'discard'``   -- the partial batch is dropped.
    * ``'roll_over'`` -- if nothing is cached yet, the partial batch is
      stored in the cache and iteration ends; a batch that was itself
      resumed from the cache is topped up and returned instead.
    * anything else (``'pad'``) -- the remaining slots are filled by a
      second ``_batchify`` call starting at the first empty slot.

    Returns
    -------
    io.DataBatch
        Batch holding one data array and one label array, with ``pad``
        set to the number of slots left unfilled by the first fill.

    Raises
    ------
    StopIteration
        For a partial batch under ``'discard'``, for a partial batch that
        has just been cached under ``'roll_over'``, or when propagated
        from ``_batchify``.
    """
    batch_size = self.batch_size
    c, h, w = self.data_shape

    if self._cache_data is None:
        # Nothing rolled over: start from a zeroed image buffer and a label
        # buffer pre-filled with -1.
        batch_data = nd.zeros((batch_size, c, h, w))
        batch_label = nd.empty(self.provide_label[0][1])
        batch_label[:] = -1
        filled = self._batchify(batch_data, batch_label)
    else:
        # Resume the partially filled batch saved by a previous call; all
        # three cache fields are expected to be set together.
        assert self._cache_label is not None, "_cache_label didn't have values"
        assert self._cache_idx is not None, "_cache_idx didn't have values"
        batch_data = self._cache_data
        batch_label = self._cache_label
        filled = self._cache_idx

    # Number of slots that the first fill left empty.
    pad = batch_size - filled
    if pad == 0:
        return io.DataBatch([batch_data], [batch_label], pad=pad)

    if self.last_batch_handle == 'discard':
        raise StopIteration

    if self.last_batch_handle == 'roll_over' and self._cache_data is None:
        # First time this partial batch is seen: stash it and end iteration.
        self._cache_data = batch_data
        self._cache_label = batch_label
        self._cache_idx = filled
        raise StopIteration

    # Top up the remaining slots, writing from index ``filled`` onwards.
    self._batchify(batch_data, batch_label, filled)
    if self.last_batch_handle == 'pad':
        # NOTE(review): presumably blocks further reads until the iterator
        # is reset -- confirm against ImageIter.
        self._allow_read = False
    else:
        # The rolled-over batch has been consumed; clear the cache.
        self._cache_data = self._cache_label = self._cache_idx = None

    return io.DataBatch([batch_data], [batch_label], pad=pad)

def augmentation_transform(self, data, label): # pylint: disable=arguments-differ
"""Override Transforms input data with specified augmentations."""
Expand Down
5 changes: 3 additions & 2 deletions python/mxnet/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -1145,7 +1145,7 @@ def __init__(self, batch_size, data_shape, label_width=1,
self.shuffle = shuffle
if self.imgrec is None:
self.seq = imgkeys
elif shuffle or num_parts > 1:
elif shuffle or num_parts > 1 or path_imgidx:
assert self.imgidx is not None
self.seq = self.imgidx
else:
Expand Down Expand Up @@ -1261,7 +1261,7 @@ def next(self):
i = self._cache_idx
# clear the cache data
else:
batch_data = nd.empty((batch_size, c, h, w))
batch_data = nd.zeros((batch_size, c, h, w))
batch_label = nd.empty(self.provide_label[0][1])
i = self._batchify(batch_data, batch_label)
# calculate the padding
Expand All @@ -1285,6 +1285,7 @@ def next(self):
self._cache_data = None
self._cache_label = None
self._cache_idx = None

return io.DataBatch([batch_data], [batch_label], pad=pad)

def check_data_shape(self, data_shape):
Expand Down
184 changes: 100 additions & 84 deletions tests/python/unittest/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from nose.tools import raises


def _get_data(url, dirname):
import os, tarfile
download(url, dirname=dirname, overwrite=False)
Expand All @@ -50,6 +51,62 @@ def _generate_objects():
label = np.hstack((cid[:, np.newaxis], boxes)).ravel().tolist()
return [2, 5] + label

def _test_imageiter_last_batch(imageiter_list, assert_data_shape):
    """Exercise the ``last_batch_handle`` behaviours of a set of iterators.

    Parameters
    ----------
    imageiter_list : list
        Five iterators, used in this order: default settings, 'discard',
        'pad', 'roll_over', and a shuffled iterator used as a smoke test.
    assert_data_shape : tuple
        Shape every batch of the first iterator must have.

    The hard-coded batch counts (5 and 6) are tied to the fixture the
    callers build.
    """
    # Default iterator: every batch has the expected shape across 3 epochs.
    shape_iter = imageiter_list[0]
    for _ in range(3):
        for batch in shape_iter:
            assert batch.data[0].shape == assert_data_shape
        shape_iter.reset()

    # 'discard': the trailing partial batch is dropped.
    discard_iter = imageiter_list[1]
    num_batches = sum(1 for _ in discard_iter)
    assert num_batches == 5

    # 'pad': the last batch is completed with images from the start.
    pad_iter = imageiter_list[2]
    num_batches = 0
    for idx, batch in enumerate(pad_iter):
        if idx == 0:
            leading_images = batch.data[0][:2]
        if idx == 5:
            padded_images = batch.data[0][1:]
        num_batches = idx + 1
    assert num_batches == 6
    assert np.array_equal(leading_images.asnumpy(), padded_images.asnumpy())

    # 'roll_over': the partial batch is carried into the next epoch.
    roll_iter = imageiter_list[3]
    num_batches = 0
    for idx, batch in enumerate(roll_iter):
        if idx == 0:
            first_image = batch.data[0][0]
        num_batches = idx + 1
    assert num_batches == 5

    roll_iter.reset()
    second_epoch_head = roll_iter.next()
    assert np.array_equal(
        second_epoch_head.data[0][1].asnumpy(), first_image.asnumpy())
    assert second_epoch_head.pad == 2

    # The iterator must keep working after several resets in roll_over mode.
    for _ in roll_iter:
        pass
    roll_iter.reset()
    third_epoch_head = roll_iter.next()
    assert np.array_equal(
        third_epoch_head.data[0][2].asnumpy(), first_image.asnumpy())
    assert third_epoch_head.pad == 1

    # One batch was already taken with next(); the third epoch has 6 batches.
    assert 1 + sum(1 for _ in roll_iter) == 6

    # Shuffled iterator: sanity check that a full pass completes.
    shuffle_iter = imageiter_list[4]
    for _ in shuffle_iter:
        pass


class TestImage(unittest.TestCase):
IMAGES_URL = "http://data.mxnet.io/data/test_images.tar.gz"
Expand Down Expand Up @@ -151,86 +208,32 @@ def test_color_normalize(self):
assert_almost_equal(mx_result.asnumpy(), (src - mean) / std, atol=1e-3)

def test_imageiter(self):
def check_imageiter(dtype='float32'):
im_list = [[np.random.randint(0, 5), x] for x in TestImage.IMAGES]
fname = './data/test_imageiter.lst'
file_list = ['\t'.join([str(k), str(np.random.randint(0, 5)), x])
for k, x in enumerate(TestImage.IMAGES)]
with open(fname, 'w') as f:
for line in file_list:
f.write(line + '\n')

test_list = ['imglist', 'path_imglist']
im_list = [[np.random.randint(0, 5), x] for x in TestImage.IMAGES]
fname = './data/test_imageiter.lst'
file_list = ['\t'.join([str(k), str(np.random.randint(0, 5)), x])
for k, x in enumerate(TestImage.IMAGES)]
with open(fname, 'w') as f:
for line in file_list:
f.write(line + '\n')

test_list = ['imglist', 'path_imglist']
for dtype in ['int32', 'float32', 'int64', 'float64']:
for test in test_list:
imglist = im_list if test == 'imglist' else None
path_imglist = fname if test == 'path_imglist' else None

test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, imglist=imglist,
path_imglist=path_imglist, path_root='', dtype=dtype)
# test batch data shape
for _ in range(3):
for batch in test_iter:
assert batch.data[0].shape == (2, 3, 224, 224)
test_iter.reset()
# test last batch handle(discard)
test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist,
path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='discard')
i = 0
for batch in test_iter:
i += 1
assert i == 5
# test last_batch_handle(pad)
test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist,
path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='pad')
i = 0
for batch in test_iter:
if i == 0:
first_three_data = batch.data[0][:2]
if i == 5:
last_three_data = batch.data[0][1:]
i += 1
assert i == 6
assert np.array_equal(first_three_data.asnumpy(), last_three_data.asnumpy())
# test last_batch_handle(roll_over)
test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist,
path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='roll_over')
i = 0
for batch in test_iter:
if i == 0:
first_image = batch.data[0][0]
i += 1
assert i == 5
test_iter.reset()
first_batch_roll_over = test_iter.next()
assert np.array_equal(
first_batch_roll_over.data[0][1].asnumpy(), first_image.asnumpy())
assert first_batch_roll_over.pad == 2
# test that the iterator works properly after calling reset several times when last_batch_handle is roll_over
for _ in test_iter:
pass
test_iter.reset()
first_batch_roll_over_twice = test_iter.next()
assert np.array_equal(
first_batch_roll_over_twice.data[0][2].asnumpy(), first_image.asnumpy())
assert first_batch_roll_over_twice.pad == 1
# we've called next once
i = 1
for _ in test_iter:
i += 1
# test the third epoch with size 6
assert i == 6
# test shuffle option for sanity test
test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, shuffle=True,
path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='pad')
for _ in test_iter:
pass

for dtype in ['int32', 'float32', 'int64', 'float64']:
check_imageiter(dtype)

# test with default dtype
check_imageiter()
imageiter_list = [
mx.image.ImageIter(2, (3, 224, 224), label_width=1, imglist=imglist,
path_imglist=path_imglist, path_root='', dtype=dtype),
mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist,
path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='discard'),
mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist,
path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='pad'),
mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist,
path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='roll_over'),
mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, shuffle=True,
path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='pad')
]
_test_imageiter_last_batch(imageiter_list, (2, 3, 224, 224))

@with_seed()
def test_augmenters(self):
Expand Down Expand Up @@ -259,27 +262,40 @@ def test_image_detiter(self):
im_list = [_generate_objects() + [x] for x in TestImage.IMAGES]
det_iter = mx.image.ImageDetIter(2, (3, 300, 300), imglist=im_list, path_root='')
for _ in range(3):
for batch in det_iter:
for _ in det_iter:
pass
det_iter.reset()

det_iter.reset()
val_iter = mx.image.ImageDetIter(2, (3, 300, 300), imglist=im_list, path_root='')
det_iter = val_iter.sync_label_shape(det_iter)
assert det_iter.data_shape == val_iter.data_shape
assert det_iter.label_shape == val_iter.label_shape

# test file list
# test batch_size is not divisible by number of images
det_iter = mx.image.ImageDetIter(4, (3, 300, 300), imglist=im_list, path_root='')
for _ in det_iter:
pass

# test file list with last batch handle
fname = './data/test_imagedetiter.lst'
im_list = [[k] + _generate_objects() + [x] for k, x in enumerate(TestImage.IMAGES)]
with open(fname, 'w') as f:
for line in im_list:
line = '\t'.join([str(k) for k in line])
f.write(line + '\n')

det_iter = mx.image.ImageDetIter(2, (3, 400, 400), path_imglist=fname,
path_root='')
for batch in det_iter:
pass
imageiter_list = [
mx.image.ImageDetIter(2, (3, 400, 400),
path_imglist=fname, path_root=''),
mx.image.ImageDetIter(3, (3, 400, 400),
path_imglist=fname, path_root='', last_batch_handle='discard'),
mx.image.ImageDetIter(3, (3, 400, 400),
path_imglist=fname, path_root='', last_batch_handle='pad'),
mx.image.ImageDetIter(3, (3, 400, 400),
path_imglist=fname, path_root='', last_batch_handle='roll_over'),
mx.image.ImageDetIter(3, (3, 400, 400), shuffle=True,
path_imglist=fname, path_root='', last_batch_handle='pad')
]
_test_imageiter_last_batch(imageiter_list, (2, 3, 400, 400))

def test_det_augmenters(self):
# only test if all augmenters will work
Expand Down