Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix NDArrayIter can't pad when size is large
Browse files Browse the repository at this point in the history
  • Loading branch information
stu1130 committed Dec 7, 2019
1 parent 9c94fdb commit 7d21b43
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 19 deletions.
29 changes: 13 additions & 16 deletions python/mxnet/io/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,23 +709,19 @@ def _getdata(self, data_source, start=None, end=None):

def _concat(self, first_data, second_data):
"""Helper function to concat two NDArrays."""
if (not first_data) and (not second_data):
return []
elif (not first_data) or (not second_data):
return first_data if first_data else second_data
assert len(first_data) == len(
second_data), 'data source should contain the same size'
if first_data and second_data:
return [
concat(
first_data[x],
second_data[x],
dim=0
) for x in range(len(first_data))
]
elif (not first_data) and (not second_data):
return []
else:
return [
first_data[0] if first_data else second_data[0]
for x in range(len(first_data))
]
return [
concat(
first_data[x],
second_data[x],
dim=0
) for x in range(len(first_data))
]

def _batchify(self, data_source):
"""Load data from underlying arrays, internal use only."""
Expand All @@ -748,11 +744,12 @@ def _batchify(self, data_source):
self.cursor + self.batch_size > self.num_data:
pad = self.batch_size - self.num_data + self.cursor
first_data = self._getdata(data_source, start=self.cursor)
second_data = None
if pad > self.num_data:
while True:
if pad <= self.num_data:
break
second_data = self._getdata(data_source, end=self.num_data)
second_data = self._concat(second_data, self._getdata(data_source, end=self.num_data))
pad -= self.num_data
second_data = self._concat(second_data, self._getdata(data_source, end=pad))
else:
Expand Down
6 changes: 3 additions & 3 deletions tests/python/unittest/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,11 +198,11 @@ def _test_shuffle(data, labels=None):
assert np.array_equal(batch.data[0].asnumpy(), batch_list[idx_list[i]])
i += 1

# fixes the issue https://github.com/apache/incubator-mxnet/issues/15535

def _test_corner_case():
data = np.arange(10)
data_iter = mx.io.NDArrayIter(data=data, batch_size=25, shuffle=False, last_batch_handle='pad')
expect = np.concatenate((np.tile(data, 2), np.arange(5)))
data_iter = mx.io.NDArrayIter(data=data, batch_size=205, shuffle=False, last_batch_handle='pad')
expect = np.concatenate((np.tile(data, 20), np.arange(5)))
assert np.array_equal(data_iter.next().data[0].asnumpy(), expect)


Expand Down

0 comments on commit 7d21b43

Please sign in to comment.