[MXNET-737]Add last batch handle for imageiter #12131
python/mxnet/image/image.py
Outdated
@@ -1059,16 +1059,21 @@ class ImageIter(io.DataIter):
Label name for provided symbols.
dtype : str
Label data type. Default: float32. Other options: int32, int64, float64
last_batch_hanle : str, optional
Use last_batch to be consistent with https://mxnet.incubator.apache.org/api/python/gluon/data.html?highlight=batchify_fn#mxnet.gluon.data.BatchSampler
The original NDArrayIter calls it last_batch_handle, but I think the name last_batch already says enough about what it does, so I will go with your suggestion.
python/mxnet/image/image.py
Outdated
@@ -1059,16 +1059,21 @@ class ImageIter(io.DataIter):
Label name for provided symbols.
dtype : str
Label data type. Default: float32. Other options: int32, int64, float64
last_batch_hanle : str, optional
How to handle the last batch. This parameter can be ‘pad’, ‘discard’ or ‘roll_over’.
'discard' is not support when reading from record file(.rec) withouting shuffle(=False)
why is discard not supported explicitly? We can add a warning to this behavior
And please add an explanation of what each option stands for.
The reason discard is not supported is that when we read the .rec file sequentially, we don't know how many images are in the file. Therefore, there is no way to precalculate the number of images we need to discard. The only two solutions I came up with are:
- iterate over the file during the initialization of the data iterator
- allow users to pass in the number of images
The first solution would make initialization slow if the file is large. The second one is not user-friendly. So I decided to give up on this option.
I think you can check whether the number of valid samples in the batch (batch.shape[0]) is smaller than batch_size, without knowing how many images there are.
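A minimal sketch of this idea, assuming a generic stream of samples of unknown length rather than the actual RecordIO reader. The helper name `batches_discard_tail` is hypothetical and not part of MXNet:

```python
from itertools import islice


def batches_discard_tail(samples, batch_size):
    """Yield only full batches from an iterable of unknown length.

    Hypothetical helper (not MXNet API). The last batch is dropped when it
    holds fewer than batch_size samples, so the total number of samples
    never has to be known in advance.
    """
    stream = iter(samples)
    while True:
        batch = list(islice(stream, batch_size))
        if len(batch) < batch_size:
            # Short (or empty) batch: the stream is exhausted, discard it.
            return
        yield batch


if __name__ == "__main__":
    # 8 samples with batch_size 3: the trailing 2 samples are discarded.
    print(list(batches_discard_tail(range(8), 3)))
```

The check happens after each read, so no pass over the file is needed at initialization.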
python/mxnet/image/image.py
Outdated
@@ -1149,22 +1157,53 @@ def __init__(self, batch_size, data_shape, label_width=1,
else:
    self.auglist = aug_list
self.cur = 0
self.is_iterated_over = False
please make internal state variables private with leading underscore
python/mxnet/image/image.py
Outdated
if pad != 0:
    _ = self.iterate(batch_data, batch_label, i, self._imgrec)

if self.last_batch_handle != 'roll_over':
Is it possible to cache the last batch in memory for roll_over? I think opening/closing RecordIO is problematic.
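One way to read this suggestion, as a sketch only: keep the leftover samples of an epoch in a small in-memory list and prepend them to the next epoch, so the underlying reader is opened once per epoch and never reopened for the remainder. `RollOverBatcher` and its `open_source` argument are hypothetical names, and the roll-over semantics shown here are an assumption, not necessarily what the PR implements:

```python
class RollOverBatcher:
    """Hypothetical model of last_batch='roll_over' with an in-memory cache."""

    def __init__(self, open_source, batch_size):
        # open_source: callable returning a fresh iterator over one epoch
        # of samples (stands in for resetting a sequential reader).
        self._open_source = open_source
        self._batch_size = batch_size
        self._cache = []  # leftover samples carried into the next epoch

    def epoch(self):
        """Yield the full batches of one epoch; cache the remainder."""
        batch = self._cache
        self._cache = []
        for sample in self._open_source():
            batch.append(sample)
            if len(batch) == self._batch_size:
                yield batch
                batch = []
        # Fewer than batch_size samples remain: keep them for the next epoch.
        self._cache = batch


if __name__ == "__main__":
    batcher = RollOverBatcher(lambda: iter(range(8)), 3)
    for _ in range(3):
        print(list(batcher.epoch()))
```

The cache never holds more than batch_size - 1 samples, so the memory cost is bounded by one batch.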
…_batch' 2. remove the extra RecordIO and refactor the code accordingly 3. support 'discard' for sequentially reading
…r both sequential read and random access
Nice work!
Welcome to the MXNet community. Thanks for your first contribution!
A few changes requested.
python/mxnet/image/image.py
Outdated
if self.imgrec is not None:
    self.imgrec.reset()
self.cur = 0
self._is_allowed_reading = True
why do we need to set this again?
We need to reset the RecordIO reader and the current cursor as long as last_batch is not roll_over. As for self._is_allowed_reading, it needs to be set to True after reset. Consider a data iter with 'pad' or 'discard' that has iterated to the last batch: self._is_allowed_reading is then False (after the padding is done, it is set to False in case users call next() again), so we need to set it back to True to prepare for the following next() call.
python/mxnet/image/image.py
Outdated
def next_sample(self):
    """Helper function for reading in next sample."""
    if self._is_allowed_reading is False:
        raise StopIteration
Please add user comprehensible message when raising errors.
StopIteration is normal iterator behavior, not an error meant for users.
Yup, my bad. It is the built-in exception for iterators. Please ignore my comment.
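For reference, a small illustration of why no message is needed: in Python's iterator protocol, `StopIteration` only signals exhaustion, and `for` loops and `list()` consume it silently. The `Countdown` class is a made-up example, unrelated to ImageIter:

```python
class Countdown:
    """Toy iterator: yields n-1, n-2, ..., 0 and then stops."""

    def __init__(self, n):
        self._n = n

    def __iter__(self):
        return self

    def __next__(self):
        if self._n == 0:
            # Signals exhaustion; the consuming loop swallows this exception.
            raise StopIteration
        self._n -= 1
        return self._n


if __name__ == "__main__":
    for value in Countdown(3):
        print(value)
```

The loop ends normally once `__next__` raises, and the user never sees a traceback.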
python/mxnet/image/image.py
Outdated
if self.cur < self.num_image:
    idx = self.seq[self.cur]
else:
    if self.last_batch != 'discard':
should this be != discard and == rollover?
Whether it is pad or roll_over, we need to "reset the cursor" (assign 0 to self.cur) to do the padding.
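A simplified model of the 'pad' case, assuming an in-memory sample list and batch_size no larger than the number of samples. `PadIter` is a hypothetical class and only loosely mirrors the `cur` and `_allow_read` attributes discussed in this review:

```python
class PadIter:
    """Hypothetical, simplified model of an iterator with last_batch='pad'."""

    def __init__(self, samples, batch_size):
        self._samples = list(samples)
        self._batch_size = batch_size
        self.reset()

    def reset(self):
        self._cur = 0
        self._allow_read = True  # re-enable reading for the next epoch

    def __iter__(self):
        return self

    def __next__(self):
        if not self._allow_read or self._cur >= len(self._samples):
            raise StopIteration
        batch = self._samples[self._cur:self._cur + self._batch_size]
        self._cur += len(batch)
        pad = self._batch_size - len(batch)
        if pad:
            # Reset the cursor to 0 and fill the batch from the beginning.
            self._cur = 0
            batch += self._samples[self._cur:self._cur + pad]
            self._cur += pad
            # Without this flag the cursor would look valid again and
            # next() would keep returning data after the padded batch.
            self._allow_read = False
        return batch, pad


if __name__ == "__main__":
    it = PadIter(range(5), 3)
    print(list(it))   # one epoch, last batch padded with 1 sample
    it.reset()
    print(next(it))   # reading works again after reset
```

The flag is what separates "cursor wrapped for padding" from "new epoch started", which is why reset() has to set it back to True.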
def next(self):
    """Returns the next batch of data."""
    batch_size = self.batch_size
    c, h, w = self.data_shape
Are we assuming always channels_first?
Yes we are
tests/python/unittest/test_image.py
Outdated
path_imglist=path_imglist, path_root='', dtype=dtype)
for _ in range(3):
    for batch in test_iter:
        pass
can we assert the returned batch size and shape here?
python/mxnet/image/image.py
Outdated
# handle the last batch
if self.seq and last_batch == 'discard':
    new_seq_n = len(self.seq) - len(self.seq) % batch_size
    self.seq = self.seq[:new_seq_n]
This is not correct. You have now changed the behavior from an initial shuffle to a per-epoch shuffle, so self.seq must be shuffled in reset, and the slice cannot be fixed.
NDArrayIter shuffles only during initialization. But you are right: we should shuffle for each epoch.
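A sketch of the point being made, under the assumption that the full index list is kept and the 'discard' slice is recomputed on every reset. `ShuffledDiscardIter` and its attribute names are hypothetical:

```python
import random


class ShuffledDiscardIter:
    """Hypothetical model: per-epoch shuffle combined with last_batch='discard'."""

    def __init__(self, indices, batch_size, seed=None):
        self._all = list(indices)      # full index list, never truncated
        self._batch_size = batch_size
        self._rng = random.Random(seed)
        self.reset()

    def reset(self):
        # Shuffle the full list first, then slice, so a different tail
        # can be discarded in each epoch.
        seq = self._all[:]
        self._rng.shuffle(seq)
        keep = len(seq) - len(seq) % self._batch_size
        self._seq = seq[:keep]

    def batches(self):
        bs = self._batch_size
        return [self._seq[i:i + bs] for i in range(0, len(self._seq), bs)]


if __name__ == "__main__":
    it = ShuffledDiscardIter(range(8), 3, seed=0)
    for _ in range(2):
        print(it.batches())
        it.reset()
```

If the slice were applied once in `__init__`, the same samples would be dropped in every epoch; slicing after each shuffle avoids that.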
python/mxnet/image/image.py
Outdated
raise StopIteration
header, img = recordio.unpack(s)
return header.label, img

def next(self):
"""Returns the next batch of data."""
def iterate(self, batch_data, batch_label, start=0):
Use _collect_batch as the name instead.
Actually, _batchify could be better IMO.
python/mxnet/image/image.py
Outdated
pad = batch_size - i
# handle padding for sequential read
if pad != 0:
    if self.seq is not None:
The logic of last_batch here is kind of weird.
…on. Now we check it in next() function 2. move shuffle back to reset function
…ta are not correct
python/mxnet/image/image.py
Outdated
# calculate the padding
pad = batch_size - i
# handle padding for 'pad' and 'roll_over' for the last batch
if pad != 0:
refactor the padding logic here
Essentially LGTM now, please address remaining comments and we are good to merge.
python/mxnet/image/image.py
Outdated
random.shuffle(self.seq)
if self.imgrec is not None:
    self.imgrec.reset()
self.cur = 0
self._is_allowed_reading = True
self._allow_read is shorter and has the same meaning.
first_batch_roll_over_twice = test_iter.next()
assert np.array_equal(
    first_batch_roll_over_twice.data[0][2].asnumpy(), first_image.asnumpy())
assert first_batch_roll_over_twice.pad == 1
Assert the second epoch with size 6?
Also, can you add a test for shuffle=True, just as a sanity test? No value assertion is required for shuffle mode.
All the code changes are completed.
Thanks, @zhreshold and @sandeep-krishnamurthy
# test the third epoch with size 6
assert i == 6
# test shuffle option for sanity test
test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, shuffle=True,
add shuffle test case
for _ in test_iter:
    i += 1
# test the third epoch with size 6
assert i == 6
test epoch size
LGTM.
Thank you @stu1130 @zhreshold
if self.imgrec is not None:
    self.imgrec.reset()
self.cur = 0
if self._allow_read is False:
nit: just self._allow_read = True would work here?
Yes, that works as well. Only an iter with 'pad' would change this flag, though.
[MXNET-737]Add last batch handle for imageiter
Description
Add last_batch_handle parameter to ImageIter based on #11883
- pad (default)
- discard
- roll_over
Note that reading record files (.rec) without shuffle (i.e. sequential read) didn't support 'discard'.
Comments
N/A
@zhreshold