forked from apache/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request apache#12 from javelinjs/scala-package-cc
Executor, Model and completed Monitor
- Loading branch information
Showing
13 changed files
with
564 additions
and
62 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
230 changes: 225 additions & 5 deletions
230
scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,231 @@ | ||
package ml.dmlc.mxnet | ||
|
||
import ml.dmlc.mxnet.Base._ | ||
|
||
import scala.collection.mutable.ArrayBuffer | ||
|
||
/**
 * Private helper routines used by the [[Executor]] class and related code in this file.
 */
object Executor {
  /**
   * Build a name -> NDArray map by pairing `names` and `ndarrays` positionally.
   * @param names Names to use as keys; must not contain duplicates.
   * @param ndarrays Arrays to use as values, in the same order as `names`.
   * @return Map from each name to the array at the same position.
   * @throws IllegalArgumentException if `names` contains duplicates.
   */
  private def getDict(names: Array[String], ndarrays: Array[NDArray]): Map[String, NDArray] = {
    require(names.toSet.size == names.length, "Duplicate names detected")
    (names zip ndarrays).toMap
  }

  /**
   * Get input slice from the input shape.
   * The batch is divided into contiguous `(begin, end)` ranges whose sizes are
   * proportional to the entries of `workLoadList`.
   * @param batchSize The number of samples in a mini-batch.
   * @param workLoadList The list of work load for different devices, in the same order as ctx
   * @return The split slices to get a specific slice.
   * @throws IllegalArgumentException If there are too many splits such that some slice
   *                                  can be empty.
   */
  private def splitInputSlice[@specialized(Int, Float, Double) V]
      (batchSize: Int, workLoadList: Array[V])
      (implicit num: Numeric[V]): Array[(Int, Int)] = {
    // Convert through the Numeric instance. `asInstanceOf[Float]` is a cast rather than a
    // numeric conversion and can fail with ClassCastException on boxed non-Float values.
    val totalWorkLoad = num.toFloat(workLoadList.sum)
    val batchNumList = workLoadList.map(workLoad =>
      math.round(num.toFloat(workLoad) * batchSize / totalWorkLoad))
    val batchNumSum = batchNumList.sum
    // Rounding may under-allocate; hand the remaining samples to the last slice.
    if (batchNumSum < batchSize) {
      batchNumList(batchNumList.length - 1) += batchSize - batchNumSum
    }

    val slices = ArrayBuffer.empty[(Int, Int)]
    var end = 0
    batchNumList.foreach { batchNum =>
      // Clamp to batchSize so that over-allocation caused by rounding never runs past the batch.
      val begin = math.min(end, batchSize)
      end = math.min(begin + batchNum, batchSize)
      require(begin < end, "Too many slices such that some splits are empty")
      slices.append((begin, end))
    }
    slices.toArray
  }

  /**
   * Check the argument names of symbol.
   * This function checks the duplication of arguments in Symbol.
   * The check is done for feedforward net for now.
   * @param symbol The network configuration
   * @throws IllegalArgumentException if argument names or auxiliary state names are duplicated.
   */
  private def checkArguments(symbol: Symbol): Unit = {
    val argNames = symbol.listArguments()
    // mkString is used for the name list: interpolating the array itself would print its
    // JVM identity instead of the names.
    require(argNames.toSet.size == argNames.length,
      "Find duplicated argument name," +
      "please make the weight name non-duplicated(using name arguments)," +
      "arguments are " + argNames.mkString(", "))

    val auxNames = symbol.listAuxiliaryStates()
    require(auxNames.toSet.size == auxNames.length,
      "Find duplicated auxiliary param name," +
      "please make the weight name non-duplicated(using name arguments)," +
      "arguments are " + auxNames.mkString(", "))
  }

  /**
   * Load a list of arrays into a list of arrays.
   * Sources and targets are paired positionally with `zip`, so surplus elements of the
   * longer array are ignored.
   * @param data Source arrays.
   * @param targets Destination arrays; `data(i)` is copied to `targets(i)`.
   */
  private def loadGeneral(data: Array[NDArray], targets: Array[NDArray]): Unit = {
    (data zip targets).foreach { case (dSrc, dTarget) =>
      dSrc.copyTo(dTarget)
    }
  }

  /**
   * Load a list of arrays into a list of arrays specified by slices.
   * Sources and targets are paired positionally with `zip`, so surplus elements of the
   * longer array are ignored.
   * @param data Source arrays.
   * @param targets `(start, end, target)` triples; `data(i).slice(start, end)` is copied
   *                to the target of the i-th triple.
   */
  private def loadGeneral(data: Array[NDArray], targets: Array[(Int, Int, NDArray)]): Unit = {
    (data zip targets).foreach { case (dSrc, (start, end, dTarget)) =>
      dSrc.slice(start, end).copyTo(dTarget)
    }
  }
}
|
||
/** | ||
* Created by yuantang on 12/23/15. | ||
* Symbolic Executor component of MXNet | ||
* @author Yizhi Liu | ||
* | ||
* Constructor: use Symbol.bind and Symbol.simple_bind instead. | ||
* @param handle ExecutorHandle generated by calling Bind | ||
* @param symbol | ||
* @see Symbol.bind : to create executor | ||
*/ | ||
abstract class Executor(var argArrays: Array[NDArray]) { | ||
def forward | ||
def backward | ||
def setMonitorCallback(callback: (String, NDArray) => Any) | ||
/**
 * Symbolic Executor component of MXNet, wrapping a native executor handle.
 *
 * @param handle ExecutorHandle generated by calling Bind
 * @param symbol Symbol whose argument and auxiliary state names are used to label the
 *               arrays exposed by `argDict` and `auxDict`.
 */
class Executor(val handle: ExecutorHandle, val symbol: Symbol) {
  // Argument, gradient and auxiliary state arrays start out unset (null); nothing in this
  // class assigns them, so they are expected to be populated from outside.
  var argArrays: Array[NDArray] = null
  protected var gradArrays: Array[NDArray] = null
  protected var auxArrays: Array[NDArray] = null
  // Output arrays are fetched from the native executor once, at construction time.
  protected var outputs: Array[NDArray] = getOutputs
  // Caches behind argDict / auxDict, built on first access.
  protected var _argDict: Map[String, NDArray] = null
  protected var _auxDict: Map[String, NDArray] = null
  // Holds the most recently installed monitor callback.
  protected var monitorCallback: MXMonitorCallback = null

  /** Free the native executor handle when this object is finalized. */
  override def finalize(): Unit = {
    checkCall(_LIB.mxExecutorFree(handle))
  }

  /**
   * list all the output ndarray
   * @return A list of ndarray binded to the heads of executor.
   */
  private def getOutputs: Array[NDArray] = {
    val ndHandles = ArrayBuffer[NDArrayHandle]()
    checkCall(_LIB.mxExecutorOutputs(handle, ndHandles))
    ndHandles.toArray.map(new NDArray(_))
  }

  /**
   * Calculate the outputs specified by the binded symbol.
   * @param isTrain Forwarded to the native call as 1 when true and 0 when false.
   *                NOTE(review): presumably selects training vs. evaluation mode - confirm.
   * @param kwargs Additional specification of input arguments: each array is copied into
   *               the argument array registered under the given name before the forward call.
   * @throws IllegalArgumentException if a name in `kwargs` is not a known argument.
   */
  def forward(isTrain: Boolean, kwargs: (String, NDArray)*): Unit = {
    kwargs.foreach { case (name, array) =>
      require(argDict.contains(name), s"Unknown argument $name")
      array.copyTo(argDict(name))
    }
    checkCall(_LIB.mxExecutorForward(handle, if (isTrain) 1 else 0))
  }

  /** Run a forward pass with `isTrain = false` and no extra input arguments. */
  def forward(): Unit = {
    forward(isTrain = false)
  }

  /**
   * Do backward pass to get the gradient of arguments.
   * @param outGrads
   *   Gradient on the outputs to be propagated back.
   *   This parameter is only needed when bind is called
   *   on outputs that are not a loss function.
   * @throws IllegalArgumentException if `outGrads` is null.
   */
  def backward(outGrads: Array[NDArray]): Unit = {
    require(outGrads != null)
    val ndArrayPtrs = outGrads.map(_.handle.value)
    checkCall(_LIB.mxExecutorBackward(handle, outGrads.length, ndArrayPtrs))
  }

  /**
   * Do backward pass with a single output gradient.
   * @param outGrad Gradient on the output to be propagated back.
   * @throws IllegalArgumentException if `outGrad` is null.
   */
  def backward(outGrad: NDArray): Unit = {
    require(outGrad != null)
    backward(Array(outGrad))
  }

  /** Do backward pass without passing any output gradients. */
  def backward(): Unit = {
    backward(Array.empty[NDArray])
  }

  /**
   * Install callback.
   * The callback is also stored in `monitorCallback`.
   * @param callback The monitor callback to register with the native executor.
   */
  def setMonitorCallback(callback: MXMonitorCallback): Unit = {
    monitorCallback = callback
    checkCall(_LIB.mxExecutorSetMonitorCallback(handle, monitorCallback))
  }

  /**
   * Get dictionary representation of argument arrays.
   * The map is built on first access and cached afterwards.
   * @return The dictionary that maps name of arguments to NDArrays.
   * @throws IllegalArgumentException if there are duplicated names in the arguments,
   *                                  or if `argArrays` has not been set yet.
   */
  def argDict: Map[String, NDArray] = {
    if (_argDict == null) {
      // Fail with a clear message instead of a NullPointerException from zip.
      require(argArrays != null, "argArrays has not been set")
      _argDict = Executor.getDict(symbol.listArguments(), argArrays)
    }
    _argDict
  }

  /**
   * Get dictionary representation of auxiliary states arrays.
   * The map is built on first access and cached afterwards.
   * @return The dictionary that maps name of auxiliary states to NDArrays.
   * @throws IllegalArgumentException if there are duplicated names in the auxiliary states,
   *                                  or if `auxArrays` has not been set yet.
   */
  def auxDict: Map[String, NDArray] = {
    if (_auxDict == null) {
      // Fail with a clear message instead of a NullPointerException from zip.
      require(auxArrays != null, "auxArrays has not been set")
      _auxDict = Executor.getDict(symbol.listAuxiliaryStates(), auxArrays)
    }
    _auxDict
  }

  /**
   * Copy parameters from argParams, auxParams into executor's internal array.
   * @param argParams : dict of name to NDArray of arguments
   * @param auxParams : dict of name to NDArray of auxiliary states; may be null, in which
   *                  case no auxiliary states are copied.
   * @param allowExtraParams
   *   Whether allow extra parameters that are not needed by symbol
   *   If this is true, no error will be thrown when argParams or auxParams
   *   contain extra parameters that is not needed by the executor.
   * @throws IllegalArgumentException If there is additional parameters in the dict
   *                                  but allowExtraParams=false
   */
  def copyParamsFrom(argParams: Map[String, NDArray],
                     auxParams: Map[String, NDArray],
                     allowExtraParams: Boolean = false): Unit = {
    argParams.foreach { case (name, array) =>
      if (argDict.contains(name)) {
        array.copyTo(argDict(name))
      } else {
        require(allowExtraParams, s"Find name $name that is not in the arguments")
      }
    }
    if (auxParams != null) {
      auxParams.foreach { case (name, array) =>
        if (auxDict.contains(name)) {
          array.copyTo(auxDict(name))
        } else {
          require(allowExtraParams, s"Find name $name that is not in the auxiliary states")
        }
      }
    }
  }

  /**
   * Copy only argument parameters; auxiliary states are left untouched.
   * @param argParams dict of name to NDArray of arguments
   * @param allowExtraParams Whether names unknown to the executor are tolerated.
   */
  def copyParamsFrom(argParams: Map[String, NDArray], allowExtraParams: Boolean): Unit = {
    copyParamsFrom(argParams, null, allowExtraParams)
  }

  /**
   * Copy only argument parameters, rejecting names unknown to the executor.
   * @param argParams dict of name to NDArray of arguments
   */
  def copyParamsFrom(argParams: Map[String, NDArray]): Unit = {
    copyParamsFrom(argParams, allowExtraParams = false)
  }

  /**
   * Get a debug string about internal execution plan.
   * @return Debug string of the executor.
   */
  def debugStr: String = {
    val str = new RefString
    checkCall(_LIB.mxExecutorPrint(handle, str))
    str.value
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.