Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add Sparse NDArray support for Scala #15378

Merged
merged 6 commits into from
Jul 8, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 4 additions & 13 deletions scala-package/core/src/main/scala/org/apache/mxnet/DType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,26 +24,17 @@ object DType extends Enumeration {
val Float16 = Value(2, "float16")
val UInt8 = Value(3, "uint8")
val Int32 = Value(4, "int32")
val Int8 = Value(5, "int8")
val Int64 = Value(6, "int64")
val Unknown = Value(-1, "unknown")
/**
 * Size in bytes of a single element of the given dtype.
 *
 * @param dtype element type to look up
 * @return number of bytes per element; 0 for DType.Unknown
 */
private[mxnet] def numOfBytes(dtype: DType): Int = {
  dtype match {
    // One case per byte width; each dtype appears exactly once so no alternative is shadowed.
    case DType.UInt8 | DType.Int8 => 1
    case DType.Float16 => 2
    case DType.Int32 | DType.Float32 => 4
    case DType.Float64 | DType.Int64 => 8
    case DType.Unknown => 0
  }
}
/**
 * Resolve a DType from its case-sensitive name.
 *
 * @param dtypeStr one of "UInt8", "Int32", "Float16", "Float32", "Float64"
 * @return the matching DType value
 * @throws IllegalArgumentException if the name is not one of the names above
 */
private[mxnet] def getType(dtypeStr: String): DType = {
  val byName: Map[String, DType] = Map(
    "UInt8" -> DType.UInt8,
    "Int32" -> DType.Int32,
    "Float16" -> DType.Float16,
    "Float32" -> DType.Float32,
    "Float64" -> DType.Float64
  )
  // getOrElse takes its default by name, so the exception is only built on a miss.
  byName.getOrElse(dtypeStr, throw new IllegalArgumentException(
    s"DType: $dtypeStr not found! please set it in DType.scala"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,14 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
/**
 * Collect the output arrays of this executor.
 * Asks the native library for the executor's output handles and wraps each handle in an
 * NDArray constructed with addToCollector = false.
 *
 * @return one NDArray per handle reported by mxExecutorOutputs
 */
private def getOutputs: Array[NDArray] = {
val ndHandles = ArrayBuffer[NDArrayHandle]()
checkCall(_LIB.mxExecutorOutputs(handle, ndHandles))
// NOTE(review): the result of this map is never used; it looks like the pre-change line
// retained by the diff rendering next to its replacement below - confirm.
ndHandles.toArray.map(new NDArray(_, addToCollector = false))
ndHandles.toArray.map(ele => {
val nd = new NDArray(ele, addToCollector = false)
// NOTE(review): the value of this cast is discarded, so the element returned is still the
// plain NDArray created above; presumably a SparseNDArray instance was intended - confirm.
if (nd.isSparse) {
nd.asInstanceOf[SparseNDArray]
}
nd
}
)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,14 @@ private[mxnet] class LibInfo {
@native def mxListAllOpNames(names: ListBuffer[String]): Int
@native def nnGetOpHandle(opName: String, opHandle: RefLong): Int
// NDArray
@native def mxImperativeInvoke(creator: FunctionHandle,
@native def mxImperativeInvokeEx(creator: FunctionHandle,
inputs: Array[NDArrayHandle],
outputsGiven: Array[NDArrayHandle],
outputs: ArrayBuffer[NDArrayHandle],
numParams: Int,
paramKeys: Array[String],
paramVals: Array[String]): Int
paramVals: Array[String],
outStype: ArrayBuffer[Int]): Int
@native def mxNDArrayFree(handle: NDArrayHandle): Int
@native def mxNDArrayCreateNone(out: NDArrayHandleRef): Int
@native def mxNDArrayCreateEx(shape: Array[Int],
Expand All @@ -47,6 +48,20 @@ private[mxnet] class LibInfo {
delayAlloc: Int,
dtype: Int,
out: NDArrayHandleRef): Int
// scalastyle:off parameterNum
/**
 * Native binding for creating a sparse NDArray.
 * NOTE(review): the parameter descriptions below are inferred from the parameter names
 * only - confirm against the MXNet C API before relying on them.
 *
 * @param storageType storage type id (presumably a SparseFormat id)
 * @param shape       shape of the array
 * @param ndim        number of dimensions (presumably the length of shape)
 * @param devType     device type id
 * @param devId       device id
 * @param delayAlloc  delayed-allocation flag passed as an Int
 * @param dtype       element type id
 * @param numAux      number of auxiliary arrays
 * @param auxTypes    element type ids of the auxiliary arrays
 * @param auxNdims    dimension counts of the auxiliary arrays
 * @param auxShapes   shapes of the auxiliary arrays (layout not visible here - TODO confirm)
 * @param out         reference that receives the created handle
 * @return Int status code
 */
@native def mxNDArrayCreateSparseEx(storageType: Int,
shape: Array[Int],
ndim: Int,
devType: Int,
devId: Int,
delayAlloc: Int,
dtype: Int,
numAux: Int,
auxTypes: Array[Int],
auxNdims: Array[Int],
auxShapes: Array[Int],
out: NDArrayHandleRef): Int
// scalastyle:on parameterNum
@native def mxNDArrayWaitAll(): Int
@native def mxNDArrayWaitToRead(handle: NDArrayHandle): Int
@native def mxListFunctions(functions: ListBuffer[FunctionHandle]): Int
Expand Down Expand Up @@ -76,6 +91,9 @@ private[mxnet] class LibInfo {
@native def mxNDArrayGetShape(handle: NDArrayHandle,
ndim: MXUintRef,
data: ArrayBuffer[Int]): Int
@native def mxNDArraySyncCopyFromNDArray(handleDst: NDArrayHandle,
handleSrc: NDArrayHandle,
locator: Int): Int
@native def mxNDArraySyncCopyToCPU(handle: NDArrayHandle,
data: Array[Byte],
size: Int): Int
Expand Down Expand Up @@ -105,10 +123,15 @@ private[mxnet] class LibInfo {
@native def mxNDArraySave(fname: String,
handles: Array[NDArrayHandle],
keys: Array[String]): Int
@native def mxNDArrayGetDataNDArray(handle: NDArrayHandle, out: NDArrayHandleRef): Int
@native def mxNDArrayGetAuxNDArray(handle: NDArrayHandle,
location: Int,
out: NDArrayHandleRef): Int
@native def mxNDArrayGetContext(handle: NDArrayHandle, devTypeId: RefInt, devId: RefInt): Int
@native def mxNDArraySaveRawBytes(handle: NDArrayHandle, buf: ArrayBuffer[Byte]): Int
@native def mxNDArrayLoadFromRawBytes(bytes: Array[Byte], handle: NDArrayHandleRef): Int
@native def mxNDArrayGetDType(handle: NDArrayHandle, dtype: RefInt): Int
/**
 * Native binding that reads the storage type of an NDArray.
 *
 * @param handle handle of the array to query
 * @param stype  out-parameter; callers read the storage type id from stype.value
 *               and map it to a SparseFormat
 * @return Int status code, which callers pass to checkCall
 */
@native def mxNDArrayGetStorageType(handle: NDArrayHandle, stype: RefInt): Int

// KVStore Server
@native def mxInitPSEnv(keys: Array[String], values: Array[String]): Int
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ import java.nio.{ByteBuffer, ByteOrder}

import org.apache.mxnet.Base._
import org.apache.mxnet.DType.DType
import org.apache.mxnet.MX_PRIMITIVES.{MX_PRIMITIVE_TYPE}
import org.apache.mxnet.MX_PRIMITIVES.MX_PRIMITIVE_TYPE
import org.apache.mxnet.SparseFormat.SparseFormat
import org.slf4j.LoggerFactory

import scala.collection.mutable
Expand Down Expand Up @@ -113,10 +114,22 @@ object NDArray extends NDArrayBase {
}

val outputs = ArrayBuffer.empty[NDArrayHandle]
checkCall(_LIB.mxImperativeInvoke(function.handle, ndArgs.map(_.handle).toArray, outputVars,
outputs, updatedKwargs.size, updatedKwargs.keys.toArray, updatedKwargs.values.toArray))
val outStypes = ArrayBuffer.empty[Int]
checkCall(_LIB.mxImperativeInvokeEx(function.handle,
ndArgs.map(_.handle).toArray,
outputVars,
outputs,
updatedKwargs.size,
updatedKwargs.keys.toArray,
updatedKwargs.values.toArray,
outStypes))
new NDArrayFuncReturn(Option(oriOutputs).getOrElse {
val outputArrs = outputs.map(new NDArray(_)).toArray
val outputArrs = (outputs zip outStypes).map(
ele => ele._2 match {
case 0 => new NDArray(ele._1)
case _ => new SparseNDArray(ele._1)
}
).toArray
addDependency(ndArgs.toArray, outputArrs)
outputArrs
})
Expand Down Expand Up @@ -943,6 +956,12 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
DType(mxDtype.value)
}

/**
 * Storage format of this array.
 * Being a val, it is queried from the native library once, when the instance is initialised.
 */
val sparseFormat: SparseFormat = {
  val storageTypeRef = new RefInt
  checkCall(_LIB.mxNDArrayGetStorageType(handle, storageTypeRef))
  // Translate the numeric id written by the native call into the enumeration value.
  SparseFormat(storageTypeRef.value)
}

/**
* Return a copied numpy array of current array with specified type.
* @param dtype Desired type of result array.
Expand Down Expand Up @@ -1309,6 +1328,30 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
if (this.context == context) this else this.copyTo(context)
}

/**
 * Tell whether this array reports a sparse storage format.
 * @return true when the storage format id differs from that of SparseFormat.DEFAULT (0)
 */
def isSparse: Boolean = sparseFormat.id != SparseFormat.DEFAULT.id

/**
 * Convert this array to a SparseNDArray.
 *
 * When no format is given and the array is already sparse, the array itself is returned
 * (cast to SparseNDArray). In every other case the cast_storage operator is invoked with
 * the requested format, or row_sparse when none was given.
 *
 * @param sfOption the target sparse format; None selects SparseFormat.ROW_SPARSE
 * @return the sparse array
 * @throws IllegalArgumentException if the requested format has id 0 (the default format)
 */
def toSparse(sfOption : Option[SparseFormat] = None): SparseNDArray = {
  val target = sfOption.getOrElse(SparseFormat.ROW_SPARSE)
  if (target.id == 0) throw new IllegalArgumentException("Require Sparse")
  sfOption match {
    // No explicit request and already sparse: nothing to convert.
    case None if isSparse => this.asInstanceOf[SparseNDArray]
    case _ =>
      NDArray.api.cast_storage(this, target.toString).head.asInstanceOf[SparseNDArray]
  }
}

override def equals(o: Any): Boolean = o match {
case that: NDArray =>
that != null && that.shape == this.shape && that.toArray.sameElements(this.toArray)
Expand Down Expand Up @@ -1479,6 +1522,7 @@ private[mxnet] class NDArrayInternal (private val internal: Array[Byte], private
case DType.Float32 => units.map(wrapBytes(_).getFloat.toDouble)
case DType.Float64 => units.map(wrapBytes(_).getDouble)
case DType.Int32 => units.map(wrapBytes(_).getInt.toDouble)
case DType.Int64 => units.map(wrapBytes(_).getLong.toDouble)
case DType.UInt8 => internal.map(_.toDouble)
}
}
Expand All @@ -1488,6 +1532,7 @@ private[mxnet] class NDArrayInternal (private val internal: Array[Byte], private
case DType.Float32 => units.map(wrapBytes(_).getFloat)
case DType.Float64 => units.map(wrapBytes(_).getDouble.toFloat)
case DType.Int32 => units.map(wrapBytes(_).getInt.toFloat)
case DType.Int64 => units.map(wrapBytes(_).getLong.toFloat)
case DType.UInt8 => internal.map(_.toFloat)
}
}
Expand All @@ -1497,15 +1542,27 @@ private[mxnet] class NDArrayInternal (private val internal: Array[Byte], private
case DType.Float32 => units.map(wrapBytes(_).getFloat.toInt)
case DType.Float64 => units.map(wrapBytes(_).getDouble.toInt)
case DType.Int32 => units.map(wrapBytes(_).getInt)
case DType.Int64 => units.map(wrapBytes(_).getLong.toInt)
case DType.UInt8 => internal.map(_.toInt)
}
}
/**
 * Convert the stored elements to an array of Long, decoding according to dtype.
 *
 * Float16 is rejected up front by the require below. Dtypes without a case in the match
 * (e.g. Int8) are not handled and end in a MatchError.
 *
 * @return the elements as Long values; floating-point values go through toLong
 */
def toLongArray: Array[Long] = {
  require(dtype != DType.Float16, "Currently cannot convert float16 to native numerical types")
  dtype match {
    // NOTE(review): Byte.toLong sign-extends, so raw bytes above 0x7F come out negative -
    // confirm that is acceptable for an unsigned dtype.
    case DType.UInt8 => internal.map(b => b.toLong)
    case DType.Int32 => units.map(chunk => wrapBytes(chunk).getInt.toLong)
    case DType.Int64 => units.map(chunk => wrapBytes(chunk).getLong)
    case DType.Float32 => units.map(chunk => wrapBytes(chunk).getFloat.toLong)
    case DType.Float64 => units.map(chunk => wrapBytes(chunk).getDouble.toLong)
  }
}
/**
 * Convert the stored elements to an array of Byte, decoding according to dtype.
 *
 * Float16 is rejected up front by the require below, so the match carries no Float16
 * alternative. Dtypes without a case in the match are not handled and end in a MatchError.
 *
 * @return the elements narrowed with toByte; for UInt8 a copy of the raw buffer
 */
def toByteArray: Array[Byte] = {
  require(dtype != DType.Float16, "Currently cannot convert float16 to native numerical types")
  dtype match {
    case DType.Float32 => units.map(wrapBytes(_).getFloat.toByte)
    case DType.Float64 => units.map(wrapBytes(_).getDouble.toByte)
    case DType.Int32 => units.map(wrapBytes(_).getInt.toByte)
    case DType.Int64 => units.map(wrapBytes(_).getLong.toByte)
    // clone() hands back a copy so callers cannot mutate the internal buffer.
    case DType.UInt8 => internal.clone()
  }
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnet

/**
 * Storage formats of an NDArray, identified by a numeric id and a lowercase name.
 */
object SparseFormat extends Enumeration {
  type SparseFormat = Value

  /** Id 0, named "default". */
  val DEFAULT: Value = Value(0, "default")

  /** Id 1, named "row_sparse". */
  val ROW_SPARSE: Value = Value(1, "row_sparse")

  /** Id 2, named "csr". */
  val CSR: Value = Value(2, "csr")
}
Loading