diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/SparseNDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/SparseNDArray.scala index ec63b750946c..f3fe638e41ed 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/SparseNDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/SparseNDArray.scala @@ -152,31 +152,26 @@ class SparseNDArray private[mxnet] (override private[mxnet] val handle: NDArrayH } override def slice(start: Int, end: Int): NDArray = { - printf(s"\n\nSlice being called!!\n\nstart:$start end:$end") NDArray.api.slice(this, Shape(start), Shape(end)) } /** - * Get the Data portion from the row Sparse NDArray + * Get the Data portion from a Row Sparse NDArray * @return NDArray */ def getData: NDArray = { + require(this.sparseFormat == SparseFormat.ROW_SPARSE, "Not Supported for CSR") val handle = new NDArrayHandleRef - if (this.sparseFormat == SparseFormat.CSR) { - throw new UnsupportedOperationException("Not Supported for CSR") - } _LIB.mxNDArrayGetDataNDArray(this.handle, handle) new NDArray(handle.value, false) } /** - * Get the indptr Array + * Get the indptr Array from a CSR NDArray * @return NDArray */ def getIndptr: NDArray = { - if (this.sparseFormat == SparseFormat.ROW_SPARSE) { - throw new UnsupportedOperationException("Not Supported for row sparse") - } + require(this.sparseFormat == SparseFormat.CSR, "Not Supported for row sparse") getAuxNDArray(0) }