
use ResourceScope in Model/Trainer/FeedForward.scala (apache#12882)
* use ResourceScope in Model/Trainer/FeedForward.scala

* add a public moveToOuterScope method that moves a resource to the outer scope if one exists

* fix a memory leak in FeedForward.scala by making it a NativeResource and disposing argParams and auxParams in the dispose() method
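
The pattern this commit introduces, in rough form: wrap the training path in ResourceScope.using() so temporary NDArrays are released when the block exits, and explicitly move anything that must outlive the block to the enclosing scope. A minimal sketch, mirroring what initParams does below (the shapes and the surrounding work are illustrative, and getCurrentScope/moveToOuterScope are package-internal helpers used inside org.apache.mxnet):

ResourceScope.using() {
  val param = NDArray.zeros(Shape(2, 2))                               // allocated inside the scope
  ResourceScope.getCurrentScope().foreach(_.moveToOuterScope(param))   // survives past this block
  val tmp = NDArray.ones(Shape(2, 2))                                  // temporary, released on scope exit
  // ... training-style work with tmp ...
}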
nswamy authored and lanking520 committed Oct 24, 2018
1 parent 9a0116d commit d6b13e1
Showing 5 changed files with 230 additions and 178 deletions.
152 changes: 97 additions & 55 deletions scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala
@@ -17,9 +17,10 @@

package org.apache.mxnet

import org.apache.mxnet.Base.CPtrAddress
import org.apache.mxnet.io.NDArrayIter
import org.apache.mxnet.optimizer.SGD
import org.slf4j.{LoggerFactory, Logger}
import org.slf4j.{Logger, LoggerFactory}

import scala.collection.mutable.ListBuffer

@@ -55,7 +56,7 @@ class FeedForward private(
argParams: Map[String, NDArray],
auxParams: Map[String, NDArray],
private val allowExtraParams: Boolean,
val beginEpoch: Int) {
val beginEpoch: Int) extends NativeResource {

val logger: Logger = LoggerFactory.getLogger(classOf[FeedForward])
private var argumentChecked = false
@@ -126,6 +127,8 @@ class FeedForward private(
}

// Initialize weight parameters and auxiliary states
// The NDArrays associated with _argParams and _auxParams are not disposed; instead
// they are passed to an outer scope if one is available.
private def initParams(inputShapes: Map[String, Shape], overwrite: Boolean = false)
: (IndexedSeq[String], IndexedSeq[String], IndexedSeq[String]) = {
val (argShapes, _, auxShapes) = symbol.inferShape(inputShapes)
@@ -137,16 +140,26 @@ class FeedForward private(
val paramNameShapes = (argNames zip argShapes).filter { case (name, _) =>
paramNames.contains(name)
}
val argParams = paramNameShapes.map { case (name, shape) =>
(name, NDArray.zeros(shape))
val argParams = paramNameShapes.map { case (name, shape) => {
val param = NDArray.zeros(shape)
val curScope = ResourceScope.getCurrentScope()
if (curScope.isDefined) curScope.get.moveToOuterScope(param)
(name, param)
}
}.toMap
val auxParams = (auxNames zip auxShapes).map { case (name, shape) =>
(name, NDArray.zeros(shape))

val auxParams = (auxNames zip auxShapes).map { case (name, shape) => {
val param = NDArray.zeros(shape)
val curScope = ResourceScope.getCurrentScope()
if (curScope.isDefined) curScope.get.moveToOuterScope(param)
(name, param)
}
}.toMap

for ((k, v) <- argParams) {
if (_argParams != null && _argParams.contains(k) && (!overwrite)) {
argParams(k).set(_argParams(k))

} else {
initializer(k, v)
}
@@ -277,13 +290,15 @@ class FeedForward private(
def fit(trainData: DataIter, evalData: DataIter, evalMetric: EvalMetric, kvStoreType: String,
epochEndCallback: EpochEndCallback, batchEndCallback: BatchEndCallback,
logger: Logger, workLoadList: Seq[Float]): Unit = {
// init params first to allow kv store use _argParams to decide its type
initSymbolParams(trainData)
// create kvstore
val (kvStore, updateOnKVStore) = Model.createKVStore(kvStoreType, ctx.length, _argParams)
fit(trainData, evalData, evalMetric, kvStore, updateOnKVStore,
epochEndCallback, batchEndCallback, logger, workLoadList)
kvStore.foreach(_.dispose())
ResourceScope.using() {
// init params first to allow kv store use _argParams to decide its type
initSymbolParams(trainData)
// create kvstore
val (kvStore, updateOnKVStore) = Model.createKVStore(kvStoreType, ctx.length, _argParams)
fit(trainData, evalData, evalMetric, kvStore, updateOnKVStore,
epochEndCallback, batchEndCallback, logger, workLoadList)
// kvStore.foreach(_.dispose())
}
}

def fit(trainData: DataIter, evalData: DataIter, evalMetric: EvalMetric,
@@ -313,11 +328,13 @@ class FeedForward private(
batchEndCallback: BatchEndCallback, logger: Logger,
workLoadList: Seq[Float]): Unit = {
// init params first to allow kv store use _argParams to decide its type
initSymbolParams(trainData)
// create kvstore
val (kvStore, updateOnKVStore) = Model.createKVStore(kv)
fit(trainData, evalData, evalMetric, kvStore, updateOnKVStore,
epochEndCallback, batchEndCallback, logger, workLoadList)
ResourceScope.using() {
initSymbolParams(trainData)
// create kvstore
val (kvStore, updateOnKVStore) = Model.createKVStore(kv)
fit(trainData, evalData, evalMetric, kvStore, updateOnKVStore,
epochEndCallback, batchEndCallback, logger, workLoadList)
}
}

def fit(trainData: DataIter, evalData: DataIter, evalMetric: EvalMetric,
@@ -352,44 +369,49 @@ class FeedForward private(
batchEndCallback: BatchEndCallback = null, logger: Logger = FeedForward.logger,
workLoadList: Seq[Float] = null): Unit = {
require(evalMetric != null, "evalMetric cannot be null")
val (argNames, paramNames, auxNames) = initSymbolParams(trainData)

// init optimizer
val batchSizeMultiplier = kvStore.map { kv =>
if (kv.`type` == "dist_sync") {
kv.numWorkers
} else {
1
}
}
val batchSize = trainData.batchSize * batchSizeMultiplier.getOrElse(1)
this.optimizer.setArgNames(argNames)
this.optimizer.setRescaleGrad(1f / batchSize)
this.optimizer.setSymbol(this.symbol)
val paramIdx2Name =
if (updateOnKVStore) {
paramNames.zipWithIndex.map { case (name, idx) => idx -> name }.toMap
} else {
paramNames.zipWithIndex.flatMap { case (name, idx) =>
(0 until ctx.length).map(k => (idx * ctx.length + k) -> name).toMap
}.toMap
// TODO: https://issues.apache.org/jira/browse/MXNET-1171
// this leaks memory: initSymbolParams -> initParams has already been called, allocating
// NDArrays in argParams and auxParams, and here we overwrite them by calling it again.
// The PhantomRef should take care of releasing them when GC runs, however we have to
// wait for that GC call to happen.
val (argNames, paramNames, auxNames) = initSymbolParams(trainData)

// init optimizer
val batchSizeMultiplier = kvStore.map { kv =>
if (kv.`type` == "dist_sync") {
kv.numWorkers
} else {
1
}
}
this.optimizer.setIdx2Name(paramIdx2Name)

logger.debug("Start training on multi-device")
Model.trainMultiDevice(
symbol, ctx, argNames, paramNames, auxNames,
_argParams, _auxParams,
this.beginEpoch, this.numEpoch,
this.epochSize, this.optimizer,
kvStore, updateOnKVStore,
trainData = trainData, evalData = Option(evalData),
evalMetric = evalMetric,
epochEndCallback = Option(epochEndCallback),
batchEndCallback = Option(batchEndCallback),
workLoadList = workLoadList,
monitor = monitor,
symGen = symGen)
val batchSize = trainData.batchSize * batchSizeMultiplier.getOrElse(1)
this.optimizer.setArgNames(argNames)
this.optimizer.setRescaleGrad(1f / batchSize)
this.optimizer.setSymbol(this.symbol)
val paramIdx2Name =
if (updateOnKVStore) {
paramNames.zipWithIndex.map { case (name, idx) => idx -> name }.toMap
} else {
paramNames.zipWithIndex.flatMap { case (name, idx) =>
(0 until ctx.length).map(k => (idx * ctx.length + k) -> name).toMap
}.toMap
}
this.optimizer.setIdx2Name(paramIdx2Name)

logger.debug("Start training on multi-device")
Model.trainMultiDevice(
symbol, ctx, argNames, paramNames, auxNames,
_argParams, _auxParams,
this.beginEpoch, this.numEpoch,
this.epochSize, this.optimizer,
kvStore, updateOnKVStore,
trainData = trainData, evalData = Option(evalData),
evalMetric = evalMetric,
epochEndCallback = Option(epochEndCallback),
batchEndCallback = Option(batchEndCallback),
workLoadList = workLoadList,
monitor = monitor,
symGen = symGen)
}

/**
@@ -416,9 +438,29 @@
def serialize(): Array[Byte] = {
Model.serialize(this.symbol, getArgParams, getAuxParams)
}

// hack to make FeedForward work with ResourceScope and
// automatically release _argParams and _auxParams
override def nativeAddress: CPtrAddress = hashCode()

override def nativeDeAllocator: CPtrAddress => Int = FeedForward.doNothingDeAllocator

override val ref: NativeResourceRef = super.register()

override val bytesAllocated: Long = 0L

override def dispose(): Unit = {
if (!super.isDisposed) {
_argParams.foreach { case (_, param) => param.dispose() }
_auxParams.foreach { case (_, param) => param.dispose() }
}
}
}

object FeedForward {

private def doNothingDeAllocator(dummy: CPtrAddress): Int = 0

private val logger: Logger = LoggerFactory.getLogger(classOf[FeedForward])
// Check if name is a data argument.
private def isDataArg(name: String): Boolean = {
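
Because FeedForward now registers itself as a NativeResource, a model built inside a ResourceScope is released with that scope unless it is moved outward; returning it from the using() block does that automatically. A hedged usage sketch (buildModel(), trainIter and valIter are hypothetical stand-ins, and the fit call assumes an overload whose remaining parameters are defaulted):

val model: FeedForward = ResourceScope.using() {
  val m = buildModel()                               // hypothetical helper constructing a FeedForward
  m.fit(trainData = trainIter, evalData = valIter)   // temporaries from training die with this scope
  m                                                  // returned value is moved to the outer scope by using()
}
// ... run prediction with model ...
model.dispose()                                      // releases _argParams and _auxParams via dispose() above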
@@ -46,7 +46,8 @@ private[mxnet] trait NativeResource
*/
def nativeDeAllocator: (CPtrAddress => Int)

/** Call NativeResource.register to get the reference
/**
* Call NativeResource.register to get the reference
*/
val ref: NativeResourceRef

@@ -56,6 +57,7 @@ private[mxnet] trait NativeResource
// intentionally making it a val, so it gets evaluated when defined
val bytesAllocated: Long

// this is set and unset by [[ResourceScope.add]] and [[ResourceScope.remove]]
private[mxnet] var scope: Option[ResourceScope] = None

@volatile private var disposed = false
@@ -69,11 +71,11 @@
* using PhantomReference
*/
def register(): NativeResourceRef = {
scope = ResourceScope.getCurrentScope()
val scope = ResourceScope.getCurrentScope()
if (scope.isDefined) scope.get.add(this)

NativeResource.totalBytesAllocated.getAndAdd(bytesAllocated)
// register with PhantomRef tracking to release incase the objects go
// register with PhantomRef tracking to release in case the objects go
// out of reference within scope but are held for long time
NativeResourceRef.register(this, nativeDeAllocator)
}
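
For orientation, the contract a NativeResource implementor satisfies, sketched as a toy class; this mirrors what FeedForward does above and is purely illustrative (a real resource such as NDArray frees an actual native handle in its de-allocator):

private[mxnet] class ToyResource(handle: CPtrAddress) extends NativeResource {
  override def nativeAddress: CPtrAddress = handle
  override def nativeDeAllocator: CPtrAddress => Int = _ => 0   // real code would free the handle here
  override val bytesAllocated: Long = 0L
  override val ref: NativeResourceRef = super.register()        // adds to the current scope and phantom-ref tracking
}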
@@ -58,6 +58,7 @@ class ResourceScope extends AutoCloseable {
*/
def add(resource: NativeResource): Unit = {
resourceQ.+=(resource)
resource.scope = Some(this)
}

/**
@@ -67,7 +68,21 @@
*/
def remove(resource: NativeResource): Unit = {
resourceQ.-=(resource)
resource.scope = None
}

/**
* Removes the resource from the current scope and moves it to the outer scope if one exists
* @param resource Resource to be moved to an outer scope
*/
def moveToOuterScope(resource: NativeResource): Unit = {
val prevScope: Option[ResourceScope] = ResourceScope.getPrevScope()
if (prevScope.isDefined) {
this.remove(resource)
prevScope.get.add(resource)
} else this.remove(resource)
}

}

object ResourceScope {
@@ -92,32 +107,22 @@

val curScope = if (scope != null) scope else new ResourceScope()

val prevScope: Option[ResourceScope] = ResourceScope.getPrevScope()

@inline def resourceInGeneric(g: scala.collection.Iterable[_]) = {
g.foreach( n =>
n match {
case nRes: NativeResource => {
removeAndAddToPrevScope(nRes)
curScope.moveToOuterScope(nRes)
}
case kv: scala.Tuple2[_, _] => {
if (kv._1.isInstanceOf[NativeResource]) removeAndAddToPrevScope(
if (kv._1.isInstanceOf[NativeResource]) curScope.moveToOuterScope(
kv._1.asInstanceOf[NativeResource])
if (kv._2.isInstanceOf[NativeResource]) removeAndAddToPrevScope(
if (kv._2.isInstanceOf[NativeResource]) curScope.moveToOuterScope(
kv._2.asInstanceOf[NativeResource])
}
}
)
}

@inline def removeAndAddToPrevScope(r: NativeResource) = {
curScope.remove(r)
if (prevScope.isDefined) {
prevScope.get.add(r)
r.scope = prevScope
}
}

@inline def safeAddSuppressed(t: Throwable, suppressed: Throwable): Unit = {
if (!t.isInstanceOf[ControlThrowable]) t.addSuppressed(suppressed)
}
@@ -129,8 +134,8 @@
ret match {
// don't de-allocate if returning any collection that contains NativeResource.
case resInGeneric: scala.collection.Iterable[_] => resourceInGeneric(resInGeneric)
case nRes: NativeResource => removeAndAddToPrevScope(nRes)
case ndRet: NDArrayFuncReturn => ndRet.arr.foreach( nd => removeAndAddToPrevScope(nd) )
case nRes: NativeResource => curScope.moveToOuterScope(nRes)
case ndRet: NDArrayFuncReturn => ndRet.arr.foreach( nd => curScope.moveToOuterScope(nd) )
case _ => // do nothing
}
ret
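
The effect of the return-value handling above, sketched with nested scopes (shapes are illustrative): an NDArray returned from an inner using() block is moved to the enclosing scope instead of being freed with the inner one.

ResourceScope.using() {                        // outer scope
  val kept = ResourceScope.using() {           // inner scope
    val tmp = NDArray.ones(Shape(3))           // released when the inner block exits
    val out = NDArray.zeros(Shape(3))
    out.set(tmp)                               // copy the values we want to keep
    out                                        // returned, so moved to the outer scope
  }
  println(kept.toArray.mkString(", "))         // kept is still valid; released when the outer block exits
}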