Skip to content

Commit

Permalink
Merge pull request apache#12 from javelinjs/scala-package-cc
Browse files Browse the repository at this point in the history
Executor, Model and completed Monitor
  • Loading branch information
terrytangyuan committed Dec 27, 2015
2 parents d4e397f + 213cd10 commit ee51492
Show file tree
Hide file tree
Showing 13 changed files with 564 additions and 62 deletions.
1 change: 1 addition & 0 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ object Base {
type NDArrayHandle = RefLong
type FunctionHandle = RefLong
type KVStoreHandle = RefLong
type ExecutorHandle = RefLong

System.loadLibrary("mxnet-scala")
val _LIB = new LibInfo
Expand Down
230 changes: 225 additions & 5 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala
Original file line number Diff line number Diff line change
@@ -1,11 +1,231 @@
package ml.dmlc.mxnet

import ml.dmlc.mxnet.Base._

import scala.collection.mutable.ArrayBuffer

object Executor {
// Get the dictionary given name and ndarray pairs.
private def getDict(names: Array[String], ndarrays: Array[NDArray]): Map[String, NDArray] = {
require(names.toSet.size == names.length, "Duplicate names detected")
(names zip ndarrays).toMap
}

/**
* Get input slice from the input shape.
* @param batchSize The number of samples in a mini-batch.
* @param workLoadList The list of work load for different devices, in the same order as ctx
* @return The split slices to get a specific slice.
* @throws IllegalArgumentException If there are two many splits such that some slice can be empty.
*/
private def splitInputSlice[@specialized(Int, Float, Double) V]
(batchSize: Int, workLoadList: Array[V])
(implicit num: Numeric[V]): Array[(Int, Int)] = {
val totalWorkLoad = workLoadList.sum.asInstanceOf[Float]
val batchNumList = workLoadList.map(workLoad =>
math.round(workLoad.asInstanceOf[Float] * batchSize / totalWorkLoad))
val batchNumSum = batchNumList.sum
if (batchNumSum < batchSize) {
batchNumList(batchNumList.length-1) += batchSize - batchNumSum
}

val slices = ArrayBuffer.empty[(Int, Int)]
var end = 0
batchNumList.foreach(batchNum => {
val begin = math.min(end, batchSize)
end = math.min(begin + batchNum, batchSize)
require(begin < end, "Too many slices such that some splits are empty")
slices.append((begin, end))
})
slices.toArray
}

/**
* Check the argument names of symbol.
* This function checks the duplication of arguments in Symbol.
* The check is done for feedforward net for now.
* @param symbol The network configuration
*/
private def checkArguments(symbol: Symbol): Unit = {
val argNames = symbol.listArguments()
require(argNames.toSet.size == argNames.length,
"Find duplicated argument name," +
"please make the weight name non-duplicated(using name arguments)," +
s"arguments are $argNames")

val auxNames = symbol.listAuxiliaryStates()
require(auxNames.toSet.size == auxNames.length,
"Find duplicated auxiliary param name," +
"please make the weight name non-duplicated(using name arguments)," +
s"arguments are $auxNames")
}

// Load a list of arrays into a list of arrays
private def loadGeneral(data: Array[NDArray], targets: Array[NDArray]): Unit = {
(data zip targets).foreach { case (dSrc, dTarget) =>
dSrc.copyTo(dTarget)
}
}

// Load a list of arrays into a list of arrays specified by slices
private def loadGeneral(data: Array[NDArray], targets: Array[(Int, Int, NDArray)]): Unit = {
(data zip targets).foreach { case (dSrc, (start, end, dTarget)) =>
dSrc.slice(start, end).copyTo(dTarget)
}
}
}

/**
* Created by yuantang on 12/23/15.
* Symbolic Executor component of MXNet
* @author Yizhi Liu
*
* Constructor: used Symbol.bind and Symbol.simple_bind instead.
* @param handle ExecutorHandle generated by calling Bind
* @param symbol
* @see Symbol.bind : to create executor
*/
abstract class Executor(var argArrays: Array[NDArray]) {
def forward
def backward
def setMonitorCallback(callback: (String, NDArray) => Any)
class Executor(val handle: ExecutorHandle, val symbol: Symbol) {
var argArrays: Array[NDArray] = null
protected var gradArrays: Array[NDArray] = null
protected var auxArrays: Array[NDArray] = null
protected var outputs: Array[NDArray] = getOutputs
protected var _argDict: Map[String, NDArray] = null
protected var _auxDict: Map[String, NDArray] = null
protected var monitorCallback: MXMonitorCallback = null

override def finalize(): Unit = {
checkCall(_LIB.mxExecutorFree(handle))
}

/**
* list all the output ndarray
* @return A list of ndarray binded to the heads of executor.
*/
private def getOutputs: Array[NDArray] = {
val ndHandles = ArrayBuffer[NDArrayHandle]()
checkCall(_LIB.mxExecutorOutputs(handle, ndHandles))
ndHandles.toArray.map(new NDArray(_))
}

/**
* Calculate the outputs specified by the binded symbol.
* @param isTrain whether this forward is for evaluation purpose.
* @param kwargs Additional specification of input arguments.
*/
def forward(isTrain: Boolean, kwargs: (String, NDArray)*): Unit = {
kwargs.foreach { case (name, array) =>
require(argDict.contains(name), s"Unknown argument $name")
array.copyTo(argDict(name))
}
checkCall(_LIB.mxExecutorForward(handle, if (isTrain) 1 else 0))
}

def forward(): Unit = {
forward(isTrain = false)
}

/**
* Do backward pass to get the gradient of arguments.
* @param outGrads
* Gradient on the outputs to be propagated back.
* This parameter is only needed when bind is called
* on outputs that are not a loss function.
*/
def backward(outGrads: Array[NDArray]):Unit = {
require(outGrads != null)
val ndArrayPtrs = outGrads.map(_.handle.value)
checkCall(_LIB.mxExecutorBackward(handle, outGrads.length, ndArrayPtrs))
}

def backward(outGrad: NDArray): Unit = {
require(outGrad != null)
backward(Array(outGrad))
}

def backward(): Unit = {
backward(Array.empty[NDArray])
}

/**
* Install callback.
* @param callback Takes a string and an NDArrayHandle.
*/
def setMonitorCallback(callback: MXMonitorCallback): Unit = {
monitorCallback = callback
checkCall(_LIB.mxExecutorSetMonitorCallback(handle, monitorCallback))
}

/**
* Get dictionary representation of argument arrrays.
* @return The dictionary that maps name of arguments to NDArrays.
* @throws IllegalArgumentException if there are duplicated names in the arguments.
*/
def argDict: Map[String, NDArray] = {
if (_argDict == null) {
_argDict = Executor.getDict(symbol.listArguments(), argArrays)
}
_argDict
}

/**
* Get dictionary representation of auxiliary states arrays.
* @return The dictionary that maps name of auxiliary states to NDArrays.
* @throws IllegalArgumentException if there are duplicated names in the auxiliary states.
*/
def auxDict: Map[String, NDArray] = {
if (_auxDict == null) {
_auxDict = Executor.getDict(
symbol.listAuxiliaryStates(), auxArrays)
}
_auxDict
}

/**
* Copy parameters from arg_params, aux_params into executor's internal array.
* @param argParams : dict of name to NDArray of arguments
* @param auxParams : dict of name to NDArray of auxiliary states.
* @param allowExtraParams
* Whether allow extra parameters that are not needed by symbol
* If this is True, no error will be thrown when arg_params or aux_params
* contain extra parameters that is not needed by the executor.
* @throws IllegalArgumentException If there is additional parameters in the dict but allow_extra_params=False
*/
def copyParamsFrom(argParams: Map[String, NDArray],
auxParams: Map[String, NDArray],
allowExtraParams: Boolean = false): Unit = {
argParams.foreach { case (name, array) =>
if (argDict.contains(name)) {
array.copyTo(argDict(name))
} else {
require(allowExtraParams, s"Find name $name that is not in the arguments")
}
}
if (auxParams != null) {
auxParams.foreach { case (name, array) =>
if (auxDict.contains(name)) {
array.copyTo(auxDict(name))
} else {
require(allowExtraParams, s"Find name $name that is not in the auxiliary states")
}
}
}
}

def copyParamsFrom(argParams: Map[String, NDArray], allowExtraParams: Boolean): Unit = {
copyParamsFrom(argParams, null, allowExtraParams)
}

def copyParamsFrom(argParams: Map[String, NDArray]): Unit = {
copyParamsFrom(argParams, allowExtraParams = false)
}

/**
* Get a debug string about internal execution plan.
* @return Debug string of the executor.
*/
def debugStr: String = {
val str = new RefString
checkCall(_LIB.mxExecutorPrint(handle, str))
str.value
}
}
18 changes: 18 additions & 0 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,15 @@ class KVStore(private val handle: KVStoreHandle) {
push(Array(key), Array(value), priority)
}

def push(key: Int, values: Array[NDArray], priority: Int): Unit = {
val keys = Array.fill(values.length)(key)
push(keys, values, priority)
}

def push(key: Int, values: Array[NDArray]): Unit = {
push(key, values, 0)
}

/**
* Pull a single value or a sequence of values from the store.
*
Expand Down Expand Up @@ -98,6 +107,15 @@ class KVStore(private val handle: KVStoreHandle) {
pull(Array(key), Array(out), priority)
}

def pull(key: Int, outs: Array[NDArray], priority: Int): Unit = {
val keys = Array.fill(outs.length)(key)
pull(keys, outs, priority)
}

def pull(key: Int, outs: Array[NDArray]): Unit = {
pull(key, outs, 0)
}

// Get the type of this kvstore
def `type`: String = {
val kvType = new RefString
Expand Down
15 changes: 14 additions & 1 deletion scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ import ml.dmlc.mxnet.Base._

import scala.collection.mutable.{ArrayBuffer, ListBuffer}

// JNI functions
/**
* JNI functions
* @author Yizhi Liu
*/
class LibInfo {
@native def mxNDArrayFree(handle: NDArrayHandle): Int
@native def mxGetLastError(): String
Expand Down Expand Up @@ -81,4 +84,14 @@ class LibInfo {
@native def mxKVStoreBarrier(handle: KVStoreHandle): Int
@native def mxKVStoreGetGroupSize(handle: KVStoreHandle, size: RefInt): Int
@native def mxKVStoreGetRank(handle: KVStoreHandle, size: RefInt): Int
@native def mxExecutorOutputs(handle: ExecutorHandle, outputs: ArrayBuffer[NDArrayHandle]): Int
@native def mxExecutorFree(handle: ExecutorHandle): Int
@native def mxExecutorForward(handle: ExecutorHandle, isTrain: Int): Int
@native def mxExecutorBackward(handle: ExecutorHandle,
gradsSize: Int,
// grads ought to be Array[NDArrayHandle],
// we pass ptr address directly for performance consideration
grads: Array[CPtrAddress]): Int
@native def mxExecutorPrint(handle: ExecutorHandle, debugStr: RefString): Int
@native def mxExecutorSetMonitorCallback(handle: ExecutorHandle, callback: MXMonitorCallback): Int
}
Loading

0 comments on commit ee51492

Please sign in to comment.