diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala b/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala
index 2ed9d8cfbb84..2b1765531824 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala
@@ -180,6 +180,7 @@ class FeedForward private(
 
   // Initialize the predictor module for running prediction.
   private def initPredictor(inputShapes: Map[String, Shape]): Unit = {
+    var shouldInit = true
     if (this.predExec != null) {
       val (argShapes, _, _) = symbol.inferShape(inputShapes)
       require(argShapes != null, "Shape inference failed." +
@@ -187,14 +188,16 @@
         s"and aux states ${symbol.listAuxiliaryStates()}")
       val predShapes = this.predExec.argArrays.map(_.shape)
       if (argShapes.sameElements(predShapes)) {
-        return
+        shouldInit = false
       }
     }
-    // for now only use the first device
-    val predExec = symbol.simpleBind(ctx(0), gradReq = "null", shapeDict = inputShapes)
-    predExec.copyParamsFrom(_argParams, _auxParams)
-    ExecutorManager.checkArguments(symbol)
-    this.predExec = predExec
+    if(shouldInit) {
+      // for now only use the first device
+      val predExec = symbol.simpleBind(ctx(0), gradReq = "null", shapeDict = inputShapes)
+      predExec.copyParamsFrom(_argParams, _auxParams)
+      ExecutorManager.checkArguments(symbol)
+      this.predExec = predExec
+    }
   }
 
   // Initialize the iterator given input.
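The FeedForward hunk above swaps an early return for a local shouldInit flag so that initPredictor has a single exit point. A minimal, self-contained sketch of the same refactor, with invented names that are not part of this patch:

// Hypothetical example, not MXNet code: the flag-based version of an early return.
// Early-return form, for contrast:
//   def cachedBuffer(buf: Array[Int], wanted: Int): Array[Int] = {
//     if (buf.length == wanted) return buf
//     Array.fill(wanted)(0)
//   }
def cachedBuffer(buf: Array[Int], wanted: Int): Array[Int] = {
  var shouldRebuild = true
  if (buf.length == wanted) {
    shouldRebuild = false // the existing buffer already has the right size
  }
  if (shouldRebuild) Array.fill(wanted)(0) else buf
}

The observable behaviour is unchanged; only the control flow is expressed without return.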
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/BucketingModule.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/BucketingModule.scala
index 1ac798e1b617..41a6f69394d2 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/module/BucketingModule.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/BucketingModule.scala
@@ -173,14 +173,13 @@ class BucketingModule(symGen: AnyRef => (Symbol, IndexedSeq[String], IndexedSeq[
                  allowMissing: Boolean = false,
                  forceInit: Boolean = false,
                  allowExtra: Boolean = false): Unit = {
-    if (paramsInitialized && !forceInit) {
-      return
+    if (!paramsInitialized || forceInit) {
+      require(binded, "call bind before initializing the parameters")
+      this._currModule.initParams(initializer, argParams, auxParams,
+        allowMissing, forceInit, allowExtra)
+      this.paramsDirty = false
+      this.paramsInitialized = true
     }
-    require(binded, "call bind before initializing the parameters")
-    this._currModule.initParams(initializer, argParams, auxParams,
-      allowMissing, forceInit, allowExtra)
-    this.paramsDirty = false
-    this.paramsInitialized = true
   }
 
   /**
@@ -218,28 +217,27 @@ class BucketingModule(symGen: AnyRef => (Symbol, IndexedSeq[String], IndexedSeq[
 
     if (this.binded) {
       logger.warn("Already bound, ignoring bind()")
-      return
-    }
+    } else {
+      require(sharedModule.isEmpty,
+        "sharedModule for BucketingModule is not supported")
 
-    require(sharedModule.isEmpty,
-      "sharedModule for BucketingModule is not supported")
-
-    this.forTraining = forTraining
-    this.inputsNeedGrad = inputsNeedGrad
-    this.binded = true
-
-    val (sym, dNames, lNames) = this.symGen(this.defaultBucketKey)
-    val module = new Module(sym, dNames, lNames, this.contexts,
-      this.workLoadList, this.fixedParamNames)
-    module.bind(dataShapes, labelShapes, forTraining, inputsNeedGrad,
-      forceRebind = false, sharedModule = None, gradReq)
-    this._currModule = module
-    this._currBucketKey = this.defaultBucketKey
-    this._buckets(this.defaultBucketKey) = module
-
-    // copy back saved params, if already initialized
-    if (this.paramsInitialized) {
-      this.setParams(argParams, auxParams)
+      this.forTraining = forTraining
+      this.inputsNeedGrad = inputsNeedGrad
+      this.binded = true
+
+      val (sym, dNames, lNames) = this.symGen(this.defaultBucketKey)
+      val module = new Module(sym, dNames, lNames, this.contexts,
+        this.workLoadList, this.fixedParamNames)
+      module.bind(dataShapes, labelShapes, forTraining, inputsNeedGrad,
+        forceRebind = false, sharedModule = None, gradReq)
+      this._currModule = module
+      this._currBucketKey = this.defaultBucketKey
+      this._buckets(this.defaultBucketKey) = module
+
+      // copy back saved params, if already initialized
+      if (this.paramsInitialized) {
+        this.setParams(argParams, auxParams)
+      }
     }
   }
 
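The initParams and bind changes above, and the matching ones in Module.scala and SequentialModule.scala below, replace a guard of the form "if (done && !force) return" with a body wrapped in "if (!done || force) { ... }". By De Morgan's law the negation of (done && !force) is exactly (!done || force), so the body runs in precisely the cases where the old code did not return early. A small sketch of the pattern, with hypothetical names that are not part of the MXNet API:

// Illustrative sketch only; the class and fields are invented.
class LazyResource {
  private var initialized = false

  def init(forceInit: Boolean = false): Unit = {
    // Old form:  if (initialized && !forceInit) return
    // New form: by De Morgan, !(initialized && !forceInit) == (!initialized || forceInit),
    // so the body runs exactly when the old code did not return early.
    if (!initialized || forceInit) {
      // ... expensive setup would go here ...
      initialized = true
    }
  }
}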
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala
index 97df3dcb307d..3255d9346b80 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala
@@ -121,36 +121,35 @@ class Module(symbolVar: Symbol,
                  allowMissing: Boolean = false,
                  forceInit: Boolean = false,
                  allowExtra: Boolean = false): Unit = {
-    if (paramsInitialized && !forceInit) {
-      return
-    }
-    require(binded, "call bind before initializing the parameters")
+    if (!paramsInitialized || forceInit) {
+      require(binded, "call bind before initializing the parameters")
 
-    if (this.argParams == null) {
-      val paramArrays =
-        execGroup.paramArrays.map(nds => NDArray.zeros(nds(0).shape, dtype = nds(0).dtype))
-      this.argParams = this.paramNames.zip(paramArrays).toMap
-    }
+      if (this.argParams == null) {
+        val paramArrays =
+          execGroup.paramArrays.map(nds => NDArray.zeros(nds(0).shape, dtype = nds(0).dtype))
+        this.argParams = this.paramNames.zip(paramArrays).toMap
+      }
 
-    if (this.auxParams == null) {
-      val auxArrays =
-        execGroup.auxArrays.map(nds => NDArray.zeros(nds(0).shape, dtype = nds(0).dtype))
-      this.auxParams = this.auxNames.zip(auxArrays).toMap
-    }
+      if (this.auxParams == null) {
+        val auxArrays =
+          execGroup.auxArrays.map(nds => NDArray.zeros(nds(0).shape, dtype = nds(0).dtype))
+        this.auxParams = this.auxNames.zip(auxArrays).toMap
+      }
 
-    this.argParams.foreach { case (name, arr) =>
-      impl(name, arr, allowMissing, Option(initializer), argParams)
-    }
+      this.argParams.foreach { case (name, arr) =>
+        impl(name, arr, allowMissing, Option(initializer), argParams)
+      }
 
-    this.auxParams.foreach { case (name, arr) =>
-      impl(name, arr, allowMissing, Option(initializer), auxParams)
-    }
+      this.auxParams.foreach { case (name, arr) =>
+        impl(name, arr, allowMissing, Option(initializer), auxParams)
+      }
 
-    this.paramsInitialized = true
-    this.paramsDirty = false
+      this.paramsInitialized = true
+      this.paramsDirty = false
 
-    // copy the initialized parameters to devices
-    this.execGroup.setParams(this.argParams, this.auxParams, allowExtra = allowExtra)
+      // copy the initialized parameters to devices
+      this.execGroup.setParams(this.argParams, this.auxParams, allowExtra = allowExtra)
+    }
   }
 
   // Internal helper for parameter initialization
@@ -246,64 +245,64 @@ class Module(symbolVar: Symbol,
 
     if (binded) {
      logger.warn("Already binded, ignoring bind()")
-      return
-    }
+    } else {
+      this.forTraining = forTraining
+      this.inputsNeedGrad = inputsNeedGrad
+      this.binded = true
 
-    this.forTraining = forTraining
-    this.inputsNeedGrad = inputsNeedGrad
-    this.binded = true
+      if (!forTraining) {
+        require(!inputsNeedGrad, "Invalid inputsNeedGrad (cannot be true if not forTraining)")
+      } else {
+        // this is not True, as some module might not contains a loss function
+        // that consumes the labels
+        // require(labelShapes != None)
+      }
 
-    if (!forTraining) {
-      require(!inputsNeedGrad, "Invalid inputsNeedGrad (cannot be true if not forTraining)")
-    } else {
-      // this is not True, as some module might not contains a loss function
-      // that consumes the labels
-      // require(labelShapes != None)
-    }
+      this.dataShapesVar = dataShapes
+      this.labelShapesVar = labelShapes
 
-    this.dataShapesVar = dataShapes
-    this.labelShapesVar = labelShapes
-
-    val sharedGroup =
-      sharedModule.map(sharedModuleInst => {
-        require(sharedModuleInst.binded && sharedModuleInst.paramsInitialized,
-          s"bind() and initParams() must be called first on shared module.")
-        sharedModuleInst.execGroup
-      })
-
-    val inputTypes = this.dataShapesVar.map(dataDesc => (dataDesc.name, dataDesc.dtype)).toMap ++
-      labelShapes.map(shapes => shapes.map(dataDesc => (dataDesc.name, dataDesc.dtype)).toMap)
-      .getOrElse(Map.empty[String, DType])
-
-    execGroup = new Builder(symbol, contexts, paramNames)
-      .setWorkLoadList(workLoads)
-      .setDataShapes(dataShapes)
-      .setLabelShapes(labelShapes.orNull)
-      .setForTraining(forTraining)
-      .setInputsNeedGrad(inputsNeedGrad)
-      .setSharedGroup(sharedGroup.orNull)
-      .setFixedParamNames(fixedParamNames.orNull)
-      .setGradReq(gradReq)
-      .setInputTypes(inputTypes)
-      .build()
-
-    if (sharedModule.isDefined) {
-      paramsInitialized = true
-      argParams = sharedModule.get.argParams
-      auxParams = sharedModule.get.auxParams
-    } else if (paramsInitialized) {
-      // if the parameters are already initialized, we are re-binding
-      // so automatically copy the already initialized params
-      execGroup.setParams(argParams, auxParams)
-    }
+      val sharedGroup =
+        sharedModule.map(sharedModuleInst => {
+          require(sharedModuleInst.binded && sharedModuleInst.paramsInitialized,
+            s"bind() and initParams() must be called first on shared module.")
+          sharedModuleInst.execGroup
+        })
 
-    sharedModule.foreach {
-      case sharedModuleInst: Module =>
-        if (sharedModuleInst.optimizerInitialized) {
-          borrowOptimizer(sharedModuleInst)
-        }
-      case _ =>
+      val inputTypes = this.dataShapesVar.map(dataDesc => (dataDesc.name, dataDesc.dtype)).toMap ++
+        labelShapes.map(shapes => shapes.map(dataDesc => (dataDesc.name, dataDesc.dtype)).toMap)
+        .getOrElse(Map.empty[String, DType])
+
+      execGroup = new Builder(symbol, contexts, paramNames)
+        .setWorkLoadList(workLoads)
+        .setDataShapes(dataShapes)
+        .setLabelShapes(labelShapes.orNull)
+        .setForTraining(forTraining)
+        .setInputsNeedGrad(inputsNeedGrad)
+        .setSharedGroup(sharedGroup.orNull)
+        .setFixedParamNames(fixedParamNames.orNull)
+        .setGradReq(gradReq)
+        .setInputTypes(inputTypes)
+        .build()
+
+      if (sharedModule.isDefined) {
+        paramsInitialized = true
+        argParams = sharedModule.get.argParams
+        auxParams = sharedModule.get.auxParams
+      } else if (paramsInitialized) {
+        // if the parameters are already initialized, we are re-binding
+        // so automatically copy the already initialized params
+        execGroup.setParams(argParams, auxParams)
+      }
+
+      sharedModule.foreach {
+        case sharedModuleInst: Module =>
+          if (sharedModuleInst.optimizerInitialized) {
+            borrowOptimizer(sharedModuleInst)
+          }
+        case _ =>
+      }
     }
+
   }
 
   /**
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/SequentialModule.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/SequentialModule.scala
index 2e506c08e548..3c3eeb97f201 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/module/SequentialModule.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/SequentialModule.scala
@@ -154,38 +154,37 @@ class SequentialModule extends BaseModule {
                  allowMissing: Boolean = false,
                  forceInit: Boolean = false,
                  allowExtra: Boolean = false): Unit = {
-    if (this.paramsInitialized && !forceInit) {
-      return
-    }
-    require(this.binded, "call bind before initializing the parameters")
+    if (!this.paramsInitialized || forceInit) {
+      require(this.binded, "call bind before initializing the parameters")
 
-    for (module <- this.modules) {
-      module.initParams(initializer = initializer, argParams = argParams,
-        auxParams = auxParams, allowMissing = allowMissing,
-        forceInit = forceInit, allowExtra = allowExtra)
-    }
+      for (module <- this.modules) {
+        module.initParams(initializer = initializer, argParams = argParams,
+          auxParams = auxParams, allowMissing = allowMissing,
+          forceInit = forceInit, allowExtra = allowExtra)
+      }
 
-    // Internal function to help checking duplicated names,
-    // make sure we do not have duplicated parameter names.
-    def checkName(knownNames: scala.collection.mutable.Map[String, Int],
-      newNames: Array[String], modules: ArrayBuffer[BaseModule], i: Int): Unit = {
-      for (name <- newNames) {
-        require(!knownNames.contains(name), s"Duplicated parameter names: " +
-          s"name $name in layer $i (${modules(i).getClass.getName}) is already " +
-          s"used in layer ${knownNames("name")}" +
-          s"(${modules(knownNames("name")).getClass.getName})")
-        knownNames(name) = i
+      // Internal function to help checking duplicated names,
+      // make sure we do not have duplicated parameter names.
+      def checkName(knownNames: scala.collection.mutable.Map[String, Int],
+        newNames: Array[String], modules: ArrayBuffer[BaseModule], i: Int): Unit = {
+        for (name <- newNames) {
+          require(!knownNames.contains(name), s"Duplicated parameter names: " +
+            s"name $name in layer $i (${modules(i).getClass.getName}) is already " +
+            s"used in layer ${knownNames("name")}" +
+            s"(${modules(knownNames("name")).getClass.getName})")
+          knownNames(name) = i
+        }
       }
-    }
 
-    val argNames = scala.collection.mutable.Map[String, Int]()
-    val auxNames = scala.collection.mutable.Map[String, Int]()
-    for ((module, iLayer) <- this.modules.zipWithIndex) {
-      val (argParams, auxParams) = module.getParams
-      checkName(argNames, argParams.keys.toArray, this.modules, iLayer)
-      checkName(auxNames, auxParams.keys.toArray, this.modules, iLayer)
+      val argNames = scala.collection.mutable.Map[String, Int]()
+      val auxNames = scala.collection.mutable.Map[String, Int]()
+      for ((module, iLayer) <- this.modules.zipWithIndex) {
+        val (argParams, auxParams) = module.getParams
+        checkName(argNames, argParams.keys.toArray, this.modules, iLayer)
+        checkName(auxNames, auxParams.keys.toArray, this.modules, iLayer)
+      }
+      this.paramsInitialized = true
     }
-    this.paramsInitialized = true
   }
 
   /**
@@ -216,54 +215,54 @@ class SequentialModule extends BaseModule {
                 gradReq: String = "write"): Unit = {
 
     if (this.binded && !forceRebind) {
       logger.warn(s"Already binded, ignoring bind()")
-      return
-    }
-
-    if (inputsNeedGrad) {
-      require(forTraining, "inputsNeedGrad can be set only for training")
-    }
-
-    require(sharedModule == None, "Shared module is not supported")
-    require(this.modules.length > 0, "Attempting to bind an empty SequentialModule")
-
-    this.forTraining = forTraining
-    this.inputsNeedGrad = inputsNeedGrad
-    this.binded = true
-
-    // the same label shapes are used for all chained modules
-    this.labelShapesVar = labelShapes
+    } else {
+      if (inputsNeedGrad) {
+        require(forTraining, "inputsNeedGrad can be set only for training")
+      }
 
-    var myDataShapes = dataShapes
-    var myLabelShapes = labelShapes
-    var anybodyEverNeedsLabel = false
-    for ((module, iLayer) <- this.modules.zipWithIndex) {
-      val meta = this.metas(iLayer)
-      if (meta.contains(META_TAKE_LABELS) && meta(META_TAKE_LABELS)) {
-        myLabelShapes = labelShapes
-        anybodyEverNeedsLabel = true
-      } else myLabelShapes = None
-
-      val myInputsNeedGrad = if (inputsNeedGrad || (forTraining && iLayer > 0)) true else false
-      if (meta.contains(META_AUTO_WIRING) && meta(META_AUTO_WIRING)) {
-        val dataNames = module.dataNames
-        require(dataNames.length == myDataShapes.length,
-          s"dataNmes $dataNames and dataShapes $myDataShapes do not match")
-        myDataShapes = dataNames.zip(myDataShapes).map { case (newName, dataDes) =>
-          DataDesc(newName, dataDes.shape)
+      require(sharedModule == None, "Shared module is not supported")
+      require(this.modules.length > 0, "Attempting to bind an empty SequentialModule")
+
+      this.forTraining = forTraining
+      this.inputsNeedGrad = inputsNeedGrad
+      this.binded = true
+
+      // the same label shapes are used for all chained modules
+      this.labelShapesVar = labelShapes
+
+      var myDataShapes = dataShapes
+      var myLabelShapes = labelShapes
+      var anybodyEverNeedsLabel = false
+      for ((module, iLayer) <- this.modules.zipWithIndex) {
+        val meta = this.metas(iLayer)
+        if (meta.contains(META_TAKE_LABELS) && meta(META_TAKE_LABELS)) {
+          myLabelShapes = labelShapes
+          anybodyEverNeedsLabel = true
+        } else myLabelShapes = None
+
+        val myInputsNeedGrad = if (inputsNeedGrad || (forTraining && iLayer > 0)) true else false
+        if (meta.contains(META_AUTO_WIRING) && meta(META_AUTO_WIRING)) {
+          val dataNames = module.dataNames
+          require(dataNames.length == myDataShapes.length,
+            s"dataNmes $dataNames and dataShapes $myDataShapes do not match")
+          myDataShapes = dataNames.zip(myDataShapes).map { case (newName, dataDes) =>
+            DataDesc(newName, dataDes.shape)
+          }
         }
-      }
-      module.bind(myDataShapes, myLabelShapes, forTraining, myInputsNeedGrad,
+        module.bind(myDataShapes, myLabelShapes, forTraining, myInputsNeedGrad,
         forceRebind, sharedModule = None, gradReq)
 
-      // the output of the previous module is the data of the next module
-      myDataShapes = module.outputShapes.map{case (name, shape) => DataDesc(name, shape)}
-    }
+        // the output of the previous module is the data of the next module
+        myDataShapes = module.outputShapes.map{case (name, shape) => DataDesc(name, shape)}
+      }
 
-    if (!anybodyEverNeedsLabel) {
-      // then I do not need label either
-      this.labelShapesVar = None
+      if (!anybodyEverNeedsLabel) {
+        // then I do not need label either
+        this.labelShapesVar = None
+      }
     }
+
   }
 
   /**
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala
index 6d414bb0328a..350e28cf8634 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala
@@ -202,10 +202,10 @@ object BucketIo {
       labelBuf.set(labels.flatten)
 
       iBucket += 1
-      val batchProvideData = { val tmp = ListMap("data" -> dataBuf.shape)
-        tmp ++ initStates.map(x => x._1 -> Shape(x._2._1, x._2._2))
-      }
-      val batchProvideLabel = ListMap("softmax_label" -> labelBuf.shape)
+      val batchProvideData = IndexedSeq(DataDesc("data", dataBuf.shape, dataBuf.dtype)) ++
+        initStates.map {
+          case (name, shape) => DataDesc(name, Shape(shape._1, shape._2), DType.Float32)}
+      val batchProvideLabel = IndexedSeq(DataDesc("softmax_label", labelBuf.shape, labelBuf.dtype))
       val initStateArrays = initStates.map(x => NDArray.zeros(x._2._1, x._2._2))
       new DataBatch(IndexedSeq(dataBuf.copy()) ++ initStateArrays,
         IndexedSeq(labelBuf.copy()),
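The BucketIo hunk above switches the batch's provide-data and provide-label information from a ListMap of name to Shape pairs to an IndexedSeq of DataDesc entries, which also carry a dtype. A rough, standalone sketch of building such descriptions the way the patched code does; the batch size, sequence length and state names are made up for illustration:

import org.apache.mxnet.{DataDesc, DType, Shape}

// Hypothetical sizes and state names, only to make the sketch concrete.
val batchSize = 32
val seqLen = 25
val numHidden = 200

// One DataDesc per input, carrying name, shape and dtype, instead of a name -> Shape map.
val provideData =
  IndexedSeq(DataDesc("data", Shape(batchSize, seqLen), DType.Float32)) ++
    IndexedSeq("lstm_init_c", "lstm_init_h").map(name =>
      DataDesc(name, Shape(batchSize, numHidden), DType.Float32))

val provideLabel =
  IndexedSeq(DataDesc("softmax_label", Shape(batchSize, seqLen), DType.Float32))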
diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/utils/Network.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/utils/Network.scala
index c61229af0035..836901f69f8f 100644
--- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/utils/Network.scala
+++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/utils/Network.scala
@@ -20,6 +20,7 @@ package org.apache.mxnet.spark.utils
 import java.io.IOException
 import java.net.{ServerSocket, NetworkInterface}
 import java.util.regex.Pattern
+import scala.collection.JavaConverters._
 
 /**
  * Helper functions to decide ip address / port
@@ -33,19 +34,16 @@ object Network {
     "([01]?\\d\\d?|2[0-4]\\d|25[0-5])$")
 
   def ipAddress: String = {
-    val interfaces = NetworkInterface.getNetworkInterfaces
-    while (interfaces.hasMoreElements) {
-      val interface = interfaces.nextElement
-      val addresses = interface.getInetAddresses
-      while (addresses.hasMoreElements) {
-        val address = addresses.nextElement
-        val ip = address.getHostAddress
-        if (!ip.startsWith("127.") && IPADDRESS_PATTERN.matcher(ip).matches()) {
-          return ip
+    val interfaces = NetworkInterface.getNetworkInterfaces.asScala
+    val interface = interfaces.toStream.flatMap(
+      _.getInetAddresses.asScala.toStream.flatMap(
+        address => {
+          val ip = address.getHostAddress
+          Option(ip).filter(ip => !ip.startsWith("127.") && IPADDRESS_PATTERN.matcher(ip).matches())
         }
-      }
-    }
-    "127.0.0.1"
+      )
+    ).headOption
+    interface.getOrElse("127.0.0.1")
   }
 
   def availablePort: Int = {
diff --git a/scala-package/spark/src/test/scala/org/apache/mxnet/spark/SharedSparkContext.scala b/scala-package/spark/src/test/scala/org/apache/mxnet/spark/SharedSparkContext.scala
index 6d36ca51db90..293cfa13cfce 100644
--- a/scala-package/spark/src/test/scala/org/apache/mxnet/spark/SharedSparkContext.scala
+++ b/scala-package/spark/src/test/scala/org/apache/mxnet/spark/SharedSparkContext.scala
@@ -92,20 +92,12 @@ trait SharedSparkContext extends FunSuite with BeforeAndAfterEach with BeforeAnd
 
   private def getJarFilePath(root: String): String = {
     val jarFiles = findJars(s"$root/target/")
-    if (jarFiles != null && jarFiles.nonEmpty) {
-      jarFiles.head.getAbsolutePath
-    } else {
-      null
-    }
+    Option(jarFiles).flatMap(_.headOption).map(_.getAbsolutePath).orNull
  }
 
   private def getSparkJar: String = {
     val jarFiles = findJars(s"$composeWorkingDirPath/target/")
-    if (jarFiles != null && jarFiles.nonEmpty) {
-      jarFiles.head.getAbsolutePath
-    } else {
-      null
-    }
+    Option(jarFiles).flatMap(_.headOption).map(_.getAbsolutePath).orNull
   }
 
   private def getNativeJars(root: String): String =
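Both Spark-side hunks above express "first match, or a fallback" without null checks or early returns: Network.scala walks the interfaces lazily and takes headOption, while SharedSparkContext.scala collapses the null and empty handling into an Option chain. A tiny sketch of the latter idiom, with an invented stand-in for the test helper:

import java.io.File

// Hypothetical stand-in for findJars(...), which may return null or an empty collection.
def findJars(dir: String): Seq[File] = Seq.empty

// Option(...) absorbs a possible null, headOption absorbs the empty case, and
// orNull restores the null-returning contract the callers already expect.
def firstJarPath(dir: String): String =
  Option(findJars(dir)).flatMap(_.headOption).map(_.getAbsolutePath).orNull

The same shape of pipeline covers all four original branches (null list, empty list, non-empty list, and the fallback) in a single expression.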