diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala b/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala index f3a8e8e9a4a5..1d5cc2847ac0 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala @@ -24,26 +24,17 @@ object DType extends Enumeration { val Float16 = Value(2, "float16") val UInt8 = Value(3, "uint8") val Int32 = Value(4, "int32") + val Int8 = Value(5, "int8") + val Int64 = Value(6, "int64") val Unknown = Value(-1, "unknown") private[mxnet] def numOfBytes(dtype: DType): Int = { dtype match { - case DType.UInt8 => 1 + case DType.UInt8 | DType.Int8 => 1 case DType.Int32 => 4 case DType.Float16 => 2 case DType.Float32 => 4 - case DType.Float64 => 8 + case DType.Float64 | DType.Int64 => 8 case DType.Unknown => 0 } } - private[mxnet] def getType(dtypeStr: String): DType = { - dtypeStr match { - case "UInt8" => DType.UInt8 - case "Int32" => DType.Int32 - case "Float16" => DType.Float16 - case "Float32" => DType.Float32 - case "Float64" => DType.Float64 - case _ => throw new IllegalArgumentException( - s"DType: $dtypeStr not found! please set it in DType.scala") - } - } } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala index b0fae0f9d58d..6365f9cb4645 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala @@ -159,7 +159,14 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, private def getOutputs: Array[NDArray] = { val ndHandles = ArrayBuffer[NDArrayHandle]() checkCall(_LIB.mxExecutorOutputs(handle, ndHandles)) - ndHandles.toArray.map(new NDArray(_, addToCollector = false)) + ndHandles.toArray.map(ele => { + val nd = new NDArray(ele, addToCollector = false) + if (nd.isSparse) { + nd.asInstanceOf[SparseNDArray] + } + nd + } + ) } /** diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala index 640ecf5d5978..0ee6476be365 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala @@ -31,13 +31,14 @@ private[mxnet] class LibInfo { @native def mxListAllOpNames(names: ListBuffer[String]): Int @native def nnGetOpHandle(opName: String, opHandle: RefLong): Int // NDArray - @native def mxImperativeInvoke(creator: FunctionHandle, + @native def mxImperativeInvokeEx(creator: FunctionHandle, inputs: Array[NDArrayHandle], outputsGiven: Array[NDArrayHandle], outputs: ArrayBuffer[NDArrayHandle], numParams: Int, paramKeys: Array[String], - paramVals: Array[String]): Int + paramVals: Array[String], + outStype: ArrayBuffer[Int]): Int @native def mxNDArrayFree(handle: NDArrayHandle): Int @native def mxNDArrayCreateNone(out: NDArrayHandleRef): Int @native def mxNDArrayCreateEx(shape: Array[Int], @@ -47,6 +48,20 @@ private[mxnet] class LibInfo { delayAlloc: Int, dtype: Int, out: NDArrayHandleRef): Int + // scalastyle:off parameterNum + @native def mxNDArrayCreateSparseEx(storageType: Int, + shape: Array[Int], + ndim: Int, + devType: Int, + devId: Int, + delayAlloc: Int, + dtype: Int, + numAux: Int, + auxTypes: Array[Int], + auxNdims: Array[Int], + auxShapes: Array[Int], + out: NDArrayHandleRef): Int + // scalastyle:on parameterNum @native def mxNDArrayWaitAll(): Int @native def mxNDArrayWaitToRead(handle: NDArrayHandle): Int @native def mxListFunctions(functions: ListBuffer[FunctionHandle]): Int @@ -76,6 +91,9 @@ private[mxnet] class LibInfo { @native def mxNDArrayGetShape(handle: NDArrayHandle, ndim: MXUintRef, data: ArrayBuffer[Int]): Int + @native def mxNDArraySyncCopyFromNDArray(handleDst: NDArrayHandle, + handleSrc: NDArrayHandle, + locator: Int): Int @native def mxNDArraySyncCopyToCPU(handle: NDArrayHandle, data: Array[Byte], size: Int): Int @@ -105,10 +123,15 @@ private[mxnet] class LibInfo { @native def mxNDArraySave(fname: String, handles: Array[NDArrayHandle], keys: Array[String]): Int + @native def mxNDArrayGetDataNDArray(handle: NDArrayHandle, out: NDArrayHandleRef): Int + @native def mxNDArrayGetAuxNDArray(handle: NDArrayHandle, + location: Int, + out: NDArrayHandleRef): Int @native def mxNDArrayGetContext(handle: NDArrayHandle, devTypeId: RefInt, devId: RefInt): Int @native def mxNDArraySaveRawBytes(handle: NDArrayHandle, buf: ArrayBuffer[Byte]): Int @native def mxNDArrayLoadFromRawBytes(bytes: Array[Byte], handle: NDArrayHandleRef): Int @native def mxNDArrayGetDType(handle: NDArrayHandle, dtype: RefInt): Int + @native def mxNDArrayGetStorageType(handle: NDArrayHandle, stype: RefInt): Int // KVStore Server @native def mxInitPSEnv(keys: Array[String], values: Array[String]): Int diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index 40888012fc5a..1b7b31b32cce 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -21,7 +21,8 @@ import java.nio.{ByteBuffer, ByteOrder} import org.apache.mxnet.Base._ import org.apache.mxnet.DType.DType -import org.apache.mxnet.MX_PRIMITIVES.{MX_PRIMITIVE_TYPE} +import org.apache.mxnet.MX_PRIMITIVES.MX_PRIMITIVE_TYPE +import org.apache.mxnet.SparseFormat.SparseFormat import org.slf4j.LoggerFactory import scala.collection.mutable @@ -113,10 +114,22 @@ object NDArray extends NDArrayBase { } val outputs = ArrayBuffer.empty[NDArrayHandle] - checkCall(_LIB.mxImperativeInvoke(function.handle, ndArgs.map(_.handle).toArray, outputVars, - outputs, updatedKwargs.size, updatedKwargs.keys.toArray, updatedKwargs.values.toArray)) + val outStypes = ArrayBuffer.empty[Int] + checkCall(_LIB.mxImperativeInvokeEx(function.handle, + ndArgs.map(_.handle).toArray, + outputVars, + outputs, + updatedKwargs.size, + updatedKwargs.keys.toArray, + updatedKwargs.values.toArray, + outStypes)) new NDArrayFuncReturn(Option(oriOutputs).getOrElse { - val outputArrs = outputs.map(new NDArray(_)).toArray + val outputArrs = (outputs zip outStypes).map( + ele => ele._2 match { + case 0 => new NDArray(ele._1) + case _ => new SparseNDArray(ele._1) + } + ).toArray addDependency(ndArgs.toArray, outputArrs) outputArrs }) @@ -943,6 +956,12 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, DType(mxDtype.value) } + val sparseFormat: SparseFormat = { + val mxSF = new RefInt + checkCall(_LIB.mxNDArrayGetStorageType(handle, mxSF)) + SparseFormat(mxSF.value) + } + /** * Return a copied numpy array of current array with specified type. * @param dtype Desired type of result array. @@ -1309,6 +1328,30 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, if (this.context == context) this else this.copyTo(context) } + /** + * check if NDArray is SparseNDArray + * @return Boolean + */ + def isSparse: Boolean = { + this.sparseFormat.id != 0 + } + + /** + * Convert a NDArray to SparseNDArray + * + * @param sfOption the target sparse type + * @return SparseNDArray + */ + def toSparse(sfOption : Option[SparseFormat] = None): SparseNDArray = { + val sf = sfOption.getOrElse(SparseFormat.ROW_SPARSE) + if (sf.id == 0) throw new IllegalArgumentException("Require Sparse") + if (isSparse && sfOption.isEmpty) { + this.asInstanceOf[SparseNDArray] + } else { + NDArray.api.cast_storage(this, sf.toString).head.asInstanceOf[SparseNDArray] + } + } + override def equals(o: Any): Boolean = o match { case that: NDArray => that != null && that.shape == this.shape && that.toArray.sameElements(this.toArray) @@ -1479,6 +1522,7 @@ private[mxnet] class NDArrayInternal (private val internal: Array[Byte], private case DType.Float32 => units.map(wrapBytes(_).getFloat.toDouble) case DType.Float64 => units.map(wrapBytes(_).getDouble) case DType.Int32 => units.map(wrapBytes(_).getInt.toDouble) + case DType.Int64 => units.map(wrapBytes(_).getLong.toDouble) case DType.UInt8 => internal.map(_.toDouble) } } @@ -1488,6 +1532,7 @@ private[mxnet] class NDArrayInternal (private val internal: Array[Byte], private case DType.Float32 => units.map(wrapBytes(_).getFloat) case DType.Float64 => units.map(wrapBytes(_).getDouble.toFloat) case DType.Int32 => units.map(wrapBytes(_).getInt.toFloat) + case DType.Int64 => units.map(wrapBytes(_).getLong.toFloat) case DType.UInt8 => internal.map(_.toFloat) } } @@ -1497,15 +1542,27 @@ private[mxnet] class NDArrayInternal (private val internal: Array[Byte], private case DType.Float32 => units.map(wrapBytes(_).getFloat.toInt) case DType.Float64 => units.map(wrapBytes(_).getDouble.toInt) case DType.Int32 => units.map(wrapBytes(_).getInt) + case DType.Int64 => units.map(wrapBytes(_).getLong.toInt) case DType.UInt8 => internal.map(_.toInt) } } + def toLongArray: Array[Long] = { + require(dtype != DType.Float16, "Currently cannot convert float16 to native numerical types") + dtype match { + case DType.Float32 => units.map(wrapBytes(_).getFloat.toLong) + case DType.Float64 => units.map(wrapBytes(_).getDouble.toLong) + case DType.Int32 => units.map(wrapBytes(_).getInt.toLong) + case DType.Int64 => units.map(wrapBytes(_).getLong) + case DType.UInt8 => internal.map(_.toLong) + } + } def toByteArray: Array[Byte] = { require(dtype != DType.Float16, "Currently cannot convert float16 to native numerical types") dtype match { case DType.Float16 | DType.Float32 => units.map(wrapBytes(_).getFloat.toByte) case DType.Float64 => units.map(wrapBytes(_).getDouble.toByte) case DType.Int32 => units.map(wrapBytes(_).getInt.toByte) + case DType.Int64 => units.map(wrapBytes(_).getLong.toByte) case DType.UInt8 => internal.clone() } } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/SparseFormat.scala b/scala-package/core/src/main/scala/org/apache/mxnet/SparseFormat.scala new file mode 100644 index 000000000000..acb0c0f24070 --- /dev/null +++ b/scala-package/core/src/main/scala/org/apache/mxnet/SparseFormat.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet + +object SparseFormat extends Enumeration { + type SparseFormat = Value + val DEFAULT = Value(0, "default") + val ROW_SPARSE = Value(1, "row_sparse") + val CSR = Value(2, "csr") +} diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/SparseNDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/SparseNDArray.scala new file mode 100644 index 000000000000..f3fe638e41ed --- /dev/null +++ b/scala-package/core/src/main/scala/org/apache/mxnet/SparseNDArray.scala @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet + +import org.apache.mxnet.Base.{NDArrayHandle, NDArrayHandleRef, checkCall, _LIB} +import org.apache.mxnet.DType.DType +import org.apache.mxnet.SparseFormat.SparseFormat + +object SparseNDArray { + /** + * Create a Compressed Sparse Row Storage (CSR) Format Matrix + * @param data the data to feed + * @param indices The indices array stores the column index for each non-zero element in data + * @param indptr The indptr array is what will help identify the rows where the data appears + * @param shape the shape of CSR NDArray to be created + * @param ctx the context of this NDArray + * @return SparseNDArray + */ + def csrMatrix(data: Array[Float], indices: Array[Float], + indptr: Array[Float], shape: Shape, ctx: Context): SparseNDArray = { + val fmt = SparseFormat.CSR + val dataND = NDArray.array(data, Shape(data.length), ctx) + val indicesND = NDArray.array(indices, Shape(indices.length), ctx).asType(DType.Int64) + val indptrND = NDArray.array(indptr, Shape(indptr.length), ctx).asType(DType.Int64) + val dTypes = Array(indptrND.dtype, indicesND.dtype) + val shapes = Array(indptrND.shape, indicesND.shape) + val handle = + newAllocHandle(fmt, shape, ctx, false, DType.Float32, dTypes, shapes) + checkCall(_LIB.mxNDArraySyncCopyFromNDArray(handle, dataND.handle, -1)) + checkCall(_LIB.mxNDArraySyncCopyFromNDArray(handle, indptrND.handle, 0)) + checkCall(_LIB.mxNDArraySyncCopyFromNDArray(handle, indicesND.handle, 1)) + new SparseNDArray(handle) + } + + /** + * RowSparseNDArray stores the matrix in row sparse format, + * which is designed for arrays of which most row slices are all zeros + * @param data Any Array(Array(... Array(Float))) + * @param indices the indices to store the data + * @param shape shape of the NDArray + * @param ctx Context + * @return SparseNDArray + */ + def rowSparseArray(data: Array[_], indices: Array[Float], + shape: Shape, ctx: Context): SparseNDArray = { + val dataND = NDArray.toNDArray(data) + val indicesND = NDArray.array(indices, Shape(indices.length), ctx).asType(DType.Int64) + rowSparseArray(dataND, indicesND, shape, ctx) + } + + /** + * RowSparseNDArray stores the matrix in row sparse format, + * which is designed for arrays of which most row slices are all zeros + * @param data NDArray input + * @param indices in NDArray. Only DType.Int64 supported + * @param shape shape of the NDArray + * @param ctx Context + * @return + */ + def rowSparseArray(data: NDArray, indices: NDArray, + shape: Shape, ctx: Context): SparseNDArray = { + val fmt = SparseFormat.ROW_SPARSE + val handle = newAllocHandle(fmt, shape, ctx, false, + DType.Float32, Array(indices.dtype), Array(indices.shape)) + checkCall(_LIB.mxNDArraySyncCopyFromNDArray(handle, data.handle, -1)) + checkCall(_LIB.mxNDArraySyncCopyFromNDArray(handle, indices.handle, 0)) + new SparseNDArray(handle) + } + + def retain(sparseNDArray: SparseNDArray, indices: Array[Float]): SparseNDArray = { + if (sparseNDArray.sparseFormat == SparseFormat.CSR) { + throw new IllegalArgumentException("CSR not supported") + } + NDArray.genericNDArrayFunctionInvoke("_sparse_retain", + Seq(sparseNDArray, NDArray.toNDArray(indices))).head.toSparse() + } + + private def newAllocHandle(stype : SparseFormat, + shape: Shape, + ctx: Context, + delayAlloc: Boolean, + dtype: DType = DType.Float32, + auxDTypes: Array[DType], + auxShapes: Array[Shape]) : NDArrayHandle = { + val hdl = new NDArrayHandleRef + checkCall(_LIB.mxNDArrayCreateSparseEx( + stype.id, + shape.toArray, + shape.length, + ctx.deviceTypeid, + ctx.deviceId, + if (delayAlloc) 1 else 0, + dtype.id, + auxDTypes.length, + auxDTypes.map(_.id), + auxShapes.map(_.length), + auxShapes.map(_.get(0)), + hdl) + ) + hdl.value + } +} + +/** + * Sparse NDArray is the child class of NDArray designed to hold the Sparse format + * + *
Currently, Rowsparse and CSR typed NDArray is supported. Most of the Operators
+ * will convert Sparse NDArray to dense. Basic operators like add
will
+ * have optimization for sparse operattions