Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add Sparse NDArray support for Scala (#15378)
Browse files Browse the repository at this point in the history
* add Sparse Support

* add imperative invoke sparse support

* add retain method and comments

* add getData method

* add Sparse NDIter test

* remove debug line
  • Loading branch information
lanking520 committed Jul 8, 2019
1 parent 1ae73de commit 09c71bf
Show file tree
Hide file tree
Showing 10 changed files with 543 additions and 28 deletions.
17 changes: 4 additions & 13 deletions scala-package/core/src/main/scala/org/apache/mxnet/DType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,26 +24,17 @@ object DType extends Enumeration {
val Float16 = Value(2, "float16")
val UInt8 = Value(3, "uint8")
val Int32 = Value(4, "int32")
val Int8 = Value(5, "int8")
val Int64 = Value(6, "int64")
val Unknown = Value(-1, "unknown")
private[mxnet] def numOfBytes(dtype: DType): Int = {
dtype match {
case DType.UInt8 => 1
case DType.UInt8 | DType.Int8 => 1
case DType.Int32 => 4
case DType.Float16 => 2
case DType.Float32 => 4
case DType.Float64 => 8
case DType.Float64 | DType.Int64 => 8
case DType.Unknown => 0
}
}
private[mxnet] def getType(dtypeStr: String): DType = {
dtypeStr match {
case "UInt8" => DType.UInt8
case "Int32" => DType.Int32
case "Float16" => DType.Float16
case "Float32" => DType.Float32
case "Float64" => DType.Float64
case _ => throw new IllegalArgumentException(
s"DType: $dtypeStr not found! please set it in DType.scala")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,14 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
private def getOutputs: Array[NDArray] = {
val ndHandles = ArrayBuffer[NDArrayHandle]()
checkCall(_LIB.mxExecutorOutputs(handle, ndHandles))
ndHandles.toArray.map(new NDArray(_, addToCollector = false))
ndHandles.toArray.map(ele => {
val nd = new NDArray(ele, addToCollector = false)
if (nd.isSparse) {
nd.asInstanceOf[SparseNDArray]
}
nd
}
)
}

/**
Expand Down
27 changes: 25 additions & 2 deletions scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,14 @@ private[mxnet] class LibInfo {
@native def mxListAllOpNames(names: ListBuffer[String]): Int
@native def nnGetOpHandle(opName: String, opHandle: RefLong): Int
// NDArray
@native def mxImperativeInvoke(creator: FunctionHandle,
@native def mxImperativeInvokeEx(creator: FunctionHandle,
inputs: Array[NDArrayHandle],
outputsGiven: Array[NDArrayHandle],
outputs: ArrayBuffer[NDArrayHandle],
numParams: Int,
paramKeys: Array[String],
paramVals: Array[String]): Int
paramVals: Array[String],
outStype: ArrayBuffer[Int]): Int
@native def mxNDArrayFree(handle: NDArrayHandle): Int
@native def mxNDArrayCreateNone(out: NDArrayHandleRef): Int
@native def mxNDArrayCreateEx(shape: Array[Int],
Expand All @@ -47,6 +48,20 @@ private[mxnet] class LibInfo {
delayAlloc: Int,
dtype: Int,
out: NDArrayHandleRef): Int
// scalastyle:off parameterNum
@native def mxNDArrayCreateSparseEx(storageType: Int,
shape: Array[Int],
ndim: Int,
devType: Int,
devId: Int,
delayAlloc: Int,
dtype: Int,
numAux: Int,
auxTypes: Array[Int],
auxNdims: Array[Int],
auxShapes: Array[Int],
out: NDArrayHandleRef): Int
// scalastyle:on parameterNum
@native def mxNDArrayWaitAll(): Int
@native def mxNDArrayWaitToRead(handle: NDArrayHandle): Int
@native def mxListFunctions(functions: ListBuffer[FunctionHandle]): Int
Expand Down Expand Up @@ -76,6 +91,9 @@ private[mxnet] class LibInfo {
@native def mxNDArrayGetShape(handle: NDArrayHandle,
ndim: MXUintRef,
data: ArrayBuffer[Int]): Int
@native def mxNDArraySyncCopyFromNDArray(handleDst: NDArrayHandle,
handleSrc: NDArrayHandle,
locator: Int): Int
@native def mxNDArraySyncCopyToCPU(handle: NDArrayHandle,
data: Array[Byte],
size: Int): Int
Expand Down Expand Up @@ -105,10 +123,15 @@ private[mxnet] class LibInfo {
@native def mxNDArraySave(fname: String,
handles: Array[NDArrayHandle],
keys: Array[String]): Int
@native def mxNDArrayGetDataNDArray(handle: NDArrayHandle, out: NDArrayHandleRef): Int
@native def mxNDArrayGetAuxNDArray(handle: NDArrayHandle,
location: Int,
out: NDArrayHandleRef): Int
@native def mxNDArrayGetContext(handle: NDArrayHandle, devTypeId: RefInt, devId: RefInt): Int
@native def mxNDArraySaveRawBytes(handle: NDArrayHandle, buf: ArrayBuffer[Byte]): Int
@native def mxNDArrayLoadFromRawBytes(bytes: Array[Byte], handle: NDArrayHandleRef): Int
@native def mxNDArrayGetDType(handle: NDArrayHandle, dtype: RefInt): Int
@native def mxNDArrayGetStorageType(handle: NDArrayHandle, stype: RefInt): Int

// KVStore Server
@native def mxInitPSEnv(keys: Array[String], values: Array[String]): Int
Expand Down
65 changes: 61 additions & 4 deletions scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ import java.nio.{ByteBuffer, ByteOrder}

import org.apache.mxnet.Base._
import org.apache.mxnet.DType.DType
import org.apache.mxnet.MX_PRIMITIVES.{MX_PRIMITIVE_TYPE}
import org.apache.mxnet.MX_PRIMITIVES.MX_PRIMITIVE_TYPE
import org.apache.mxnet.SparseFormat.SparseFormat
import org.slf4j.LoggerFactory

import scala.collection.mutable
Expand Down Expand Up @@ -113,10 +114,22 @@ object NDArray extends NDArrayBase {
}

val outputs = ArrayBuffer.empty[NDArrayHandle]
checkCall(_LIB.mxImperativeInvoke(function.handle, ndArgs.map(_.handle).toArray, outputVars,
outputs, updatedKwargs.size, updatedKwargs.keys.toArray, updatedKwargs.values.toArray))
val outStypes = ArrayBuffer.empty[Int]
checkCall(_LIB.mxImperativeInvokeEx(function.handle,
ndArgs.map(_.handle).toArray,
outputVars,
outputs,
updatedKwargs.size,
updatedKwargs.keys.toArray,
updatedKwargs.values.toArray,
outStypes))
new NDArrayFuncReturn(Option(oriOutputs).getOrElse {
val outputArrs = outputs.map(new NDArray(_)).toArray
val outputArrs = (outputs zip outStypes).map(
ele => ele._2 match {
case 0 => new NDArray(ele._1)
case _ => new SparseNDArray(ele._1)
}
).toArray
addDependency(ndArgs.toArray, outputArrs)
outputArrs
})
Expand Down Expand Up @@ -943,6 +956,12 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
DType(mxDtype.value)
}

val sparseFormat: SparseFormat = {
val mxSF = new RefInt
checkCall(_LIB.mxNDArrayGetStorageType(handle, mxSF))
SparseFormat(mxSF.value)
}

/**
* Return a copied numpy array of current array with specified type.
* @param dtype Desired type of result array.
Expand Down Expand Up @@ -1309,6 +1328,30 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
if (this.context == context) this else this.copyTo(context)
}

/**
* check if NDArray is SparseNDArray
* @return Boolean
*/
def isSparse: Boolean = {
this.sparseFormat.id != 0
}

/**
* Convert a NDArray to SparseNDArray
*
* @param sfOption the target sparse type
* @return SparseNDArray
*/
def toSparse(sfOption : Option[SparseFormat] = None): SparseNDArray = {
val sf = sfOption.getOrElse(SparseFormat.ROW_SPARSE)
if (sf.id == 0) throw new IllegalArgumentException("Require Sparse")
if (isSparse && sfOption.isEmpty) {
this.asInstanceOf[SparseNDArray]
} else {
NDArray.api.cast_storage(this, sf.toString).head.asInstanceOf[SparseNDArray]
}
}

override def equals(o: Any): Boolean = o match {
case that: NDArray =>
that != null && that.shape == this.shape && that.toArray.sameElements(this.toArray)
Expand Down Expand Up @@ -1479,6 +1522,7 @@ private[mxnet] class NDArrayInternal (private val internal: Array[Byte], private
case DType.Float32 => units.map(wrapBytes(_).getFloat.toDouble)
case DType.Float64 => units.map(wrapBytes(_).getDouble)
case DType.Int32 => units.map(wrapBytes(_).getInt.toDouble)
case DType.Int64 => units.map(wrapBytes(_).getLong.toDouble)
case DType.UInt8 => internal.map(_.toDouble)
}
}
Expand All @@ -1488,6 +1532,7 @@ private[mxnet] class NDArrayInternal (private val internal: Array[Byte], private
case DType.Float32 => units.map(wrapBytes(_).getFloat)
case DType.Float64 => units.map(wrapBytes(_).getDouble.toFloat)
case DType.Int32 => units.map(wrapBytes(_).getInt.toFloat)
case DType.Int64 => units.map(wrapBytes(_).getLong.toFloat)
case DType.UInt8 => internal.map(_.toFloat)
}
}
Expand All @@ -1497,15 +1542,27 @@ private[mxnet] class NDArrayInternal (private val internal: Array[Byte], private
case DType.Float32 => units.map(wrapBytes(_).getFloat.toInt)
case DType.Float64 => units.map(wrapBytes(_).getDouble.toInt)
case DType.Int32 => units.map(wrapBytes(_).getInt)
case DType.Int64 => units.map(wrapBytes(_).getLong.toInt)
case DType.UInt8 => internal.map(_.toInt)
}
}
def toLongArray: Array[Long] = {
require(dtype != DType.Float16, "Currently cannot convert float16 to native numerical types")
dtype match {
case DType.Float32 => units.map(wrapBytes(_).getFloat.toLong)
case DType.Float64 => units.map(wrapBytes(_).getDouble.toLong)
case DType.Int32 => units.map(wrapBytes(_).getInt.toLong)
case DType.Int64 => units.map(wrapBytes(_).getLong)
case DType.UInt8 => internal.map(_.toLong)
}
}
def toByteArray: Array[Byte] = {
require(dtype != DType.Float16, "Currently cannot convert float16 to native numerical types")
dtype match {
case DType.Float16 | DType.Float32 => units.map(wrapBytes(_).getFloat.toByte)
case DType.Float64 => units.map(wrapBytes(_).getDouble.toByte)
case DType.Int32 => units.map(wrapBytes(_).getInt.toByte)
case DType.Int64 => units.map(wrapBytes(_).getLong.toByte)
case DType.UInt8 => internal.clone()
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnet

object SparseFormat extends Enumeration {
type SparseFormat = Value
val DEFAULT = Value(0, "default")
val ROW_SPARSE = Value(1, "row_sparse")
val CSR = Value(2, "csr")
}
Loading

0 comments on commit 09c71bf

Please sign in to comment.