diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala b/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala index f3a8e8e9a4a5..1d5cc2847ac0 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala @@ -24,26 +24,17 @@ object DType extends Enumeration { val Float16 = Value(2, "float16") val UInt8 = Value(3, "uint8") val Int32 = Value(4, "int32") + val Int8 = Value(5, "int8") + val Int64 = Value(6, "int64") val Unknown = Value(-1, "unknown") private[mxnet] def numOfBytes(dtype: DType): Int = { dtype match { - case DType.UInt8 => 1 + case DType.UInt8 | DType.Int8 => 1 case DType.Int32 => 4 case DType.Float16 => 2 case DType.Float32 => 4 - case DType.Float64 => 8 + case DType.Float64 | DType.Int64 => 8 case DType.Unknown => 0 } } - private[mxnet] def getType(dtypeStr: String): DType = { - dtypeStr match { - case "UInt8" => DType.UInt8 - case "Int32" => DType.Int32 - case "Float16" => DType.Float16 - case "Float32" => DType.Float32 - case "Float64" => DType.Float64 - case _ => throw new IllegalArgumentException( - s"DType: $dtypeStr not found! please set it in DType.scala") - } - } } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala index b0fae0f9d58d..6365f9cb4645 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala @@ -159,7 +159,14 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, private def getOutputs: Array[NDArray] = { val ndHandles = ArrayBuffer[NDArrayHandle]() checkCall(_LIB.mxExecutorOutputs(handle, ndHandles)) - ndHandles.toArray.map(new NDArray(_, addToCollector = false)) + ndHandles.toArray.map(ele => { + val nd = new NDArray(ele, addToCollector = false) + if (nd.isSparse) { + nd.asInstanceOf[SparseNDArray] + } + nd + } + ) } /** diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala index 640ecf5d5978..0ee6476be365 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala @@ -31,13 +31,14 @@ private[mxnet] class LibInfo { @native def mxListAllOpNames(names: ListBuffer[String]): Int @native def nnGetOpHandle(opName: String, opHandle: RefLong): Int // NDArray - @native def mxImperativeInvoke(creator: FunctionHandle, + @native def mxImperativeInvokeEx(creator: FunctionHandle, inputs: Array[NDArrayHandle], outputsGiven: Array[NDArrayHandle], outputs: ArrayBuffer[NDArrayHandle], numParams: Int, paramKeys: Array[String], - paramVals: Array[String]): Int + paramVals: Array[String], + outStype: ArrayBuffer[Int]): Int @native def mxNDArrayFree(handle: NDArrayHandle): Int @native def mxNDArrayCreateNone(out: NDArrayHandleRef): Int @native def mxNDArrayCreateEx(shape: Array[Int], @@ -47,6 +48,20 @@ private[mxnet] class LibInfo { delayAlloc: Int, dtype: Int, out: NDArrayHandleRef): Int + // scalastyle:off parameterNum + @native def mxNDArrayCreateSparseEx(storageType: Int, + shape: Array[Int], + ndim: Int, + devType: Int, + devId: Int, + delayAlloc: Int, + dtype: Int, + numAux: Int, + auxTypes: Array[Int], + auxNdims: Array[Int], + auxShapes: Array[Int], + out: NDArrayHandleRef): Int + // scalastyle:on parameterNum @native def mxNDArrayWaitAll(): Int @native def mxNDArrayWaitToRead(handle: NDArrayHandle): Int @native def mxListFunctions(functions: ListBuffer[FunctionHandle]): Int @@ -76,6 +91,9 @@ private[mxnet] class LibInfo { @native def mxNDArrayGetShape(handle: NDArrayHandle, ndim: MXUintRef, data: ArrayBuffer[Int]): Int + @native def mxNDArraySyncCopyFromNDArray(handleDst: NDArrayHandle, + handleSrc: NDArrayHandle, + locator: Int): Int @native def mxNDArraySyncCopyToCPU(handle: NDArrayHandle, data: Array[Byte], size: Int): Int @@ -105,10 +123,15 @@ private[mxnet] class LibInfo { @native def mxNDArraySave(fname: String, handles: Array[NDArrayHandle], keys: Array[String]): Int + @native def mxNDArrayGetDataNDArray(handle: NDArrayHandle, out: NDArrayHandleRef): Int + @native def mxNDArrayGetAuxNDArray(handle: NDArrayHandle, + location: Int, + out: NDArrayHandleRef): Int @native def mxNDArrayGetContext(handle: NDArrayHandle, devTypeId: RefInt, devId: RefInt): Int @native def mxNDArraySaveRawBytes(handle: NDArrayHandle, buf: ArrayBuffer[Byte]): Int @native def mxNDArrayLoadFromRawBytes(bytes: Array[Byte], handle: NDArrayHandleRef): Int @native def mxNDArrayGetDType(handle: NDArrayHandle, dtype: RefInt): Int + @native def mxNDArrayGetStorageType(handle: NDArrayHandle, stype: RefInt): Int // KVStore Server @native def mxInitPSEnv(keys: Array[String], values: Array[String]): Int diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index 40888012fc5a..1b7b31b32cce 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -21,7 +21,8 @@ import java.nio.{ByteBuffer, ByteOrder} import org.apache.mxnet.Base._ import org.apache.mxnet.DType.DType -import org.apache.mxnet.MX_PRIMITIVES.{MX_PRIMITIVE_TYPE} +import org.apache.mxnet.MX_PRIMITIVES.MX_PRIMITIVE_TYPE +import org.apache.mxnet.SparseFormat.SparseFormat import org.slf4j.LoggerFactory import scala.collection.mutable @@ -113,10 +114,22 @@ object NDArray extends NDArrayBase { } val outputs = ArrayBuffer.empty[NDArrayHandle] - checkCall(_LIB.mxImperativeInvoke(function.handle, ndArgs.map(_.handle).toArray, outputVars, - outputs, updatedKwargs.size, updatedKwargs.keys.toArray, updatedKwargs.values.toArray)) + val outStypes = ArrayBuffer.empty[Int] + checkCall(_LIB.mxImperativeInvokeEx(function.handle, + ndArgs.map(_.handle).toArray, + outputVars, + outputs, + updatedKwargs.size, + updatedKwargs.keys.toArray, + updatedKwargs.values.toArray, + outStypes)) new NDArrayFuncReturn(Option(oriOutputs).getOrElse { - val outputArrs = outputs.map(new NDArray(_)).toArray + val outputArrs = (outputs zip outStypes).map( + ele => ele._2 match { + case 0 => new NDArray(ele._1) + case _ => new SparseNDArray(ele._1) + } + ).toArray addDependency(ndArgs.toArray, outputArrs) outputArrs }) @@ -943,6 +956,12 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, DType(mxDtype.value) } + val sparseFormat: SparseFormat = { + val mxSF = new RefInt + checkCall(_LIB.mxNDArrayGetStorageType(handle, mxSF)) + SparseFormat(mxSF.value) + } + /** * Return a copied numpy array of current array with specified type. * @param dtype Desired type of result array. @@ -1309,6 +1328,30 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, if (this.context == context) this else this.copyTo(context) } + /** + * check if NDArray is SparseNDArray + * @return Boolean + */ + def isSparse: Boolean = { + this.sparseFormat.id != 0 + } + + /** + * Convert a NDArray to SparseNDArray + * + * @param sfOption the target sparse type + * @return SparseNDArray + */ + def toSparse(sfOption : Option[SparseFormat] = None): SparseNDArray = { + val sf = sfOption.getOrElse(SparseFormat.ROW_SPARSE) + if (sf.id == 0) throw new IllegalArgumentException("Require Sparse") + if (isSparse && sfOption.isEmpty) { + this.asInstanceOf[SparseNDArray] + } else { + NDArray.api.cast_storage(this, sf.toString).head.asInstanceOf[SparseNDArray] + } + } + override def equals(o: Any): Boolean = o match { case that: NDArray => that != null && that.shape == this.shape && that.toArray.sameElements(this.toArray) @@ -1479,6 +1522,7 @@ private[mxnet] class NDArrayInternal (private val internal: Array[Byte], private case DType.Float32 => units.map(wrapBytes(_).getFloat.toDouble) case DType.Float64 => units.map(wrapBytes(_).getDouble) case DType.Int32 => units.map(wrapBytes(_).getInt.toDouble) + case DType.Int64 => units.map(wrapBytes(_).getLong.toDouble) case DType.UInt8 => internal.map(_.toDouble) } } @@ -1488,6 +1532,7 @@ private[mxnet] class NDArrayInternal (private val internal: Array[Byte], private case DType.Float32 => units.map(wrapBytes(_).getFloat) case DType.Float64 => units.map(wrapBytes(_).getDouble.toFloat) case DType.Int32 => units.map(wrapBytes(_).getInt.toFloat) + case DType.Int64 => units.map(wrapBytes(_).getLong.toFloat) case DType.UInt8 => internal.map(_.toFloat) } } @@ -1497,15 +1542,27 @@ private[mxnet] class NDArrayInternal (private val internal: Array[Byte], private case DType.Float32 => units.map(wrapBytes(_).getFloat.toInt) case DType.Float64 => units.map(wrapBytes(_).getDouble.toInt) case DType.Int32 => units.map(wrapBytes(_).getInt) + case DType.Int64 => units.map(wrapBytes(_).getLong.toInt) case DType.UInt8 => internal.map(_.toInt) } } + def toLongArray: Array[Long] = { + require(dtype != DType.Float16, "Currently cannot convert float16 to native numerical types") + dtype match { + case DType.Float32 => units.map(wrapBytes(_).getFloat.toLong) + case DType.Float64 => units.map(wrapBytes(_).getDouble.toLong) + case DType.Int32 => units.map(wrapBytes(_).getInt.toLong) + case DType.Int64 => units.map(wrapBytes(_).getLong) + case DType.UInt8 => internal.map(_.toLong) + } + } def toByteArray: Array[Byte] = { require(dtype != DType.Float16, "Currently cannot convert float16 to native numerical types") dtype match { case DType.Float16 | DType.Float32 => units.map(wrapBytes(_).getFloat.toByte) case DType.Float64 => units.map(wrapBytes(_).getDouble.toByte) case DType.Int32 => units.map(wrapBytes(_).getInt.toByte) + case DType.Int64 => units.map(wrapBytes(_).getLong.toByte) case DType.UInt8 => internal.clone() } } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/SparseFormat.scala b/scala-package/core/src/main/scala/org/apache/mxnet/SparseFormat.scala new file mode 100644 index 000000000000..acb0c0f24070 --- /dev/null +++ b/scala-package/core/src/main/scala/org/apache/mxnet/SparseFormat.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet + +object SparseFormat extends Enumeration { + type SparseFormat = Value + val DEFAULT = Value(0, "default") + val ROW_SPARSE = Value(1, "row_sparse") + val CSR = Value(2, "csr") +} diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/SparseNDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/SparseNDArray.scala new file mode 100644 index 000000000000..f3fe638e41ed --- /dev/null +++ b/scala-package/core/src/main/scala/org/apache/mxnet/SparseNDArray.scala @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet + +import org.apache.mxnet.Base.{NDArrayHandle, NDArrayHandleRef, checkCall, _LIB} +import org.apache.mxnet.DType.DType +import org.apache.mxnet.SparseFormat.SparseFormat + +object SparseNDArray { + /** + * Create a Compressed Sparse Row Storage (CSR) Format Matrix + * @param data the data to feed + * @param indices The indices array stores the column index for each non-zero element in data + * @param indptr The indptr array is what will help identify the rows where the data appears + * @param shape the shape of CSR NDArray to be created + * @param ctx the context of this NDArray + * @return SparseNDArray + */ + def csrMatrix(data: Array[Float], indices: Array[Float], + indptr: Array[Float], shape: Shape, ctx: Context): SparseNDArray = { + val fmt = SparseFormat.CSR + val dataND = NDArray.array(data, Shape(data.length), ctx) + val indicesND = NDArray.array(indices, Shape(indices.length), ctx).asType(DType.Int64) + val indptrND = NDArray.array(indptr, Shape(indptr.length), ctx).asType(DType.Int64) + val dTypes = Array(indptrND.dtype, indicesND.dtype) + val shapes = Array(indptrND.shape, indicesND.shape) + val handle = + newAllocHandle(fmt, shape, ctx, false, DType.Float32, dTypes, shapes) + checkCall(_LIB.mxNDArraySyncCopyFromNDArray(handle, dataND.handle, -1)) + checkCall(_LIB.mxNDArraySyncCopyFromNDArray(handle, indptrND.handle, 0)) + checkCall(_LIB.mxNDArraySyncCopyFromNDArray(handle, indicesND.handle, 1)) + new SparseNDArray(handle) + } + + /** + * RowSparseNDArray stores the matrix in row sparse format, + * which is designed for arrays of which most row slices are all zeros + * @param data Any Array(Array(... Array(Float))) + * @param indices the indices to store the data + * @param shape shape of the NDArray + * @param ctx Context + * @return SparseNDArray + */ + def rowSparseArray(data: Array[_], indices: Array[Float], + shape: Shape, ctx: Context): SparseNDArray = { + val dataND = NDArray.toNDArray(data) + val indicesND = NDArray.array(indices, Shape(indices.length), ctx).asType(DType.Int64) + rowSparseArray(dataND, indicesND, shape, ctx) + } + + /** + * RowSparseNDArray stores the matrix in row sparse format, + * which is designed for arrays of which most row slices are all zeros + * @param data NDArray input + * @param indices in NDArray. Only DType.Int64 supported + * @param shape shape of the NDArray + * @param ctx Context + * @return + */ + def rowSparseArray(data: NDArray, indices: NDArray, + shape: Shape, ctx: Context): SparseNDArray = { + val fmt = SparseFormat.ROW_SPARSE + val handle = newAllocHandle(fmt, shape, ctx, false, + DType.Float32, Array(indices.dtype), Array(indices.shape)) + checkCall(_LIB.mxNDArraySyncCopyFromNDArray(handle, data.handle, -1)) + checkCall(_LIB.mxNDArraySyncCopyFromNDArray(handle, indices.handle, 0)) + new SparseNDArray(handle) + } + + def retain(sparseNDArray: SparseNDArray, indices: Array[Float]): SparseNDArray = { + if (sparseNDArray.sparseFormat == SparseFormat.CSR) { + throw new IllegalArgumentException("CSR not supported") + } + NDArray.genericNDArrayFunctionInvoke("_sparse_retain", + Seq(sparseNDArray, NDArray.toNDArray(indices))).head.toSparse() + } + + private def newAllocHandle(stype : SparseFormat, + shape: Shape, + ctx: Context, + delayAlloc: Boolean, + dtype: DType = DType.Float32, + auxDTypes: Array[DType], + auxShapes: Array[Shape]) : NDArrayHandle = { + val hdl = new NDArrayHandleRef + checkCall(_LIB.mxNDArrayCreateSparseEx( + stype.id, + shape.toArray, + shape.length, + ctx.deviceTypeid, + ctx.deviceId, + if (delayAlloc) 1 else 0, + dtype.id, + auxDTypes.length, + auxDTypes.map(_.id), + auxShapes.map(_.length), + auxShapes.map(_.get(0)), + hdl) + ) + hdl.value + } +} + +/** + * Sparse NDArray is the child class of NDArray designed to hold the Sparse format + * + *

Currently, Rowsparse and CSR typed NDArray is supported. Most of the Operators + * will convert Sparse NDArray to dense. Basic operators like add will + * have optimization for sparse operattions

+ * @param handle The pointer that SparseNDArray holds + * @param writable whether the NDArray is writable + */ +class SparseNDArray private[mxnet] (override private[mxnet] val handle: NDArrayHandle, + override val writable: Boolean = true) + extends NDArray(handle, writable) { + + private lazy val dense: NDArray = toDense + + override def toString: String = { + dense.toString + } + + /** + * Convert a SparseNDArray to dense NDArray + * @return NDArray + */ + def toDense: NDArray = { + NDArray.api.cast_storage(this, SparseFormat.DEFAULT.toString).head + } + + override def toArray: Array[Float] = { + dense.toArray + } + + override def at(idx: Int): NDArray = { + dense.at(idx) + } + + override def slice(start: Int, end: Int): NDArray = { + NDArray.api.slice(this, Shape(start), Shape(end)) + } + + /** + * Get the Data portion from a Row Sparse NDArray + * @return NDArray + */ + def getData: NDArray = { + require(this.sparseFormat == SparseFormat.ROW_SPARSE, "Not Supported for CSR") + val handle = new NDArrayHandleRef + _LIB.mxNDArrayGetDataNDArray(this.handle, handle) + new NDArray(handle.value, false) + } + + /** + * Get the indptr Array from a CSR NDArray + * @return NDArray + */ + def getIndptr: NDArray = { + require(this.sparseFormat == SparseFormat.CSR, "Not Supported for row sparse") + getAuxNDArray(0) + } + + /** + * Get the indice Array + * @return NDArray + */ + def getIndices: NDArray = { + if (this.sparseFormat == SparseFormat.ROW_SPARSE) { + getAuxNDArray(0) + } else { + getAuxNDArray(1) + } + } + + private def getAuxNDArray(idx: Int): NDArray = { + val handle = new NDArrayHandleRef + checkCall(_LIB.mxNDArrayGetAuxNDArray(this.handle, idx, handle)) + new NDArray(handle.value, false) + } + +} diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala index c2ef641f9c9a..82b9edc8f4bb 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala @@ -45,6 +45,22 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { assert(ndones.toScalar === 1f) } + test("to sparse") { + val arr = Array( + Array(1f, 0f, 0f), + Array(0f, 3f, 0f), + Array(0f, 0f, 1f) + ) + val nd = NDArray.toNDArray(arr) + assert(!nd.isSparse) + // row sparse + var ndSparse = nd.toSparse() + assert(ndSparse.getIndices.toArray sameElements Array(0f, 1f, 2f)) + // csr + ndSparse = nd.toSparse(Some(SparseFormat.CSR)) + assert(ndSparse.getIndptr.toArray sameElements Array(0f, 1f, 2f, 3f)) + } + test("to float 64 scalar") { val ndzeros = NDArray.zeros(Shape(1), dtype = DType.Float64) assert(ndzeros.toFloat64Scalar === 0d) diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/SparseNDArraySuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/SparseNDArraySuite.scala new file mode 100644 index 000000000000..f9968efd80c5 --- /dev/null +++ b/scala-package/core/src/test/scala/org/apache/mxnet/SparseNDArraySuite.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet + +import org.apache.mxnet.io.NDArrayIter +import org.scalatest.FunSuite +import org.slf4j.LoggerFactory + +class SparseNDArraySuite extends FunSuite { + + private val logger = LoggerFactory.getLogger(classOf[SparseNDArraySuite]) + + test("create CSR NDArray") { + val data = Array(7f, 8f, 9f) + val indices = Array(0f, 2f, 1f) + val indptr = Array(0f, 2f, 2f, 3f) + val shape = Shape(3, 4) + val sparseND = SparseNDArray.csrMatrix(data, indices, indptr, shape, Context.cpu()) + assert(sparseND.shape == Shape(3, 4)) + assert(sparseND.toArray + sameElements Array(7.0f, 0.0f, 8.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 9.0f, 0.0f, 0.0f)) + assert(sparseND.sparseFormat == SparseFormat.CSR) + assert(sparseND.getIndptr.toArray sameElements indptr) + assert(sparseND.getIndices.toArray sameElements indices) + } + + test("create Row Sparse NDArray") { + val data = Array( + Array(1f, 2f), + Array(3f, 4f) + ) + val indices = Array(1f, 4f) + val shape = Shape(6, 2) + val sparseND = SparseNDArray.rowSparseArray(data, indices, shape, Context.cpu()) + assert(sparseND.sparseFormat == SparseFormat.ROW_SPARSE) + assert(sparseND.shape == Shape(6, 2)) + assert(sparseND.at(1).toArray sameElements Array(1f, 2f)) + assert(sparseND.getIndices.toArray sameElements indices) + } + + test("Test retain") { + val arr = Array( + Array(1f, 2f), + Array(3f, 4f), + Array(5f, 6f) + ) + val indices = Array(0f, 1f, 3f) + val rspIn = SparseNDArray.rowSparseArray(arr, indices, Shape(4, 2), Context.cpu()) + val toRetain = Array(0f, 3f) + val rspOut = SparseNDArray.retain(rspIn, toRetain) + assert(rspOut.getData.toArray sameElements Array(1f, 2f, 5f, 6f)) + assert(rspOut.getIndices.toArray sameElements Array(0f, 3f)) + } + + test("Test add") { + val nd = NDArray.array(Array(1f, 2f, 3f), Shape(3)).toSparse(Some(SparseFormat.ROW_SPARSE)) + val nd2 = nd + nd + assert(nd2.isInstanceOf[SparseNDArray]) + assert(nd2.toArray sameElements Array(2f, 4f, 6f)) + } + + test("Test DataIter") { + val nd = NDArray.array(Array(1f, 2f, 3f), Shape(1, 3)).toSparse(Some(SparseFormat.CSR)) + val arr = IndexedSeq(nd, nd, nd, nd) + val iter = new NDArrayIter(arr) + while (iter.hasNext) { + val tempArr = iter.next().data + tempArr.foreach(ele => { + assert(ele.sparseFormat == SparseFormat.CSR) + assert(ele.shape == Shape(1, 3)) + }) + } + } + + +} diff --git a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc index 9b19fd360fc4..387a0b17e252 100644 --- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc @@ -93,6 +93,31 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayCreateEx return ret; } +JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayCreateSparseEx + (JNIEnv *env, jobject obj, jint storageType, jintArray shape, jint ndim, jint devType, + jint devId, jint delayAlloc, jint dtype, jint numAux, jintArray auxTypes, + jintArray auxNdims, jintArray auxShapes, jobject ndArrayHandle) { + jint *shapeArr = env->GetIntArrayElements(shape, NULL); + jint *auxTypesArr = env->GetIntArrayElements(auxTypes, NULL); + jint *auxNdimsArr = env->GetIntArrayElements(auxNdims, NULL); + jint *auxShapesArr = env->GetIntArrayElements(auxShapes, NULL); + NDArrayHandle out; + int ret = MXNDArrayCreateSparseEx(storageType, + reinterpret_cast(shapeArr), + static_cast(ndim), + devType, devId, delayAlloc, dtype, + static_cast(numAux), + reinterpret_cast(auxTypesArr), + reinterpret_cast(auxNdimsArr), + reinterpret_cast(auxShapesArr), &out); + env->ReleaseIntArrayElements(shape, shapeArr, 0); + env->ReleaseIntArrayElements(auxTypes, auxTypesArr, 0); + env->ReleaseIntArrayElements(auxNdims, auxNdimsArr, 0); + env->ReleaseIntArrayElements(auxShapes, auxShapesArr, 0); + SetLongField(env, ndArrayHandle, reinterpret_cast(out)); + return ret; +} + JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayWaitAll(JNIEnv *env, jobject obj) { return MXNDArrayWaitAll(); } @@ -179,10 +204,10 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxFuncGetInfo return ret; } -JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxImperativeInvoke +JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxImperativeInvokeEx (JNIEnv *env, jobject obj, jlong funcPtr, jlongArray inputs, jlongArray outputsGiven, jobject outputs, jint numParams, - jobjectArray paramKeys, jobjectArray paramVals) { + jobjectArray paramKeys, jobjectArray paramVals, jobject outStypes) { const char **cParamKeys = NULL; const char **cParamVals = NULL; @@ -204,6 +229,7 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxImperativeInvoke int numOutputs = 0; jlong *cOutputsGiven = NULL; NDArrayHandle *cOutputs = NULL; + const int *cOutStypes; if (outputsGiven) { cOutputsGiven = env->GetLongArrayElements(outputsGiven, NULL); cOutputs = reinterpret_cast(cOutputsGiven); @@ -211,14 +237,15 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxImperativeInvoke } jlong *cInputs = env->GetLongArrayElements(inputs, NULL); jsize numInputs = env->GetArrayLength(inputs); - int ret = MXImperativeInvoke(reinterpret_cast(funcPtr), + int ret = MXImperativeInvokeEx(reinterpret_cast(funcPtr), static_cast(numInputs), reinterpret_cast(cInputs), &numOutputs, &cOutputs, static_cast(numParams), cParamKeys, - cParamVals); + cParamVals, + &cOutStypes); env->ReleaseLongArrayElements(inputs, cInputs, 0); if (cOutputsGiven) { env->ReleaseLongArrayElements(outputsGiven, cOutputsGiven, 0); @@ -240,7 +267,9 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxImperativeInvoke if (cOutputs) { jclass longCls = env->FindClass("java/lang/Long"); + jclass intCls = env->FindClass("java/lang/Integer"); jmethodID longConst = env->GetMethodID(longCls, "", "(J)V"); + jmethodID intConst = env->GetMethodID(intCls, "", "(I)V"); // scala.collection.mutable.ListBuffer append method jclass listClass = env->FindClass("scala/collection/mutable/ArrayBuffer"); jmethodID listAppend = env->GetMethodID(listClass, "$plus$eq", @@ -249,6 +278,9 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxImperativeInvoke env->CallObjectMethod(outputs, listAppend, env->NewObject(longCls, longConst, reinterpret_cast(cOutputs[i]))); + env->CallObjectMethod(outStypes, listAppend, + env->NewObject(intCls, intConst, + cOutStypes[i])); } } @@ -379,6 +411,14 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetShape return ret; } +JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArraySyncCopyFromNDArray + (JNIEnv *env, jobject obj, jlong dstPtr, jlong srcPtr, jint locator) { + int ret = MXNDArraySyncCopyFromNDArray(reinterpret_cast(dstPtr), + reinterpret_cast(srcPtr), + locator); + return ret; +} + JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArraySyncCopyToCPU (JNIEnv *env, jobject obj, jlong ndArrayPtr, jbyteArray data, jint size) { jbyte *pdata = env->GetByteArrayElements(data, NULL); @@ -434,6 +474,25 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxFloat64NDArraySyncCopyFro return ret; } +JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetDataNDArray + (JNIEnv *env, jobject obj, jlong arrayPtr, jobject ndArrayHandle) { + NDArrayHandle out; + int ret = MXNDArrayGetDataNDArray(reinterpret_cast(arrayPtr), + &out); + SetLongField(env, ndArrayHandle, reinterpret_cast(out)); + return ret; +} + +JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetAuxNDArray + (JNIEnv *env, jobject obj, jlong arrayPtr, jint location, jobject ndArrayHandle) { + NDArrayHandle out; + int ret = MXNDArrayGetAuxNDArray(reinterpret_cast(arrayPtr), + static_cast(location), + &out); + SetLongField(env, ndArrayHandle, reinterpret_cast(out)); + return ret; +} + JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetContext (JNIEnv *env, jobject obj, jlong arrayPtr, jobject devTypeId, jobject devId) { int outDevType; @@ -540,6 +599,14 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetDType return ret; } +JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetStorageType + (JNIEnv * env, jobject obj, jlong jhandle, jobject jstype) { + int stype; + int ret = MXNDArrayGetStorageType(reinterpret_cast(jhandle), &stype); + SetIntField(env, jstype, stype); + return ret; +} + JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxInitPSEnv (JNIEnv *env, jobject obj, jobjectArray jkeys, jobjectArray jvals) { // keys and values diff --git a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h index fac32bb0a410..c8ee0ce8d22b 100644 --- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h +++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h @@ -41,11 +41,11 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_nnGetOpHandle /* * Class: org_apache_mxnet_LibInfo - * Method: mxImperativeInvoke - * Signature: (J[J[JLscala/collection/mutable/ArrayBuffer;I[Ljava/lang/String;[Ljava/lang/String;)I + * Method: mxImperativeInvokeEx + * Signature: (J[J[JLscala/collection/mutable/ArrayBuffer;I[Ljava/lang/String;[Ljava/lang/String;Lscala/collection/mutable/ArrayBuffer;)I */ -JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxImperativeInvoke - (JNIEnv *, jobject, jlong, jlongArray, jlongArray, jobject, jint, jobjectArray, jobjectArray); +JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxImperativeInvokeEx + (JNIEnv *, jobject, jlong, jlongArray, jlongArray, jobject, jint, jobjectArray, jobjectArray, jobject); /* * Class: org_apache_mxnet_LibInfo @@ -71,6 +71,14 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayCreateNone JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayCreateEx (JNIEnv *, jobject, jintArray, jint, jint, jint, jint, jint, jobject); +/* + * Class: org_apache_mxnet_LibInfo + * Method: mxNDArrayCreateSparseEx + * Signature: (I[IIIIIII[I[I[ILorg/apache/mxnet/Base/RefLong;)I + */ +JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayCreateSparseEx + (JNIEnv *, jobject, jint, jintArray, jint, jint, jint, jint, jint, jint, jintArray, jintArray, jintArray, jobject); + /* * Class: org_apache_mxnet_LibInfo * Method: mxNDArrayWaitAll @@ -135,6 +143,14 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxFuncInvokeEx JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetShape (JNIEnv *, jobject, jlong, jobject, jobject); +/* + * Class: org_apache_mxnet_LibInfo + * Method: mxNDArraySyncCopyFromNDArray + * Signature: (JJI)I + */ +JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArraySyncCopyFromNDArray + (JNIEnv *, jobject, jlong, jlong, jint); + /* * Class: org_apache_mxnet_LibInfo * Method: mxNDArraySyncCopyToCPU @@ -199,6 +215,22 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayLoad JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArraySave (JNIEnv *, jobject, jstring, jlongArray, jobjectArray); +/* + * Class: org_apache_mxnet_LibInfo + * Method: mxNDArrayGetDataNDArray + * Signature: (JLorg/apache/mxnet/Base/RefLong;)I + */ +JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetDataNDArray + (JNIEnv *, jobject, jlong, jobject); + +/* + * Class: org_apache_mxnet_LibInfo + * Method: mxNDArrayGetAuxNDArray + * Signature: (JILorg/apache/mxnet/Base/RefLong;)I + */ +JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetAuxNDArray + (JNIEnv *, jobject, jlong, jint, jobject); + /* * Class: org_apache_mxnet_LibInfo * Method: mxNDArrayGetContext @@ -231,6 +263,14 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayLoadFromRawBytes JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetDType (JNIEnv *, jobject, jlong, jobject); +/* + * Class: org_apache_mxnet_LibInfo + * Method: mxNDArrayGetStorageType + * Signature: (JLorg/apache/mxnet/Base/RefInt;)I + */ +JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetStorageType + (JNIEnv *, jobject, jlong, jobject); + /* * Class: org_apache_mxnet_LibInfo * Method: mxInitPSEnv