diff --git a/scala-package/core/scalastyle-config.xml b/scala-package/core/scalastyle-config.xml index 847bbc2babe9..583a815a6fbe 100644 --- a/scala-package/core/scalastyle-config.xml +++ b/scala-package/core/scalastyle-config.xml @@ -60,7 +60,7 @@ You can also disable only one rule, by specifying its rule id, as specified in: - + diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala index 9f02c6e4e2e0..206b8d13f108 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala @@ -10,6 +10,8 @@ object Base { type MXUint = Int type MXFloat = Float type CPtrAddress = Long + // TODO: make it more friendly to java + type Shape = Array[Int] type NDArrayHandle = CPtrAddress type FunctionHandle = CPtrAddress diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Context.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Context.scala index 1ca12df2a975..c2a2325047f3 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Context.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Context.scala @@ -4,10 +4,19 @@ object Context { val devtype2str = Map(1 -> "cpu", 2 -> "gpu", 3 -> "cpu_pinned") val devstr2type = Map("cpu" -> 1, "gpu" -> 2, "cpu_pinned" -> 3) val defaultCtx = new Context("cpu", 0) + + def cpu(deviceId: Int = 0): Context = { + new Context("cpu", deviceId) + } + + def gpu(deviceId: Int = 0): Context = { + new Context("gpu", deviceId) + } } /** * Constructing a context. + * @author Yizhi Liu * @param deviceTypeName {'cpu', 'gpu'} String representing the device type * @param deviceId (default=0) The device id of the device, needed for GPU */ @@ -23,4 +32,8 @@ class Context(deviceTypeName: String, val deviceId: Int = 0) { * @return device_type */ def deviceType: String = Context.devtype2str(deviceTypeid) + + override def toString: String = { + deviceType + } } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala index 60513db13635..5d7ffd8c3795 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala @@ -1,12 +1,14 @@ package ml.dmlc.mxnet import ml.dmlc.mxnet.Base._ +import org.slf4j.{Logger, LoggerFactory} import scala.collection.mutable.ArrayBuffer object Executor { // Get the dictionary given name and ndarray pairs. - private def getDict(names: Array[String], ndarrays: Array[NDArray]): Map[String, NDArray] = { + private[mxnet] def getDict(names: Seq[String], + ndarrays: Seq[NDArray]): Map[String, NDArray] = { require(names.toSet.size == names.length, "Duplicate names detected") (names zip ndarrays).toMap } @@ -19,12 +21,11 @@ object Executor { * @throws IllegalArgumentException * If there are two many splits such that some slice can be empty. 
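 *
 * For illustration (values assumed): batchSize = 10 with workLoadList = Seq(1f, 1f)
 * yields the slices Array((0, 5), (5, 10)).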
*/ - private def splitInputSlice[@specialized(Int, Float, Double) V] - (batchSize: Int, workLoadList: Array[V]) - (implicit num: Numeric[V]): Array[(Int, Int)] = { - val totalWorkLoad = workLoadList.sum.asInstanceOf[Float] + private[mxnet] def splitInputSlice(batchSize: Int, + workLoadList: Seq[Float]): Array[(Int, Int)] = { + val totalWorkLoad = workLoadList.sum val batchNumList = workLoadList.map(workLoad => - math.round(workLoad.asInstanceOf[Float] * batchSize / totalWorkLoad)) + math.round(workLoad * batchSize / totalWorkLoad)).toArray val batchNumSum = batchNumList.sum if (batchNumSum < batchSize) { batchNumList(batchNumList.length-1) += batchSize - batchNumSum @@ -47,7 +48,7 @@ object Executor { * The check is done for feedforward net for now. * @param symbol The network configuration */ - private def checkArguments(symbol: Symbol): Unit = { + private[mxnet] def checkArguments(symbol: Symbol): Unit = { val argNames = symbol.listArguments() require(argNames.toSet.size == argNames.length, "Find duplicated argument name," + @@ -62,35 +63,50 @@ object Executor { } // Load a list of arrays into a list of arrays - private def loadGeneral(data: Array[NDArray], targets: Array[NDArray]): Unit = { + private[mxnet] def loadGeneral(data: Array[NDArray], targets: Array[NDArray]): Unit = { (data zip targets).foreach { case (dSrc, dTarget) => dSrc.copyTo(dTarget) } } // Load a list of arrays into a list of arrays specified by slices - private def loadGeneral(data: Array[NDArray], targets: Array[(Int, Int, NDArray)]): Unit = { - (data zip targets).foreach { case (dSrc, (start, end, dTarget)) => - dSrc.slice(start, end).copyTo(dTarget) + private[mxnet] def loadGeneral(data: IndexedSeq[NDArray], + targets: Seq[Array[(Int, Int, NDArray)]]): Unit = { + for ((src, dTargets) <- data zip targets) { + for ((start, end, dst) <- dTargets) { + src.slice(start, end).copyTo(dst) + } } } + + // Load data into sliced arrays + private[mxnet] def loadData(batch: DataBatch, + targets: Seq[Array[(Int, Int, NDArray)]]): Unit = { + loadGeneral(batch.data, targets) + } + + // Load label into sliced arrays + private[mxnet] def loadLabel(batch: DataBatch, + targets: Seq[Array[(Int, Int, NDArray)]]): Unit = { + loadGeneral(batch.label, targets) + } } /** * Symbolic Executor component of MXNet * @author Yizhi Liu * - * Constructor: used Symbol.bind and Symbol.simple_bind instead. + * Constructor: please use Symbol.bind and Symbol.simpleBind instead. * @param handle ExecutorHandle generated by calling Bind * @param symbol * @see Symbol.bind : to create executor */ // scalastyle:off finalize class Executor(private[mxnet] val handle: ExecutorHandle, private[mxnet] val symbol: Symbol) { - var argArrays: Array[NDArray] = null - protected var gradArrays: Array[NDArray] = null - protected var auxArrays: Array[NDArray] = null - protected var outputs: Array[NDArray] = getOutputs + private[mxnet] var argArrays: Array[NDArray] = null + private[mxnet] var gradArrays: Array[NDArray] = null + private[mxnet] var auxArrays: Array[NDArray] = null + val outputs: Array[NDArray] = getOutputs protected var _argDict: Map[String, NDArray] = null protected var _auxDict: Map[String, NDArray] = null protected var monitorCallback: MXMonitorCallback = null @@ -128,10 +144,9 @@ class Executor(private[mxnet] val handle: ExecutorHandle, private[mxnet] val sym /** * Do backward pass to get the gradient of arguments. - * @param outGrads - * Gradient on the outputs to be propagated back. 
- * This parameter is only needed when bind is called - * on outputs that are not a loss function. + * @param outGrads Gradient on the outputs to be propagated back. + * This parameter is only needed when bind is called + * on outputs that are not a loss function. */ def backward(outGrads: Array[NDArray]): Unit = { require(outGrads != null) @@ -233,3 +248,139 @@ class Executor(private[mxnet] val handle: ExecutorHandle, private[mxnet] val sym } } // scalastyle:on finalize + +/** + * Helper class to manage multiple executors for data parallelism. + * @author Yizhi Liu + * @param symbol output symbol + * @param ctx devices to run on + * @param paramNames Name of all trainable parameters of the network. + * @param argNames Name of all arguments of the network. + * @param auxNames Name of all auxiliary states of the network. + * @param trainData Training data iterator. + * @param workLoadList The list of work load for different devices, in the same order as ctx + * @param logger When not specified, default logger will be used. + */ +class DataParallelExecutorManager(symbol: Symbol, + ctx: Array[Context], + paramNames: Seq[String], + argNames: Seq[String], + private val auxNames: Seq[String], + trainData: DataIter, + private var workLoadList: Seq[Float] = null, + logger: Logger = DataParallelExecutorManager.logger) { + // preparation + private val numDevice = ctx.length + logger.info(s"Start training with ${ctx.mkString(",")}") + + // make sure the architecture is valid + Executor.checkArguments(symbol) + + if (workLoadList == null) { + workLoadList = Seq.fill(numDevice)(1f) + } + require(workLoadList.size == numDevice, "Invalid settings for work load.") + + private val slices = Executor.splitInputSlice(trainData.batchSize, workLoadList) + + private val trainExecs = + ctx.zipWithIndex.map { case (context, i) => + val dataShapes = + trainData.provideData.map { case (name: String, shape: Shape) => + (name, Array(slices(i)._2 - slices(i)._1) ++ shape.drop(1)) + } + symbol.simpleBind(context, "write", shapeDict = dataShapes) + } + + // data structure + private val dataNames = trainData.provideData.map(_._1).toArray + private val labelNames = trainData.provideLabel.map(_._1).toArray + + private val dataArrays = + dataNames.map { name => + trainExecs.zipWithIndex.map { case (exec, i) => + val slice = slices(i) + (slice._1, slice._2, exec.argDict(name)) + } + } + private val labelArrays = + labelNames.map { name => + trainExecs.zipWithIndex.map { case (exec, i) => + val slice = slices(i) + (slice._1, slice._2, exec.argDict(name)) + } + } + + private val paramIdx = (0 until argNames.length).filter { i => + paramNames.contains(argNames(i)) + } + private val _paramNames = paramIdx.map(argNames(_)) + private val paramArrays = paramIdx.map { i => trainExecs.map(_.argArrays(i)) }.toArray + private val gradArrays = paramIdx.map { i => trainExecs.map(_.gradArrays(i)) }.toArray + + private val auxArrays = (0 until auxNames.length).map { i => + trainExecs.map(_.auxArrays(i)) + }.toArray + private val batchSize = trainData.batchSize + private val outputShapes: Array[Shape] = trainExecs(0).outputs.map { x: NDArray => + Array(batchSize) ++ x.shape.drop(1) + } + private val cpuOutputArrays = outputShapes.map(NDArray.zeros(_)) + + // Install monitor on all executors + def installMonitor(monitor: Monitor): Unit = { + trainExecs.foreach(monitor.install) + } + + /** + * Set parameter and aux values + * @param argParams source parameter arrays + * @param auxParams source aux arrays + */ + def setParams(argParams: 
Map[String, NDArray], auxParams: Map[String, NDArray]): Unit = { + trainExecs.foreach(_.copyParamsFrom(argParams, auxParams)) + } + + /** + * Copy data from each executor to `arg_params` and `aux_params` + * @param argParams target parameter arrays + * @param auxParams target aux arrays + * @note This function will inplace update the NDArrays in arg_params and aux_params. + */ + def copyTo(argParams: Map[String, NDArray], auxParams: Map[String, NDArray]): Unit = { + for ((name, block) <- _paramNames zip paramArrays) { + val weight = block.map(_.copyTo(Context.cpu())).reduce(_ + _) / block.length + weight.copyTo(argParams(name)) + } + for ((name, block) <- auxNames zip auxArrays) { + val weight = block.map(_.copyTo(Context.cpu())).reduce(_ + _) / block.length + weight.copyTo(auxParams(name)) + } + } + + // load data and labels into arrays + def loadDataBatch(dataBatch: DataBatch): Unit = { + Executor.loadData(dataBatch, dataArrays) + Executor.loadLabel(dataBatch, labelArrays) + } + + // Perform a forward pass on each executor + def forward(isTrain: Boolean = false): Unit = { + for ((texec, islice) <- trainExecs zip slices) { + texec.forward(isTrain) + for ((cpuOut, devOut) <- cpuOutputArrays zip texec.outputs) { + devOut.copyTo(cpuOut.slice(islice)) + } + } + } + + // Perform a backward pass on each executor + def backward(): Unit = { + trainExecs.foreach(_.backward()) + } +} + +object DataParallelExecutorManager { + private val logger = LoggerFactory.getLogger(classOf[Model]) +} + diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala index 3b6cac583f75..6cd073461250 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala @@ -5,12 +5,18 @@ import org.slf4j.LoggerFactory import scala.collection.mutable.ListBuffer +/** + * IO iterators for loading training & validation data + * @author Zixuan Huang, Yizhi Liu + */ object IO { type IterCreateFunc = (Map[String, String]) => DataIter private val logger = LoggerFactory.getLogger(classOf[DataIter]) private val iterCreateFuncs: Map[String, IterCreateFunc] = _initIOModule() + def MNISTIter: IterCreateFunc = iterCreateFuncs("MNISTIter") + /** * create iterator via iterName and params * @param iterName name of iterator; "MNISTIter" or "ImageRecordIter" @@ -58,6 +64,12 @@ object IO { checkCall(_LIB.mxDataIterCreateIter(handle, keys, vals, out)) new MXDataIter(out.value) } + + // Convert data into canonical form. + private def initData(data: NDArray, allowEmpty: Boolean, defaultName: String) = { + require(data != null || allowEmpty) + // TODO + } } @@ -68,15 +80,15 @@ object IO { * @param index * @param pad */ -case class DataBatch(data: NDArray, - label: NDArray, - index: List[Long], +case class DataBatch(data: IndexedSeq[NDArray], + label: IndexedSeq[NDArray], + index: IndexedSeq[Long], pad: Int) /** * DataIter object in mxnet. 
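 *
 * A minimal pull-style loop over an iterator, following the pattern used in IOSuite
 * (the params map passed to IO.MNISTIter is assumed to be prepared by the caller):
 * {{{
 * val iter = IO.MNISTIter(params)
 * iter.reset()
 * while (iter.iterNext()) {
 *   val data: IndexedSeq[NDArray] = iter.getData()
 *   val label: IndexedSeq[NDArray] = iter.getLabel()
 * }
 * }}}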
*/ -abstract class DataIter (val batchSize: Int = 0) { +abstract class DataIter(val batchSize: Int = 0) { /** * reset the iterator */ @@ -100,13 +112,13 @@ abstract class DataIter (val batchSize: Int = 0) { * get data of current batch * @return the data of current batch */ - def getData(): NDArray + def getData(): IndexedSeq[NDArray] /** * Get label of current batch * @return the label of current batch */ - def getLabel(): NDArray + def getLabel(): IndexedSeq[NDArray] /** * get the number of padding examples @@ -119,8 +131,13 @@ abstract class DataIter (val batchSize: Int = 0) { * the index of current batch * @return */ - def getIndex(): List[Long] + def getIndex(): IndexedSeq[Long] + + // The name and shape of data provided by this iterator + def provideData: Map[String, Shape] + // The name and shape of label provided by this iterator + def provideLabel: Map[String, Shape] } /** @@ -156,31 +173,31 @@ class MXDataIter(val handle: DataIterHandle) extends DataIter { * get data of current batch * @return the data of current batch */ - override def getData(): NDArray = { + override def getData(): IndexedSeq[NDArray] = { val out = new NDArrayHandleRef checkCall(_LIB.mxDataIterGetData(handle, out)) - new NDArray(out.value, writable = false) + IndexedSeq(new NDArray(out.value, writable = false)) } /** * Get label of current batch * @return the label of current batch */ - override def getLabel(): NDArray = { + override def getLabel(): IndexedSeq[NDArray] = { val out = new NDArrayHandleRef checkCall(_LIB.mxDataIterGetLabel(handle, out)) - new NDArray(out.value, writable = false) + IndexedSeq(new NDArray(out.value, writable = false)) } /** * the index of current batch * @return */ - override def getIndex(): List[Long] = { + override def getIndex(): IndexedSeq[Long] = { val outIndex = new ListBuffer[Long] val outSize = new RefLong checkCall(_LIB.mxDataIterGetIndex(handle, outIndex, outSize)) - outIndex.toList + outIndex.toIndexedSeq } /** @@ -193,6 +210,12 @@ class MXDataIter(val handle: DataIterHandle) extends DataIter { checkCall(_LIB.mxDataIterGetPadNum(handle, out)) out.value } + + // The name and shape of data provided by this iterator + override def provideData: Map[String, Shape] = ??? + + // The name and shape of label provided by this iterator + override def provideLabel: Map[String, Shape] = ??? } // scalastyle:on finalize @@ -209,19 +232,19 @@ class ArrayDataIter() extends DataIter { * get data of current batch * @return the data of current batch */ - override def getData(): NDArray = ??? + override def getData(): IndexedSeq[NDArray] = ??? /** * Get label of current batch * @return the label of current batch */ - override def getLabel(): NDArray = ??? + override def getLabel(): IndexedSeq[NDArray] = ??? /** * the index of current batch * @return */ - override def getIndex(): List[Long] = ??? + override def getIndex(): IndexedSeq[Long] = ??? /** * Iterate to next batch @@ -235,6 +258,69 @@ class ArrayDataIter() extends DataIter { * @return number of padding examples in current batch */ override def getPad(): Int = ??? + + // The name and shape of data provided by this iterator + override def provideData: Map[String, Shape] = ??? + + // The name and shape of label provided by this iterator + override def provideLabel: Map[String, Shape] = ??? } +/** + * TODO + * NDArrayIter object in mxnet. Taking NDArray or numpy array to get dataiter. + * @param data NDArrayIter supports single or multiple data and label. + * @param label Same as data, but is not fed to the model during testing. 
+ * @param batchSize Batch Size + * @param shuffle Whether to shuffle the data + * @param lastBatchHandle "pad", "discard" or "roll_over". How to handle the last batch + * @note + * This iterator will pad, discard or roll over the last batch if + * the size of data does not match batch_size. Roll over is intended + * for training and can cause problems if used for prediction. + */ +class NDArrayIter(data: NDArray, label: NDArray = null, + batchSize: Int = 1, shuffle: Boolean = false, + lastBatchHandle: String = "pad") extends DataIter(batchSize) { + /** + * reset the iterator + */ + override def reset(): Unit = ??? + + /** + * get data of current batch + * @return the data of current batch + */ + override def getData(): IndexedSeq[NDArray] = ??? + + /** + * Get label of current batch + * @return the label of current batch + */ + override def getLabel(): IndexedSeq[NDArray] = ??? + + /** + * the index of current batch + * @return + */ + override def getIndex(): IndexedSeq[Long] = ??? + /** + * Iterate to next batch + * @return whether the move is successful + */ + override def iterNext(): Boolean = ??? + + /** + * get the number of padding examples + * in current batch + * @return number of padding examples in current batch + */ + override def getPad(): MXUint = ??? + + // The name and shape of data provided by this iterator + override def provideData: Map[String, Shape] = ??? + + // The name and shape of label provided by this iterator + override def provideLabel: Map[String, Shape] = ??? +} diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala index 8bfdf2f9cc2e..6e4ccbb273e3 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala @@ -145,8 +145,43 @@ class LibInfo { @native def mxSymbolListArguments(handle: SymbolHandle, arguments: ArrayBuffer[String]): Int @native def mxSymbolCopy(handle: SymbolHandle, clonedHandle: SymbolHandleRef): Int + @native def mxSymbolListAuxiliaryStates(handle: SymbolHandle, + arguments: ArrayBuffer[String]): Int @native def mxSymbolListOutputs(handle: SymbolHandle, outputs: ArrayBuffer[String]): Int @native def mxSymbolCreateGroup(handles: Array[SymbolHandle], out: SymbolHandleRef): Int @native def mxSymbolPrint(handle: SymbolHandle, str: RefString): Int + @native def mxSymbolGetInternals(handle: SymbolHandle, out: SymbolHandleRef): Int + @native def mxSymbolInferType(handle: SymbolHandle, + keys: Array[String], + sdata: Array[Int], + argTypeData: ListBuffer[Int], + outTypeData: ListBuffer[Int], + auxTypeData: ListBuffer[Int], + complete: RefInt): Int + @native def mxSymbolInferShape(handle: SymbolHandle, + numArgs: MXUint, + keys: Array[String], + argIndPtr: Array[MXUint], + argShapeData: Array[MXUint], + inShapeData: ListBuffer[Shape], + outShapeData: ListBuffer[Shape], + auxShapeData: ListBuffer[Shape], + complete: RefInt): Int + @native def mxSymbolGetOutput(handle: SymbolHandle, index: Int, out: SymbolHandleRef): Int + // scalastyle:off parameterNum + @native def mxExecutorBindX(handle: SymbolHandle, + deviceTypeId: Int, + deviceID: Int, + numCtx: Int, + ctxMapKeys: Array[String], + ctxMapDevTypes: Array[Int], + ctxMapDevIDs: Array[Int], + numArgs: Int, + argsHandle: Array[NDArrayHandle], + argsGradHandle: Array[NDArrayHandle], + reqsArray: Array[Int], + auxArgsHandle: Array[NDArrayHandle], + out: ExecutorHandleRef): Int + // scalastyle:on parameterNum } diff --git 
a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala index 1912832341a3..f5100440042e 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala @@ -1,11 +1,16 @@ package ml.dmlc.mxnet -import org.slf4j.LoggerFactory +import ml.dmlc.mxnet.Base.Shape +import ml.dmlc.mxnet.optimizer.SGD +import org.slf4j.{Logger, LoggerFactory} + +import scala.collection.mutable /** * Describe the model flow * @author Yizhi Liu */ +class Model object Model { private val logger = LoggerFactory.getLogger(classOf[Model]) /** @@ -13,12 +18,12 @@ object Model { * This function select and create a proper kvstore given the kvstore type * @param kvStore KVStore type * @param numDevice The number of devices - * @param maxSize max size of the kvstore + * @param argParams Model parameter, dict of name to NDArray of net's weights. * @return Option of created [[KVStore]] and whether or not update weight on it */ - private def createKVStore(kvStore: String, - numDevice: Int, - maxSize: Int): (Option[KVStore], Boolean) = { + private[mxnet] def createKVStore(kvStore: String, + numDevice: Int, + argParams: Map[String, NDArray]): (Option[KVStore], Boolean) = { if (numDevice == 1 && !kvStore.contains("dist")) { // no need to use kv for single device and single machine (None, false) @@ -26,6 +31,7 @@ object Model { var kvType = kvStore if (kvType == "local") { // automatically select a proper local + val maxSize = argParams.values.map(_.shape.product).max kvType = if (maxSize < 1024 * 1024 * 16) { "local_update_cpu" @@ -39,8 +45,8 @@ object Model { } /** - * Create a kvstore (wrap it with Option, None if given kvStore == null) - * @param kvStore + * Create a kvStore (wrap it with Option, None if given kvStore == null) + * @param kvStore KVStore * @return Option of created [[KVStore]] and whether or not update weight on it */ private def createKVStore(kvStore: KVStore): (Option[KVStore], Boolean) = { @@ -100,7 +106,273 @@ object Model { } } } + + /** + * TODO + * Internal training function on multiple devices. + * This function will also work for single device as well. + * @param symbol The network configuration + * @param ctx The training devices. + * @param argNames Name of all arguments of the network. + * @param paramNames Name of all trainable parameters of the network. + * @param auxNames Name of all auxiliary states of the network. + * @param argParams Model parameter, dict of name to NDArray of net's weights. + * @param auxParams Model parameter, dict of name to NDArray of net's auxiliary states. + * @param beginEpoch The begining training epoch. + * @param endEpoch The end training epoch. + * @param epochSize Number of batches in a epoch. + * In default, it is set to ceil(num_train_examples / batch_size) + * @param optimizer The optimization algorithm + * @param kvStore The KVStore + * @param updateOnKVStore whether or not perform weight updating on kvstore + * @param trainData Training data iterator. + * @param evalData Validation data iterator. + * @param evalMetric A evaluation function. + * @param epochEndCallback A callback that is invoked at end of each epoch. + * This can be used to checkpoint model each epoch. + * @param batchEndCallback A callback that is invoked at end of each batch. + * This can be used to measure speed, + * get result from evaluation metric. etc. + * @param logger When not specified, default logger will be used. 
+ * @param workLoadList The list of work load for different devices, in the same order as ctx + * @param monitor Monitor outputs, weights, and gradients for debugging + * @note This function will inplace update the NDArrays in argParams and auxStates. + */ + // scalastyle:off parameterNum + private[mxnet] def trainMultiDevice(symbol: Symbol, ctx: Array[Context], + argNames: Seq[String], paramNames: Seq[String], + auxNames: Seq[String], argParams: Map[String, NDArray], + auxParams: Map[String, NDArray], + beginEpoch: Int, endEpoch: Int, epochSize: Int, + optimizer: Optimizer, + kvStore: KVStore, updateOnKVStore: Boolean, + trainData: DataIter = null, evalData: DataIter = null, + evalMetric: EvalMetric = null, + epochEndCallback: EpochEndCallback = null, + batchEndCallback: BatchEndCallback = null, + logger: Logger = logger, + workLoadList: Seq[Float] = Nil, + monitor: Monitor = null): Unit = { + val executorManager = new DataParallelExecutorManager( + symbol = symbol, + ctx = ctx, + trainData = trainData, + paramNames = paramNames, + argNames = argNames, + auxNames = auxNames, + workLoadList = workLoadList, + logger = logger) + } + // scalastyle:on parameterNum +} + +trait EpochEndCallback { + def invoke(epoch: Int, symbol: Symbol, + argParams: Map[String, NDArray], + auxStates: Map[String, NDArray]): Unit } -class Model { +trait BatchEndCallback { + def invoke(epoch: Int, nBatch: Int, evalMetric: EvalMetric) +} + +/** + * Model class of MXNet for training and predicting feedforward nets. + * This class is designed for a single-data single output supervised network. + * @param symbol The symbol configuration of computation network. + * @param ctx The device context of training and prediction. + * To use multi GPU training, pass in a list of gpu contexts. + * @param numEpoch Training parameter, number of training epochs(epochs). + * @param epochSize Number of batches in a epoch. In default, it is set to + * ceil(num_train_examples / batch_size) + * @param optimizer Training parameter, name or optimizer object for training. + * @param initializer Training parameter, the initialization scheme used. + * @param batchSize The batch size of training data. + * @param argParams Model parameter, dict of name to NDArray of net's weights. + * @param auxParams Model parameter, dict of name to NDArray of net's auxiliary states. + * @param allowExtraParams Whether allow extra parameters that are not needed by symbol + * to be passed by aux_params and arg_params. + * If this is True, no error will be thrown when aux_params and arg_params + * contain extra parameters than needed. + * @param beginEpoch The begining training epoch. + * @param kwargs The additional keyword arguments passed to optimizer. + */ +class FeedForward(val symbol: Symbol, val ctx: Array[Context] = Array(new Context("cpu")), + val numEpoch: Int = -1, val epochSize: Int = -1, + val optimizer: String = "sgd", + val initializer: Initializer = new Uniform(0.01f), + val batchSize: Int = 128, + argParams: Map[String, NDArray] = null, + auxParams: Map[String, NDArray] = null, + allowExtraParams: Boolean = false, + val beginEpoch: Int = 0, + val kwargs: mutable.Map[String, Any]) { + private val LOG: Logger = LoggerFactory.getLogger(classOf[FeedForward]) + // check if symbol contain duplicated names. 
+ Executor.checkArguments(symbol) + + // rematch parameters to delete useless ones + private var _argParams = + if (allowExtraParams) { + if (argParams != null) { + val argNames = symbol.listArguments().toSet + argParams.filter { case (k, v) => argNames.contains(k) } + } else { + null + } + } else { + argParams + } + private var _auxParams = + if (allowExtraParams) { + if (auxParams != null) { + val auxNames = symbol.listAuxiliaryStates().toSet + auxParams.filter { case (k, v) => auxNames.contains(k) } + } else { + null + } + } else { + auxParams + } + + // internal helper state + var predExec: Executor = null + + // Initialize weight parameters and auxiliary states + private def initParams(inputShapes: Map[String, Shape], overwrite: Boolean = false) + : (Seq[String], Seq[String], Seq[String]) = { + val (argShapes, _, auxShapes) = symbol.inferShape(inputShapes) + val argNames = symbol.listArguments() + val inputNames = inputShapes.keys + val paramNames = argNames.toSet -- inputNames.toSet + val auxNames = symbol.listAuxiliaryStates() + + val paramNameShapes = (argNames zip argShapes).filter { case (name, _) => + paramNames.contains(name) + } + val argParams = paramNameShapes.map { case (name, shape) => + (name, NDArray.zeros(shape)) + }.toMap + val auxParams = (auxNames zip auxShapes).map { case (name, shape) => + (name, NDArray.zeros(shape)) + }.toMap + + for ((k, v) <- argParams) { + if (_argParams != null && _argParams.contains(k) && (!overwrite)) { + argParams(k).set(_argParams(k)) + } else { + initializer(k, v) + } + } + + for ((k, v) <- auxParams) { + if (_auxParams != null && _auxParams.contains(k) && (!overwrite)) { + auxParams(k).set(_auxParams(k)) + } else { + initializer(k, v) + } + } + + _argParams = argParams + _auxParams = auxParams + (argNames, paramNames.toSeq, auxNames) + } + + // Initialize the predictor module for running prediction. + private def initPredictor(inputShapes: Map[String, Shape]): Unit = { + if (this.predExec == null) { + val predExec = symbol.simpleBind(ctx(0), gradReq = "null", shapeDict = inputShapes) + predExec.copyParamsFrom(_argParams, _auxParams) + Executor.checkArguments(symbol) + this.predExec = predExec + } + } + + // Initialize the iterator given input. + private def initIter(X: NDArray, y: NDArray, isTrain: Boolean): DataIter = { + require(y != null || !isTrain, "y must be specified") + val label = if (y == null) NDArray.zeros(X.shape(0)) else y + require(label.shape.length == 1, "Label must be 1D") + require(X.shape(0) == label.shape(0), "The numbers of data points and labels not equal") + if (isTrain) { + new NDArrayIter(X, label, batchSize, shuffle = isTrain, lastBatchHandle = "roll_over") + } else { + new NDArrayIter(X, label, batchSize, shuffle = false) + } + } + + // Initialize the iterator given eval_data. + private def initEvalIter(evalData: (NDArray, NDArray)): DataIter = { + if (evalData == null) { + null + } else { + initIter(evalData._1, evalData._2, isTrain = true) + } + } + + /** + * TODO + * Fit the model. + * @param trainData Training data. If X is an DataIter, the name or, if not available, + * position, of its outputs should match the corresponding variable + * names defined in the symbolic graph. + * @param evalData If eval_data is numpy.ndarray/list/NDArray pair, + * it should be (valid_data, valid_label). + * @param evalMetric The evaluation metric, name of evaluation metric. + * Or a customize evaluation function that returns the statistics + * based on minibatch. 
+   * @param epochEndCallback A callback that is invoked at the end of each epoch.
+   *                         This can be used to checkpoint model each epoch.
+   * @param batchEndCallback A callback that is invoked at the end of each batch.
+   *                         Used for printing purposes.
+   * @param kvStoreType A string kvstore type:
+   *                    'local' : multi-devices on a single machine, will automatically
+   *                    choose one from 'local_update_cpu', 'local_allreduce_cpu', and
+   *                    'local_allreduce_device'
+   *                    'dist_sync' : multi-machines with BSP
+   *                    'dist_async' : multi-machines with partial asynchronous updates
+   *                    By default 'local' is used; there is often no need to change it
+   *                    for a single machine.
+   * @param logger When not specified, default logger will be used.
+   * @param workLoadList The list of work load for different devices, in the same order as ctx
+   */
+  def fit(trainData: DataIter, evalData: DataIter, evalMetric: String = "acc",
+          kvStoreType: String = "local", logger: Logger = LOG,
+          workLoadList: Seq[Float] = null): Unit = {
+    val (argNames, paramNames, auxNames) =
+      initParams(trainData.provideData ++ trainData.provideLabel)
+    // TODO: kwargs.put("arg_names", argNames)
+
+    // TODO: setup metric
+
+    // create kvstore
+    val (kvStore, updateOnKVStore) = Model.createKVStore(kvStoreType, ctx.length, _argParams)
+
+    // init optimizer
+    val batchSizeMultiplier = kvStore.map { kv =>
+      if (kv.`type` == "dist_sync") {
+        kv.numWorkers
+      } else {
+        1
+      }
+    }
+    val batchSize = trainData.batchSize * batchSizeMultiplier.getOrElse(1)
+    // TODO: temporarily hard-coded sgd optimizer
+    val optimizer = new SGD(rescaleGrad = 1f / batchSize)
+    Model.trainMultiDevice(
+      symbol, ctx, argNames, paramNames, auxNames,
+      _argParams, _auxParams,
+      this.beginEpoch, this.numEpoch,
+      this.epochSize,
+      optimizer,
+      kvStore.orNull, updateOnKVStore,
+      trainData = trainData, evalData = evalData,
+      logger = logger, workLoadList = workLoadList)
+  }
+}
+
+object FeedForward {
+  // Check if name is a data argument.
+  private def isDataArg(name: String): Boolean = {
+    name.endsWith("data") || name.endsWith("label")
+  }
 }
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala
index 9a2c7912f104..b301c1abf03c 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala
@@ -11,6 +11,21 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer}
  */
 object NDArray {
   private val logger = LoggerFactory.getLogger(classOf[NDArray])
+
+  private[mxnet] val DTYPE_NATIVE_TO_MX: Map[Class[_ >: Float with Int with Double], Int] = Map(
+    classOf[Float] -> 0,
+    classOf[Double] -> 1,
+    classOf[Int] -> 4
+  )
+
+  private[mxnet] val DTYPE_MX_TO_NATIVE: Map[Int, Class[_ >: Float with Int with Double]] = Map(
+    0 -> classOf[Float],
+    1 -> classOf[Double],
+    2 -> classOf[Float],
+    3 -> classOf[Int],
+    4 -> classOf[Int]
+  )
+
   private val functions: Map[String, NDArrayFunction] = initNDArrayModule()
   // Definition of internal functions.
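The two DTYPE maps above translate Scala element classes to MXNet's integer dtype ids and back; Symbol.inferType (later in this patch) encodes its arguments and decodes its results through them. A minimal round trip, valid inside the ml.dmlc.mxnet package since both maps are private[mxnet] (the value names are illustrative only):

  val mxTypeId: Int = NDArray.DTYPE_NATIVE_TO_MX(classOf[Float])  // 0
  val nativeClass = NDArray.DTYPE_MX_TO_NATIVE(mxTypeId)          // classOf[Float]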
@@ -552,6 +567,10 @@ class NDArray(private[mxnet] val handle: NDArrayHandle, val writable: Boolean = new NDArray(handle = sliceHandle.value, writable = this.writable) } + def slice(range: (Int, Int)): NDArray = { + slice(range._1, range._2) + } + def slice(start: Int): NDArray = { slice(start, shape(0)) } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Random.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Random.scala index d2935590f055..3b14d5b14fcc 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Random.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Random.scala @@ -1,5 +1,6 @@ package ml.dmlc.mxnet +import ml.dmlc.mxnet.Base.Shape import ml.dmlc.mxnet.NDArray.{randomGaussian, randomUniform, empty} /** @@ -19,7 +20,7 @@ object Random { */ def uniform(low: Float, high: Float, - shape: Array[Int] = null, + shape: Shape = null, ctx: Context = null, out: NDArray = null): NDArray = { var outCopy = out @@ -45,7 +46,7 @@ object Random { */ def normal(mean: Float, stdvar: Float, - shape: Array[Int] = null, + shape: Shape = null, ctx: Context = null, out: NDArray = null): NDArray = { var outCopy = out diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala index 33e10ad5ea7e..2d048422a989 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala @@ -10,7 +10,7 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer} * @author Yizhi Liu */ class Symbol(private[mxnet] val handle: SymbolHandle) { - def +(other: Symbol): Symbol = Symbol.create("_Plus", other) + def +(other: Symbol): Symbol = Symbol.create("_Plus", this, other) override def clone(): Symbol = { val clonedHandle = new SymbolHandleRef @@ -18,37 +18,203 @@ class Symbol(private[mxnet] val handle: SymbolHandle) { new Symbol(clonedHandle.value) } + def get(index: Int): Symbol = { + val newHandle = new SymbolHandleRef + checkCall(_LIB.mxSymbolGetOutput(handle, index, newHandle)) + new Symbol(handle = newHandle.value) + } + + def get(name: String): Symbol = { + var index: Int = -1 + for ((output, i) <- listOutputs().view.zipWithIndex) { + if (output == name) { + require(index == -1, s"There are multiple outputs with name $name") + index = i + } + } + require(index >= 0, s"Cannot find output that matches name $name") + get(index) + } + + /** + * Get a new grouped symbol whose output contains all the internal outputs of this symbol. + * @return The internal of the symbol. + */ + def getInternals: Symbol = { + val newHandle = new SymbolHandleRef + checkCall(_LIB.mxSymbolGetInternals(handle, newHandle)) + new Symbol(handle = newHandle.value) + } + /** * List all the arguments in the symbol. * @return Array of all the arguments. */ - def listArguments(): Array[String] = { + def listArguments(): Seq[String] = { val arr = ArrayBuffer.empty[String] checkCall(_LIB.mxSymbolListArguments(handle, arr)) - arr.toArray + arr } /** * List all outputs in the symbol. * @return : List of all the outputs. */ - def listOutputs(): Array[String] = { + def listOutputs(): Seq[String] = { val arr = ArrayBuffer.empty[String] checkCall(_LIB.mxSymbolListOutputs(handle, arr)) - arr.toArray + arr } /** * List all auxiliary states in the symbol. * @return The names of the auxiliary states. - * Notes - * ----- + * @note * Auxiliary states are special states of symbols that do not corresponds to an argument, * and do not have gradient. 
But still be useful for the specific operations. * A common example of auxiliary state is the moving_mean and moving_variance in BatchNorm. * Most operators do not have Auxiliary states. */ - def listAuxiliaryStates(): Array[String] = ??? + def listAuxiliaryStates(): Seq[String] = { + val sarr = ArrayBuffer.empty[String] + checkCall(_LIB.mxSymbolListAuxiliaryStates(handle, sarr)) + sarr + } + + /** + * Infer the type of outputs and arguments of given known types of arguments. + * Tuple of Nones is returned if there is not enough information passed in. + * An error will be raised if there is inconsistency found in the known types passed in. + * @param args Provide type of arguments in a positional way. Unknown type can be marked as null + * @return + * argTypes : list of numpy.dtype or None + * List of types of arguments. + * The order is in the same order as list_arguments() + * outTypes : list of numpy.dtype or None + * List of types of outputs. + * The order is in the same order as list_outputs() + * auxTypes : list of numpy.dtype or None + * List of types of outputs. + * The order is in the same order as list_auxiliary() + */ + def inferType(args: Class[_ >: Float with Int with Double]*) + : (Seq[Class[_ >: Float with Int with Double]], + Seq[Class[_ >: Float with Int with Double]], + Seq[Class[_ >: Float with Int with Double]]) = { + val sdata: Array[Int] = args.map(NDArray.DTYPE_NATIVE_TO_MX.getOrElse(_, -1)).toArray + inferType(null, sdata) + } + + /** + * Infer the type of outputs and arguments of given known types of arguments. + * Tuple of Nones is returned if there is not enough information passed in. + * An error will be raised if there is inconsistency found in the known types passed in. + * @param kwargs Provide keyword arguments of known types. + * @return + * argTypes : list of numpy.dtype or None + * List of types of arguments. + * The order is in the same order as list_arguments() + * outTypes : list of numpy.dtype or None + * List of types of outputs. + * The order is in the same order as list_outputs() + * auxTypes : list of numpy.dtype or None + * List of types of outputs. + * The order is in the same order as list_auxiliary() + */ + def inferType(kwargs: Map[String, Class[_ >: Float with Int with Double]]) + : (Seq[Class[_ >: Float with Int with Double]], + Seq[Class[_ >: Float with Int with Double]], + Seq[Class[_ >: Float with Int with Double]]) = { + val filteredArgs = kwargs.filter { case (key, value) => + NDArray.DTYPE_NATIVE_TO_MX.contains(value) + } + val keys = filteredArgs.keys.toArray + val sdata = filteredArgs.values.map(NDArray.DTYPE_NATIVE_TO_MX(_)).toArray + inferType(keys, sdata) + } + + private def inferType(keys: Array[String], values: Array[Int]) + : (Seq[Class[_ >: Float with Int with Double]], + Seq[Class[_ >: Float with Int with Double]], + Seq[Class[_ >: Float with Int with Double]]) = { + val argTypeData = ListBuffer.empty[Int] + val outTypeData = ListBuffer.empty[Int] + val auxTypeData = ListBuffer.empty[Int] + val complete = new RefInt + checkCall(_LIB.mxSymbolInferType( + handle, keys, values, argTypeData, outTypeData, auxTypeData, complete)) + if (complete.value != 0) { + (argTypeData.map(NDArray.DTYPE_MX_TO_NATIVE), + outTypeData.map(NDArray.DTYPE_MX_TO_NATIVE), + auxTypeData.map(NDArray.DTYPE_MX_TO_NATIVE)) + } else { + (null, null, null) + } + } + + /** + * Infer the shape of outputs and arguments of given known shapes of arguments. + * User can either pass in the known shapes in positional way or keyword argument way. 
+ * Tuple of Nones is returned if there is not enough information passed in. + * An error will be raised if there is inconsistency found in the known shapes passed in. + * @param args Provide shape of arguments in a positional way. + * Unknown shape can be marked as None + * @return + * argShapes List of shapes of arguments. The order is in the same order as list_arguments() + * outShapes List of shapes of outputs. The order is in the same order as list_outputs() + * auxShapes List of shapes of outputs. The order is in the same order as list_auxiliary() + */ + def inferShape(args: Shape*): (Seq[Shape], Seq[Shape], Seq[Shape]) = { + val keys: Array[String] = null + val indPtr = ArrayBuffer(0) + val sdata = ArrayBuffer.empty[Int] + args.foreach { shape => + if (shape != null) { + sdata ++= shape + indPtr += sdata.size + } + } + inferShape(keys, indPtr.toArray, sdata.toArray) + } + + /** + * Infer the shape of outputs and arguments of given known shapes of arguments. + * User can either pass in the known shapes in positional way or keyword argument way. + * Tuple of Nones is returned if there is not enough information passed in. + * An error will be raised if there is inconsistency found in the known shapes passed in. + * @param kwargs Provide keyword arguments of known shapes. + * @return + * argShapes List of shapes of arguments. The order is in the same order as list_arguments() + * outShapes List of shapes of outputs. The order is in the same order as list_outputs() + * auxShapes List of shapes of outputs. The order is in the same order as list_auxiliary() + */ + def inferShape(kwargs: Map[String, Shape]): (Seq[Shape], Seq[Shape], Seq[Shape]) = { + val keys = ArrayBuffer.empty[String] + val indPtr = ArrayBuffer(0) + val sdata = ArrayBuffer.empty[Int] + kwargs.foreach { case (key, shape) => + keys += key + sdata ++= shape + indPtr += sdata.size + } + inferShape(keys.toArray, indPtr.toArray, sdata.toArray) + } + + def inferShape(keys: Array[String], indPtr: Array[Int], values: Array[Int]) + : (Seq[Shape], Seq[Shape], Seq[Shape]) = { + val argShapeData = ListBuffer.empty[Shape] + val outShapeData = ListBuffer.empty[Shape] + val auxShapeData = ListBuffer.empty[Shape] + val complete = new RefInt + + checkCall(_LIB.mxSymbolInferShape(handle, indPtr.size - 1, keys, indPtr, values, + argShapeData, outShapeData, auxShapeData, complete)) + if (complete.value != 0) { + (argShapeData, outShapeData, auxShapeData) + } else { + (null, null, null) + } + } /** * Get attribute string from the symbol, this function only works for non-grouped symbol. @@ -112,11 +278,390 @@ class Symbol(private[mxnet] val handle: SymbolHandle) { val args = symbols.values.map(_.handle).toArray checkCall(_LIB.mxSymbolCompose(handle, name, keys, args)) } + + /** + * Bind current symbol to get an executor, allocate all the ndarrays needed. + * Allows specifying data types. + * This function will ask user to pass in ndarray of position + * they like to bind to, and it will automatically allocate the ndarray + * for arguments and auxiliary states that user did not specify explicitly. + * + * @param ctx The device context the generated executor to run on. + * @param gradReq {'write', 'add', 'null'}, or list of str or dict of str to str, optional + * Specifies how we should update the gradient to the args_grad. + * - 'write' means everytime gradient is write to specified args_grad NDArray. + * - 'add' means everytime gradient is add to the specified NDArray. + * - 'null' means no action is taken, the gradient may not be calculated. 
+ * @param typeDict Input type dictionary, name->dtype + * @param shapeDict Input shape dictionary, name->shape + * @return The generated Executor + */ + def simpleBind(ctx: Context, gradReq: String = "write", + shapeDict: Map[String, Shape], + typeDict: Map[String, Class[_ >: Float with Int with Double]] = null): Executor = { + val types = + if (typeDict == null) listArguments().map((_, classOf[Float])).toMap + else typeDict + val (argShapes, _, auxShapes) = inferShape(shapeDict) + val (argTypes, _, auxTypes) = inferType(types) + require(argShapes != null && argTypes != null, "Input node is not complete") + // alloc space + val argNDArrays = (argShapes zip argTypes) map { case (shape, t) => + // TODO: NDArray dtype + NDArray.zeros(shape, ctx) + } + val gradNDArrays = + if (gradReq != "null") { + (((listArguments() zip argShapes) zip argTypes) flatMap { case ((name, shape), t) => + if (!(name.endsWith("data") || name.endsWith("label"))) { + // TODO: NDArray dtype + Map(name -> NDArray.zeros(shape, ctx)) + } else { + Map.empty[String, NDArray] + } + }).toMap + } else { + null + } + val auxNDArrays = (auxShapes zip auxTypes) map { case (shape, t) => + // TODO: NDArray dtype + NDArray.zeros(shape, ctx) + } + bind(ctx, argNDArrays, gradNDArrays, gradReq, auxNDArrays, null) + } + + /** + * Bind current symbol to get an executor. + * + * @param ctx Context The device context the generated executor to run on. + * @param args Input arguments to the symbol. + * - If type is list of NDArray, the position is in the same order of list_arguments. + * - If type is dict of str to NDArray, then it maps the name of arguments + * to the corresponding NDArray. + * - In either case, all the arguments must be provided. + * @param argsGrad When specified, args_grad provide NDArrays to hold + * the result of gradient value in backward. + * - If type is list of NDArray, + * the position is in the same order of list_arguments. + * - If type is dict of str to NDArray, then it maps the name of arguments + * to the corresponding NDArray. + * - When the type is dict of str to NDArray, users only need to provide the dict + * for needed argument gradient. + * Only the specified argument gradient will be calculated. + * @param gradReq {'write', 'add', 'null'}, or list of str or dict of str to str, optional + * Specifies how we should update the gradient to the args_grad. + * - 'write' means everytime gradient is write to specified args_grad NDArray. + * - 'add' means everytime gradient is add to the specified NDArray. + * - 'null' means no action is taken, the gradient may not be calculated. + * @param auxStates Input auxiliary states to the symbol, only need to specify when + * list_auxiliary_states is not empty. + * - If type is list of NDArray, + * the position is in the same order of listAuxiliaryStates + * - If type is dict of str to NDArray, then it maps the name of auxiliary_states + * to the corresponding NDArray, + * - In either case, all the auxiliary_states need to be provided. + * @param group2ctx The dict mapping the ``ctx_group`` attribute to the context assignment. + * @return The generated Executor + * @note + * Auxiliary states are special states of symbols that do not corresponds to an argument, + * and do not have gradient. But still be useful for the specific operations. + * A common example of auxiliary state is the moving_mean and moving_variance in BatchNorm. + * Most operators do not have auxiliary states and this parameter can be safely ignored. 
+ * + * User can give up gradient by using a dict in args_grad and only specify + * gradient they interested in. + */ + def bind(ctx: Context, args: Seq[NDArray], argsGrad: Seq[NDArray], + gradReq: String, auxStates: Seq[NDArray], + group2ctx: Map[String, Context]): Executor = { + val symbolArguments = listArguments() + bindHelper(ctx, symbolArguments, args, argsGrad, + Seq.fill(symbolArguments.size)(gradReq), auxStates, group2ctx) + } + + def bind(ctx: Context, args: Seq[NDArray], argsGrad: Seq[NDArray], + gradReq: String, auxStates: Map[String, NDArray], + group2ctx: Map[String, Context]): Executor = { + val symbolArguments = listArguments() + bindHelper(ctx, symbolArguments, args, argsGrad, + Seq.fill(symbolArguments.size)(gradReq), auxStates, group2ctx) + } + + def bind(ctx: Context, args: Seq[NDArray], argsGrad: Map[String, NDArray], + gradReq: String, auxStates: Seq[NDArray], + group2ctx: Map[String, Context]): Executor = { + val symbolArguments = listArguments() + bindHelper(ctx, symbolArguments, args, argsGrad, + Seq.fill(symbolArguments.size)(gradReq), auxStates, group2ctx) + } + + def bind(ctx: Context, args: Seq[NDArray], argsGrad: Map[String, NDArray], + gradReq: String, auxStates: Map[String, NDArray], + group2ctx: Map[String, Context]): Executor = { + val symbolArguments = listArguments() + bindHelper(ctx, symbolArguments, args, argsGrad, + Seq.fill(symbolArguments.size)(gradReq), auxStates, group2ctx) + } + + def bind(ctx: Context, args: Map[String, NDArray], argsGrad: Seq[NDArray], + gradReq: String, auxStates: Seq[NDArray], + group2ctx: Map[String, Context]): Executor = { + val symbolArguments = listArguments() + bindHelper(ctx, symbolArguments, args, argsGrad, + Seq.fill(symbolArguments.size)(gradReq), auxStates, group2ctx) + } + + def bind(ctx: Context, args: Map[String, NDArray], argsGrad: Seq[NDArray], + gradReq: String, auxStates: Map[String, NDArray], + group2ctx: Map[String, Context]): Executor = { + val symbolArguments = listArguments() + bindHelper(ctx, symbolArguments, args, argsGrad, + Seq.fill(symbolArguments.size)(gradReq), auxStates, group2ctx) + } + + def bind(ctx: Context, args: Map[String, NDArray], argsGrad: Map[String, NDArray], + gradReq: String, auxStates: Seq[NDArray], + group2ctx: Map[String, Context]): Executor = { + val symbolArguments = listArguments() + bindHelper(ctx, symbolArguments, args, argsGrad, + Seq.fill(symbolArguments.size)(gradReq), auxStates, group2ctx) + } + + def bind(ctx: Context, args: Map[String, NDArray], argsGrad: Map[String, NDArray], + gradReq: String, auxStates: Map[String, NDArray], + group2ctx: Map[String, Context]): Executor = { + val symbolArguments = listArguments() + bindHelper(ctx, symbolArguments, args, argsGrad, + Seq.fill(symbolArguments.size)(gradReq), auxStates, group2ctx) + } + + def bind(ctx: Context, args: Seq[NDArray], argsGrad: Seq[NDArray], + gradsReq: Seq[String], auxStates: Seq[NDArray], + group2ctx: Map[String, Context]): Executor = { + val symbolArguments = listArguments() + bindHelper(ctx, symbolArguments, args, argsGrad, gradsReq, auxStates, group2ctx) + } + + def bind(ctx: Context, args: Seq[NDArray], argsGrad: Seq[NDArray], + gradsReq: Seq[String], auxStates: Map[String, NDArray], + group2ctx: Map[String, Context]): Executor = { + val symbolArguments = listArguments() + bindHelper(ctx, symbolArguments, args, argsGrad, gradsReq, auxStates, group2ctx) + } + + def bind(ctx: Context, args: Seq[NDArray], argsGrad: Map[String, NDArray], + gradsReq: Seq[String], auxStates: Seq[NDArray], + group2ctx: 
Map[String, Context]): Executor = { + val symbolArguments = listArguments() + bindHelper(ctx, symbolArguments, args, argsGrad, gradsReq, auxStates, group2ctx) + } + + def bind(ctx: Context, args: Seq[NDArray], argsGrad: Map[String, NDArray], + gradsReq: Seq[String], auxStates: Map[String, NDArray], + group2ctx: Map[String, Context]): Executor = { + val symbolArguments = listArguments() + bindHelper(ctx, symbolArguments, args, argsGrad, gradsReq, auxStates, group2ctx) + } + + def bind(ctx: Context, args: Map[String, NDArray], argsGrad: Seq[NDArray], + gradsReq: Seq[String], auxStates: Seq[NDArray], + group2ctx: Map[String, Context]): Executor = { + val symbolArguments = listArguments() + bindHelper(ctx, symbolArguments, args, argsGrad, gradsReq, auxStates, group2ctx) + } + + def bind(ctx: Context, args: Map[String, NDArray], argsGrad: Seq[NDArray], + gradsReq: Seq[String], auxStates: Map[String, NDArray], + group2ctx: Map[String, Context]): Executor = { + val symbolArguments = listArguments() + bindHelper(ctx, symbolArguments, args, argsGrad, gradsReq, auxStates, group2ctx) + } + + def bind(ctx: Context, args: Map[String, NDArray], argsGrad: Map[String, NDArray], + gradsReq: Seq[String], auxStates: Seq[NDArray], + group2ctx: Map[String, Context]): Executor = { + val symbolArguments = listArguments() + bindHelper(ctx, symbolArguments, args, argsGrad, gradsReq, auxStates, group2ctx) + } + + def bind(ctx: Context, args: Map[String, NDArray], argsGrad: Map[String, NDArray], + gradsReq: Seq[String], auxStates: Map[String, NDArray], + group2ctx: Map[String, Context]): Executor = { + val symbolArguments = listArguments() + bindHelper(ctx, symbolArguments, args, argsGrad, gradsReq, auxStates, group2ctx) + } + + def bind(ctx: Context, args: Seq[NDArray], argsGrad: Seq[NDArray], + gradsReq: Map[String, String], auxStates: Seq[NDArray], + group2ctx: Map[String, Context]): Executor = { + val symbolArguments = listArguments() + bindHelper(ctx, symbolArguments, args, argsGrad, gradsReq, auxStates, group2ctx) + } + + def bind(ctx: Context, args: Seq[NDArray], argsGrad: Seq[NDArray], + gradsReq: Map[String, String], auxStates: Map[String, NDArray], + group2ctx: Map[String, Context]): Executor = { + val symbolArguments = listArguments() + bindHelper(ctx, symbolArguments, args, argsGrad, gradsReq, auxStates, group2ctx) + } + + def bind(ctx: Context, args: Seq[NDArray], argsGrad: Map[String, NDArray], + gradsReq: Map[String, String], auxStates: Seq[NDArray], + group2ctx: Map[String, Context]): Executor = { + val symbolArguments = listArguments() + bindHelper(ctx, symbolArguments, args, argsGrad, gradsReq, auxStates, group2ctx) + } + + def bind(ctx: Context, args: Seq[NDArray], argsGrad: Map[String, NDArray], + gradsReq: Map[String, String], auxStates: Map[String, NDArray], + group2ctx: Map[String, Context]): Executor = { + val symbolArguments = listArguments() + bindHelper(ctx, symbolArguments, args, argsGrad, gradsReq, auxStates, group2ctx) + } + + def bind(ctx: Context, args: Map[String, NDArray], argsGrad: Seq[NDArray], + gradsReq: Map[String, String], auxStates: Seq[NDArray], + group2ctx: Map[String, Context]): Executor = { + val symbolArguments = listArguments() + bindHelper(ctx, symbolArguments, args, argsGrad, gradsReq, auxStates, group2ctx) + } + + def bind(ctx: Context, args: Map[String, NDArray], argsGrad: Seq[NDArray], + gradsReq: Map[String, String], auxStates: Map[String, NDArray], + group2ctx: Map[String, Context]): Executor = { + val symbolArguments = listArguments() + bindHelper(ctx, 
symbolArguments, args, argsGrad, gradsReq, auxStates, group2ctx) + } + + def bind(ctx: Context, args: Map[String, NDArray], argsGrad: Map[String, NDArray], + gradsReq: Map[String, String], auxStates: Seq[NDArray], + group2ctx: Map[String, Context]): Executor = { + val symbolArguments = listArguments() + bindHelper(ctx, symbolArguments, args, argsGrad, gradsReq, auxStates, group2ctx) + } + + def bind(ctx: Context, args: Map[String, NDArray], argsGrad: Map[String, NDArray], + gradsReq: Map[String, String], auxStates: Map[String, NDArray], + group2ctx: Map[String, Context]): Executor = { + val symbolArguments = listArguments() + bindHelper(ctx, symbolArguments, args, argsGrad, gradsReq, auxStates, group2ctx) + } + + def bind(ctx: Context, args: Seq[NDArray], argsGrad: Seq[NDArray]): Executor = { + bind(ctx, args, argsGrad, "write", Nil, null) + } + + def bind(ctx: Context, args: Map[String, NDArray], argsGrad: Map[String, NDArray]): Executor = { + bind(ctx, args, argsGrad, "write", Nil, null) + } + + def bind(ctx: Context, args: Seq[NDArray]): Executor = { + val symbolArguments = listArguments() + bindHelper(ctx, symbolArguments, args, null, + Seq.fill(symbolArguments.size)("write"), Nil, null) + } + + def bind(ctx: Context, args: Map[String, NDArray]): Executor = { + val symbolArguments = listArguments() + bindHelper(ctx, symbolArguments, args, null, + Seq.fill(symbolArguments.size)("write"), Nil, null) + } + + private def bindHelper(ctx: Context, symbolArguments: Seq[String], + args: Iterable[_], argsGrad: Iterable[_], + gradsReq: Iterable[_], auxStates: Iterable[_], + group2ctx: Map[String, Context]): Executor = { + require(args != null && !args.isInstanceOf[Set[_]]) + require(argsGrad == null || !argsGrad.isInstanceOf[Set[_]]) + require(auxStates == null || !auxStates.isInstanceOf[Set[_]]) + require(gradsReq != null && !gradsReq.isInstanceOf[Set[_]]) + + val (argsHandle, argsNDArray) = + if (args.isInstanceOf[Seq[_]]) { + Symbol.getNDArrayInputs("args", args.asInstanceOf[Seq[NDArray]], + symbolArguments, allowMissing = false) + } else { + Symbol.getNDArrayInputs("args", args.asInstanceOf[Map[String, NDArray]], + symbolArguments, allowMissing = false) + } + + // setup args gradient + val (argsGradHandle, argsGradNDArray) = + if (argsGrad == null) { + (Array.fill[NDArrayHandle](args.size)(0L), null) + } else if (argsGrad.isInstanceOf[Seq[_]]) { + Symbol.getNDArrayInputs("args_grad", argsGrad.asInstanceOf[Seq[NDArray]], + symbolArguments, allowMissing = true) + } else { + Symbol.getNDArrayInputs("args_grad", argsGrad.asInstanceOf[Map[String, NDArray]], + symbolArguments, allowMissing = true) + } + + val (auxArgsHandle, auxStatesNDArray) = + if (auxStates == null) { + Symbol.getNDArrayInputs("aux_states", Nil, listAuxiliaryStates(), allowMissing = false) + } else if (auxStates.isInstanceOf[Seq[_]]) { + Symbol.getNDArrayInputs("aux_states", auxStates.asInstanceOf[Seq[NDArray]], + listAuxiliaryStates(), allowMissing = false) + } else { + Symbol.getNDArrayInputs("aux_states", auxStates.asInstanceOf[Map[String, NDArray]], + listAuxiliaryStates(), allowMissing = false) + } + + // setup requirements + val reqsArray = + if (gradsReq.isInstanceOf[Seq[_]]) { + gradsReq.asInstanceOf[Seq[String]].map { req => + require(Symbol.bindReqMap.contains(req), s"grad_req must be in ${Symbol.bindReqMap}") + Symbol.bindReqMap(req) + }.toArray + } else { + val gradsReqMap = gradsReq.asInstanceOf[Map[String, String]] + symbolArguments.map { req => + val value = gradsReqMap.getOrElse(req, "null") + 
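+          // `value` falls back to "null" for any argument missing from gradsReq,
+          // so no gradient is requested for that argument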
require(Symbol.bindReqMap.contains(value), s"grad_req must be in ${Symbol.bindReqMap}") + Symbol.bindReqMap(value) + }.toArray + } + + val ctxMapKeys = ArrayBuffer.empty[String] + val ctxMapDevTypes = ArrayBuffer.empty[Int] + val ctxMapDevIDs = ArrayBuffer.empty[Int] + + if (group2ctx != null) { + group2ctx.foreach { case (key, value) => + ctxMapKeys += key + ctxMapDevTypes += value.deviceTypeid + ctxMapDevIDs += value.deviceId + } + } + + val execHandle = new ExecutorHandleRef + checkCall(_LIB.mxExecutorBindX(handle, + ctx.deviceTypeid, + ctx.deviceId, + ctxMapKeys.size, + ctxMapKeys.toArray, + ctxMapDevTypes.toArray, + ctxMapDevIDs.toArray, + args.size, + argsHandle, + argsGradHandle, + reqsArray, + auxArgsHandle, + execHandle)) + val executor = new Executor(execHandle.value, this) + executor.argArrays = argsNDArray + executor.gradArrays = argsGradNDArray + executor.auxArrays = auxStatesNDArray + executor + } } object Symbol { private val logger = LoggerFactory.getLogger(classOf[Symbol]) private val functions: Map[String, SymbolFunction] = initSymbolModule() + private val bindReqMap = Map("null" -> 0, "write" -> 1, "add" -> 3) /** * Create a symbolic variable with specified name. @@ -148,6 +693,34 @@ object Symbol { createNoCheck("Activation", attr) } + def Convolution(attr: Map[String, String]): Map[String, Any] => Symbol = { + createNoCheck("Convolution", attr) + } + + def Convolution: Map[String, Any] => Symbol = { + Convolution(null) + } + + def BatchNorm: Map[String, Any] => Symbol = { + createNoCheck("BatchNorm") + } + + def Pooling: Map[String, Any] => Symbol = { + createNoCheck("Pooling") + } + + def Flatten: Map[String, Any] => Symbol = { + createNoCheck("Flatten") + } + + def SoftmaxOutput: Map[String, Any] => Symbol = { + createNoCheck("SoftmaxOutput") + } + + def Cast: Map[String, Any] => Symbol = { + createNoCheck("Cast") + } + /** * Create a symbol that groups symbols together. * @param symbols List of symbols to be grouped. @@ -295,6 +868,44 @@ object Symbol { } create(operator, symbolArgs, strArgs, attr) } + + /** + * Helper function to get ndarray lists handles from various inputs. + * @param argKey The name of argument, used for error message. + * @param args list of NDArray or dict of str to NDArray + * Input arguments to the symbols. + * If type is list of NDArray, the position is in the same order of arg_names. + * If type is dict of str to NDArray, then it maps the name of arguments + * to the corresponding NDArray + * @param argNames List of argument names. + * @param allowMissing Whether missing argument is allowed. + * When allowed, the missing handle will be set to None(null) + * @return The positional list of NDArrayHandles generated from input. 
+ */ + private def getNDArrayInputs(argKey: String, args: Seq[NDArray], argNames: Seq[String], + allowMissing: Boolean): (Array[NDArrayHandle], Array[NDArray]) = { + require(args.length == argNames.length, s"Length of $argKey do not match number of arguments") + val argHandles = args.map(_.handle) + (argHandles.toArray, args.toArray) + } + + private def getNDArrayInputs(argKey: String, args: Map[String, NDArray], argNames: Seq[String], + allowMissing: Boolean): (Array[NDArrayHandle], Array[NDArray]) = { + val argArrays = ArrayBuffer.empty[NDArray] + val argHandles = ArrayBuffer.empty[NDArrayHandle] + argNames.foreach { name => + args.get(name) match { + case narr: Some[NDArray] => + argArrays += narr.get + argHandles += narr.get.handle + case None => + require(allowMissing, s"Must specify all the arguments in $argKey") + argArrays += null + argHandles += 0L + } + } + (argHandles.toArray, argArrays.toArray) + } } private case class SymbolFunction(handle: SymbolHandle, keyVarNumArgs: String) diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/ExecutorSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/ExecutorSuite.scala new file mode 100644 index 000000000000..8716bb181e04 --- /dev/null +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/ExecutorSuite.scala @@ -0,0 +1,48 @@ +package ml.dmlc.mxnet + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +class ExecutorSuite extends FunSuite with BeforeAndAfterAll { + test("bind") { + val shape = Array(100, 30) + val lhs = Symbol.Variable("lhs") + val rhs = Symbol.Variable("rhs") + val ret = lhs + rhs + assert(ret.listArguments().toArray === Array("lhs", "rhs")) + + val lhsArr = Random.uniform(-10f, 10f, shape) + val rhsArr = Random.uniform(-10f, 10f, shape) + val lhsGrad = NDArray.empty(shape) + val rhsGrad = NDArray.empty(shape) + + val executor = ret.bind(Context.cpu(), args = Seq(lhsArr, rhsArr), + argsGrad = Seq(lhsGrad, rhsGrad)) + val exec3 = ret.bind(Context.cpu(), args = Seq(lhsArr, rhsArr)) + val exec4 = ret.bind(Context.cpu(), args = Map("rhs" -> rhsArr, "lhs" -> lhsArr), + argsGrad = Map("lhs" -> lhsGrad, "rhs" -> rhsGrad)) + executor.forward() + exec3.forward() + exec4.forward() + + val out1 = lhsArr + rhsArr + val out2 = executor.outputs(0) + val out3 = exec3.outputs(0) + val out4 = exec4.outputs(0) + assert(reldiff(out1, out2) < 1e-6) + assert(reldiff(out1, out3) < 1e-6) + assert(reldiff(out1, out4) < 1e-6) + + // test gradient + val outGrad = NDArray.ones(shape) + val (lhsGrad2, rhsGrad2) = (outGrad, outGrad) + executor.backward(Array(outGrad)) + assert(reldiff(lhsGrad, lhsGrad2) < 1e-6) + assert(reldiff(rhsGrad, rhsGrad2) < 1e-6) + } + + private def reldiff(a: NDArray, b: NDArray): Float = { + val diff = NDArray.sum(NDArray.abs(a - b)).toScalar + val norm = NDArray.sum(NDArray.abs(a)).toScalar + diff / norm + } +} diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala index 27be2eb228dc..87717977e284 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala @@ -20,7 +20,7 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { "seed" -> "10" ) - val mnistIter = IO.createIterator("MNISTIter", params) + val mnistIter = IO.MNISTIter(params) // test_loop mnistIter.reset() val nBatch = 600 @@ -34,15 +34,15 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { // test reset mnistIter.reset() mnistIter.iterNext() - val label0 = 
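For orientation (not part of the patch itself), a minimal sketch of how these overloads are intended to be called. It only uses APIs that appear elsewhere in this change set (`Symbol.Variable`, symbol `+`, `NDArray.empty`, `Context.cpu()`); the value names are illustrative:

  // Sketch only: positional vs. named binding of a trivial symbol.
  val lhs = Symbol.Variable("lhs")
  val rhs = Symbol.Variable("rhs")
  val net = lhs + rhs                    // listArguments() yields "lhs", "rhs"
  val shape = Array(2, 3)
  val a = NDArray.empty(shape)
  val b = NDArray.empty(shape)
  // Seq form: NDArrays are matched to arguments by position.
  val execBySeq = net.bind(Context.cpu(), args = Seq(a, b))
  // Map form: NDArrays are matched to arguments by name, so order does not matter.
  val execByMap = net.bind(Context.cpu(), args = Map("rhs" -> b, "lhs" -> a))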
diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/ExecutorSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/ExecutorSuite.scala
new file mode 100644
index 000000000000..8716bb181e04
--- /dev/null
+++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/ExecutorSuite.scala
@@ -0,0 +1,48 @@
+package ml.dmlc.mxnet
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+class ExecutorSuite extends FunSuite with BeforeAndAfterAll {
+  test("bind") {
+    val shape = Array(100, 30)
+    val lhs = Symbol.Variable("lhs")
+    val rhs = Symbol.Variable("rhs")
+    val ret = lhs + rhs
+    assert(ret.listArguments().toArray === Array("lhs", "rhs"))
+
+    val lhsArr = Random.uniform(-10f, 10f, shape)
+    val rhsArr = Random.uniform(-10f, 10f, shape)
+    val lhsGrad = NDArray.empty(shape)
+    val rhsGrad = NDArray.empty(shape)
+
+    val executor = ret.bind(Context.cpu(), args = Seq(lhsArr, rhsArr),
+      argsGrad = Seq(lhsGrad, rhsGrad))
+    val exec3 = ret.bind(Context.cpu(), args = Seq(lhsArr, rhsArr))
+    val exec4 = ret.bind(Context.cpu(), args = Map("rhs" -> rhsArr, "lhs" -> lhsArr),
+      argsGrad = Map("lhs" -> lhsGrad, "rhs" -> rhsGrad))
+    executor.forward()
+    exec3.forward()
+    exec4.forward()
+
+    val out1 = lhsArr + rhsArr
+    val out2 = executor.outputs(0)
+    val out3 = exec3.outputs(0)
+    val out4 = exec4.outputs(0)
+    assert(reldiff(out1, out2) < 1e-6)
+    assert(reldiff(out1, out3) < 1e-6)
+    assert(reldiff(out1, out4) < 1e-6)
+
+    // test gradient
+    val outGrad = NDArray.ones(shape)
+    val (lhsGrad2, rhsGrad2) = (outGrad, outGrad)
+    executor.backward(Array(outGrad))
+    assert(reldiff(lhsGrad, lhsGrad2) < 1e-6)
+    assert(reldiff(rhsGrad, rhsGrad2) < 1e-6)
+  }
+
+  private def reldiff(a: NDArray, b: NDArray): Float = {
+    val diff = NDArray.sum(NDArray.abs(a - b)).toScalar
+    val norm = NDArray.sum(NDArray.abs(a)).toScalar
+    diff / norm
+  }
+}
diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala
index 27be2eb228dc..87717977e284 100644
--- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala
+++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala
@@ -20,7 +20,7 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
       "seed" -> "10"
     )
 
-    val mnistIter = IO.createIterator("MNISTIter", params)
+    val mnistIter = IO.MNISTIter(params)
     // test_loop
     mnistIter.reset()
     val nBatch = 600
@@ -34,15 +34,15 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
     // test reset
     mnistIter.reset()
     mnistIter.iterNext()
-    val label0 = mnistIter.getLabel().toArray
-    val data0 = mnistIter.getData().toArray
+    val label0 = mnistIter.getLabel().head.toArray
+    val data0 = mnistIter.getData().head.toArray
     mnistIter.iterNext()
     mnistIter.iterNext()
     mnistIter.iterNext()
     mnistIter.reset()
     mnistIter.iterNext()
-    val label1 = mnistIter.getLabel().toArray
-    val data1 = mnistIter.getData().toArray
+    val label1 = mnistIter.getLabel().head.toArray
+    val data1 = mnistIter.getData().head.toArray
     assert(label0 === label1)
     assert(data0 === data1)
   }
diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/SymbolSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/SymbolSuite.scala
index edc6ace47444..8e14d9565551 100644
--- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/SymbolSuite.scala
+++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/SymbolSuite.scala
@@ -8,7 +8,7 @@ class SymbolSuite extends FunSuite with BeforeAndAfterAll {
     var net1 = Symbol.FullyConnected(Map("data" -> data, "name" -> "fc1", "num_hidden" -> 10))
     net1 = Symbol.FullyConnected(Map("data" -> net1, "name" -> "fc2", "num_hidden" -> 100))
-    assert(net1.listArguments() ===
+    assert(net1.listArguments().toArray ===
       Array("data", "fc1_weight", "fc1_bias", "fc2_weight", "fc2_bias"))
 
     var net2 = Symbol.FullyConnected(Map("name" -> "fc3", "num_hidden" -> 10))
@@ -25,4 +25,27 @@ class SymbolSuite extends FunSuite with BeforeAndAfterAll {
     val multiOut = Symbol.Group(composed, net1)
     assert(multiOut.listOutputs().length === 2)
   }
+
+  test("symbol internal") {
+    val data = Symbol.Variable("data")
+    val oldfc = Symbol.FullyConnected(Map("data" -> data, "name" -> "fc1", "num_hidden" -> 10))
+    val net1 = Symbol.FullyConnected(Map("data" -> oldfc, "name" -> "fc2", "num_hidden" -> 100))
+    assert(net1.listArguments().toArray
+      === Array("data", "fc1_weight", "fc1_bias", "fc2_weight", "fc2_bias"))
+    val internal = net1.getInternals
+    val fc1 = internal.get("fc1_output")
+    assert(fc1.listArguments() === oldfc.listArguments())
+  }
+
+  test("symbol infer type") {
+    val data = Symbol.Variable("data")
+    val f32data = Symbol.Cast(Map("data" -> data, "dtype" -> "float32"))
+    val fc1 = Symbol.FullyConnected(Map("data" -> f32data, "name" -> "fc1", "num_hidden" -> 128))
+    val mlp = Symbol.SoftmaxOutput(Map("data" -> fc1, "name" -> "softmax"))
+
+    val (arg, out, aux) = mlp.inferType(Map("data" -> classOf[Double]))
+    assert(arg.toArray === Array(classOf[Double], classOf[Float], classOf[Float], classOf[Float]))
+    assert(out.toArray === Array(classOf[Float]))
+    assert(aux.isEmpty)
+  }
 }
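The `symbol infer type` test above expresses dtypes as JVM classes, while the JNI layer in the next file traffics in MXNet's integer type codes; the conversion between the two is not part of these hunks. A rough sketch of what such a mapping could look like, assuming mshadow's standard type flags (an assumption, shown only for illustration):

  // Hypothetical helper, not in this patch; mshadow type flags: float32 = 0, float64 = 1, int32 = 4.
  val dtypeToCode: Map[Class[_], Int] = Map(
    classOf[Float]  -> 0,
    classOf[Double] -> 1,
    classOf[Int]    -> 4)
  val codeToDtype: Map[Int, Class[_]] = dtypeToCode.map(_.swap)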
diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc
index a37d8a699409..a30215b067a0 100644
--- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc
+++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc
@@ -527,7 +527,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxListDataIters
 JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterCreateIter
   (JNIEnv * env, jobject obj, jlong creator, jobjectArray jkeys, jobjectArray jvals,
     jobject dataIterHandleRef) {
-  //keys and values
+  // keys and values
   int paramSize = env->GetArrayLength(jkeys);
   char** keys = new char*[paramSize];
   char** vals = new char*[paramSize];
@@ -863,6 +863,23 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolListOutputs
   return ret;
 }
 
+JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolListAuxiliaryStates
+  (JNIEnv *env, jobject obj, jlong symbolPtr, jobject outputs) {
+  mx_uint outSize;
+  const char **outStrArray;
+  int ret = MXSymbolListAuxiliaryStates((SymbolHandle) symbolPtr, &outSize, &outStrArray);
+
+  jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
+  jmethodID arrayAppend = env->GetMethodID(arrayClass,
+    "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
+  for (int i = 0; i < outSize; i++) {
+    jstring output = env->NewStringUTF(outStrArray[i]);
+    env->CallObjectMethod(outputs, arrayAppend, output);
+  }
+
+  return ret;
+}
+
 JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolCopy
   (JNIEnv *env, jobject obj, jlong symbolPtr, jobject clonedSymbolRef) {
   SymbolHandle clonedSymbol;
@@ -889,3 +906,231 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolPrint
   setStringField(env, out, outStr);
   return ret;
 }
+
+JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolGetOutput
+  (JNIEnv *env, jobject obj, jlong symbolPtr, jint index, jobject jout) {
+  SymbolHandle out;
+  int ret = MXSymbolGetOutput((SymbolHandle) symbolPtr, (mx_uint) index, &out);
+  setLongField(env, jout, (long) out);
+  return ret;
+}
+
+JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolGetInternals
+  (JNIEnv *env, jobject obj, jlong symbolPtr, jobject jout) {
+  SymbolHandle out;
+  int ret = MXSymbolGetInternals((SymbolHandle) symbolPtr, &out);
+  setLongField(env, jout, (long) out);
+  return ret;
+}
+
+JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolInferType
+  (JNIEnv *env, jobject obj, jlong symbolPtr, jobjectArray jkeys, jintArray jvals,
+    jobject jargTypeData, jobject joutTypeData, jobject jauxTypeData, jobject jcomplete) {
+  int numArgs = env->GetArrayLength(jvals);
+  const char **keys = NULL;
+  if (jkeys != NULL) {
+    keys = new const char *[numArgs];
+    for (int i = 0; i < numArgs; i++) {
+      jstring jkey = (jstring) env->GetObjectArrayElement(jkeys, i);
+      const char *key = env->GetStringUTFChars(jkey, 0);
+      keys[i] = key;
+    }
+  }
+
+  mx_uint inTypeSize;
+  const int *inTypeData;
+  mx_uint outTypeSize;
+  const int *outTypeData;
+  mx_uint auxTypeSize;
+  const int *auxTypeData;
+  int complete;
+
+  jint *vals = env->GetIntArrayElements(jvals, NULL);
+  int ret = MXSymbolInferType((SymbolHandle) symbolPtr,
+                              (mx_uint) numArgs, keys, (const int *) vals,
+                              &inTypeSize, &inTypeData,
+                              &outTypeSize, &outTypeData,
+                              &auxTypeSize, &auxTypeData,
+                              &complete);
+  env->ReleaseIntArrayElements(jvals, vals, 0);
+
+  jclass integerClass = env->FindClass("java/lang/Integer");
+  jmethodID newInteger = env->GetMethodID(integerClass, "<init>", "(I)V");
+
+  jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer");
+  jmethodID listAppend = env->GetMethodID(listClass,
+    "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;");
+
+  for (int i = 0; i < inTypeSize; ++i) {
+    jobject data = env->NewObject(integerClass, newInteger, inTypeData[i]);
+    env->CallObjectMethod(jargTypeData, listAppend, data);
+  }
+  for (int i = 0; i < outTypeSize; ++i) {
+    jobject data = env->NewObject(integerClass, newInteger, outTypeData[i]);
+    env->CallObjectMethod(joutTypeData, listAppend, data);
+  }
+  for (int i = 0; i < auxTypeSize; ++i) {
+    jobject data = env->NewObject(integerClass, newInteger, auxTypeData[i]);
+    env->CallObjectMethod(jauxTypeData, listAppend, data);
+  }
+
+  setIntField(env, jcomplete, complete);
+
+  // release allocated memory
+  if (jkeys != NULL) {
+    for (int i = 0; i < numArgs; i++) {
+      jstring jkey = (jstring) env->GetObjectArrayElement(jkeys, i);
+      env->ReleaseStringUTFChars(jkey, keys[i]);
+    }
+    delete[] keys;
+  }
+
+  return ret;
+}
+
+int FillSymbolInferShape
+  (JNIEnv *env, jmethodID listAppend, jobject joutData,
+    mx_uint shapeSize, const mx_uint *shapeNdim, const mx_uint **shapeData) {
+  for (int i = 0; i < shapeSize; ++i) {
+    jintArray jshape;
+    jshape = env->NewIntArray(shapeNdim[i]);
+    if (jshape == NULL) {
+      // TODO: out of memory error thrown, return a specific error code ?
+      return -1;
+    }
+    env->SetIntArrayRegion(jshape, 0, shapeNdim[i], (const jint *) shapeData[i]);
+    env->CallObjectMethod(joutData, listAppend, jshape);
+  }
+  return 0;
+}
+
+JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolInferShape
+  (JNIEnv *env, jobject obj, jlong symbolPtr, jint jnumArgs, jobjectArray jkeys,
+    jintArray jargIndPtr, jintArray jargShapeData,
+    jobject jinShapeData, jobject joutShapeData, jobject jauxShapeData, jobject jcomplete) {
+  const char **keys = NULL;
+  if (jkeys != NULL) {
+    keys = new const char *[jnumArgs];
+    for (int i = 0; i < jnumArgs; i++) {
+      jstring jkey = (jstring) env->GetObjectArrayElement(jkeys, i);
+      const char *key = env->GetStringUTFChars(jkey, 0);
+      keys[i] = key;
+    }
+  }
+
+  mx_uint inShapeSize;
+  const mx_uint *inShapeNdim;
+  const mx_uint **inShapeData;
+
+  mx_uint outShapeSize;
+  const mx_uint *outShapeNdim;
+  const mx_uint **outShapeData;
+
+  mx_uint auxShapeSize;
+  const mx_uint *auxShapeNdim;
+  const mx_uint **auxShapeData;
+
+  int complete;
+
+  jint *argIndPtr = env->GetIntArrayElements(jargIndPtr, NULL);
+  jint *argShapeData = env->GetIntArrayElements(jargShapeData, NULL);
+  int ret = MXSymbolInferShape((SymbolHandle) symbolPtr,
+                               (mx_uint) jnumArgs,
+                               keys,
+                               (const mx_uint *) argIndPtr,
+                               (const mx_uint *) argShapeData,
+                               &inShapeSize,
+                               &inShapeNdim,
+                               &inShapeData,
+                               &outShapeSize,
+                               &outShapeNdim,
+                               &outShapeData,
+                               &auxShapeSize,
+                               &auxShapeNdim,
+                               &auxShapeData,
+                               &complete);
+  env->ReleaseIntArrayElements(jargShapeData, argShapeData, 0);
+  env->ReleaseIntArrayElements(jargIndPtr, argIndPtr, 0);
+
+  jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer");
+  jmethodID listAppend = env->GetMethodID(listClass,
+    "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;");
+
+  if (FillSymbolInferShape(env, listAppend, jinShapeData, inShapeSize, inShapeNdim, inShapeData)) {
+    // TODO: out of memory error thrown, return a specific error code ?
+    return -1;
+  }
+  if (FillSymbolInferShape(
+      env, listAppend, joutShapeData, outShapeSize, outShapeNdim, outShapeData)) {
+    // TODO: out of memory error thrown, return a specific error code ?
+    return -1;
+  }
+  if (FillSymbolInferShape(
+      env, listAppend, jauxShapeData, auxShapeSize, auxShapeNdim, auxShapeData)) {
+    // TODO: out of memory error thrown, return a specific error code ?
+    return -1;
+  }
+
+  setIntField(env, jcomplete, complete);
+
+  // release allocated memory
+  if (jkeys != NULL) {
+    for (int i = 0; i < jnumArgs; i++) {
+      jstring jkey = (jstring) env->GetObjectArrayElement(jkeys, i);
+      env->ReleaseStringUTFChars(jkey, keys[i]);
+    }
+    delete[] keys;
+  }
+
+  return ret;
+}
+
+JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorBindX
+  (JNIEnv *env, jobject obj, jlong symbolPtr, jint deviceTypeId, jint deviceID, jint numCtx,
+    jobjectArray jctxMapKeys, jintArray jctxMapDevTypes, jintArray jctxMapDevIDs, jint numArgs,
+    jlongArray jargsHandle, jlongArray jargsGradHandle, jintArray jreqsArray,
+    jlongArray jauxArgsHandle, jobject jexecOut) {
+
+  ExecutorHandle out;
+  int auxStatesLen = env->GetArrayLength(jauxArgsHandle);
+
+  const char **mapKeys = new const char *[numCtx];
+  for (int i = 0; i < numCtx; i++) {
+    jstring jkey = (jstring) env->GetObjectArrayElement(jctxMapKeys, i);
+    const char *key = env->GetStringUTFChars(jkey, 0);
+    mapKeys[i] = key;
+  }
+  jlong *auxStates = env->GetLongArrayElements(jauxArgsHandle, NULL);
+  jint *gradReqType = env->GetIntArrayElements(jreqsArray, NULL);
+  jlong *inArgs = env->GetLongArrayElements(jargsHandle, NULL);
+  jlong *argGradStore = env->GetLongArrayElements(jargsGradHandle, NULL);
+  jint *mapDevTypes = env->GetIntArrayElements(jctxMapDevTypes, NULL);
+  jint *mapDevIDs = env->GetIntArrayElements(jctxMapDevIDs, NULL);
+  int ret = MXExecutorBindX((SymbolHandle) symbolPtr,
+                            deviceTypeId,
+                            deviceID,
+                            (mx_uint) numCtx,
+                            mapKeys,
+                            mapDevTypes,
+                            mapDevIDs,
+                            (mx_uint) numArgs,
+                            (NDArrayHandle *) inArgs,
+                            (NDArrayHandle *) argGradStore,
+                            (mx_uint *) gradReqType,
+                            (mx_uint) auxStatesLen,
+                            (NDArrayHandle *) auxStates,
+                            &out);
+  env->ReleaseIntArrayElements(jctxMapDevIDs, mapDevIDs, 0);
+  env->ReleaseIntArrayElements(jctxMapDevTypes, mapDevTypes, 0);
+  env->ReleaseLongArrayElements(jargsGradHandle, argGradStore, 0);
+  env->ReleaseLongArrayElements(jargsHandle, inArgs, 0);
+  env->ReleaseIntArrayElements(jreqsArray, gradReqType, 0);
+  env->ReleaseLongArrayElements(jauxArgsHandle, auxStates, 0);
+  for (int i = 0; i < numCtx; i++) {
+    jstring jkey = (jstring) env->GetObjectArrayElement(jctxMapKeys, i);
+    env->ReleaseStringUTFChars(jkey, mapKeys[i]);
+  }
+  delete[] mapKeys;
+
+  setLongField(env, jexecOut, (long) out);
+  return ret;
+}
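For cross-reference (not part of the patch), the `mxExecutorBindX` entry point above is the native counterpart of the call made in `Symbol.bindHelper`. On the Scala side it presumably corresponds to a `LibInfo` declaration shaped roughly as follows; the parameter names are illustrative, and only their order and types are implied by the JNI signature and the call site:

  // Sketch of the presumed declaration in ml.dmlc.mxnet.LibInfo (names illustrative).
  @native def mxExecutorBindX(handle: SymbolHandle,
                              deviceTypeId: Int, deviceId: Int, numCtx: Int,
                              ctxMapKeys: Array[String],
                              ctxMapDevTypes: Array[Int], ctxMapDevIds: Array[Int],
                              numArgs: Int,
                              argsHandle: Array[NDArrayHandle],
                              argsGradHandle: Array[NDArrayHandle],
                              reqsArray: Array[Int],
                              auxArgsHandle: Array[NDArrayHandle],
                              execHandle: ExecutorHandleRef): Int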