From 3948930565d5fb36405151bad0288f4f11f1d35b Mon Sep 17 00:00:00 2001 From: r3stl355 Date: Wed, 2 Sep 2020 22:25:43 +0100 Subject: [PATCH] Assure NDArray.reshape does not change the array size --- python/mxnet/ndarray/ndarray.py | 7 ++++++- tests/python/unittest/test_ndarray.py | 3 ++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index 0f638a1ed562..a6eae0511cc6 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -1546,7 +1546,12 @@ def reshape(self, *shape, **kwargs): c_array(ctypes.c_int64, shape), reverse, ctypes.byref(handle))) - return self.__class__(handle=handle, writable=self.writable) + res = self.__class__(handle=handle, writable=self.writable) + + # Array size should not change + if np.prod(res.shape) != np.prod(self.shape): + raise ValueError('Cannot reshape array of size {} into shape {}'.format(np.prod(self.shape), shape)) + return res def reshape_like(self, *args, **kwargs): """Convenience fluent method for :py:func:`reshape_like`. diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index a01746ee5e6f..9e80d48c2460 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -237,7 +237,8 @@ def test_ndarray_reshape(): assert same(tensor.reshape(-1, 15).reshape(0, -4, 3, -1).asnumpy(), true_res.reshape(2, 3, 5).asnumpy()) assert same(tensor.reshape(-1, 0).asnumpy(), true_res.reshape(10, 3).asnumpy()) assert same(tensor.reshape(-1, 0, reverse=True).asnumpy(), true_res.reshape(6, 5).asnumpy()) - + # https://github.com/apache/incubator-mxnet/issues/18886 + assertRaises(ValueError, tensor.reshape, (2, 3)) @with_seed() def test_ndarray_flatten():