Skip to content

Commit

Permalink
NativeResource Management in Scala (apache#12647)
Browse files Browse the repository at this point in the history
* add Generic MXNetHandle trait and MXNetHandlePhantomRef class that will be used by all MXNetObjects

* Generic Handle with AutoCloseable

* add NativeResource and NativeResourceManager with Periodic GC calling

* use NativeResource trait in NDArray, Symbol and Executor

* add run train mnist script

* create a Generic ResourceScope that can collect all NativeResources to dispose at the end

* modify NativeResource and ResourceScope, extend NativeResource in NDArray, Symbol and Executor

* remove GCExecutor

* deRegister PhantomReferences by when calling dispose()

* add Finalizer(temporary) to NativeResource

* refactor NativeResource.dispose() method

* update NativeResource/add Unit Test for NativeResource

* updates to NativeResource/NativeResourceRef and unit tests to NativeResource

* remove redundant code added because of the object equality that was needed

* add ResourceScope

* Fix NativeResource to not remove from Scope, add Unit Tests to ResourceScope

* cleanup log/print debug statements

* use TreeSet inplace of ArrayBuffer to speedup removal of resources from ResourceScope
Fix Executor dispose and make KVStore a NativeResource

* fix segfault that was happening because of NDArray creation on the fly in Optimizer

* Add comments for dispose(param:Boolean)
  • Loading branch information
nswamy authored and azai91 committed Dec 1, 2018
1 parent 4ae6086 commit c3693a4
Show file tree
Hide file tree
Showing 14 changed files with 778 additions and 124 deletions.
7 changes: 7 additions & 0 deletions scala-package/core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -123,5 +123,12 @@
<artifactId>commons-io</artifactId>
<version>2.1</version>
</dependency>
<!-- https://mvnrepository.com/artifact/org.mockito/mockito-all -->
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-all</artifactId>
<version>1.10.19</version>
<scope>test</scope>
</dependency>
</dependencies>
</project>
20 changes: 11 additions & 9 deletions scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ object Executor {
* @see Symbol.bind : to create executor
*/
class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
private[mxnet] val symbol: Symbol) extends WarnIfNotDisposed {
private[mxnet] val symbol: Symbol) extends NativeResource {
private[mxnet] var argArrays: Array[NDArray] = null
private[mxnet] var gradArrays: Array[NDArray] = null
private[mxnet] var auxArrays: Array[NDArray] = null
Expand All @@ -59,14 +59,15 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
private[mxnet] var _group2ctx: Map[String, Context] = null
private val logger: Logger = LoggerFactory.getLogger(classOf[Executor])

private var disposed = false
protected def isDisposed = disposed

def dispose(): Unit = {
if (!disposed) {
outputs.foreach(_.dispose())
_LIB.mxExecutorFree(handle)
disposed = true
override def nativeAddress: CPtrAddress = handle
override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxExecutorFree
// cannot determine the off-heap size of this object
override val bytesAllocated: Long = 0
override val ref: NativeResourceRef = super.register()
override def dispose(): Unit = {
if (!super.isDisposed) {
super.dispose()
outputs.foreach(o => o.dispose())
}
}

Expand Down Expand Up @@ -305,4 +306,5 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
checkCall(_LIB.mxExecutorPrint(handle, str))
str.value
}

}
21 changes: 8 additions & 13 deletions scala-package/core/src/main/scala/org/apache/mxnet/KVStore.scala
Original file line number Diff line number Diff line change
Expand Up @@ -52,22 +52,17 @@ object KVStore {
}
}

class KVStore(private[mxnet] val handle: KVStoreHandle) extends WarnIfNotDisposed {
class KVStore(private[mxnet] val handle: KVStoreHandle) extends NativeResource {
private val logger: Logger = LoggerFactory.getLogger(classOf[KVStore])
private var updaterFunc: MXKVStoreUpdater = null
private var disposed = false
protected def isDisposed = disposed

/**
* Release the native memory.
* The object shall never be used after it is disposed.
*/
def dispose(): Unit = {
if (!disposed) {
_LIB.mxKVStoreFree(handle)
disposed = true
}
}
override def nativeAddress: CPtrAddress = handle

override def nativeDeAllocator: CPtrAddress => MXUint = _LIB.mxKVStoreFree

override val ref: NativeResourceRef = super.register()

override val bytesAllocated: Long = 0L

/**
* Initialize a single or a sequence of key-value pairs into the store.
Expand Down
122 changes: 63 additions & 59 deletions scala-package/core/src/main/scala/org/apache/mxnet/Model.scala
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,9 @@ object Model {
workLoadList: Seq[Float] = Nil,
monitor: Option[Monitor] = None,
symGen: SymbolGenerator = null): Unit = {
val executorManager = new DataParallelExecutorManager(
ResourceScope.using() {

val executorManager = new DataParallelExecutorManager(
symbol = symbol,
symGen = symGen,
ctx = ctx,
Expand All @@ -269,17 +271,17 @@ object Model {
auxNames = auxNames,
workLoadList = workLoadList)

monitor.foreach(executorManager.installMonitor)
executorManager.setParams(argParams, auxParams)
monitor.foreach(executorManager.installMonitor)
executorManager.setParams(argParams, auxParams)

// updater for updateOnKVStore = false
val updaterLocal = Optimizer.getUpdater(optimizer)
// updater for updateOnKVStore = false
val updaterLocal = Optimizer.getUpdater(optimizer)

kvStore.foreach(initializeKVStore(_, executorManager.paramArrays,
argParams, executorManager.paramNames, updateOnKVStore))
if (updateOnKVStore) {
kvStore.foreach(_.setOptimizer(optimizer))
}
kvStore.foreach(initializeKVStore(_, executorManager.paramArrays,
argParams, executorManager.paramNames, updateOnKVStore))
if (updateOnKVStore) {
kvStore.foreach(_.setOptimizer(optimizer))
}

// Now start training
for (epoch <- beginEpoch until endEpoch) {
Expand All @@ -290,66 +292,69 @@ object Model {
var epochDone = false
// Iterate over training data.
trainData.reset()
while (!epochDone) {
var doReset = true
while (doReset && trainData.hasNext) {
val dataBatch = trainData.next()
executorManager.loadDataBatch(dataBatch)
monitor.foreach(_.tic())
executorManager.forward(isTrain = true)
executorManager.backward()
if (updateOnKVStore) {
updateParamsOnKVStore(executorManager.paramArrays,
executorManager.gradArrays,
kvStore, executorManager.paramNames)
} else {
updateParams(executorManager.paramArrays,
executorManager.gradArrays,
updaterLocal, ctx.length,
executorManager.paramNames,
kvStore)
}
monitor.foreach(_.tocPrint())
// evaluate at end, so out_cpu_array can lazy copy
executorManager.updateMetric(evalMetric, dataBatch.label)
ResourceScope.using() {
while (!epochDone) {
var doReset = true
while (doReset && trainData.hasNext) {
val dataBatch = trainData.next()
executorManager.loadDataBatch(dataBatch)
monitor.foreach(_.tic())
executorManager.forward(isTrain = true)
executorManager.backward()
if (updateOnKVStore) {
updateParamsOnKVStore(executorManager.paramArrays,
executorManager.gradArrays,
kvStore, executorManager.paramNames)
} else {
updateParams(executorManager.paramArrays,
executorManager.gradArrays,
updaterLocal, ctx.length,
executorManager.paramNames,
kvStore)
}
monitor.foreach(_.tocPrint())
// evaluate at end, so out_cpu_array can lazy copy
executorManager.updateMetric(evalMetric, dataBatch.label)

nBatch += 1
batchEndCallback.foreach(_.invoke(epoch, nBatch, evalMetric))
nBatch += 1
batchEndCallback.foreach(_.invoke(epoch, nBatch, evalMetric))

// this epoch is done possibly earlier
if (epochSize != -1 && nBatch >= epochSize) {
doReset = false
// this epoch is done possibly earlier
if (epochSize != -1 && nBatch >= epochSize) {
doReset = false
}
}
if (doReset) {
trainData.reset()
}
}
if (doReset) {
trainData.reset()
}

// this epoch is done
epochDone = (epochSize == -1 || nBatch >= epochSize)
// this epoch is done
epochDone = (epochSize == -1 || nBatch >= epochSize)
}
}

val (name, value) = evalMetric.get
name.zip(value).foreach { case (n, v) =>
logger.info(s"Epoch[$epoch] Train-$n=$v")
}
val toc = System.currentTimeMillis
logger.info(s"Epoch[$epoch] Time cost=${toc - tic}")

evalData.foreach { evalDataIter =>
evalMetric.reset()
evalDataIter.reset()
// TODO: make DataIter implement Iterator
while (evalDataIter.hasNext) {
val evalBatch = evalDataIter.next()
executorManager.loadDataBatch(evalBatch)
executorManager.forward(isTrain = false)
executorManager.updateMetric(evalMetric, evalBatch.label)
}
ResourceScope.using() {
evalData.foreach { evalDataIter =>
evalMetric.reset()
evalDataIter.reset()
// TODO: make DataIter implement Iterator
while (evalDataIter.hasNext) {
val evalBatch = evalDataIter.next()
executorManager.loadDataBatch(evalBatch)
executorManager.forward(isTrain = false)
executorManager.updateMetric(evalMetric, evalBatch.label)
}

val (name, value) = evalMetric.get
name.zip(value).foreach { case (n, v) =>
logger.info(s"Epoch[$epoch] Train-$n=$v")
val (name, value) = evalMetric.get
name.zip(value).foreach { case (n, v) =>
logger.info(s"Epoch[$epoch] Validation-$n=$v")
}
}
}

Expand All @@ -359,8 +364,7 @@ object Model {
epochEndCallback.foreach(_.invoke(epoch, symbol, argParams, auxParams))
}

updaterLocal.dispose()
executorManager.dispose()
}
}
// scalastyle:on parameterNum
}
Expand Down
18 changes: 11 additions & 7 deletions scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
Original file line number Diff line number Diff line change
Expand Up @@ -562,16 +562,20 @@ object NDArray extends NDArrayBase {
*/
class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
val writable: Boolean = true,
addToCollector: Boolean = true) extends WarnIfNotDisposed {
addToCollector: Boolean = true) extends NativeResource {
if (addToCollector) {
NDArrayCollector.collect(this)
}

override def nativeAddress: CPtrAddress = handle
override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxNDArrayFree
override val bytesAllocated: Long = DType.numOfBytes(this.dtype) * this.shape.product

override val ref: NativeResourceRef = super.register()

// record arrays who construct this array instance
// we use weak reference to prevent gc blocking
private[mxnet] val dependencies = mutable.HashMap.empty[Long, WeakReference[NDArray]]
@volatile private var disposed = false
def isDisposed: Boolean = disposed

def serialize(): Array[Byte] = {
val buf = ArrayBuffer.empty[Byte]
Expand All @@ -584,11 +588,10 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
* The NDArrays it depends on will NOT be disposed. <br />
* The object shall never be used after it is disposed.
*/
def dispose(): Unit = {
if (!disposed) {
_LIB.mxNDArrayFree(handle)
override def dispose(): Unit = {
if (!super.isDisposed) {
super.dispose()
dependencies.clear()
disposed = true
}
}

Expand Down Expand Up @@ -1034,6 +1037,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
// TODO: naive implementation
shape.hashCode + toArray.hashCode
}

}

private[mxnet] object NDArrayConversions {
Expand Down
Loading

0 comments on commit c3693a4

Please sign in to comment.