Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix NDArrayIter iteration bug when last_batch_handle='pad' (#16166)
Browse files Browse the repository at this point in the history
* fix NDarrayIter pad bug

* retrigger CI
  • Loading branch information
stu1130 authored and roywei committed Dec 6, 2019
1 parent 8f53e77 commit 9c94fdb
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
10 changes: 9 additions & 1 deletion python/mxnet/io/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,15 @@ def _batchify(self, data_source):
self.cursor + self.batch_size > self.num_data:
pad = self.batch_size - self.num_data + self.cursor
first_data = self._getdata(data_source, start=self.cursor)
second_data = self._getdata(data_source, end=pad)
if pad > self.num_data:
while True:
if pad <= self.num_data:
break
second_data = self._getdata(data_source, end=self.num_data)
pad -= self.num_data
second_data = self._concat(second_data, self._getdata(data_source, end=pad))
else:
second_data = self._getdata(data_source, end=pad)
return self._concat(first_data, second_data)
# normal case
else:
Expand Down
8 changes: 8 additions & 0 deletions tests/python/unittest/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,13 @@ def _test_shuffle(data, labels=None):
assert np.array_equal(batch.data[0].asnumpy(), batch_list[idx_list[i]])
i += 1

# fixes the issue https://github.com/apache/incubator-mxnet/issues/15535
def _test_corner_case():
data = np.arange(10)
data_iter = mx.io.NDArrayIter(data=data, batch_size=25, shuffle=False, last_batch_handle='pad')
expect = np.concatenate((np.tile(data, 2), np.arange(5)))
assert np.array_equal(data_iter.next().data[0].asnumpy(), expect)


def test_NDArrayIter():
dtype_list = ['NDArray', 'ndarray']
Expand All @@ -220,6 +227,7 @@ def test_NDArrayIter():
_test_shuffle({'data1': data, 'data2': data})
_test_shuffle(data, [])
_test_shuffle(data)
_test_corner_case()


def test_NDArrayIter_h5py():
Expand Down

0 comments on commit 9c94fdb

Please sign in to comment.