From 7d21b4388162ebcbddaed3d5627abc80be2c59da Mon Sep 17 00:00:00 2001 From: stu1130 Date: Fri, 6 Dec 2019 16:39:01 -0800 Subject: [PATCH] Fix NDArrayIter cant pad when size is large --- python/mxnet/io/io.py | 29 +++++++++++++---------------- tests/python/unittest/test_io.py | 6 +++--- 2 files changed, 16 insertions(+), 19 deletions(-) diff --git a/python/mxnet/io/io.py b/python/mxnet/io/io.py index dcf964df976a..3d9fb833d4ab 100644 --- a/python/mxnet/io/io.py +++ b/python/mxnet/io/io.py @@ -709,23 +709,19 @@ def _getdata(self, data_source, start=None, end=None): def _concat(self, first_data, second_data): """Helper function to concat two NDArrays.""" + if (not first_data) and (not second_data): + return [] + elif (not first_data) or (not second_data): + return first_data if first_data else second_data assert len(first_data) == len( second_data), 'data source should contain the same size' - if first_data and second_data: - return [ - concat( - first_data[x], - second_data[x], - dim=0 - ) for x in range(len(first_data)) - ] - elif (not first_data) and (not second_data): - return [] - else: - return [ - first_data[0] if first_data else second_data[0] - for x in range(len(first_data)) - ] + return [ + concat( + first_data[x], + second_data[x], + dim=0 + ) for x in range(len(first_data)) + ] def _batchify(self, data_source): """Load data from underlying arrays, internal use only.""" @@ -748,11 +744,12 @@ def _batchify(self, data_source): self.cursor + self.batch_size > self.num_data: pad = self.batch_size - self.num_data + self.cursor first_data = self._getdata(data_source, start=self.cursor) + second_data = None if pad > self.num_data: while True: if pad <= self.num_data: break - second_data = self._getdata(data_source, end=self.num_data) + second_data = self._concat(second_data, self._getdata(data_source, end=self.num_data)) pad -= self.num_data second_data = self._concat(second_data, self._getdata(data_source, end=pad)) else: diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py index 2a806efc9034..a13addb0adca 100644 --- a/tests/python/unittest/test_io.py +++ b/tests/python/unittest/test_io.py @@ -198,11 +198,11 @@ def _test_shuffle(data, labels=None): assert np.array_equal(batch.data[0].asnumpy(), batch_list[idx_list[i]]) i += 1 -# fixes the issue https://github.com/apache/incubator-mxnet/issues/15535 + def _test_corner_case(): data = np.arange(10) - data_iter = mx.io.NDArrayIter(data=data, batch_size=25, shuffle=False, last_batch_handle='pad') - expect = np.concatenate((np.tile(data, 2), np.arange(5))) + data_iter = mx.io.NDArrayIter(data=data, batch_size=205, shuffle=False, last_batch_handle='pad') + expect = np.concatenate((np.tile(data, 20), np.arange(5))) assert np.array_equal(data_iter.next().data[0].asnumpy(), expect)