diff --git a/scala-package/core/scalastyle-config.xml b/scala-package/core/scalastyle-config.xml
index 847bbc2babe9..583a815a6fbe 100644
--- a/scala-package/core/scalastyle-config.xml
+++ b/scala-package/core/scalastyle-config.xml
@@ -60,7 +60,7 @@ You can also disable only one rule, by specifying its rule id, as specified in:
-
+
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala
index 9f02c6e4e2e0..206b8d13f108 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala
@@ -10,6 +10,8 @@ object Base {
type MXUint = Int
type MXFloat = Float
type CPtrAddress = Long
+  // TODO: make it more friendly to Java
+ type Shape = Array[Int]
type NDArrayHandle = CPtrAddress
type FunctionHandle = CPtrAddress
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Context.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Context.scala
index 1ca12df2a975..c2a2325047f3 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Context.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Context.scala
@@ -4,10 +4,19 @@ object Context {
val devtype2str = Map(1 -> "cpu", 2 -> "gpu", 3 -> "cpu_pinned")
val devstr2type = Map("cpu" -> 1, "gpu" -> 2, "cpu_pinned" -> 3)
val defaultCtx = new Context("cpu", 0)
+
+ def cpu(deviceId: Int = 0): Context = {
+ new Context("cpu", deviceId)
+ }
+
+ def gpu(deviceId: Int = 0): Context = {
+ new Context("gpu", deviceId)
+ }
}
/**
* Constructing a context.
+ * @author Yizhi Liu
* @param deviceTypeName {'cpu', 'gpu'} String representing the device type
* @param deviceId (default=0) The device id of the device, needed for GPU
*/
@@ -23,4 +32,8 @@ class Context(deviceTypeName: String, val deviceId: Int = 0) {
* @return device_type
*/
def deviceType: String = Context.devtype2str(deviceTypeid)
+
+ override def toString: String = {
+ deviceType
+ }
}
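
A quick usage sketch for the new factory methods and toString (a minimal sketch, assuming only the import below):

    import ml.dmlc.mxnet.Context

    val cpu0 = Context.cpu()    // equivalent to new Context("cpu", 0)
    val gpu1 = Context.gpu(1)   // GPU with device id 1
    println(gpu1)               // toString now yields the device type, e.g. "gpu"
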
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala
index 60513db13635..5d7ffd8c3795 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala
@@ -1,12 +1,14 @@
package ml.dmlc.mxnet
import ml.dmlc.mxnet.Base._
+import org.slf4j.{Logger, LoggerFactory}
import scala.collection.mutable.ArrayBuffer
object Executor {
// Get the dictionary given name and ndarray pairs.
- private def getDict(names: Array[String], ndarrays: Array[NDArray]): Map[String, NDArray] = {
+ private[mxnet] def getDict(names: Seq[String],
+ ndarrays: Seq[NDArray]): Map[String, NDArray] = {
require(names.toSet.size == names.length, "Duplicate names detected")
(names zip ndarrays).toMap
}
@@ -19,12 +21,11 @@ object Executor {
* @throws IllegalArgumentException
 * If there are too many splits such that some slice can be empty.
*/
- private def splitInputSlice[@specialized(Int, Float, Double) V]
- (batchSize: Int, workLoadList: Array[V])
- (implicit num: Numeric[V]): Array[(Int, Int)] = {
- val totalWorkLoad = workLoadList.sum.asInstanceOf[Float]
+ private[mxnet] def splitInputSlice(batchSize: Int,
+ workLoadList: Seq[Float]): Array[(Int, Int)] = {
+ val totalWorkLoad = workLoadList.sum
val batchNumList = workLoadList.map(workLoad =>
- math.round(workLoad.asInstanceOf[Float] * batchSize / totalWorkLoad))
+ math.round(workLoad * batchSize / totalWorkLoad)).toArray
val batchNumSum = batchNumList.sum
if (batchNumSum < batchSize) {
batchNumList(batchNumList.length-1) += batchSize - batchNumSum
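
For intuition, here is a standalone sketch of the slicing arithmetic (an illustrative re-implementation, not the private API itself): a batch of 128 split across devices with workloads 1:1:2 yields slices (0,32), (32,64), (64,128).

    // Illustrative re-implementation of the splitInputSlice arithmetic.
    def split(batchSize: Int, workLoads: Seq[Float]): Array[(Int, Int)] = {
      val total = workLoads.sum
      val batchNums = workLoads.map(w => math.round(w * batchSize / total)).toArray
      // absorb the rounding remainder in the last slice, as the real code does
      batchNums(batchNums.length - 1) += batchSize - batchNums.sum
      var end = 0
      batchNums.map { n => val start = end; end += n; (start, end) }
    }
    split(128, Seq(1f, 1f, 2f))  // Array((0,32), (32,64), (64,128))
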
@@ -47,7 +48,7 @@ object Executor {
* The check is done for feedforward net for now.
* @param symbol The network configuration
*/
- private def checkArguments(symbol: Symbol): Unit = {
+ private[mxnet] def checkArguments(symbol: Symbol): Unit = {
val argNames = symbol.listArguments()
require(argNames.toSet.size == argNames.length,
"Find duplicated argument name," +
@@ -62,35 +63,50 @@ object Executor {
}
// Load a list of arrays into a list of arrays
- private def loadGeneral(data: Array[NDArray], targets: Array[NDArray]): Unit = {
+ private[mxnet] def loadGeneral(data: Array[NDArray], targets: Array[NDArray]): Unit = {
(data zip targets).foreach { case (dSrc, dTarget) =>
dSrc.copyTo(dTarget)
}
}
// Load a list of arrays into a list of arrays specified by slices
- private def loadGeneral(data: Array[NDArray], targets: Array[(Int, Int, NDArray)]): Unit = {
- (data zip targets).foreach { case (dSrc, (start, end, dTarget)) =>
- dSrc.slice(start, end).copyTo(dTarget)
+ private[mxnet] def loadGeneral(data: IndexedSeq[NDArray],
+ targets: Seq[Array[(Int, Int, NDArray)]]): Unit = {
+ for ((src, dTargets) <- data zip targets) {
+ for ((start, end, dst) <- dTargets) {
+ src.slice(start, end).copyTo(dst)
+ }
}
}
+
+ // Load data into sliced arrays
+ private[mxnet] def loadData(batch: DataBatch,
+ targets: Seq[Array[(Int, Int, NDArray)]]): Unit = {
+ loadGeneral(batch.data, targets)
+ }
+
+ // Load label into sliced arrays
+ private[mxnet] def loadLabel(batch: DataBatch,
+ targets: Seq[Array[(Int, Int, NDArray)]]): Unit = {
+ loadGeneral(batch.label, targets)
+ }
}
/**
* Symbolic Executor component of MXNet
* @author Yizhi Liu
*
- * Constructor: used Symbol.bind and Symbol.simple_bind instead.
+ * Constructor: please use Symbol.bind and Symbol.simpleBind instead.
* @param handle ExecutorHandle generated by calling Bind
* @param symbol
* @see Symbol.bind : to create executor
*/
// scalastyle:off finalize
class Executor(private[mxnet] val handle: ExecutorHandle, private[mxnet] val symbol: Symbol) {
- var argArrays: Array[NDArray] = null
- protected var gradArrays: Array[NDArray] = null
- protected var auxArrays: Array[NDArray] = null
- protected var outputs: Array[NDArray] = getOutputs
+ private[mxnet] var argArrays: Array[NDArray] = null
+ private[mxnet] var gradArrays: Array[NDArray] = null
+ private[mxnet] var auxArrays: Array[NDArray] = null
+ val outputs: Array[NDArray] = getOutputs
protected var _argDict: Map[String, NDArray] = null
protected var _auxDict: Map[String, NDArray] = null
protected var monitorCallback: MXMonitorCallback = null
@@ -128,10 +144,9 @@ class Executor(private[mxnet] val handle: ExecutorHandle, private[mxnet] val sym
/**
* Do backward pass to get the gradient of arguments.
- * @param outGrads
- * Gradient on the outputs to be propagated back.
- * This parameter is only needed when bind is called
- * on outputs that are not a loss function.
+ * @param outGrads Gradient on the outputs to be propagated back.
+ * This parameter is only needed when bind is called
+ * on outputs that are not a loss function.
*/
def backward(outGrads: Array[NDArray]): Unit = {
require(outGrads != null)
@@ -233,3 +248,139 @@ class Executor(private[mxnet] val handle: ExecutorHandle, private[mxnet] val sym
}
}
// scalastyle:on finalize
+
+/**
+ * Helper class to manage multiple executors for data parallelism.
+ * @author Yizhi Liu
+ * @param symbol output symbol
+ * @param ctx devices to run on
+ * @param paramNames Name of all trainable parameters of the network.
+ * @param argNames Name of all arguments of the network.
+ * @param auxNames Name of all auxiliary states of the network.
+ * @param trainData Training data iterator.
+ * @param workLoadList The list of workloads for different devices, in the same order as ctx
+ * @param logger When not specified, default logger will be used.
+ */
+class DataParallelExecutorManager(symbol: Symbol,
+ ctx: Array[Context],
+ paramNames: Seq[String],
+ argNames: Seq[String],
+ private val auxNames: Seq[String],
+ trainData: DataIter,
+ private var workLoadList: Seq[Float] = null,
+ logger: Logger = DataParallelExecutorManager.logger) {
+ // preparation
+ private val numDevice = ctx.length
+ logger.info(s"Start training with ${ctx.mkString(",")}")
+
+ // make sure the architecture is valid
+ Executor.checkArguments(symbol)
+
+ if (workLoadList == null) {
+ workLoadList = Seq.fill(numDevice)(1f)
+ }
+ require(workLoadList.size == numDevice, "Invalid settings for workload.")
+
+ private val slices = Executor.splitInputSlice(trainData.batchSize, workLoadList)
+
+ private val trainExecs =
+ ctx.zipWithIndex.map { case (context, i) =>
+ val dataShapes =
+ trainData.provideData.map { case (name: String, shape: Shape) =>
+ (name, Array(slices(i)._2 - slices(i)._1) ++ shape.drop(1))
+ }
+ symbol.simpleBind(context, "write", shapeDict = dataShapes)
+ }
+
+ // data structure
+ private val dataNames = trainData.provideData.map(_._1).toArray
+ private val labelNames = trainData.provideLabel.map(_._1).toArray
+
+ private val dataArrays =
+ dataNames.map { name =>
+ trainExecs.zipWithIndex.map { case (exec, i) =>
+ val slice = slices(i)
+ (slice._1, slice._2, exec.argDict(name))
+ }
+ }
+ private val labelArrays =
+ labelNames.map { name =>
+ trainExecs.zipWithIndex.map { case (exec, i) =>
+ val slice = slices(i)
+ (slice._1, slice._2, exec.argDict(name))
+ }
+ }
+
+ private val paramIdx = (0 until argNames.length).filter { i =>
+ paramNames.contains(argNames(i))
+ }
+ private val _paramNames = paramIdx.map(argNames(_))
+ private val paramArrays = paramIdx.map { i => trainExecs.map(_.argArrays(i)) }.toArray
+ private val gradArrays = paramIdx.map { i => trainExecs.map(_.gradArrays(i)) }.toArray
+
+ private val auxArrays = (0 until auxNames.length).map { i =>
+ trainExecs.map(_.auxArrays(i))
+ }.toArray
+ private val batchSize = trainData.batchSize
+ private val outputShapes: Array[Shape] = trainExecs(0).outputs.map { x: NDArray =>
+ Array(batchSize) ++ x.shape.drop(1)
+ }
+ private val cpuOutputArrays = outputShapes.map(NDArray.zeros(_))
+
+ // Install monitor on all executors
+ def installMonitor(monitor: Monitor): Unit = {
+ trainExecs.foreach(monitor.install)
+ }
+
+ /**
+ * Set parameter and aux values
+ * @param argParams source parameter arrays
+ * @param auxParams source aux arrays
+ */
+ def setParams(argParams: Map[String, NDArray], auxParams: Map[String, NDArray]): Unit = {
+ trainExecs.foreach(_.copyParamsFrom(argParams, auxParams))
+ }
+
+ /**
+ * Copy data from each executor to `arg_params` and `aux_params`
+ * @param argParams target parameter arrays
+ * @param auxParams target aux arrays
+ * @note This function will inplace update the NDArrays in arg_params and aux_params.
+ */
+ def copyTo(argParams: Map[String, NDArray], auxParams: Map[String, NDArray]): Unit = {
+ for ((name, block) <- _paramNames zip paramArrays) {
+ val weight = block.map(_.copyTo(Context.cpu())).reduce(_ + _) / block.length
+ weight.copyTo(argParams(name))
+ }
+ for ((name, block) <- auxNames zip auxArrays) {
+ val weight = block.map(_.copyTo(Context.cpu())).reduce(_ + _) / block.length
+ weight.copyTo(auxParams(name))
+ }
+ }
+
+ // load data and labels into arrays
+ def loadDataBatch(dataBatch: DataBatch): Unit = {
+ Executor.loadData(dataBatch, dataArrays)
+ Executor.loadLabel(dataBatch, labelArrays)
+ }
+
+ // Perform a forward pass on each executor
+ def forward(isTrain: Boolean = false): Unit = {
+ for ((texec, islice) <- trainExecs zip slices) {
+ texec.forward(isTrain)
+ for ((cpuOut, devOut) <- cpuOutputArrays zip texec.outputs) {
+ devOut.copyTo(cpuOut.slice(islice))
+ }
+ }
+ }
+
+ // Perform a backward pass on each executor
+ def backward(): Unit = {
+ trainExecs.foreach(_.backward())
+ }
+}
+
+object DataParallelExecutorManager {
+ private val logger = LoggerFactory.getLogger(classOf[DataParallelExecutorManager])
+}
+
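
A sketch of how the manager is meant to drive one epoch (`net` and `trainIter` are hypothetical; the parameter update is elided because the optimizer wiring lands in Model.trainMultiDevice):

    import ml.dmlc.mxnet._

    // Hypothetical: `net` is a Symbol, `trainIter` a DataIter built elsewhere.
    val argNames = net.listArguments()
    val paramNames = argNames.filterNot(n => n.endsWith("data") || n.endsWith("label"))
    val mgr = new DataParallelExecutorManager(
      symbol = net,
      ctx = Array(Context.gpu(0), Context.gpu(1)),
      paramNames = paramNames,
      argNames = argNames,
      auxNames = net.listAuxiliaryStates(),
      trainData = trainIter)

    trainIter.reset()
    while (trainIter.iterNext()) {
      val batch = DataBatch(trainIter.getData(), trainIter.getLabel(),
                            trainIter.getIndex(), trainIter.getPad())
      mgr.loadDataBatch(batch)     // scatter sliced data/label onto each device
      mgr.forward(isTrain = true)  // also gathers device outputs into CPU arrays
      mgr.backward()
      // weight updates happen via KVStore/optimizer in Model.trainMultiDevice
    }
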
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala
index 3b6cac583f75..6cd073461250 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala
@@ -5,12 +5,18 @@ import org.slf4j.LoggerFactory
import scala.collection.mutable.ListBuffer
+/**
+ * IO iterators for loading training & validation data
+ * @author Zixuan Huang, Yizhi Liu
+ */
object IO {
type IterCreateFunc = (Map[String, String]) => DataIter
private val logger = LoggerFactory.getLogger(classOf[DataIter])
private val iterCreateFuncs: Map[String, IterCreateFunc] = _initIOModule()
+ def MNISTIter: IterCreateFunc = iterCreateFuncs("MNISTIter")
+
/**
* create iterator via iterName and params
* @param iterName name of iterator; "MNISTIter" or "ImageRecordIter"
@@ -58,6 +64,12 @@ object IO {
checkCall(_LIB.mxDataIterCreateIter(handle, keys, vals, out))
new MXDataIter(out.value)
}
+
+ // Convert data into canonical form.
+ private def initData(data: NDArray, allowEmpty: Boolean, defaultName: String) = {
+ require(data != null || allowEmpty)
+ // TODO
+ }
}
@@ -68,15 +80,15 @@ object IO {
* @param index
* @param pad
*/
-case class DataBatch(data: NDArray,
- label: NDArray,
- index: List[Long],
+case class DataBatch(data: IndexedSeq[NDArray],
+ label: IndexedSeq[NDArray],
+ index: IndexedSeq[Long],
pad: Int)
/**
* DataIter object in mxnet.
*/
-abstract class DataIter (val batchSize: Int = 0) {
+abstract class DataIter(val batchSize: Int = 0) {
/**
* reset the iterator
*/
@@ -100,13 +112,13 @@ abstract class DataIter (val batchSize: Int = 0) {
* get data of current batch
* @return the data of current batch
*/
- def getData(): NDArray
+ def getData(): IndexedSeq[NDArray]
/**
* Get label of current batch
* @return the label of current batch
*/
- def getLabel(): NDArray
+ def getLabel(): IndexedSeq[NDArray]
/**
* get the number of padding examples
@@ -119,8 +131,13 @@ abstract class DataIter (val batchSize: Int = 0) {
* the index of current batch
* @return
*/
- def getIndex(): List[Long]
+ def getIndex(): IndexedSeq[Long]
+
+ // The name and shape of data provided by this iterator
+ def provideData: Map[String, Shape]
+ // The name and shape of label provided by this iterator
+ def provideLabel: Map[String, Shape]
}
/**
@@ -156,31 +173,31 @@ class MXDataIter(val handle: DataIterHandle) extends DataIter {
* get data of current batch
* @return the data of current batch
*/
- override def getData(): NDArray = {
+ override def getData(): IndexedSeq[NDArray] = {
val out = new NDArrayHandleRef
checkCall(_LIB.mxDataIterGetData(handle, out))
- new NDArray(out.value, writable = false)
+ IndexedSeq(new NDArray(out.value, writable = false))
}
/**
* Get label of current batch
* @return the label of current batch
*/
- override def getLabel(): NDArray = {
+ override def getLabel(): IndexedSeq[NDArray] = {
val out = new NDArrayHandleRef
checkCall(_LIB.mxDataIterGetLabel(handle, out))
- new NDArray(out.value, writable = false)
+ IndexedSeq(new NDArray(out.value, writable = false))
}
/**
* the index of current batch
* @return
*/
- override def getIndex(): List[Long] = {
+ override def getIndex(): IndexedSeq[Long] = {
val outIndex = new ListBuffer[Long]
val outSize = new RefLong
checkCall(_LIB.mxDataIterGetIndex(handle, outIndex, outSize))
- outIndex.toList
+ outIndex.toIndexedSeq
}
/**
@@ -193,6 +210,12 @@ class MXDataIter(val handle: DataIterHandle) extends DataIter {
checkCall(_LIB.mxDataIterGetPadNum(handle, out))
out.value
}
+
+ // The name and shape of data provided by this iterator
+ override def provideData: Map[String, Shape] = ???
+
+ // The name and shape of label provided by this iterator
+ override def provideLabel: Map[String, Shape] = ???
}
// scalastyle:on finalize
@@ -209,19 +232,19 @@ class ArrayDataIter() extends DataIter {
* get data of current batch
* @return the data of current batch
*/
- override def getData(): NDArray = ???
+ override def getData(): IndexedSeq[NDArray] = ???
/**
* Get label of current batch
* @return the label of current batch
*/
- override def getLabel(): NDArray = ???
+ override def getLabel(): IndexedSeq[NDArray] = ???
/**
* the index of current batch
* @return
*/
- override def getIndex(): List[Long] = ???
+ override def getIndex(): IndexedSeq[Long] = ???
/**
* Iterate to next batch
@@ -235,6 +258,69 @@ class ArrayDataIter() extends DataIter {
* @return number of padding examples in current batch
*/
override def getPad(): Int = ???
+
+ // The name and shape of data provided by this iterator
+ override def provideData: Map[String, Shape] = ???
+
+ // The name and shape of label provided by this iterator
+ override def provideLabel: Map[String, Shape] = ???
}
+/**
+ * TODO
+ * NDArrayIter object in mxnet. Takes NDArray to construct the data iterator.
+ * @param data NDArrayIter supports single or multiple data and label.
+ * @param label Same as data, but is not fed to the model during testing.
+ * @param batchSize Batch Size
+ * @param shuffle Whether to shuffle the data
+ * @param lastBatchHandle "pad", "discard" or "roll_over". How to handle the last batch
+ * @note
+ * This iterator will pad, discard or roll over the last batch if
+ * the size of data does not match batch_size. Roll over is intended
+ * for training and can cause problems if used for prediction.
+ */
+class NDArrayIter(data: NDArray, label: NDArray = null,
+ batchSize: Int = 1, shuffle: Boolean = false,
+ lastBatchHandle: String = "pad") extends DataIter(batchSize) {
+ /**
+ * reset the iterator
+ */
+ override def reset(): Unit = ???
+
+ /**
+ * get data of current batch
+ * @return the data of current batch
+ */
+ override def getData(): IndexedSeq[NDArray] = ???
+
+ /**
+ * Get label of current batch
+ * @return the label of current batch
+ */
+ override def getLabel(): IndexedSeq[NDArray] = ???
+
+ /**
+ * the index of current batch
+ * @return
+ */
+ override def getIndex(): IndexedSeq[Long] = ???
+ /**
+ * Iterate to next batch
+ * @return whether the move is successful
+ */
+ override def iterNext(): Boolean = ???
+
+ /**
+ * get the number of padding examples
+ * in current batch
+ * @return number of padding examples in current batch
+ */
+ override def getPad(): MXUint = ???
+
+ // The name and shape of data provided by this iterator
+ override def provideData: Map[String, Shape] = ???
+
+ // The name and shape of label provided by this iterator
+ override def provideLabel: Map[String, Shape] = ???
+}
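
A usage sketch for the newly exposed creator; the parameter keys mirror the native MNISTIter operator and the paths are placeholders (assumptions, not part of this change):

    import ml.dmlc.mxnet.IO

    val trainIter = IO.MNISTIter(Map(
      "image" -> "data/train-images-idx3-ubyte",   // placeholder path
      "label" -> "data/train-labels-idx1-ubyte",   // placeholder path
      "batch_size" -> "128",
      "shuffle" -> "1"))

    trainIter.reset()
    while (trainIter.iterNext()) {
      val data = trainIter.getData()  // now IndexedSeq[NDArray]
    }
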
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala
index 8bfdf2f9cc2e..6e4ccbb273e3 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala
@@ -145,8 +145,43 @@ class LibInfo {
@native def mxSymbolListArguments(handle: SymbolHandle,
arguments: ArrayBuffer[String]): Int
@native def mxSymbolCopy(handle: SymbolHandle, clonedHandle: SymbolHandleRef): Int
+ @native def mxSymbolListAuxiliaryStates(handle: SymbolHandle,
+ arguments: ArrayBuffer[String]): Int
@native def mxSymbolListOutputs(handle: SymbolHandle,
outputs: ArrayBuffer[String]): Int
@native def mxSymbolCreateGroup(handles: Array[SymbolHandle], out: SymbolHandleRef): Int
@native def mxSymbolPrint(handle: SymbolHandle, str: RefString): Int
+ @native def mxSymbolGetInternals(handle: SymbolHandle, out: SymbolHandleRef): Int
+ @native def mxSymbolInferType(handle: SymbolHandle,
+ keys: Array[String],
+ sdata: Array[Int],
+ argTypeData: ListBuffer[Int],
+ outTypeData: ListBuffer[Int],
+ auxTypeData: ListBuffer[Int],
+ complete: RefInt): Int
+ @native def mxSymbolInferShape(handle: SymbolHandle,
+ numArgs: MXUint,
+ keys: Array[String],
+ argIndPtr: Array[MXUint],
+ argShapeData: Array[MXUint],
+ inShapeData: ListBuffer[Shape],
+ outShapeData: ListBuffer[Shape],
+ auxShapeData: ListBuffer[Shape],
+ complete: RefInt): Int
+ @native def mxSymbolGetOutput(handle: SymbolHandle, index: Int, out: SymbolHandleRef): Int
+ // scalastyle:off parameterNum
+ @native def mxExecutorBindX(handle: SymbolHandle,
+ deviceTypeId: Int,
+ deviceID: Int,
+ numCtx: Int,
+ ctxMapKeys: Array[String],
+ ctxMapDevTypes: Array[Int],
+ ctxMapDevIDs: Array[Int],
+ numArgs: Int,
+ argsHandle: Array[NDArrayHandle],
+ argsGradHandle: Array[NDArrayHandle],
+ reqsArray: Array[Int],
+ auxArgsHandle: Array[NDArrayHandle],
+ out: ExecutorHandleRef): Int
+ // scalastyle:on parameterNum
}
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala
index 1912832341a3..f5100440042e 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala
@@ -1,11 +1,16 @@
package ml.dmlc.mxnet
-import org.slf4j.LoggerFactory
+import ml.dmlc.mxnet.Base.Shape
+import ml.dmlc.mxnet.optimizer.SGD
+import org.slf4j.{Logger, LoggerFactory}
+
+import scala.collection.mutable
/**
* Describe the model flow
* @author Yizhi Liu
*/
+class Model
object Model {
private val logger = LoggerFactory.getLogger(classOf[Model])
/**
@@ -13,12 +18,12 @@ object Model {
* This function select and create a proper kvstore given the kvstore type
* @param kvStore KVStore type
* @param numDevice The number of devices
- * @param maxSize max size of the kvstore
+ * @param argParams Model parameter, dict of name to NDArray of net's weights.
* @return Option of created [[KVStore]] and whether or not update weight on it
*/
- private def createKVStore(kvStore: String,
- numDevice: Int,
- maxSize: Int): (Option[KVStore], Boolean) = {
+ private[mxnet] def createKVStore(kvStore: String,
+ numDevice: Int,
+ argParams: Map[String, NDArray]): (Option[KVStore], Boolean) = {
if (numDevice == 1 && !kvStore.contains("dist")) {
// no need to use kv for single device and single machine
(None, false)
@@ -26,6 +31,7 @@ object Model {
var kvType = kvStore
if (kvType == "local") {
// automatically select a proper local
+ val maxSize = argParams.values.map(_.shape.product).max
kvType =
if (maxSize < 1024 * 1024 * 16) {
"local_update_cpu"
@@ -39,8 +45,8 @@ object Model {
}
/**
- * Create a kvstore (wrap it with Option, None if given kvStore == null)
- * @param kvStore
+ * Create a kvStore (wrap it with Option, None if given kvStore == null)
+ * @param kvStore KVStore
* @return Option of created [[KVStore]] and whether or not update weight on it
*/
private def createKVStore(kvStore: KVStore): (Option[KVStore], Boolean) = {
@@ -100,7 +106,273 @@ object Model {
}
}
}
+
+ /**
+ * TODO
+ * Internal training function on multiple devices.
+ * This function will also work for single device as well.
+ * @param symbol The network configuration
+ * @param ctx The training devices.
+ * @param argNames Name of all arguments of the network.
+ * @param paramNames Name of all trainable parameters of the network.
+ * @param auxNames Name of all auxiliary states of the network.
+ * @param argParams Model parameter, dict of name to NDArray of net's weights.
+ * @param auxParams Model parameter, dict of name to NDArray of net's auxiliary states.
+ * @param beginEpoch The beginning training epoch.
+ * @param endEpoch The end training epoch.
+ * @param epochSize Number of batches in an epoch.
+ *                  By default, it is set to ceil(num_train_examples / batch_size)
+ * @param optimizer The optimization algorithm
+ * @param kvStore The KVStore
+ * @param updateOnKVStore whether or not perform weight updating on kvstore
+ * @param trainData Training data iterator.
+ * @param evalData Validation data iterator.
+ * @param evalMetric An evaluation function.
+ * @param epochEndCallback A callback that is invoked at end of each epoch.
+ * This can be used to checkpoint model each epoch.
+ * @param batchEndCallback A callback that is invoked at end of each batch.
+ *                         This can be used to measure speed and
+ *                         get results from the evaluation metric, etc.
+ * @param logger When not specified, default logger will be used.
+ * @param workLoadList The list of workloads for different devices, in the same order as ctx
+ * @param monitor Monitor outputs, weights, and gradients for debugging
+ * @note This function will inplace update the NDArrays in argParams and auxStates.
+ */
+ // scalastyle:off parameterNum
+ private[mxnet] def trainMultiDevice(symbol: Symbol, ctx: Array[Context],
+ argNames: Seq[String], paramNames: Seq[String],
+ auxNames: Seq[String], argParams: Map[String, NDArray],
+ auxParams: Map[String, NDArray],
+ beginEpoch: Int, endEpoch: Int, epochSize: Int,
+ optimizer: Optimizer,
+ kvStore: KVStore, updateOnKVStore: Boolean,
+ trainData: DataIter = null, evalData: DataIter = null,
+ evalMetric: EvalMetric = null,
+ epochEndCallback: EpochEndCallback = null,
+ batchEndCallback: BatchEndCallback = null,
+ logger: Logger = logger,
+ workLoadList: Seq[Float] = Nil,
+ monitor: Monitor = null): Unit = {
+ val executorManager = new DataParallelExecutorManager(
+ symbol = symbol,
+ ctx = ctx,
+ trainData = trainData,
+ paramNames = paramNames,
+ argNames = argNames,
+ auxNames = auxNames,
+ workLoadList = workLoadList,
+ logger = logger)
+ }
+ // scalastyle:on parameterNum
+}
+
+trait EpochEndCallback {
+ def invoke(epoch: Int, symbol: Symbol,
+ argParams: Map[String, NDArray],
+ auxStates: Map[String, NDArray]): Unit
}
-class Model {
+trait BatchEndCallback {
+ def invoke(epoch: Int, nBatch: Int, evalMetric: EvalMetric)
+}
+
+/**
+ * Model class of MXNet for training and predicting feedforward nets.
+ * This class is designed for a single-data single output supervised network.
+ * @param symbol The symbol configuration of computation network.
+ * @param ctx The device context of training and prediction.
+ * To use multi GPU training, pass in a list of gpu contexts.
+ * @param numEpoch Training parameter, number of training epochs.
+ * @param epochSize Number of batches in an epoch. By default, it is set to
+ *                  ceil(num_train_examples / batch_size)
+ * @param optimizer Training parameter, name or optimizer object for training.
+ * @param initializer Training parameter, the initialization scheme used.
+ * @param batchSize The batch size of training data.
+ * @param argParams Model parameter, dict of name to NDArray of net's weights.
+ * @param auxParams Model parameter, dict of name to NDArray of net's auxiliary states.
+ * @param allowExtraParams Whether to allow extra parameters that are not needed by the symbol
+ *                         to be passed in via auxParams and argParams.
+ *                         If true, no error will be thrown when auxParams and argParams
+ *                         contain more parameters than needed.
+ * @param beginEpoch The beginning training epoch.
+ * @param kwargs The additional keyword arguments passed to optimizer.
+ */
+class FeedForward(val symbol: Symbol, val ctx: Array[Context] = Array(new Context("cpu")),
+ val numEpoch: Int = -1, val epochSize: Int = -1,
+ val optimizer: String = "sgd",
+ val initializer: Initializer = new Uniform(0.01f),
+ val batchSize: Int = 128,
+ argParams: Map[String, NDArray] = null,
+ auxParams: Map[String, NDArray] = null,
+ allowExtraParams: Boolean = false,
+ val beginEpoch: Int = 0,
+ val kwargs: mutable.Map[String, Any]) {
+ private val LOG: Logger = LoggerFactory.getLogger(classOf[FeedForward])
+ // check if symbol contain duplicated names.
+ Executor.checkArguments(symbol)
+
+ // rematch parameters to delete useless ones
+ private var _argParams =
+ if (allowExtraParams) {
+ if (argParams != null) {
+ val argNames = symbol.listArguments().toSet
+ argParams.filter { case (k, v) => argNames.contains(k) }
+ } else {
+ null
+ }
+ } else {
+ argParams
+ }
+ private var _auxParams =
+ if (allowExtraParams) {
+ if (auxParams != null) {
+ val auxNames = symbol.listAuxiliaryStates().toSet
+ auxParams.filter { case (k, v) => auxNames.contains(k) }
+ } else {
+ null
+ }
+ } else {
+ auxParams
+ }
+
+ // internal helper state
+ var predExec: Executor = null
+
+ // Initialize weight parameters and auxiliary states
+ private def initParams(inputShapes: Map[String, Shape], overwrite: Boolean = false)
+ : (Seq[String], Seq[String], Seq[String]) = {
+ val (argShapes, _, auxShapes) = symbol.inferShape(inputShapes)
+ val argNames = symbol.listArguments()
+ val inputNames = inputShapes.keys
+ val paramNames = argNames.toSet -- inputNames.toSet
+ val auxNames = symbol.listAuxiliaryStates()
+
+ val paramNameShapes = (argNames zip argShapes).filter { case (name, _) =>
+ paramNames.contains(name)
+ }
+ val argParams = paramNameShapes.map { case (name, shape) =>
+ (name, NDArray.zeros(shape))
+ }.toMap
+ val auxParams = (auxNames zip auxShapes).map { case (name, shape) =>
+ (name, NDArray.zeros(shape))
+ }.toMap
+
+ for ((k, v) <- argParams) {
+ if (_argParams != null && _argParams.contains(k) && (!overwrite)) {
+ argParams(k).set(_argParams(k))
+ } else {
+ initializer(k, v)
+ }
+ }
+
+ for ((k, v) <- auxParams) {
+ if (_auxParams != null && _auxParams.contains(k) && (!overwrite)) {
+ auxParams(k).set(_auxParams(k))
+ } else {
+ initializer(k, v)
+ }
+ }
+
+ _argParams = argParams
+ _auxParams = auxParams
+ (argNames, paramNames.toSeq, auxNames)
+ }
+
+ // Initialize the predictor module for running prediction.
+ private def initPredictor(inputShapes: Map[String, Shape]): Unit = {
+ if (this.predExec == null) {
+ val predExec = symbol.simpleBind(ctx(0), gradReq = "null", shapeDict = inputShapes)
+ predExec.copyParamsFrom(_argParams, _auxParams)
+ Executor.checkArguments(symbol)
+ this.predExec = predExec
+ }
+ }
+
+ // Initialize the iterator given input.
+ private def initIter(X: NDArray, y: NDArray, isTrain: Boolean): DataIter = {
+ require(y != null || !isTrain, "y must be specified")
+ val label = if (y == null) NDArray.zeros(X.shape(0)) else y
+ require(label.shape.length == 1, "Label must be 1D")
+    require(X.shape(0) == label.shape(0), "The numbers of data points and labels are not equal")
+ if (isTrain) {
+ new NDArrayIter(X, label, batchSize, shuffle = isTrain, lastBatchHandle = "roll_over")
+ } else {
+ new NDArrayIter(X, label, batchSize, shuffle = false)
+ }
+ }
+
+ // Initialize the iterator given eval_data.
+ private def initEvalIter(evalData: (NDArray, NDArray)): DataIter = {
+ if (evalData == null) {
+ null
+ } else {
+ initIter(evalData._1, evalData._2, isTrain = true)
+ }
+ }
+
+ /**
+ * TODO
+ * Fit the model.
+ * @param trainData Training data. If X is a DataIter, the name or, if not available,
+ * position, of its outputs should match the corresponding variable
+ * names defined in the symbolic graph.
+ * @param evalData If evalData is a pair of NDArrays,
+ *                 it should be (validData, validLabel).
+ * @param evalMetric The evaluation metric: the name of an evaluation metric,
+ *                   or a custom evaluation function that returns statistics
+ *                   based on a minibatch.
+ * @param epochEndCallback A callback that is invoked at end of each epoch.
+ * This can be used to checkpoint model each epoch.
+ * @param batchEndCallback A callback that is invoked at end of each batch,
+ *                         for printing purposes
+ * @param kvStoreType A string kvstore type:
+ * 'local' : multi-devices on a single machine, will automatically
+ * choose one from 'local_update_cpu', 'local_allreduce_cpu', and
+ * 'local_allreduce_device'
+ * 'dist_sync' : multi-machines with BSP
+ *                    'dist_async' : multi-machines with partial asynchronous updates
+ *                    Defaults to 'local'; often no need to change for a single machine.
+ * @param logger When not specified, default logger will be used.
+ * @param workLoadList The list of workloads for different devices, in the same order as ctx
+ */
+ def fit(trainData: DataIter, evalData: DataIter, evalMetric: String = "acc",
+ kvStoreType: String = "local", logger: Logger = LOG,
+ workLoadList: Seq[Float] = null): Unit = {
+ val (argNames, paramNames, auxNames) =
+ initParams(trainData.provideData ++ trainData.provideLabel)
+ // TODO: kwargs.put("arg_names", argNames)
+
+ // TODO: setup metric
+
+ // create kvstore
+ val (kvStore, updateOnKVStore) = Model.createKVStore(kvStoreType, ctx.length, _argParams)
+
+    // init optimizer
+ val batchSizeMultiplier = kvStore.map { kv =>
+ if (kv.`type` == "dist_sync") {
+ kv.numWorkers
+ } else {
+ 1
+ }
+ }
+ val batchSize = trainData.batchSize * batchSizeMultiplier.getOrElse(1)
+ // TODO: temporarily hard-coded sgd optimizer
+ val optimizer = new SGD(rescaleGrad = 1f / batchSize)
+ Model.trainMultiDevice(
+ symbol, ctx, argNames, paramNames, auxNames,
+ _argParams, _auxParams,
+ this.beginEpoch, this.numEpoch,
+ this.epochSize,
+ optimizer,
+ kvStore.orNull, updateOnKVStore,
+ trainData = trainData, evalData = evalData,
+ logger = logger, workLoadList = workLoadList)
+ }
+}
+
+object FeedForward {
+ // Check if name is a data argument.
+ private def isDataArg(name: String): Boolean = {
+ name.endsWith("data") || name.endsWith("label")
+ }
}
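
Putting the pieces together, the intended end-to-end shape is roughly the following (a sketch; `net` and the iterators are hypothetical, and kwargs is still a raw map at this stage):

    import ml.dmlc.mxnet._
    import scala.collection.mutable

    // Hypothetical: `net` is a Symbol, trainIter/valIter are DataIters.
    val model = new FeedForward(
      symbol = net,
      ctx = Array(Context.cpu()),
      numEpoch = 10,
      kwargs = mutable.Map.empty[String, Any])
    model.fit(trainData = trainIter, evalData = valIter)
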
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala
index 9a2c7912f104..b301c1abf03c 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala
@@ -11,6 +11,21 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer}
*/
object NDArray {
private val logger = LoggerFactory.getLogger(classOf[NDArray])
+
+ private[mxnet] val DTYPE_NATIVE_TO_MX: Map[Class[_ >: Float with Int with Double], Int] = Map(
+ classOf[Float] -> 0,
+ classOf[Double] -> 1,
+ classOf[Int] -> 4
+ )
+
+ private[mxnet] val DTYPE_MX_TO_NATIVE: Map[Int, Class[_ >: Float with Int with Double]] = Map(
+ 0 -> classOf[Float],
+ 1 -> classOf[Double],
+ 2 -> classOf[Float],
+ 3 -> classOf[Int],
+ 4 -> classOf[Int]
+ )
+
private val functions: Map[String, NDArrayFunction] = initNDArrayModule()
// Definition of internal functions.
@@ -552,6 +567,10 @@ class NDArray(private[mxnet] val handle: NDArrayHandle, val writable: Boolean =
new NDArray(handle = sliceHandle.value, writable = this.writable)
}
+ def slice(range: (Int, Int)): NDArray = {
+ slice(range._1, range._2)
+ }
+
def slice(start: Int): NDArray = {
slice(start, shape(0))
}
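
A small sketch of the new tuple overload, which makes it convenient to pass the (start, end) pairs produced by Executor.splitInputSlice directly:

    import ml.dmlc.mxnet.NDArray

    val arr = NDArray.zeros(Array(6, 4))
    val a = arr.slice(0, 3)    // rows [0, 3)
    val b = arr.slice((0, 3))  // tuple overload, same result
    val c = arr.slice(3)       // rows [3, shape(0))
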
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Random.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Random.scala
index d2935590f055..3b14d5b14fcc 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Random.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Random.scala
@@ -1,5 +1,6 @@
package ml.dmlc.mxnet
+import ml.dmlc.mxnet.Base.Shape
import ml.dmlc.mxnet.NDArray.{randomGaussian, randomUniform, empty}
/**
@@ -19,7 +20,7 @@ object Random {
*/
def uniform(low: Float,
high: Float,
- shape: Array[Int] = null,
+ shape: Shape = null,
ctx: Context = null,
out: NDArray = null): NDArray = {
var outCopy = out
@@ -45,7 +46,7 @@ object Random {
*/
def normal(mean: Float,
stdvar: Float,
- shape: Array[Int] = null,
+ shape: Shape = null,
ctx: Context = null,
out: NDArray = null): NDArray = {
var outCopy = out
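
Call sites are unchanged since Shape is just an alias for Array[Int]; a minimal sketch:

    import ml.dmlc.mxnet.{Context, Random}

    val u = Random.uniform(0f, 1f, shape = Array(2, 3), ctx = Context.cpu())
    val n = Random.normal(0f, 1f, shape = Array(2, 3))
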
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala
index 33e10ad5ea7e..2d048422a989 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala
@@ -10,7 +10,7 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer}
* @author Yizhi Liu
*/
class Symbol(private[mxnet] val handle: SymbolHandle) {
- def +(other: Symbol): Symbol = Symbol.create("_Plus", other)
+ def +(other: Symbol): Symbol = Symbol.create("_Plus", this, other)
override def clone(): Symbol = {
val clonedHandle = new SymbolHandleRef
@@ -18,37 +18,203 @@ class Symbol(private[mxnet] val handle: SymbolHandle) {
new Symbol(clonedHandle.value)
}
+ def get(index: Int): Symbol = {
+ val newHandle = new SymbolHandleRef
+ checkCall(_LIB.mxSymbolGetOutput(handle, index, newHandle))
+ new Symbol(handle = newHandle.value)
+ }
+
+ def get(name: String): Symbol = {
+ var index: Int = -1
+ for ((output, i) <- listOutputs().view.zipWithIndex) {
+ if (output == name) {
+ require(index == -1, s"There are multiple outputs with name $name")
+ index = i
+ }
+ }
+ require(index >= 0, s"Cannot find output that matches name $name")
+ get(index)
+ }
+
+ /**
+ * Get a new grouped symbol whose output contains all the internal outputs of this symbol.
+ * @return The internal of the symbol.
+ */
+ def getInternals: Symbol = {
+ val newHandle = new SymbolHandleRef
+ checkCall(_LIB.mxSymbolGetInternals(handle, newHandle))
+ new Symbol(handle = newHandle.value)
+ }
+
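
For example, picking out an intermediate output (a sketch; `net` and the layer name are hypothetical):

    // Hypothetical: "fc1" is a layer name assigned when building `net`.
    val internals = net.getInternals
    println(internals.listOutputs())          // e.g. contains "fc1_output"
    val fc1Out = internals.get("fc1_output")  // or internals.get(index)
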
/**
* List all the arguments in the symbol.
* @return Array of all the arguments.
*/
- def listArguments(): Array[String] = {
+ def listArguments(): Seq[String] = {
val arr = ArrayBuffer.empty[String]
checkCall(_LIB.mxSymbolListArguments(handle, arr))
- arr.toArray
+ arr
}
/**
* List all outputs in the symbol.
* @return : List of all the outputs.
*/
- def listOutputs(): Array[String] = {
+ def listOutputs(): Seq[String] = {
val arr = ArrayBuffer.empty[String]
checkCall(_LIB.mxSymbolListOutputs(handle, arr))
- arr.toArray
+ arr
}
/**
* List all auxiliary states in the symbol.
* @return The names of the auxiliary states.
- * Notes
- * -----
+ * @note
 * Auxiliary states are special states of symbols that do not correspond to an argument,
 * and do not have gradients, but are still useful for specific operations.
* A common example of auxiliary state is the moving_mean and moving_variance in BatchNorm.
* Most operators do not have Auxiliary states.
*/
- def listAuxiliaryStates(): Array[String] = ???
+ def listAuxiliaryStates(): Seq[String] = {
+ val sarr = ArrayBuffer.empty[String]
+ checkCall(_LIB.mxSymbolListAuxiliaryStates(handle, sarr))
+ sarr
+ }
+
+ /**
+ * Infer the types of outputs and arguments given known types of some arguments.
+ * A tuple of nulls is returned if there is not enough information passed in.
+ * An error will be raised if there is inconsistency found in the known types passed in.
+ * @param args Provide types of arguments in a positional way. Unknown types can be marked as null
+ * @return
+ * argTypes : list of argument types.
+ * The order is the same as that of listArguments()
+ * outTypes : list of output types.
+ * The order is the same as that of listOutputs()
+ * auxTypes : list of auxiliary state types.
+ * The order is the same as that of listAuxiliaryStates()
+ */
+ def inferType(args: Class[_ >: Float with Int with Double]*)
+ : (Seq[Class[_ >: Float with Int with Double]],
+ Seq[Class[_ >: Float with Int with Double]],
+ Seq[Class[_ >: Float with Int with Double]]) = {
+ val sdata: Array[Int] = args.map(NDArray.DTYPE_NATIVE_TO_MX.getOrElse(_, -1)).toArray
+ inferType(null, sdata)
+ }
+
+ /**
+ * Infer the types of outputs and arguments given known types of some arguments.
+ * A tuple of nulls is returned if there is not enough information passed in.
+ * An error will be raised if there is inconsistency found in the known types passed in.
+ * @param kwargs Provide keyword arguments of known types.
+ * @return
+ * argTypes : list of argument types.
+ * The order is the same as that of listArguments()
+ * outTypes : list of output types.
+ * The order is the same as that of listOutputs()
+ * auxTypes : list of auxiliary state types.
+ * The order is the same as that of listAuxiliaryStates()
+ */
+ def inferType(kwargs: Map[String, Class[_ >: Float with Int with Double]])
+ : (Seq[Class[_ >: Float with Int with Double]],
+ Seq[Class[_ >: Float with Int with Double]],
+ Seq[Class[_ >: Float with Int with Double]]) = {
+ val filteredArgs = kwargs.filter { case (key, value) =>
+ NDArray.DTYPE_NATIVE_TO_MX.contains(value)
+ }
+ val keys = filteredArgs.keys.toArray
+ val sdata = filteredArgs.values.map(NDArray.DTYPE_NATIVE_TO_MX(_)).toArray
+ inferType(keys, sdata)
+ }
+
+ private def inferType(keys: Array[String], values: Array[Int])
+ : (Seq[Class[_ >: Float with Int with Double]],
+ Seq[Class[_ >: Float with Int with Double]],
+ Seq[Class[_ >: Float with Int with Double]]) = {
+ val argTypeData = ListBuffer.empty[Int]
+ val outTypeData = ListBuffer.empty[Int]
+ val auxTypeData = ListBuffer.empty[Int]
+ val complete = new RefInt
+ checkCall(_LIB.mxSymbolInferType(
+ handle, keys, values, argTypeData, outTypeData, auxTypeData, complete))
+ if (complete.value != 0) {
+ (argTypeData.map(NDArray.DTYPE_MX_TO_NATIVE),
+ outTypeData.map(NDArray.DTYPE_MX_TO_NATIVE),
+ auxTypeData.map(NDArray.DTYPE_MX_TO_NATIVE))
+ } else {
+ (null, null, null)
+ }
+ }
+
+ /**
+ * Infer the shape of outputs and arguments of given known shapes of arguments.
+ * Users can pass in the known shapes either positionally or by keyword.
+ * A tuple of nulls is returned if there is not enough information passed in.
+ * An error will be raised if there is inconsistency found in the known shapes passed in.
+ * @param args Provide shapes of arguments in a positional way.
+ * Unknown shapes can be marked as null
+ * @return
+ * argShapes List of shapes of arguments. The order is the same as that of listArguments()
+ * outShapes List of shapes of outputs. The order is the same as that of listOutputs()
+ * auxShapes List of shapes of auxiliary states. The order is the same as that of listAuxiliaryStates()
+ */
+ def inferShape(args: Shape*): (Seq[Shape], Seq[Shape], Seq[Shape]) = {
+ val keys: Array[String] = null
+ val indPtr = ArrayBuffer(0)
+ val sdata = ArrayBuffer.empty[Int]
+ args.foreach { shape =>
+ if (shape != null) {
+ sdata ++= shape
+ indPtr += sdata.size
+ }
+ }
+ inferShape(keys, indPtr.toArray, sdata.toArray)
+ }
+
+ /**
+ * Infer the shape of outputs and arguments of given known shapes of arguments.
+ * Users can pass in the known shapes either positionally or by keyword.
+ * A tuple of nulls is returned if there is not enough information passed in.
+ * An error will be raised if there is inconsistency found in the known shapes passed in.
+ * @param kwargs Provide keyword arguments of known shapes.
+ * @return
+ * argShapes List of shapes of arguments. The order is the same as that of listArguments()
+ * outShapes List of shapes of outputs. The order is the same as that of listOutputs()
+ * auxShapes List of shapes of auxiliary states. The order is the same as that of listAuxiliaryStates()
+ */
+ def inferShape(kwargs: Map[String, Shape]): (Seq[Shape], Seq[Shape], Seq[Shape]) = {
+ val keys = ArrayBuffer.empty[String]
+ val indPtr = ArrayBuffer(0)
+ val sdata = ArrayBuffer.empty[Int]
+ kwargs.foreach { case (key, shape) =>
+ keys += key
+ sdata ++= shape
+ indPtr += sdata.size
+ }
+ inferShape(keys.toArray, indPtr.toArray, sdata.toArray)
+ }
+
+ def inferShape(keys: Array[String], indPtr: Array[Int], values: Array[Int])
+ : (Seq[Shape], Seq[Shape], Seq[Shape]) = {
+ val argShapeData = ListBuffer.empty[Shape]
+ val outShapeData = ListBuffer.empty[Shape]
+ val auxShapeData = ListBuffer.empty[Shape]
+ val complete = new RefInt
+
+ checkCall(_LIB.mxSymbolInferShape(handle, indPtr.size - 1, keys, indPtr, values,
+ argShapeData, outShapeData, auxShapeData, complete))
+ if (complete.value != 0) {
+ (argShapeData, outShapeData, auxShapeData)
+ } else {
+ (null, null, null)
+ }
+ }
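
A sketch of both inference calls on a hypothetical single-input net named "data":

    // Hypothetical: `net` takes one input named "data".
    val (argShapes, outShapes, auxShapes) =
      net.inferShape(Map("data" -> Array(128, 784)))
    val (argTypes, outTypes, auxTypes) =
      net.inferType(Map("data" -> classOf[Float]))
    // Results follow listArguments()/listOutputs()/listAuxiliaryStates() order;
    // nulls are returned when the inputs underdetermine the network.
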
/**
* Get attribute string from the symbol, this function only works for non-grouped symbol.
@@ -112,11 +278,390 @@ class Symbol(private[mxnet] val handle: SymbolHandle) {
val args = symbols.values.map(_.handle).toArray
checkCall(_LIB.mxSymbolCompose(handle, name, keys, args))
}
+
+ /**
+ * Bind current symbol to get an executor, allocate all the ndarrays needed.
+ * Allows specifying data types.
+ * This function lets the user pass in the NDArrays they would like to bind
+ * at each position, and automatically allocates the NDArrays for arguments
+ * and auxiliary states that the user did not specify explicitly.
+ *
+ * @param ctx The device context the generated executor to run on.
+ * @param gradReq {'write', 'add', 'null'}, or list of str or dict of str to str, optional
+ * Specifies how we should update the gradient to the args_grad.
+ * - 'write' means the gradient is written to the specified args_grad NDArray every time.
+ * - 'add' means the gradient is added to the specified NDArray every time.
+ * - 'null' means no action is taken; the gradient may not be calculated.
+ * @param typeDict Input type dictionary, name->dtype
+ * @param shapeDict Input shape dictionary, name->shape
+ * @return The generated Executor
+ */
+ def simpleBind(ctx: Context, gradReq: String = "write",
+ shapeDict: Map[String, Shape],
+ typeDict: Map[String, Class[_ >: Float with Int with Double]] = null): Executor = {
+ val types =
+ if (typeDict == null) listArguments().map((_, classOf[Float])).toMap
+ else typeDict
+ val (argShapes, _, auxShapes) = inferShape(shapeDict)
+ val (argTypes, _, auxTypes) = inferType(types)
+ require(argShapes != null && argTypes != null, "Input node is not complete")
+ // alloc space
+ val argNDArrays = (argShapes zip argTypes) map { case (shape, t) =>
+ // TODO: NDArray dtype
+ NDArray.zeros(shape, ctx)
+ }
+ val gradNDArrays =
+ if (gradReq != "null") {
+ (((listArguments() zip argShapes) zip argTypes) flatMap { case ((name, shape), t) =>
+ if (!(name.endsWith("data") || name.endsWith("label"))) {
+ // TODO: NDArray dtype
+ Map(name -> NDArray.zeros(shape, ctx))
+ } else {
+ Map.empty[String, NDArray]
+ }
+ }).toMap
+ } else {
+ null
+ }
+ val auxNDArrays = (auxShapes zip auxTypes) map { case (shape, t) =>
+ // TODO: NDArray dtype
+ NDArray.zeros(shape, ctx)
+ }
+ bind(ctx, argNDArrays, gradNDArrays, gradReq, auxNDArrays, null)
+ }
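
A sketch of binding and running on a single device (the input name and shape are assumptions):

    import ml.dmlc.mxnet._

    // Hypothetical single-input net; "write" allocates gradient buffers.
    val exec = net.simpleBind(Context.cpu(),
      shapeDict = Map("data" -> Array(128, 784)))
    exec.forward(true)         // isTrain = true
    exec.backward()
    val out = exec.outputs(0)  // outputs is a public val after this change
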
+
+ /**
+ * Bind current symbol to get an executor.
+ *
+ * @param ctx Context The device context the generated executor to run on.
+ * @param args Input arguments to the symbol.
+ * - If type is list of NDArray, the position is in the same order of list_arguments.
+ * - If type is dict of str to NDArray, then it maps the name of arguments
+ * to the corresponding NDArray.
+ * - In either case, all the arguments must be provided.
+ * @param argsGrad When specified, args_grad provide NDArrays to hold
+ * the result of gradient value in backward.
+ * - If type is list of NDArray,
+ * the position is in the same order of list_arguments.
+ * - If type is dict of str to NDArray, then it maps the name of arguments
+ * to the corresponding NDArray.
+ * - When the type is dict of str to NDArray, users only need to provide the dict
+ * for needed argument gradient.
+ * Only the specified argument gradient will be calculated.
+ * @param gradReq {'write', 'add', 'null'}, or list of str or dict of str to str, optional
+ * Specifies how we should update the gradient to the args_grad.
+ * - 'write' means the gradient is written to the specified args_grad NDArray every time.
+ * - 'add' means the gradient is added to the specified NDArray every time.
+ * - 'null' means no action is taken; the gradient may not be calculated.
+ * @param auxStates Input auxiliary states to the symbol, only need to specify when
+ * list_auxiliary_states is not empty.
+ * - If type is list of NDArray,
+ * the position is in the same order of listAuxiliaryStates
+ * - If type is dict of str to NDArray, then it maps the name of auxiliary_states
+ * to the corresponding NDArray,
+ * - In either case, all the auxiliary_states need to be provided.
+ * @param group2ctx The dict mapping the ``ctx_group`` attribute to the context assignment.
+ * @return The generated Executor
+ * @note
+ * Auxiliary states are special states of symbols that do not correspond to an argument,
+ * and do not have gradients, but are still useful for specific operations.
+ * A common example of auxiliary state is the moving_mean and moving_variance in BatchNorm.
+ * Most operators do not have auxiliary states and this parameter can be safely ignored.
+ *
+ * Users can skip computing some gradients by passing a dict as args_grad and
+ * specifying only the gradients they are interested in.
+ */
+ def bind(ctx: Context, args: Seq[NDArray], argsGrad: Seq[NDArray],
+ gradReq: String, auxStates: Seq[NDArray],
+ group2ctx: Map[String, Context]): Executor = {
+ val symbolArguments = listArguments()
+ bindHelper(ctx, symbolArguments, args, argsGrad,
+ Seq.fill(symbolArguments.size)(gradReq), auxStates, group2ctx)
+ }
+
+ def bind(ctx: Context, args: Seq[NDArray], argsGrad: Seq[NDArray],
+ gradReq: String, auxStates: Map[String, NDArray],
+ group2ctx: Map[String, Context]): Executor = {
+ val symbolArguments = listArguments()
+ bindHelper(ctx, symbolArguments, args, argsGrad,
+ Seq.fill(symbolArguments.size)(gradReq), auxStates, group2ctx)
+ }
+
+ def bind(ctx: Context, args: Seq[NDArray], argsGrad: Map[String, NDArray],
+ gradReq: String, auxStates: Seq[NDArray],
+ group2ctx: Map[String, Context]): Executor = {
+ val symbolArguments = listArguments()
+ bindHelper(ctx, symbolArguments, args, argsGrad,
+ Seq.fill(symbolArguments.size)(gradReq), auxStates, group2ctx)
+ }
+
+ def bind(ctx: Context, args: Seq[NDArray], argsGrad: Map[String, NDArray],
+ gradReq: String, auxStates: Map[String, NDArray],
+ group2ctx: Map[String, Context]): Executor = {
+ val symbolArguments = listArguments()
+ bindHelper(ctx, symbolArguments, args, argsGrad,
+ Seq.fill(symbolArguments.size)(gradReq), auxStates, group2ctx)
+ }
+
+ def bind(ctx: Context, args: Map[String, NDArray], argsGrad: Seq[NDArray],
+ gradReq: String, auxStates: Seq[NDArray],
+ group2ctx: Map[String, Context]): Executor = {
+ val symbolArguments = listArguments()
+ bindHelper(ctx, symbolArguments, args, argsGrad,
+ Seq.fill(symbolArguments.size)(gradReq), auxStates, group2ctx)
+ }
+
+ def bind(ctx: Context, args: Map[String, NDArray], argsGrad: Seq[NDArray],
+ gradReq: String, auxStates: Map[String, NDArray],
+ group2ctx: Map[String, Context]): Executor = {
+ val symbolArguments = listArguments()
+ bindHelper(ctx, symbolArguments, args, argsGrad,
+ Seq.fill(symbolArguments.size)(gradReq), auxStates, group2ctx)
+ }
+
+ def bind(ctx: Context, args: Map[String, NDArray], argsGrad: Map[String, NDArray],
+ gradReq: String, auxStates: Seq[NDArray],
+ group2ctx: Map[String, Context]): Executor = {
+ val symbolArguments = listArguments()
+ bindHelper(ctx, symbolArguments, args, argsGrad,
+ Seq.fill(symbolArguments.size)(gradReq), auxStates, group2ctx)
+ }
+
+ def bind(ctx: Context, args: Map[String, NDArray], argsGrad: Map[String, NDArray],
+ gradReq: String, auxStates: Map[String, NDArray],
+ group2ctx: Map[String, Context]): Executor = {
+ val symbolArguments = listArguments()
+ bindHelper(ctx, symbolArguments, args, argsGrad,
+ Seq.fill(symbolArguments.size)(gradReq), auxStates, group2ctx)
+ }
+
+ def bind(ctx: Context, args: Seq[NDArray], argsGrad: Seq[NDArray],
+ gradsReq: Seq[String], auxStates: Seq[NDArray],
+ group2ctx: Map[String, Context]): Executor = {
+ val symbolArguments = listArguments()
+ bindHelper(ctx, symbolArguments, args, argsGrad, gradsReq, auxStates, group2ctx)
+ }
+
+ def bind(ctx: Context, args: Seq[NDArray], argsGrad: Seq[NDArray],
+ gradsReq: Seq[String], auxStates: Map[String, NDArray],
+ group2ctx: Map[String, Context]): Executor = {
+ val symbolArguments = listArguments()
+ bindHelper(ctx, symbolArguments, args, argsGrad, gradsReq, auxStates, group2ctx)
+ }
+
+ def bind(ctx: Context, args: Seq[NDArray], argsGrad: Map[String, NDArray],
+ gradsReq: Seq[String], auxStates: Seq[NDArray],
+ group2ctx: Map[String, Context]): Executor = {
+ val symbolArguments = listArguments()
+ bindHelper(ctx, symbolArguments, args, argsGrad, gradsReq, auxStates, group2ctx)
+ }
+
+ def bind(ctx: Context, args: Seq[NDArray], argsGrad: Map[String, NDArray],
+ gradsReq: Seq[String], auxStates: Map[String, NDArray],
+ group2ctx: Map[String, Context]): Executor = {
+ val symbolArguments = listArguments()
+ bindHelper(ctx, symbolArguments, args, argsGrad, gradsReq, auxStates, group2ctx)
+ }
+
+ def bind(ctx: Context, args: Map[String, NDArray], argsGrad: Seq[NDArray],
+ gradsReq: Seq[String], auxStates: Seq[NDArray],
+ group2ctx: Map[String, Context]): Executor = {
+ val symbolArguments = listArguments()
+ bindHelper(ctx, symbolArguments, args, argsGrad, gradsReq, auxStates, group2ctx)
+ }
+
+ def bind(ctx: Context, args: Map[String, NDArray], argsGrad: Seq[NDArray],
+ gradsReq: Seq[String], auxStates: Map[String, NDArray],
+ group2ctx: Map[String, Context]): Executor = {
+ val symbolArguments = listArguments()
+ bindHelper(ctx, symbolArguments, args, argsGrad, gradsReq, auxStates, group2ctx)
+ }
+
+ def bind(ctx: Context, args: Map[String, NDArray], argsGrad: Map[String, NDArray],
+ gradsReq: Seq[String], auxStates: Seq[NDArray],
+ group2ctx: Map[String, Context]): Executor = {
+ val symbolArguments = listArguments()
+ bindHelper(ctx, symbolArguments, args, argsGrad, gradsReq, auxStates, group2ctx)
+ }
+
+ def bind(ctx: Context, args: Map[String, NDArray], argsGrad: Map[String, NDArray],
+ gradsReq: Seq[String], auxStates: Map[String, NDArray],
+ group2ctx: Map[String, Context]): Executor = {
+ val symbolArguments = listArguments()
+ bindHelper(ctx, symbolArguments, args, argsGrad, gradsReq, auxStates, group2ctx)
+ }
+
+ def bind(ctx: Context, args: Seq[NDArray], argsGrad: Seq[NDArray],
+ gradsReq: Map[String, String], auxStates: Seq[NDArray],
+ group2ctx: Map[String, Context]): Executor = {
+ val symbolArguments = listArguments()
+ bindHelper(ctx, symbolArguments, args, argsGrad, gradsReq, auxStates, group2ctx)
+ }
+
+ def bind(ctx: Context, args: Seq[NDArray], argsGrad: Seq[NDArray],
+ gradsReq: Map[String, String], auxStates: Map[String, NDArray],
+ group2ctx: Map[String, Context]): Executor = {
+ val symbolArguments = listArguments()
+ bindHelper(ctx, symbolArguments, args, argsGrad, gradsReq, auxStates, group2ctx)
+ }
+
+ def bind(ctx: Context, args: Seq[NDArray], argsGrad: Map[String, NDArray],
+ gradsReq: Map[String, String], auxStates: Seq[NDArray],
+ group2ctx: Map[String, Context]): Executor = {
+ val symbolArguments = listArguments()
+ bindHelper(ctx, symbolArguments, args, argsGrad, gradsReq, auxStates, group2ctx)
+ }
+
+ def bind(ctx: Context, args: Seq[NDArray], argsGrad: Map[String, NDArray],
+ gradsReq: Map[String, String], auxStates: Map[String, NDArray],
+ group2ctx: Map[String, Context]): Executor = {
+ val symbolArguments = listArguments()
+ bindHelper(ctx, symbolArguments, args, argsGrad, gradsReq, auxStates, group2ctx)
+ }
+
+ def bind(ctx: Context, args: Map[String, NDArray], argsGrad: Seq[NDArray],
+ gradsReq: Map[String, String], auxStates: Seq[NDArray],
+ group2ctx: Map[String, Context]): Executor = {
+ val symbolArguments = listArguments()
+ bindHelper(ctx, symbolArguments, args, argsGrad, gradsReq, auxStates, group2ctx)
+ }
+
+ def bind(ctx: Context, args: Map[String, NDArray], argsGrad: Seq[NDArray],
+ gradsReq: Map[String, String], auxStates: Map[String, NDArray],
+ group2ctx: Map[String, Context]): Executor = {
+ val symbolArguments = listArguments()
+ bindHelper(ctx, symbolArguments, args, argsGrad, gradsReq, auxStates, group2ctx)
+ }
+
+ def bind(ctx: Context, args: Map[String, NDArray], argsGrad: Map[String, NDArray],
+ gradsReq: Map[String, String], auxStates: Seq[NDArray],
+ group2ctx: Map[String, Context]): Executor = {
+ val symbolArguments = listArguments()
+ bindHelper(ctx, symbolArguments, args, argsGrad, gradsReq, auxStates, group2ctx)
+ }
+
+ def bind(ctx: Context, args: Map[String, NDArray], argsGrad: Map[String, NDArray],
+ gradsReq: Map[String, String], auxStates: Map[String, NDArray],
+ group2ctx: Map[String, Context]): Executor = {
+ val symbolArguments = listArguments()
+ bindHelper(ctx, symbolArguments, args, argsGrad, gradsReq, auxStates, group2ctx)
+ }
+
+ def bind(ctx: Context, args: Seq[NDArray], argsGrad: Seq[NDArray]): Executor = {
+ bind(ctx, args, argsGrad, "write", Nil, null)
+ }
+
+ def bind(ctx: Context, args: Map[String, NDArray], argsGrad: Map[String, NDArray]): Executor = {
+ bind(ctx, args, argsGrad, "write", Nil, null)
+ }
+
+ def bind(ctx: Context, args: Seq[NDArray]): Executor = {
+ val symbolArguments = listArguments()
+ bindHelper(ctx, symbolArguments, args, null,
+ Seq.fill(symbolArguments.size)("write"), Nil, null)
+ }
+
+ def bind(ctx: Context, args: Map[String, NDArray]): Executor = {
+ val symbolArguments = listArguments()
+ bindHelper(ctx, symbolArguments, args, null,
+ Seq.fill(symbolArguments.size)("write"), Nil, null)
+ }
+
+ private def bindHelper(ctx: Context, symbolArguments: Seq[String],
+ args: Iterable[_], argsGrad: Iterable[_],
+ gradsReq: Iterable[_], auxStates: Iterable[_],
+ group2ctx: Map[String, Context]): Executor = {
+ require(args != null && !args.isInstanceOf[Set[_]])
+ require(argsGrad == null || !argsGrad.isInstanceOf[Set[_]])
+ require(auxStates == null || !auxStates.isInstanceOf[Set[_]])
+ require(gradsReq != null && !gradsReq.isInstanceOf[Set[_]])
+
+ val (argsHandle, argsNDArray) =
+ if (args.isInstanceOf[Seq[_]]) {
+ Symbol.getNDArrayInputs("args", args.asInstanceOf[Seq[NDArray]],
+ symbolArguments, allowMissing = false)
+ } else {
+ Symbol.getNDArrayInputs("args", args.asInstanceOf[Map[String, NDArray]],
+ symbolArguments, allowMissing = false)
+ }
+
+ // setup args gradient
+ val (argsGradHandle, argsGradNDArray) =
+ if (argsGrad == null) {
+ (Array.fill[NDArrayHandle](args.size)(0L), null)
+ } else if (argsGrad.isInstanceOf[Seq[_]]) {
+ Symbol.getNDArrayInputs("args_grad", argsGrad.asInstanceOf[Seq[NDArray]],
+ symbolArguments, allowMissing = true)
+ } else {
+ Symbol.getNDArrayInputs("args_grad", argsGrad.asInstanceOf[Map[String, NDArray]],
+ symbolArguments, allowMissing = true)
+ }
+
+ val (auxArgsHandle, auxStatesNDArray) =
+ if (auxStates == null) {
+ Symbol.getNDArrayInputs("aux_states", Nil, listAuxiliaryStates(), allowMissing = false)
+ } else if (auxStates.isInstanceOf[Seq[_]]) {
+ Symbol.getNDArrayInputs("aux_states", auxStates.asInstanceOf[Seq[NDArray]],
+ listAuxiliaryStates(), allowMissing = false)
+ } else {
+ Symbol.getNDArrayInputs("aux_states", auxStates.asInstanceOf[Map[String, NDArray]],
+ listAuxiliaryStates(), allowMissing = false)
+ }
+
+ // setup gradient requirements
+ val reqsArray =
+ if (gradsReq.isInstanceOf[Seq[_]]) {
+ gradsReq.asInstanceOf[Seq[String]].map { req =>
+ require(Symbol.bindReqMap.contains(req),
+ s"grad_req must be one of ${Symbol.bindReqMap.keys.mkString(", ")}")
+ Symbol.bindReqMap(req)
+ }.toArray
+ } else {
+ val gradsReqMap = gradsReq.asInstanceOf[Map[String, String]]
+ symbolArguments.map { argName =>
+ val req = gradsReqMap.getOrElse(argName, "null")
+ require(Symbol.bindReqMap.contains(req),
+ s"grad_req must be one of ${Symbol.bindReqMap.keys.mkString(", ")}")
+ Symbol.bindReqMap(req)
+ }.toArray
+ }
+
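+ // Flatten the optional group-to-context mapping into three parallel arrays,
+ // the layout the C API expects.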
+ val ctxMapKeys = ArrayBuffer.empty[String]
+ val ctxMapDevTypes = ArrayBuffer.empty[Int]
+ val ctxMapDevIDs = ArrayBuffer.empty[Int]
+
+ if (group2ctx != null) {
+ group2ctx.foreach { case (key, value) =>
+ ctxMapKeys += key
+ ctxMapDevTypes += value.deviceTypeid
+ ctxMapDevIDs += value.deviceId
+ }
+ }
+
+ val execHandle = new ExecutorHandleRef
+ checkCall(_LIB.mxExecutorBindX(handle,
+ ctx.deviceTypeid,
+ ctx.deviceId,
+ ctxMapKeys.size,
+ ctxMapKeys.toArray,
+ ctxMapDevTypes.toArray,
+ ctxMapDevIDs.toArray,
+ args.size,
+ argsHandle,
+ argsGradHandle,
+ reqsArray,
+ auxArgsHandle,
+ execHandle))
+ val executor = new Executor(execHandle.value, this)
+ executor.argArrays = argsNDArray
+ executor.gradArrays = argsGradNDArray
+ executor.auxArrays = auxStatesNDArray
+ executor
+ }
}
object Symbol {
private val logger = LoggerFactory.getLogger(classOf[Symbol])
private val functions: Map[String, SymbolFunction] = initSymbolModule()
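+ // Gradient request codes understood by libmxnet (kNullOp = 0, kWriteTo = 1,
+ // kAddTo = 3; kWriteInplace = 2 is not exposed here)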
+ private val bindReqMap = Map("null" -> 0, "write" -> 1, "add" -> 3)
/**
* Create a symbolic variable with specified name.
@@ -148,6 +693,34 @@ object Symbol {
createNoCheck("Activation", attr)
}
+ def Convolution(attr: Map[String, String]): Map[String, Any] => Symbol = {
+ createNoCheck("Convolution", attr)
+ }
+
+ def Convolution: Map[String, Any] => Symbol = {
+ Convolution(null)
+ }
+
+ def BatchNorm: Map[String, Any] => Symbol = {
+ createNoCheck("BatchNorm")
+ }
+
+ def Pooling: Map[String, Any] => Symbol = {
+ createNoCheck("Pooling")
+ }
+
+ def Flatten: Map[String, Any] => Symbol = {
+ createNoCheck("Flatten")
+ }
+
+ def SoftmaxOutput: Map[String, Any] => Symbol = {
+ createNoCheck("SoftmaxOutput")
+ }
+
+ def Cast: Map[String, Any] => Symbol = {
+ createNoCheck("Cast")
+ }
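+
+ // Each factory returns a creator function; applying it to an attribute map
+ // yields the Symbol, e.g. Symbol.Cast(Map("data" -> data, "dtype" -> "float32")).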
+
/**
* Create a symbol that groups symbols together.
* @param symbols List of symbols to be grouped.
@@ -295,6 +868,44 @@ object Symbol {
}
create(operator, symbolArgs, strArgs, attr)
}
+
+ /**
+ * Helper function to get NDArray handle lists from various inputs.
+ * @param argKey The name of the argument, used in error messages.
+ * @param args Input arguments to the symbols, either a Seq of NDArray
+ * or a Map from argument name to NDArray.
+ * If a Seq, the positions must follow the order of argNames.
+ * If a Map, it maps argument names to the corresponding NDArrays;
+ * e.g. with argNames = Seq("lhs", "rhs"), Map("rhs" -> r, "lhs" -> l)
+ * yields handles in the order (l.handle, r.handle).
+ * @param argNames List of argument names.
+ * @param allowMissing Whether missing arguments are allowed.
+ * When allowed, a missing argument yields a 0 handle and a null NDArray.
+ * @return The positional list of NDArrayHandles generated from the input.
+ */
+ private def getNDArrayInputs(argKey: String, args: Seq[NDArray], argNames: Seq[String],
+ allowMissing: Boolean): (Array[NDArrayHandle], Array[NDArray]) = {
+ require(args.length == argNames.length,
+ s"Length of $argKey does not match the number of arguments")
+ val argHandles = args.map(_.handle)
+ (argHandles.toArray, args.toArray)
+ }
+
+ private def getNDArrayInputs(argKey: String, args: Map[String, NDArray], argNames: Seq[String],
+ allowMissing: Boolean): (Array[NDArrayHandle], Array[NDArray]) = {
+ val argArrays = ArrayBuffer.empty[NDArray]
+ val argHandles = ArrayBuffer.empty[NDArrayHandle]
+ argNames.foreach { name =>
+ args.get(name) match {
+ case Some(narr) =>
+ argArrays += narr
+ argHandles += narr.handle
+ case None =>
+ require(allowMissing, s"Must specify all the arguments in $argKey")
+ argArrays += null
+ argHandles += 0L
+ }
+ }
+ (argHandles.toArray, argArrays.toArray)
+ }
}
private case class SymbolFunction(handle: SymbolHandle, keyVarNumArgs: String)
diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/ExecutorSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/ExecutorSuite.scala
new file mode 100644
index 000000000000..8716bb181e04
--- /dev/null
+++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/ExecutorSuite.scala
@@ -0,0 +1,48 @@
+package ml.dmlc.mxnet
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+class ExecutorSuite extends FunSuite with BeforeAndAfterAll {
+ test("bind") {
+ val shape = Array(100, 30)
+ val lhs = Symbol.Variable("lhs")
+ val rhs = Symbol.Variable("rhs")
+ val ret = lhs + rhs
+ assert(ret.listArguments().toArray === Array("lhs", "rhs"))
+
+ val lhsArr = Random.uniform(-10f, 10f, shape)
+ val rhsArr = Random.uniform(-10f, 10f, shape)
+ val lhsGrad = NDArray.empty(shape)
+ val rhsGrad = NDArray.empty(shape)
+
+ val executor = ret.bind(Context.cpu(), args = Seq(lhsArr, rhsArr),
+ argsGrad = Seq(lhsGrad, rhsGrad))
+ val exec3 = ret.bind(Context.cpu(), args = Seq(lhsArr, rhsArr))
+ val exec4 = ret.bind(Context.cpu(), args = Map("rhs" -> rhsArr, "lhs" -> lhsArr),
+ argsGrad = Map("lhs" -> lhsGrad, "rhs" -> rhsGrad))
+ executor.forward()
+ exec3.forward()
+ exec4.forward()
+
+ val out1 = lhsArr + rhsArr
+ val out2 = executor.outputs(0)
+ val out3 = exec3.outputs(0)
+ val out4 = exec4.outputs(0)
+ assert(reldiff(out1, out2) < 1e-6)
+ assert(reldiff(out1, out3) < 1e-6)
+ assert(reldiff(out1, out4) < 1e-6)
+
+ // test gradient
+ val outGrad = NDArray.ones(shape)
+ val (lhsGrad2, rhsGrad2) = (outGrad, outGrad)
+ executor.backward(Array(outGrad))
+ assert(reldiff(lhsGrad, lhsGrad2) < 1e-6)
+ assert(reldiff(rhsGrad, rhsGrad2) < 1e-6)
+ }
+
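+ // Relative difference used as the comparison metric: sum(|a - b|) / sum(|a|).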
+ private def reldiff(a: NDArray, b: NDArray): Float = {
+ val diff = NDArray.sum(NDArray.abs(a - b)).toScalar
+ val norm = NDArray.sum(NDArray.abs(a)).toScalar
+ diff / norm
+ }
+}
diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala
index 27be2eb228dc..87717977e284 100644
--- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala
+++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala
@@ -20,7 +20,7 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
"seed" -> "10"
)
- val mnistIter = IO.createIterator("MNISTIter", params)
+ val mnistIter = IO.MNISTIter(params)
// test_loop
mnistIter.reset()
val nBatch = 600
@@ -34,15 +34,15 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
// test reset
mnistIter.reset()
mnistIter.iterNext()
- val label0 = mnistIter.getLabel().toArray
- val data0 = mnistIter.getData().toArray
+ val label0 = mnistIter.getLabel().head.toArray
+ val data0 = mnistIter.getData().head.toArray
mnistIter.iterNext()
mnistIter.iterNext()
mnistIter.iterNext()
mnistIter.reset()
mnistIter.iterNext()
- val label1 = mnistIter.getLabel().toArray
- val data1 = mnistIter.getData().toArray
+ val label1 = mnistIter.getLabel().head.toArray
+ val data1 = mnistIter.getData().head.toArray
assert(label0 === label1)
assert(data0 === data1)
}
diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/SymbolSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/SymbolSuite.scala
index edc6ace47444..8e14d9565551 100644
--- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/SymbolSuite.scala
+++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/SymbolSuite.scala
@@ -8,7 +8,7 @@ class SymbolSuite extends FunSuite with BeforeAndAfterAll {
var net1 = Symbol.FullyConnected(Map("data" -> data, "name" -> "fc1", "num_hidden" -> 10))
net1 = Symbol.FullyConnected(Map("data" -> net1, "name" -> "fc2", "num_hidden" -> 100))
- assert(net1.listArguments() ===
+ assert(net1.listArguments().toArray ===
Array("data", "fc1_weight", "fc1_bias", "fc2_weight", "fc2_bias"))
var net2 = Symbol.FullyConnected(Map("name" -> "fc3", "num_hidden" -> 10))
@@ -25,4 +25,27 @@ class SymbolSuite extends FunSuite with BeforeAndAfterAll {
val multiOut = Symbol.Group(composed, net1)
assert(multiOut.listOutputs().length === 2)
}
+
+ test("symbol internal") {
+ val data = Symbol.Variable("data")
+ val oldfc = Symbol.FullyConnected(Map("data" -> data, "name" -> "fc1", "num_hidden" -> 10))
+ val net1 = Symbol.FullyConnected(Map("data" -> oldfc, "name" -> "fc2", "num_hidden" -> 100))
+ assert(net1.listArguments().toArray
+ === Array("data", "fc1_weight", "fc1_bias", "fc2_weight", "fc2_bias"))
+ val internal = net1.getInternals
+ val fc1 = internal.get("fc1_output")
+ assert(fc1.listArguments() === oldfc.listArguments())
+ }
+
+ test("symbol infer type") {
+ val data = Symbol.Variable("data")
+ val f32data = Symbol.Cast(Map("data" -> data, "dtype" -> "float32"))
+ val fc1 = Symbol.FullyConnected(Map("data" -> f32data, "name" -> "fc1", "num_hidden" -> 128))
+ val mlp = Symbol.SoftmaxOutput(Map("data" -> fc1, "name" -> "softmax"))
+
+ val (arg, out, aux) = mlp.inferType(Map("data" -> classOf[Double]))
+ assert(arg.toArray === Array(classOf[Double], classOf[Float], classOf[Float], classOf[Float]))
+ assert(out.toArray === Array(classOf[Float]))
+ assert(aux.isEmpty)
+ }
}
diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc
index a37d8a699409..a30215b067a0 100644
--- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc
+++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc
@@ -527,7 +527,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxListDataIters
JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterCreateIter
(JNIEnv * env, jobject obj, jlong creator, jobjectArray jkeys,
jobjectArray jvals, jobject dataIterHandleRef) {
- //keys and values
+ // keys and values
int paramSize = env->GetArrayLength(jkeys);
char** keys = new char*[paramSize];
char** vals = new char*[paramSize];
@@ -863,6 +863,23 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolListOutputs
return ret;
}
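+
+// Appends each auxiliary-state name to the Scala ArrayBuffer passed in from the
+// JVM, via its $plus$eq (+=) method.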
+JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolListAuxiliaryStates
+ (JNIEnv *env, jobject obj, jlong symbolPtr, jobject outputs) {
+ mx_uint outSize;
+ const char **outStrArray;
+ int ret = MXSymbolListAuxiliaryStates((SymbolHandle) symbolPtr, &outSize, &outStrArray);
+
+ jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
+ jmethodID arrayAppend = env->GetMethodID(arrayClass,
+ "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
+ for (int i = 0; i < outSize; i++) {
+ jstring output = env->NewStringUTF(outStrArray[i]);
+ env->CallObjectMethod(outputs, arrayAppend, output);
+ }
+
+ return ret;
+}
+
JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolCopy
(JNIEnv *env, jobject obj, jlong symbolPtr, jobject clonedSymbolRef) {
SymbolHandle clonedSymbol;
@@ -889,3 +906,231 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolPrint
setStringField(env, out, outStr);
return ret;
}
+
+JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolGetOutput
+ (JNIEnv *env, jobject obj, jlong symbolPtr, jint index, jobject jout) {
+ SymbolHandle out;
+ int ret = MXSymbolGetOutput((SymbolHandle) symbolPtr, (mx_uint) index, &out);
+ setLongField(env, jout, (jlong) out);
+ return ret;
+}
+
+JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolGetInternals
+ (JNIEnv *env, jobject obj, jlong symbolPtr, jobject jout) {
+ SymbolHandle out;
+ int ret = MXSymbolGetInternals((SymbolHandle) symbolPtr, &out);
+ setLongField(env, jout, (jlong) out);
+ return ret;
+}
+
+JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolInferType
+ (JNIEnv *env, jobject obj, jlong symbolPtr, jobjectArray jkeys, jintArray jvals,
+ jobject jargTypeData, jobject joutTypeData, jobject jauxTypeData, jobject jcomplete) {
+ int numArgs = env->GetArrayLength(jvals);
+ const char **keys = NULL;
+ if (jkeys != NULL) {
+ keys = new const char *[numArgs];
+ for (int i = 0; i < numArgs; i++) {
+ jstring jkey = (jstring) env->GetObjectArrayElement(jkeys, i);
+ const char *key = env->GetStringUTFChars(jkey, 0);
+ keys[i] = key;
+ }
+ }
+
+ mx_uint inTypeSize;
+ const int *inTypeData;
+ mx_uint outTypeSize;
+ const int *outTypeData;
+ mx_uint auxTypeSize;
+ const int *auxTypeData;
+ int complete;
+
+ jint *vals = env->GetIntArrayElements(jvals, NULL);
+ int ret = MXSymbolInferType((SymbolHandle) symbolPtr,
+ (mx_uint) numArgs, keys, (const int *) vals,
+ &inTypeSize, &inTypeData,
+ &outTypeSize, &outTypeData,
+ &auxTypeSize, &auxTypeData,
+ &complete);
+ env->ReleaseIntArrayElements(jvals, vals, 0);
+
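+ // Box each native int into a java.lang.Integer and append it to the Scala
+ // ListBuffer through its $plus$eq (+=) method.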
+ jclass integerClass = env->FindClass("java/lang/Integer");
+ jmethodID newInteger = env->GetMethodID(integerClass, "<init>", "(I)V");
+
+ jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer");
+ jmethodID listAppend = env->GetMethodID(listClass,
+ "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;");
+
+ for (int i = 0; i < inTypeSize; ++i) {
+ jobject data = env->NewObject(integerClass, newInteger, inTypeData[i]);
+ env->CallObjectMethod(jargTypeData, listAppend, data);
+ }
+ for (int i = 0; i < outTypeSize; ++i) {
+ jobject data = env->NewObject(integerClass, newInteger, outTypeData[i]);
+ env->CallObjectMethod(joutTypeData, listAppend, data);
+ }
+ for (int i = 0; i < auxTypeSize; ++i) {
+ jobject data = env->NewObject(integerClass, newInteger, auxTypeData[i]);
+ env->CallObjectMethod(jauxTypeData, listAppend, data);
+ }
+
+ setIntField(env, jcomplete, complete);
+
+ // release allocated memory
+ if (jkeys != NULL) {
+ for (int i = 0; i < numArgs; i++) {
+ jstring jkey = (jstring) env->GetObjectArrayElement(jkeys, i);
+ env->ReleaseStringUTFChars(jkey, keys[i]);
+ }
+ delete[] keys;
+ }
+
+ return ret;
+}
+
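+// Copies one group of inferred shapes (arg / out / aux) into the given Scala
+// ListBuffer as int arrays; returns non-zero if a JVM allocation fails.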
+int FillSymbolInferShape
+ (JNIEnv *env, jmethodID listAppend, jobject joutData,
+ mx_uint shapeSize, const mx_uint *shapeNdim, const mx_uint **shapeData) {
+ for (int i = 0; i < shapeSize; ++i) {
+ jintArray jshape;
+ jshape = env->NewIntArray(shapeNdim[i]);
+ if (jshape == NULL) {
+ // TODO: an OutOfMemoryError has been thrown by the JVM; return a specific error code?
+ return -1;
+ }
+ env->SetIntArrayRegion(jshape, 0, shapeNdim[i], (const jint *) shapeData[i]);
+ env->CallObjectMethod(joutData, listAppend, jshape);
+ }
+ return 0;
+}
+JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolInferShape
+ (JNIEnv *env, jobject obj, jlong symbolPtr, jint jnumArgs, jobjectArray jkeys,
+ jintArray jargIndPtr, jintArray jargShapeData,
+ jobject jinShapeData, jobject joutShapeData, jobject jauxShapeData, jobject jcomplete) {
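+ // jargIndPtr / jargShapeData encode the known argument shapes in CSR style:
+ // the shape of argument i occupies argShapeData[argIndPtr[i] .. argIndPtr[i+1]).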
+ const char **keys = NULL;
+ if (jkeys != NULL) {
+ keys = new const char *[jnumArgs];
+ for (int i = 0; i < jnumArgs; i++) {
+ jstring jkey = (jstring) env->GetObjectArrayElement(jkeys, i);
+ const char *key = env->GetStringUTFChars(jkey, 0);
+ keys[i] = key;
+ }
+ }
+
+ mx_uint inShapeSize;
+ const mx_uint *inShapeNdim;
+ const mx_uint **inShapeData;
+
+ mx_uint outShapeSize;
+ const mx_uint *outShapeNdim;
+ const mx_uint **outShapeData;
+
+ mx_uint auxShapeSize;
+ const mx_uint *auxShapeNdim;
+ const mx_uint **auxShapeData;
+
+ int complete;
+
+ jint *argIndPtr = env->GetIntArrayElements(jargIndPtr, NULL);
+ jint *argShapeData = env->GetIntArrayElements(jargShapeData, NULL);
+ int ret = MXSymbolInferShape((SymbolHandle) symbolPtr,
+ (mx_uint) jnumArgs,
+ keys,
+ (const mx_uint *) argIndPtr,
+ (const mx_uint *) argShapeData,
+ &inShapeSize,
+ &inShapeNdim,
+ &inShapeData,
+ &outShapeSize,
+ &outShapeNdim,
+ &outShapeData,
+ &auxShapeSize,
+ &auxShapeNdim,
+ &auxShapeData,
+ &complete);
+ env->ReleaseIntArrayElements(jargShapeData, argShapeData, 0);
+ env->ReleaseIntArrayElements(jargIndPtr, argIndPtr, 0);
+
+ jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer");
+ jmethodID listAppend = env->GetMethodID(listClass,
+ "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;");
+
+ if (FillSymbolInferShape(env, listAppend, jinShapeData, inShapeSize, inShapeNdim, inShapeData)) {
+ // TODO: an OutOfMemoryError has been thrown by the JVM; return a specific error code?
+ return -1;
+ }
+ if (FillSymbolInferShape(
+ env, listAppend, joutShapeData, outShapeSize, outShapeNdim, outShapeData)) {
+ // TODO: an OutOfMemoryError has been thrown by the JVM; return a specific error code?
+ return -1;
+ }
+ if (FillSymbolInferShape(
+ env, listAppend, jauxShapeData, auxShapeSize, auxShapeNdim, auxShapeData)) {
+ // TODO: an OutOfMemoryError has been thrown by the JVM; return a specific error code?
+ return -1;
+ }
+
+ setIntField(env, jcomplete, complete);
+
+ // release allocated memory
+ if (jkeys != NULL) {
+ for (int i = 0; i < jnumArgs; i++) {
+ jstring jkey = (jstring) env->GetObjectArrayElement(jkeys, i);
+ env->ReleaseStringUTFChars(jkey, keys[i]);
+ }
+ delete[] keys;
+ }
+
+ return ret;
+}
+
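+// Thin wrapper over MXExecutorBindX: marshals the group2ctx map and the
+// argument / gradient / aux-state handle arrays from the JVM, then writes the
+// resulting executor handle back through jexecOut.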
+JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorBindX
+ (JNIEnv *env, jobject obj, jlong symbolPtr, jint deviceTypeId, jint deviceID, jint numCtx,
+ jobjectArray jctxMapKeys, jintArray jctxMapDevTypes, jintArray jctxMapDevIDs, jint numArgs,
+ jlongArray jargsHandle, jlongArray jargsGradHandle, jintArray jreqsArray,
+ jlongArray jauxArgsHandle, jobject jexecOut) {
+
+ ExecutorHandle out;
+ int auxStatesLen = env->GetArrayLength(jauxArgsHandle);
+
+ const char **mapKeys = new const char *[numCtx];
+ for (int i = 0; i < numCtx; i++) {
+ jstring jkey = (jstring) env->GetObjectArrayElement(jctxMapKeys, i);
+ const char *key = env->GetStringUTFChars(jkey, 0);
+ mapKeys[i] = key;
+ }
+ jlong *auxStates = env->GetLongArrayElements(jauxArgsHandle, NULL);
+ jint *gradReqType = env->GetIntArrayElements(jreqsArray, NULL);
+ jlong *inArgs = env->GetLongArrayElements(jargsHandle, NULL);
+ jlong *argGradStore = env->GetLongArrayElements(jargsGradHandle, NULL);
+ jint *mapDevTypes = env->GetIntArrayElements(jctxMapDevTypes, NULL);
+ jint *mapDevIDs = env->GetIntArrayElements(jctxMapDevIDs, NULL);
+ int ret = MXExecutorBindX((SymbolHandle) symbolPtr,
+ deviceTypeId,
+ deviceID,
+ (mx_uint) numCtx,
+ mapKeys,
+ mapDevTypes,
+ mapDevIDs,
+ (mx_uint) numArgs,
+ (NDArrayHandle *) inArgs,
+ (NDArrayHandle *) argGradStore,
+ (mx_uint *) gradReqType,
+ (mx_uint) auxStatesLen,
+ (NDArrayHandle *) auxStates,
+ &out);
+ env->ReleaseIntArrayElements(jctxMapDevIDs, mapDevIDs, 0);
+ env->ReleaseIntArrayElements(jctxMapDevTypes, mapDevTypes, 0);
+ env->ReleaseLongArrayElements(jargsGradHandle, argGradStore, 0);
+ env->ReleaseLongArrayElements(jargsHandle, inArgs, 0);
+ env->ReleaseIntArrayElements(jreqsArray, gradReqType, 0);
+ env->ReleaseLongArrayElements(jauxArgsHandle, auxStates, 0);
+ for (int i = 0; i < numCtx; i++) {
+ jstring jkey = (jstring) env->GetObjectArrayElement(jctxMapKeys, i);
+ env->ReleaseStringUTFChars(jkey, mapKeys[i]);
+ }
+ delete[] mapKeys;
+
+ setLongField(env, jexecOut, (jlong) out);
+ return ret;
+}