Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Assure NDArray.reshape does not change the array size (#19078)
Browse files Browse the repository at this point in the history
Co-authored-by: r3stl355 <[email protected]>
  • Loading branch information
r3stl355 and ulmasov authored Sep 6, 2020
1 parent 23b3665 commit 62b7f03
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
7 changes: 6 additions & 1 deletion python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1546,7 +1546,12 @@ def reshape(self, *shape, **kwargs):
c_array(ctypes.c_int64, shape),
reverse,
ctypes.byref(handle)))
return self.__class__(handle=handle, writable=self.writable)
res = self.__class__(handle=handle, writable=self.writable)

# Array size should not change
if np.prod(res.shape) != np.prod(self.shape):
raise ValueError('Cannot reshape array of size {} into shape {}'.format(np.prod(self.shape), shape))
return res

def reshape_like(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`reshape_like`.
Expand Down
3 changes: 2 additions & 1 deletion tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,8 @@ def test_ndarray_reshape():
assert same(tensor.reshape(-1, 15).reshape(0, -4, 3, -1).asnumpy(), true_res.reshape(2, 3, 5).asnumpy())
assert same(tensor.reshape(-1, 0).asnumpy(), true_res.reshape(10, 3).asnumpy())
assert same(tensor.reshape(-1, 0, reverse=True).asnumpy(), true_res.reshape(6, 5).asnumpy())

# https://github.com/apache/incubator-mxnet/issues/18886
assertRaises(ValueError, tensor.reshape, (2, 3))

@with_seed()
def test_ndarray_flatten():
Expand Down

0 comments on commit 62b7f03

Please sign in to comment.