diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index 7e21daedcde1..3fb1af6a7336 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -2227,6 +2227,10 @@ def tostype(self, stype): NDArray, CSRNDArray or RowSparseNDArray A copy of the array with the chosen storage stype """ + if stype == 'csr' and len(self.shape) != 2: + raise ValueError("To convert to a CSR, the NDArray should be 2 Dimensional. Current " + "shape is %s" % str(self.shape)) + return op.cast_storage(self, stype=stype) def to_dlpack_for_read(self): diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py index 3b4c684e8696..9a1fce4ff197 100644 --- a/tests/python/unittest/test_sparse_ndarray.py +++ b/tests/python/unittest/test_sparse_ndarray.py @@ -963,6 +963,11 @@ def test_sparse_nd_check_format(): indptr_list = [0, -2, 2, 3] a = mx.nd.sparse.csr_matrix((data_list, indices_list, indptr_list), shape=shape) assertRaises(mx.base.MXNetError, a.check_format) + # CSR format should be 2 Dimensional. + a = mx.nd.array([1, 2, 3]) + assertRaises(ValueError, a.tostype, 'csr') + a = mx.nd.array([[[1, 2, 3]]]) + assertRaises(ValueError, a.tostype, 'csr') # Row Sparse format indices should be less than the number of rows shape = (3, 2) data_list = [[1, 2], [3, 4]]