Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Showing proper error when csr array is not 2D in shape. (#15242)
Browse files Browse the repository at this point in the history
* Showing proper error when csr array is not 2D in shape.

* Fixed failing CI

* Nudge to CI
  • Loading branch information
piyushghai authored and eric-haibin-lin committed Jun 20, 2019
1 parent 12c4226 commit 2de0db0
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 0 deletions.
4 changes: 4 additions & 0 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2227,6 +2227,10 @@ def tostype(self, stype):
NDArray, CSRNDArray or RowSparseNDArray
A copy of the array with the chosen storage stype
"""
if stype == 'csr' and len(self.shape) != 2:
raise ValueError("To convert to a CSR, the NDArray should be 2 Dimensional. Current "
"shape is %s" % str(self.shape))

return op.cast_storage(self, stype=stype)

def to_dlpack_for_read(self):
Expand Down
5 changes: 5 additions & 0 deletions tests/python/unittest/test_sparse_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,6 +963,11 @@ def test_sparse_nd_check_format():
indptr_list = [0, -2, 2, 3]
a = mx.nd.sparse.csr_matrix((data_list, indices_list, indptr_list), shape=shape)
assertRaises(mx.base.MXNetError, a.check_format)
# CSR format should be 2 Dimensional.
a = mx.nd.array([1, 2, 3])
assertRaises(ValueError, a.tostype, 'csr')
a = mx.nd.array([[[1, 2, 3]]])
assertRaises(ValueError, a.tostype, 'csr')
# Row Sparse format indices should be less than the number of rows
shape = (3, 2)
data_list = [[1, 2], [3, 4]]
Expand Down

0 comments on commit 2de0db0

Please sign in to comment.