Skip to content

Commit

Permalink
[1.x] Backport of apache#19078 (apache#19095)
Browse files Browse the repository at this point in the history
* Assure NDArray.reshape does not change the array size

* Truncate wikitext-2 to match target array size on reshape

Co-authored-by: r3stl355 <[email protected]>
  • Loading branch information
r3stl355 and ulmasov authored Sep 29, 2020
1 parent 93b7ff7 commit 16280ad
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 4 deletions.
7 changes: 5 additions & 2 deletions python/mxnet/gluon/contrib/data/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,11 @@ def _get_data(self):

data, label = self._read_batch(path)

self._data = nd.array(data, dtype=data.dtype).reshape((-1, self._seq_len))
self._label = nd.array(label, dtype=label.dtype).reshape((-1, self._seq_len))
# https://github.com/apache/incubator-mxnet/issues/18886 breaks this unless array size is
# multiple of self._seq_len. Truncating the source is consistent with pre #18886 outcome
seq_len_mult = len(data) // self._seq_len * self._seq_len
self._data = nd.array(data, dtype=data.dtype)[:seq_len_mult].reshape((-1, self._seq_len))
self._label = nd.array(label, dtype=label.dtype)[:seq_len_mult].reshape((-1, self._seq_len))

def __getitem__(self, idx):
return self._data[idx], self._label[idx]
Expand Down
7 changes: 6 additions & 1 deletion python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1517,7 +1517,12 @@ def reshape(self, *shape, **kwargs):
c_array(ctypes.c_int64, shape),
reverse,
ctypes.byref(handle)))
return self.__class__(handle=handle, writable=self.writable)
res = self.__class__(handle=handle, writable=self.writable)

# Array size should not change
if np.prod(res.shape) != np.prod(self.shape):
raise ValueError('Cannot reshape array of size {} into shape {}'.format(np.prod(self.shape), shape))
return res

def reshape_like(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`reshape_like`.
Expand Down
3 changes: 2 additions & 1 deletion tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,8 @@ def test_ndarray_reshape():
assert same(tensor.reshape(-1, 15).reshape(0, -4, 3, -1).asnumpy(), true_res.reshape(2, 3, 5).asnumpy())
assert same(tensor.reshape(-1, 0).asnumpy(), true_res.reshape(10, 3).asnumpy())
assert same(tensor.reshape(-1, 0, reverse=True).asnumpy(), true_res.reshape(6, 5).asnumpy())

# https://github.com/apache/incubator-mxnet/issues/18886
assertRaises(ValueError, tensor.reshape, (2, 3))

@with_seed()
def test_ndarray_flatten():
Expand Down

0 comments on commit 16280ad

Please sign in to comment.