From e1c54be698453eae4d6f5be949d60be1717f17d9 Mon Sep 17 00:00:00 2001 From: Piyush Ghai Date: Fri, 12 Apr 2019 15:48:00 -0700 Subject: [PATCH 1/3] Showing proper error when csr array is not 2D in shape. --- python/mxnet/ndarray/ndarray.py | 4 ++++ tests/python/unittest/test_sparse_ndarray.py | 5 +++++ 2 files changed, 9 insertions(+) diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index 7e21daedcde1..3fb1af6a7336 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -2227,6 +2227,10 @@ def tostype(self, stype): NDArray, CSRNDArray or RowSparseNDArray A copy of the array with the chosen storage stype """ + if stype == 'csr' and len(self.shape) != 2: + raise ValueError("To convert to a CSR, the NDArray should be 2 Dimensional. Current " + "shape is %s" % str(self.shape)) + return op.cast_storage(self, stype=stype) def to_dlpack_for_read(self): diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py index 3b4c684e8696..41a9b5d49c60 100644 --- a/tests/python/unittest/test_sparse_ndarray.py +++ b/tests/python/unittest/test_sparse_ndarray.py @@ -963,6 +963,11 @@ def test_sparse_nd_check_format(): indptr_list = [0, -2, 2, 3] a = mx.nd.sparse.csr_matrix((data_list, indices_list, indptr_list), shape=shape) assertRaises(mx.base.MXNetError, a.check_format) + # CSR format should be 2 Dimensional. + a = mx.nd.array([1, 2, 3]).tostype('csr') + assertRaises(ValueError, a.tostype, 'csr') + a = mx.nd.array([[[1, 2, 3]]]).tostype('csr') + assertRaises(ValueError, a.tostype, 'csr') # Row Sparse format indices should be less than the number of rows shape = (3, 2) data_list = [[1, 2], [3, 4]] From 2c3d846a7092038ba4a4958c67f80d47a30680c9 Mon Sep 17 00:00:00 2001 From: Piyush Ghai Date: Thu, 13 Jun 2019 17:14:03 -0700 Subject: [PATCH 2/3] Fixed failing CI --- tests/python/unittest/test_sparse_ndarray.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py index 41a9b5d49c60..9a1fce4ff197 100644 --- a/tests/python/unittest/test_sparse_ndarray.py +++ b/tests/python/unittest/test_sparse_ndarray.py @@ -964,9 +964,9 @@ def test_sparse_nd_check_format(): a = mx.nd.sparse.csr_matrix((data_list, indices_list, indptr_list), shape=shape) assertRaises(mx.base.MXNetError, a.check_format) # CSR format should be 2 Dimensional. - a = mx.nd.array([1, 2, 3]).tostype('csr') + a = mx.nd.array([1, 2, 3]) assertRaises(ValueError, a.tostype, 'csr') - a = mx.nd.array([[[1, 2, 3]]]).tostype('csr') + a = mx.nd.array([[[1, 2, 3]]]) assertRaises(ValueError, a.tostype, 'csr') # Row Sparse format indices should be less than the number of rows shape = (3, 2) From c06bbef3e4b629b31533d24438d553a030cd391c Mon Sep 17 00:00:00 2001 From: Piyush Ghai Date: Thu, 13 Jun 2019 21:00:32 -0700 Subject: [PATCH 3/3] Nudge to CI