NDArrayIter: fix 'pad' batches when batch_size exceeds the dataset size

Reviewer notes:
* Line structure and indentation were reconstructed from a whitespace-collapsed
  copy of this patch; confirm context-line whitespace against the target files.
* The "index" blob hashes are carried over from the original patch and no
  longer describe the post-image of the modified hunks.
* The original loop reassigned second_data on every iteration, so only one
  full pass survived when pad > 2 * num_data; passes are now accumulated.

diff --git a/python/mxnet/io/io.py b/python/mxnet/io/io.py
index 2a42840bcf22..dcf964df976a 100644
--- a/python/mxnet/io/io.py
+++ b/python/mxnet/io/io.py
@@ -748,7 +748,14 @@ def _batchify(self, data_source):
             self.cursor + self.batch_size > self.num_data:
             pad = self.batch_size - self.num_data + self.cursor
             first_data = self._getdata(data_source, start=self.cursor)
+            # pad exceeds num_data when batch_size is larger than the whole
+            # dataset: append one full pass per wrap-around so that no pass
+            # is dropped, then fall through to fetch the remaining samples.
+            while pad > self.num_data:
+                first_data = self._concat(
+                    first_data, self._getdata(data_source, end=self.num_data))
+                pad -= self.num_data
             second_data = self._getdata(data_source, end=pad)
             return self._concat(first_data, second_data)
         # normal case
         else:
diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py
index 144f042ea719..2a806efc9034 100644
--- a/tests/python/unittest/test_io.py
+++ b/tests/python/unittest/test_io.py
@@ -198,6 +198,19 @@ def _test_shuffle(data, labels=None):
         assert np.array_equal(batch.data[0].asnumpy(), batch_list[idx_list[i]])
         i += 1
 
+# fixes the issue https://github.com/apache/incubator-mxnet/issues/15535
+def _test_corner_case():
+    """Check 'pad' batches when batch_size exceeds the number of samples."""
+    data = np.arange(10)
+    # batch_size=25 wraps the data twice plus a partial pass;
+    # batch_size=35 needs more than one full wrap-around pass.
+    for batch_size in (25, 35):
+        data_iter = mx.io.NDArrayIter(data=data, batch_size=batch_size, shuffle=False,
+                                      last_batch_handle='pad')
+        full_passes, remainder = divmod(batch_size, len(data))
+        expect = np.concatenate((np.tile(data, full_passes), data[:remainder]))
+        assert np.array_equal(data_iter.next().data[0].asnumpy(), expect)
+
 
 def test_NDArrayIter():
     dtype_list = ['NDArray', 'ndarray']
@@ -220,6 +233,7 @@ def test_NDArrayIter():
         _test_shuffle({'data1': data, 'data2': data})
         _test_shuffle(data, [])
         _test_shuffle(data)
+        _test_corner_case()
 
 
 def test_NDArrayIter_h5py():