diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala b/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala index 00a1450089f7..2ed9d8cfbb84 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala @@ -17,9 +17,10 @@ package org.apache.mxnet +import org.apache.mxnet.Base.CPtrAddress import org.apache.mxnet.io.NDArrayIter import org.apache.mxnet.optimizer.SGD -import org.slf4j.{LoggerFactory, Logger} +import org.slf4j.{Logger, LoggerFactory} import scala.collection.mutable.ListBuffer @@ -55,7 +56,7 @@ class FeedForward private( argParams: Map[String, NDArray], auxParams: Map[String, NDArray], private val allowExtraParams: Boolean, - val beginEpoch: Int) { + val beginEpoch: Int) extends NativeResource { val logger: Logger = LoggerFactory.getLogger(classOf[FeedForward]) private var argumentChecked = false @@ -126,6 +127,8 @@ class FeedForward private( } // Initialize weight parameters and auxiliary states + // The NDArrays associated with the _argParms and _auxParams are not disposed instead + // they are passed a outer scope if available. private def initParams(inputShapes: Map[String, Shape], overwrite: Boolean = false) : (IndexedSeq[String], IndexedSeq[String], IndexedSeq[String]) = { val (argShapes, _, auxShapes) = symbol.inferShape(inputShapes) @@ -137,16 +140,26 @@ class FeedForward private( val paramNameShapes = (argNames zip argShapes).filter { case (name, _) => paramNames.contains(name) } - val argParams = paramNameShapes.map { case (name, shape) => - (name, NDArray.zeros(shape)) + val argParams = paramNameShapes.map { case (name, shape) => { + val param = NDArray.zeros(shape) + val curScope = ResourceScope.getCurrentScope() + if (curScope.isDefined) curScope.get.moveToOuterScope(param) + (name, param) + } }.toMap - val auxParams = (auxNames zip auxShapes).map { case (name, shape) => - (name, NDArray.zeros(shape)) + + val auxParams = (auxNames zip auxShapes).map { case (name, shape) => { + val param = NDArray.zeros(shape) + val curScope = ResourceScope.getCurrentScope() + if (curScope.isDefined) curScope.get.moveToOuterScope(param) + (name, param) + } }.toMap for ((k, v) <- argParams) { if (_argParams != null && _argParams.contains(k) && (!overwrite)) { argParams(k).set(_argParams(k)) + } else { initializer(k, v) } @@ -277,13 +290,15 @@ class FeedForward private( def fit(trainData: DataIter, evalData: DataIter, evalMetric: EvalMetric, kvStoreType: String, epochEndCallback: EpochEndCallback, batchEndCallback: BatchEndCallback, logger: Logger, workLoadList: Seq[Float]): Unit = { - // init params first to allow kv store use _argParams to decide its type - initSymbolParams(trainData) - // create kvstore - val (kvStore, updateOnKVStore) = Model.createKVStore(kvStoreType, ctx.length, _argParams) - fit(trainData, evalData, evalMetric, kvStore, updateOnKVStore, - epochEndCallback, batchEndCallback, logger, workLoadList) - kvStore.foreach(_.dispose()) + ResourceScope.using() { + // init params first to allow kv store use _argParams to decide its type + initSymbolParams(trainData) + // create kvstore + val (kvStore, updateOnKVStore) = Model.createKVStore(kvStoreType, ctx.length, _argParams) + fit(trainData, evalData, evalMetric, kvStore, updateOnKVStore, + epochEndCallback, batchEndCallback, logger, workLoadList) +// kvStore.foreach(_.dispose()) + } } def fit(trainData: DataIter, evalData: DataIter, evalMetric: EvalMetric, @@ -313,11 +328,13 @@ class FeedForward private( batchEndCallback: BatchEndCallback, logger: Logger, workLoadList: Seq[Float]): Unit = { // init params first to allow kv store use _argParams to decide its type - initSymbolParams(trainData) - // create kvstore - val (kvStore, updateOnKVStore) = Model.createKVStore(kv) - fit(trainData, evalData, evalMetric, kvStore, updateOnKVStore, - epochEndCallback, batchEndCallback, logger, workLoadList) + ResourceScope.using() { + initSymbolParams(trainData) + // create kvstore + val (kvStore, updateOnKVStore) = Model.createKVStore(kv) + fit(trainData, evalData, evalMetric, kvStore, updateOnKVStore, + epochEndCallback, batchEndCallback, logger, workLoadList) + } } def fit(trainData: DataIter, evalData: DataIter, evalMetric: EvalMetric, @@ -352,44 +369,49 @@ class FeedForward private( batchEndCallback: BatchEndCallback = null, logger: Logger = FeedForward.logger, workLoadList: Seq[Float] = null): Unit = { require(evalMetric != null, "evalMetric cannot be null") - val (argNames, paramNames, auxNames) = initSymbolParams(trainData) - - // init optimizer - val batchSizeMultiplier = kvStore.map { kv => - if (kv.`type` == "dist_sync") { - kv.numWorkers - } else { - 1 - } - } - val batchSize = trainData.batchSize * batchSizeMultiplier.getOrElse(1) - this.optimizer.setArgNames(argNames) - this.optimizer.setRescaleGrad(1f / batchSize) - this.optimizer.setSymbol(this.symbol) - val paramIdx2Name = - if (updateOnKVStore) { - paramNames.zipWithIndex.map { case (name, idx) => idx -> name }.toMap - } else { - paramNames.zipWithIndex.flatMap { case (name, idx) => - (0 until ctx.length).map(k => (idx * ctx.length + k) -> name).toMap - }.toMap + // TODO: https://issues.apache.org/jira/browse/MXNET-1171 + // this leaks memory, initSymbolParams->initParams is already called which allocates + // NDArray in argParams, auxParams and here we are overwriting it by calling again. + // PhantomRef should take care of releasing this when GC is called, however we have to + // wait for the GC call to happen. + val (argNames, paramNames, auxNames) = initSymbolParams(trainData) + + // init optimizer + val batchSizeMultiplier = kvStore.map { kv => + if (kv.`type` == "dist_sync") { + kv.numWorkers + } else { + 1 + } } - this.optimizer.setIdx2Name(paramIdx2Name) - - logger.debug("Start training on multi-device") - Model.trainMultiDevice( - symbol, ctx, argNames, paramNames, auxNames, - _argParams, _auxParams, - this.beginEpoch, this.numEpoch, - this.epochSize, this.optimizer, - kvStore, updateOnKVStore, - trainData = trainData, evalData = Option(evalData), - evalMetric = evalMetric, - epochEndCallback = Option(epochEndCallback), - batchEndCallback = Option(batchEndCallback), - workLoadList = workLoadList, - monitor = monitor, - symGen = symGen) + val batchSize = trainData.batchSize * batchSizeMultiplier.getOrElse(1) + this.optimizer.setArgNames(argNames) + this.optimizer.setRescaleGrad(1f / batchSize) + this.optimizer.setSymbol(this.symbol) + val paramIdx2Name = + if (updateOnKVStore) { + paramNames.zipWithIndex.map { case (name, idx) => idx -> name }.toMap + } else { + paramNames.zipWithIndex.flatMap { case (name, idx) => + (0 until ctx.length).map(k => (idx * ctx.length + k) -> name).toMap + }.toMap + } + this.optimizer.setIdx2Name(paramIdx2Name) + + logger.debug("Start training on multi-device") + Model.trainMultiDevice( + symbol, ctx, argNames, paramNames, auxNames, + _argParams, _auxParams, + this.beginEpoch, this.numEpoch, + this.epochSize, this.optimizer, + kvStore, updateOnKVStore, + trainData = trainData, evalData = Option(evalData), + evalMetric = evalMetric, + epochEndCallback = Option(epochEndCallback), + batchEndCallback = Option(batchEndCallback), + workLoadList = workLoadList, + monitor = monitor, + symGen = symGen) } /** @@ -416,9 +438,29 @@ class FeedForward private( def serialize(): Array[Byte] = { Model.serialize(this.symbol, getArgParams, getAuxParams) } + + // hack to make the FeedForward.scala work with ResourceScope and + // automatically release _argParms and _auxParms + override def nativeAddress: CPtrAddress = hashCode() + + override def nativeDeAllocator: CPtrAddress => Int = FeedForward.doNothingDeAllocator + + override val ref: NativeResourceRef = super.register() + + override val bytesAllocated: Long = 0L + + override def dispose(): Unit = { + if (!super.isDisposed) { + _argParams.foreach { case (_, param) => param.dispose() } + _auxParams.foreach { case (_, param) => param.dispose() } + } + } } object FeedForward { + + private def doNothingDeAllocator(dummy: CPtrAddress): Int = 0 + private val logger: Logger = LoggerFactory.getLogger(classOf[FeedForward]) // Check if name is a data argument. private def isDataArg(name: String): Boolean = { diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala index 48d4b0c193b1..1806b8653376 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala @@ -46,7 +46,8 @@ private[mxnet] trait NativeResource */ def nativeDeAllocator: (CPtrAddress => Int) - /** Call NativeResource.register to get the reference + /** + * Call NativeResource.register to get the reference */ val ref: NativeResourceRef @@ -56,6 +57,7 @@ private[mxnet] trait NativeResource // intentionally making it a val, so it gets evaluated when defined val bytesAllocated: Long + // this is set and unset by [[ResourceScope.add]] and [[ResourceScope.remove]] private[mxnet] var scope: Option[ResourceScope] = None @volatile private var disposed = false @@ -69,11 +71,11 @@ private[mxnet] trait NativeResource * using PhantomReference */ def register(): NativeResourceRef = { - scope = ResourceScope.getCurrentScope() + val scope = ResourceScope.getCurrentScope() if (scope.isDefined) scope.get.add(this) NativeResource.totalBytesAllocated.getAndAdd(bytesAllocated) - // register with PhantomRef tracking to release incase the objects go + // register with PhantomRef tracking to release in case the objects go // out of reference within scope but are held for long time NativeResourceRef.register(this, nativeDeAllocator) } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala index 1c5782d873a9..30fe1473a2cd 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala @@ -58,6 +58,7 @@ class ResourceScope extends AutoCloseable { */ def add(resource: NativeResource): Unit = { resourceQ.+=(resource) + resource.scope = Some(this) } /** @@ -67,7 +68,21 @@ class ResourceScope extends AutoCloseable { */ def remove(resource: NativeResource): Unit = { resourceQ.-=(resource) + resource.scope = None } + + /** + * Removes from current Scope and moves to outer scope if it exists + * @param resource Resource to be moved to an outer scope + */ + def moveToOuterScope(resource: NativeResource): Unit = { + val prevScope: Option[ResourceScope] = ResourceScope.getPrevScope() + if (prevScope.isDefined) { + this.remove(resource) + prevScope.get.add(resource) + } else this.remove(resource) + } + } object ResourceScope { @@ -92,32 +107,22 @@ object ResourceScope { val curScope = if (scope != null) scope else new ResourceScope() - val prevScope: Option[ResourceScope] = ResourceScope.getPrevScope() - @inline def resourceInGeneric(g: scala.collection.Iterable[_]) = { g.foreach( n => n match { case nRes: NativeResource => { - removeAndAddToPrevScope(nRes) + curScope.moveToOuterScope(nRes) } case kv: scala.Tuple2[_, _] => { - if (kv._1.isInstanceOf[NativeResource]) removeAndAddToPrevScope( + if (kv._1.isInstanceOf[NativeResource]) curScope.moveToOuterScope( kv._1.asInstanceOf[NativeResource]) - if (kv._2.isInstanceOf[NativeResource]) removeAndAddToPrevScope( + if (kv._2.isInstanceOf[NativeResource]) curScope.moveToOuterScope( kv._2.asInstanceOf[NativeResource]) } } ) } - @inline def removeAndAddToPrevScope(r: NativeResource) = { - curScope.remove(r) - if (prevScope.isDefined) { - prevScope.get.add(r) - r.scope = prevScope - } - } - @inline def safeAddSuppressed(t: Throwable, suppressed: Throwable): Unit = { if (!t.isInstanceOf[ControlThrowable]) t.addSuppressed(suppressed) } @@ -129,8 +134,8 @@ object ResourceScope { ret match { // don't de-allocate if returning any collection that contains NativeResource. case resInGeneric: scala.collection.Iterable[_] => resourceInGeneric(resInGeneric) - case nRes: NativeResource => removeAndAddToPrevScope(nRes) - case ndRet: NDArrayFuncReturn => ndRet.arr.foreach( nd => removeAndAddToPrevScope(nd) ) + case nRes: NativeResource => curScope.moveToOuterScope(nRes) + case ndRet: NDArrayFuncReturn => ndRet.arr.foreach( nd => curScope.moveToOuterScope(nd) ) case _ => // do nothing } ret diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala index 608e191e019f..f6c283c3dfb2 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala @@ -43,7 +43,7 @@ object TrainModel { */ def test(model: String, dataPath: String, numExamples: Int = 60000, numEpochs: Int = 10, benchmark: Boolean = false): Float = { - NDArrayCollector.auto().withScope { + ResourceScope.using() { val devs = Array(Context.cpu(0)) val envs: mutable.Map[String, String] = mutable.HashMap.empty[String, String] val (dataLoader, net) = dataLoaderAndModel("mnist", model, dataPath, @@ -110,44 +110,46 @@ object TrainModel { val inst = new TrainModel val parser: CmdLineParser = new CmdLineParser(inst) try { - parser.parseArgument(args.toList.asJava) - - val dataPath = if (inst.dataDir == null) System.getenv("MXNET_HOME") - else inst.dataDir - - val (dataLoader, net) = dataLoaderAndModel(inst.dataset, inst.network, dataPath, - inst.numLayers, inst.numExamples, inst.benchmark) - - val devs = - if (inst.gpus != null) inst.gpus.split(',').map(id => Context.gpu(id.trim.toInt)) - else if (inst.cpus != null) inst.cpus.split(',').map(id => Context.cpu(id.trim.toInt)) - else Array(Context.cpu(0)) - - val envs: mutable.Map[String, String] = mutable.HashMap.empty[String, String] - envs.put("DMLC_ROLE", inst.role) - if (inst.schedulerHost != null) { - require(inst.schedulerPort > 0, "scheduler port not specified") - envs.put("DMLC_PS_ROOT_URI", inst.schedulerHost) - envs.put("DMLC_PS_ROOT_PORT", inst.schedulerPort.toString) - require(inst.numWorker > 0, "Num of workers must > 0") - envs.put("DMLC_NUM_WORKER", inst.numWorker.toString) - require(inst.numServer > 0, "Num of servers must > 0") - envs.put("DMLC_NUM_SERVER", inst.numServer.toString) - logger.info("Init PS environments") - KVStoreServer.init(envs.toMap) - } - - if (inst.role != "worker") { - logger.info("Start KVStoreServer for scheduler & servers") - KVStoreServer.start() - } else { - Trainer.fit(batchSize = inst.batchSize, numExamples = inst.numExamples, devs = devs, - network = net, dataLoader = dataLoader, - kvStore = inst.kvStore, numEpochs = inst.numEpochs, - modelPrefix = inst.modelPrefix, loadEpoch = inst.loadEpoch, - lr = inst.lr, lrFactor = inst.lrFactor, lrFactorEpoch = inst.lrFactorEpoch, - monitorSize = inst.monitor) - logger.info("Finish fit ...") + ResourceScope.using() { + parser.parseArgument(args.toList.asJava) + + val dataPath = if (inst.dataDir == null) System.getenv("MXNET_HOME") + else inst.dataDir + + val (dataLoader, net) = dataLoaderAndModel(inst.dataset, inst.network, dataPath, + inst.numLayers, inst.numExamples, inst.benchmark) + + val devs = + if (inst.gpus != null) inst.gpus.split(',').map(id => Context.gpu(id.trim.toInt)) + else if (inst.cpus != null) inst.cpus.split(',').map(id => Context.cpu(id.trim.toInt)) + else Array(Context.cpu(0)) + + val envs: mutable.Map[String, String] = mutable.HashMap.empty[String, String] + envs.put("DMLC_ROLE", inst.role) + if (inst.schedulerHost != null) { + require(inst.schedulerPort > 0, "scheduler port not specified") + envs.put("DMLC_PS_ROOT_URI", inst.schedulerHost) + envs.put("DMLC_PS_ROOT_PORT", inst.schedulerPort.toString) + require(inst.numWorker > 0, "Num of workers must > 0") + envs.put("DMLC_NUM_WORKER", inst.numWorker.toString) + require(inst.numServer > 0, "Num of servers must > 0") + envs.put("DMLC_NUM_SERVER", inst.numServer.toString) + logger.info("Init PS environments") + KVStoreServer.init(envs.toMap) + } + + if (inst.role != "worker") { + logger.info("Start KVStoreServer for scheduler & servers") + KVStoreServer.start() + } else { + Trainer.fit(batchSize = inst.batchSize, numExamples = inst.numExamples, devs = devs, + network = net, dataLoader = dataLoader, + kvStore = inst.kvStore, numEpochs = inst.numEpochs, + modelPrefix = inst.modelPrefix, loadEpoch = inst.loadEpoch, + lr = inst.lr, lrFactor = inst.lrFactor, lrFactorEpoch = inst.lrFactorEpoch, + monitorSize = inst.monitor) + logger.info("Finish fit ...") + } } } catch { case ex: Exception => { diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/util/Trainer.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/util/Trainer.scala index 9a54e58b653e..276816cf8c8c 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/util/Trainer.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/util/Trainer.scala @@ -50,83 +50,84 @@ object Trainer { lr: Float = 0.1f, lrFactor: Float = 1f, lrFactorEpoch: Float = 1f, clipGradient: Float = 0f, monitorSize: Int = -1): Accuracy = { // kvstore - var kv = KVStore.create(kvStore) + ResourceScope.using() { + var kv = KVStore.create(kvStore) - // load model - val modelPrefixWithRank = - if (modelPrefix == null) null - else modelPrefix + s"-${kv.rank}" + // load model + val modelPrefixWithRank = + if (modelPrefix == null) null + else modelPrefix + s"-${kv.rank}" - val (argParams, auxParams, beginEpoch) = - if (loadEpoch >= 0) { - require(modelPrefixWithRank != null) - val tmp = FeedForward.load(modelPrefix, loadEpoch) - (tmp.getArgParams, tmp.getAuxParams, loadEpoch) - } else { - (null, null, 0) - } + val (argParams, auxParams, beginEpoch) = + if (loadEpoch >= 0) { + require(modelPrefixWithRank != null) + val tmp = FeedForward.load(modelPrefix, loadEpoch) + (tmp.getArgParams, tmp.getAuxParams, loadEpoch) + } else { + (null, null, 0) + } - // save model - val checkpoint: EpochEndCallback = - if (modelPrefix == null) null - else new EpochEndCallback { - override def invoke(epoch: Int, symbol: Symbol, - argParams: Map[String, NDArray], - auxStates: Map[String, NDArray]): Unit = { - Model.saveCheckpoint(modelPrefix, epoch + 1, symbol, argParams, auxParams) + // save model + val checkpoint: EpochEndCallback = + if (modelPrefix == null) null + else new EpochEndCallback { + override def invoke(epoch: Int, symbol: Symbol, + argParams: Map[String, NDArray], + auxStates: Map[String, NDArray]): Unit = { + Model.saveCheckpoint(modelPrefix, epoch + 1, symbol, argParams, auxParams) + } } - } - // data - val (train, validation) = dataLoader(batchSize, kv) + // data + val (train, validation) = dataLoader(batchSize, kv) - // train - val epochSize = - if (kvStore == "dist_sync") numExamples / batchSize / kv.numWorkers - else numExamples / batchSize + // train + val epochSize = + if (kvStore == "dist_sync") numExamples / batchSize / kv.numWorkers + else numExamples / batchSize - val lrScheduler = - if (lrFactor < 1f) { - new FactorScheduler(step = Math.max((epochSize * lrFactorEpoch).toInt, 1), - factor = lrFactor) - } else { - null - } - val optimizer: Optimizer = new SGD(learningRate = lr, - lrScheduler = lrScheduler, clipGradient = clipGradient, - momentum = 0.9f, wd = 0.00001f) + val lrScheduler = + if (lrFactor < 1f) { + new FactorScheduler(step = Math.max((epochSize * lrFactorEpoch).toInt, 1), + factor = lrFactor) + } else { + null + } + val optimizer: Optimizer = new SGD(learningRate = lr, + lrScheduler = lrScheduler, clipGradient = clipGradient, + momentum = 0.9f, wd = 0.00001f) - // disable kvstore for single device - if (kv.`type`.contains("local") && (devs.length == 1 || devs(0).deviceType != "gpu")) { - kv.dispose() - kv = null - } + // disable kvstore for single device + if (kv.`type`.contains("local") && (devs.length == 1 || devs(0).deviceType != "gpu")) { + kv.dispose() + kv = null + } - val model = new FeedForward(ctx = devs, - symbol = network, - numEpoch = numEpochs, - optimizer = optimizer, - initializer = new Xavier(factorType = "in", magnitude = 2.34f), - argParams = argParams, - auxParams = auxParams, - beginEpoch = beginEpoch, - epochSize = epochSize) - if (monitorSize > 0) { - model.setMonitor(new Monitor(monitorSize)) - } - val acc = new Accuracy() - model.fit(trainData = train, - evalData = validation, - evalMetric = acc, - kvStore = kv, - batchEndCallback = new Speedometer(batchSize, 50), - epochEndCallback = checkpoint) - if (kv != null) { - kv.dispose() + val model = new FeedForward(ctx = devs, + symbol = network, + numEpoch = numEpochs, + optimizer = optimizer, + initializer = new Xavier(factorType = "in", magnitude = 2.34f), + argParams = argParams, + auxParams = auxParams, + beginEpoch = beginEpoch, + epochSize = epochSize) + if (monitorSize > 0) { + model.setMonitor(new Monitor(monitorSize)) + } + val acc = new Accuracy() + model.fit(trainData = train, + evalData = validation, + evalMetric = acc, + kvStore = kv, + batchEndCallback = new Speedometer(batchSize, 50), + epochEndCallback = checkpoint) + if (kv != null) { + kv.dispose() + } + acc } - acc } - // scalastyle:on parameterNum }