Skip to content

Commit

Permalink
Merge pull request apache#23 from javelinjs/scala-package-l
Browse files Browse the repository at this point in the history
Symbol::bind and Executor forward/backward
  • Loading branch information
yanqingmen committed Jan 25, 2016
2 parents 0013f85 + 427b87a commit 10f1d5a
Show file tree
Hide file tree
Showing 14 changed files with 1,568 additions and 62 deletions.
2 changes: 1 addition & 1 deletion scala-package/core/scalastyle-config.xml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ You can also disable only one rule, by specifying its rule id, as specified in:
<parameters><parameter name="regex"><![CDATA[^[a-z][A-Za-z]*$]]></parameter></parameters>
</check>

<check level="error" class="org.scalastyle.scalariform.ParameterNumberChecker" enabled="true">
<check customId="parameterNum" level="error" class="org.scalastyle.scalariform.ParameterNumberChecker" enabled="true">
<parameters><parameter name="maxParameters"><![CDATA[10]]></parameter></parameters>
</check>

Expand Down
2 changes: 2 additions & 0 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ object Base {
type MXUint = Int
type MXFloat = Float
type CPtrAddress = Long
// TODO: make it more friendly to java
type Shape = Array[Int]

type NDArrayHandle = CPtrAddress
type FunctionHandle = CPtrAddress
Expand Down
13 changes: 13 additions & 0 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/Context.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,19 @@ object Context {
val devtype2str = Map(1 -> "cpu", 2 -> "gpu", 3 -> "cpu_pinned")
val devstr2type = Map("cpu" -> 1, "gpu" -> 2, "cpu_pinned" -> 3)
val defaultCtx = new Context("cpu", 0)

def cpu(deviceId: Int = 0): Context = {
new Context("cpu", deviceId)
}

def gpu(deviceId: Int = 0): Context = {
new Context("gpu", deviceId)
}
}

/**
* Constructing a context.
* @author Yizhi Liu
* @param deviceTypeName {'cpu', 'gpu'} String representing the device type
* @param deviceId (default=0) The device id of the device, needed for GPU
*/
Expand All @@ -23,4 +32,8 @@ class Context(deviceTypeName: String, val deviceId: Int = 0) {
* @return device_type
*/
def deviceType: String = Context.devtype2str(deviceTypeid)

override def toString: String = {
deviceType
}
}
191 changes: 171 additions & 20 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package ml.dmlc.mxnet

import ml.dmlc.mxnet.Base._
import org.slf4j.{Logger, LoggerFactory}

import scala.collection.mutable.ArrayBuffer

object Executor {
// Get the dictionary given name and ndarray pairs.
private def getDict(names: Array[String], ndarrays: Array[NDArray]): Map[String, NDArray] = {
private[mxnet] def getDict(names: Seq[String],
ndarrays: Seq[NDArray]): Map[String, NDArray] = {
require(names.toSet.size == names.length, "Duplicate names detected")
(names zip ndarrays).toMap
}
Expand All @@ -19,12 +21,11 @@ object Executor {
* @throws IllegalArgumentException
* If there are two many splits such that some slice can be empty.
*/
private def splitInputSlice[@specialized(Int, Float, Double) V]
(batchSize: Int, workLoadList: Array[V])
(implicit num: Numeric[V]): Array[(Int, Int)] = {
val totalWorkLoad = workLoadList.sum.asInstanceOf[Float]
private[mxnet] def splitInputSlice(batchSize: Int,
workLoadList: Seq[Float]): Array[(Int, Int)] = {
val totalWorkLoad = workLoadList.sum
val batchNumList = workLoadList.map(workLoad =>
math.round(workLoad.asInstanceOf[Float] * batchSize / totalWorkLoad))
math.round(workLoad * batchSize / totalWorkLoad)).toArray
val batchNumSum = batchNumList.sum
if (batchNumSum < batchSize) {
batchNumList(batchNumList.length-1) += batchSize - batchNumSum
Expand All @@ -47,7 +48,7 @@ object Executor {
* The check is done for feedforward net for now.
* @param symbol The network configuration
*/
private def checkArguments(symbol: Symbol): Unit = {
private[mxnet] def checkArguments(symbol: Symbol): Unit = {
val argNames = symbol.listArguments()
require(argNames.toSet.size == argNames.length,
"Find duplicated argument name," +
Expand All @@ -62,35 +63,50 @@ object Executor {
}

// Load a list of arrays into a list of arrays
private def loadGeneral(data: Array[NDArray], targets: Array[NDArray]): Unit = {
private[mxnet] def loadGeneral(data: Array[NDArray], targets: Array[NDArray]): Unit = {
(data zip targets).foreach { case (dSrc, dTarget) =>
dSrc.copyTo(dTarget)
}
}

// Load a list of arrays into a list of arrays specified by slices
private def loadGeneral(data: Array[NDArray], targets: Array[(Int, Int, NDArray)]): Unit = {
(data zip targets).foreach { case (dSrc, (start, end, dTarget)) =>
dSrc.slice(start, end).copyTo(dTarget)
private[mxnet] def loadGeneral(data: IndexedSeq[NDArray],
targets: Seq[Array[(Int, Int, NDArray)]]): Unit = {
for ((src, dTargets) <- data zip targets) {
for ((start, end, dst) <- dTargets) {
src.slice(start, end).copyTo(dst)
}
}
}

// Load data into sliced arrays
private[mxnet] def loadData(batch: DataBatch,
targets: Seq[Array[(Int, Int, NDArray)]]): Unit = {
loadGeneral(batch.data, targets)
}

// Load label into sliced arrays
private[mxnet] def loadLabel(batch: DataBatch,
targets: Seq[Array[(Int, Int, NDArray)]]): Unit = {
loadGeneral(batch.label, targets)
}
}

/**
* Symbolic Executor component of MXNet
* @author Yizhi Liu
*
* Constructor: used Symbol.bind and Symbol.simple_bind instead.
* Constructor: please use Symbol.bind and Symbol.simpleBind instead.
* @param handle ExecutorHandle generated by calling Bind
* @param symbol
* @see Symbol.bind : to create executor
*/
// scalastyle:off finalize
class Executor(private[mxnet] val handle: ExecutorHandle, private[mxnet] val symbol: Symbol) {
var argArrays: Array[NDArray] = null
protected var gradArrays: Array[NDArray] = null
protected var auxArrays: Array[NDArray] = null
protected var outputs: Array[NDArray] = getOutputs
private[mxnet] var argArrays: Array[NDArray] = null
private[mxnet] var gradArrays: Array[NDArray] = null
private[mxnet] var auxArrays: Array[NDArray] = null
val outputs: Array[NDArray] = getOutputs
protected var _argDict: Map[String, NDArray] = null
protected var _auxDict: Map[String, NDArray] = null
protected var monitorCallback: MXMonitorCallback = null
Expand Down Expand Up @@ -128,10 +144,9 @@ class Executor(private[mxnet] val handle: ExecutorHandle, private[mxnet] val sym

/**
* Do backward pass to get the gradient of arguments.
* @param outGrads
* Gradient on the outputs to be propagated back.
* This parameter is only needed when bind is called
* on outputs that are not a loss function.
* @param outGrads Gradient on the outputs to be propagated back.
* This parameter is only needed when bind is called
* on outputs that are not a loss function.
*/
def backward(outGrads: Array[NDArray]): Unit = {
require(outGrads != null)
Expand Down Expand Up @@ -233,3 +248,139 @@ class Executor(private[mxnet] val handle: ExecutorHandle, private[mxnet] val sym
}
}
// scalastyle:on finalize

/**
* Helper class to manage multiple executors for data parallelism.
* @author Yizhi Liu
* @param symbol output symbol
* @param ctx devices to run on
* @param paramNames Name of all trainable parameters of the network.
* @param argNames Name of all arguments of the network.
* @param auxNames Name of all auxiliary states of the network.
* @param trainData Training data iterator.
* @param workLoadList The list of work load for different devices, in the same order as ctx
* @param logger When not specified, default logger will be used.
*/
class DataParallelExecutorManager(symbol: Symbol,
ctx: Array[Context],
paramNames: Seq[String],
argNames: Seq[String],
private val auxNames: Seq[String],
trainData: DataIter,
private var workLoadList: Seq[Float] = null,
logger: Logger = DataParallelExecutorManager.logger) {
// preparation
private val numDevice = ctx.length
logger.info(s"Start training with ${ctx.mkString(",")}")

// make sure the architecture is valid
Executor.checkArguments(symbol)

if (workLoadList == null) {
workLoadList = Seq.fill(numDevice)(1f)
}
require(workLoadList.size == numDevice, "Invalid settings for work load.")

private val slices = Executor.splitInputSlice(trainData.batchSize, workLoadList)

private val trainExecs =
ctx.zipWithIndex.map { case (context, i) =>
val dataShapes =
trainData.provideData.map { case (name: String, shape: Shape) =>
(name, Array(slices(i)._2 - slices(i)._1) ++ shape.drop(1))
}
symbol.simpleBind(context, "write", shapeDict = dataShapes)
}

// data structure
private val dataNames = trainData.provideData.map(_._1).toArray
private val labelNames = trainData.provideLabel.map(_._1).toArray

private val dataArrays =
dataNames.map { name =>
trainExecs.zipWithIndex.map { case (exec, i) =>
val slice = slices(i)
(slice._1, slice._2, exec.argDict(name))
}
}
private val labelArrays =
labelNames.map { name =>
trainExecs.zipWithIndex.map { case (exec, i) =>
val slice = slices(i)
(slice._1, slice._2, exec.argDict(name))
}
}

private val paramIdx = (0 until argNames.length).filter { i =>
paramNames.contains(argNames(i))
}
private val _paramNames = paramIdx.map(argNames(_))
private val paramArrays = paramIdx.map { i => trainExecs.map(_.argArrays(i)) }.toArray
private val gradArrays = paramIdx.map { i => trainExecs.map(_.gradArrays(i)) }.toArray

private val auxArrays = (0 until auxNames.length).map { i =>
trainExecs.map(_.auxArrays(i))
}.toArray
private val batchSize = trainData.batchSize
private val outputShapes: Array[Shape] = trainExecs(0).outputs.map { x: NDArray =>
Array(batchSize) ++ x.shape.drop(1)
}
private val cpuOutputArrays = outputShapes.map(NDArray.zeros(_))

// Install monitor on all executors
def installMonitor(monitor: Monitor): Unit = {
trainExecs.foreach(monitor.install)
}

/**
* Set parameter and aux values
* @param argParams source parameter arrays
* @param auxParams source aux arrays
*/
def setParams(argParams: Map[String, NDArray], auxParams: Map[String, NDArray]): Unit = {
trainExecs.foreach(_.copyParamsFrom(argParams, auxParams))
}

/**
* Copy data from each executor to `arg_params` and `aux_params`
* @param argParams target parameter arrays
* @param auxParams target aux arrays
* @note This function will inplace update the NDArrays in arg_params and aux_params.
*/
def copyTo(argParams: Map[String, NDArray], auxParams: Map[String, NDArray]): Unit = {
for ((name, block) <- _paramNames zip paramArrays) {
val weight = block.map(_.copyTo(Context.cpu())).reduce(_ + _) / block.length
weight.copyTo(argParams(name))
}
for ((name, block) <- auxNames zip auxArrays) {
val weight = block.map(_.copyTo(Context.cpu())).reduce(_ + _) / block.length
weight.copyTo(auxParams(name))
}
}

// load data and labels into arrays
def loadDataBatch(dataBatch: DataBatch): Unit = {
Executor.loadData(dataBatch, dataArrays)
Executor.loadLabel(dataBatch, labelArrays)
}

// Perform a forward pass on each executor
def forward(isTrain: Boolean = false): Unit = {
for ((texec, islice) <- trainExecs zip slices) {
texec.forward(isTrain)
for ((cpuOut, devOut) <- cpuOutputArrays zip texec.outputs) {
devOut.copyTo(cpuOut.slice(islice))
}
}
}

// Perform a backward pass on each executor
def backward(): Unit = {
trainExecs.foreach(_.backward())
}
}

object DataParallelExecutorManager {
private val logger = LoggerFactory.getLogger(classOf[Model])
}

Loading

0 comments on commit 10f1d5a

Please sign in to comment.