Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

NativeResource Management in Scala #12647

Merged
merged 21 commits into from
Oct 19, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
db87a2b
add Generic MXNetHandle trait and MXNetHandlePhantomRef class that wi…
nswamy Aug 27, 2018
cba8a43
use nswamy@ personal repo for mac testing
nswamy Aug 27, 2018
34106a4
Generic Handle with AutoCloseable
nswamy Aug 28, 2018
373ac78
add NativeResource and NativeResourceManager with Periodic GC calling
nswamy Aug 30, 2018
e0016d7
use NativeResource trait in NDArray, Symbol and Executor
nswamy Aug 30, 2018
5cd3cd3
add run train mnist script
nswamy Aug 30, 2018
ef4bfe8
create a Generic ResourceScope that can collect all NativeResources t…
nswamy Sep 4, 2018
c04e4f0
modify NativeResource and ResourceScope, extend NativeResource in NDA…
nswamy Sep 7, 2018
bb36934
remove GCExecutor
nswamy Sep 7, 2018
9d92dc1
deRegister PhantomReferences by when calling dispose()
nswamy Sep 7, 2018
1465717
add Finalizer(temporary) to NativeResource
nswamy Sep 7, 2018
e9b4b70
refactor NativeResource.dispose() method
nswamy Sep 7, 2018
dd294f0
update NativeResource/add Unit Test for NativeResource
nswamy Sep 21, 2018
980db5a
updates to NativeResource/NativeResourceRef and unit tests to NativeR…
nswamy Sep 24, 2018
2b0b073
remove redundant code added because of the object equality that was n…
nswamy Oct 12, 2018
18b1175
add ResourceScope
nswamy Oct 12, 2018
e2d5c99
Fix NativeResource to not remove from Scope, add Unit Tests to Resour…
nswamy Oct 14, 2018
21140cf
cleanup log/print debug statements
nswamy Oct 14, 2018
362ae18
use TreeSet inplace of ArrayBuffer to speedup removal of resources fr…
nswamy Oct 15, 2018
d78f571
fix segfault that was happening because of NDArray creation on the fl…
nswamy Oct 18, 2018
f0e873b
Add comments for dispose(param:Boolean)
nswamy Oct 18, 2018
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions scala-package/core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -123,5 +123,12 @@
<artifactId>commons-io</artifactId>
<version>2.1</version>
</dependency>
<!-- https://mvnrepository.com/artifact/org.mockito/mockito-all -->
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this module only needed for the unit tests?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-all</artifactId>
<version>1.10.19</version>
<scope>test</scope>
</dependency>
</dependencies>
</project>
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ object Executor {
* @see Symbol.bind : to create executor
*/
class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
private[mxnet] val symbol: Symbol) extends WarnIfNotDisposed {
private[mxnet] val symbol: Symbol) extends NativeResource {
private[mxnet] var argArrays: Array[NDArray] = null
private[mxnet] var gradArrays: Array[NDArray] = null
private[mxnet] var auxArrays: Array[NDArray] = null
Expand All @@ -59,14 +59,15 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
private[mxnet] var _group2ctx: Map[String, Context] = null
private val logger: Logger = LoggerFactory.getLogger(classOf[Executor])

private var disposed = false
protected def isDisposed = disposed

def dispose(): Unit = {
if (!disposed) {
outputs.foreach(_.dispose())
_LIB.mxExecutorFree(handle)
disposed = true
override def nativeAddress: CPtrAddress = handle
override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxExecutorFree
// cannot determine the off-heap size of this object
override val bytesAllocated: Long = 0
override val ref: NativeResourceRef = super.register()
override def dispose(): Unit = {
if (!super.isDisposed) {
super.dispose()
outputs.foreach(o => o.dispose())
}
}

Expand Down Expand Up @@ -305,4 +306,5 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
checkCall(_LIB.mxExecutorPrint(handle, str))
str.value
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -52,22 +52,17 @@ object KVStore {
}
}

class KVStore(private[mxnet] val handle: KVStoreHandle) extends WarnIfNotDisposed {
class KVStore(private[mxnet] val handle: KVStoreHandle) extends NativeResource {
private val logger: Logger = LoggerFactory.getLogger(classOf[KVStore])
private var updaterFunc: MXKVStoreUpdater = null
private var disposed = false
protected def isDisposed = disposed

/**
* Release the native memory.
* The object shall never be used after it is disposed.
*/
def dispose(): Unit = {
if (!disposed) {
_LIB.mxKVStoreFree(handle)
disposed = true
}
}
override def nativeAddress: CPtrAddress = handle

override def nativeDeAllocator: CPtrAddress => MXUint = _LIB.mxKVStoreFree

override val ref: NativeResourceRef = super.register()

override val bytesAllocated: Long = 0L

/**
* Initialize a single or a sequence of key-value pairs into the store.
Expand Down
122 changes: 63 additions & 59 deletions scala-package/core/src/main/scala/org/apache/mxnet/Model.scala
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,9 @@ object Model {
workLoadList: Seq[Float] = Nil,
monitor: Option[Monitor] = None,
symGen: SymbolGenerator = null): Unit = {
val executorManager = new DataParallelExecutorManager(
ResourceScope.using() {

val executorManager = new DataParallelExecutorManager(
symbol = symbol,
symGen = symGen,
ctx = ctx,
Expand All @@ -269,17 +271,17 @@ object Model {
auxNames = auxNames,
workLoadList = workLoadList)

monitor.foreach(executorManager.installMonitor)
executorManager.setParams(argParams, auxParams)
monitor.foreach(executorManager.installMonitor)
executorManager.setParams(argParams, auxParams)

// updater for updateOnKVStore = false
val updaterLocal = Optimizer.getUpdater(optimizer)
// updater for updateOnKVStore = false
val updaterLocal = Optimizer.getUpdater(optimizer)

kvStore.foreach(initializeKVStore(_, executorManager.paramArrays,
argParams, executorManager.paramNames, updateOnKVStore))
if (updateOnKVStore) {
kvStore.foreach(_.setOptimizer(optimizer))
}
kvStore.foreach(initializeKVStore(_, executorManager.paramArrays,
argParams, executorManager.paramNames, updateOnKVStore))
if (updateOnKVStore) {
kvStore.foreach(_.setOptimizer(optimizer))
}

// Now start training
for (epoch <- beginEpoch until endEpoch) {
Expand All @@ -290,66 +292,69 @@ object Model {
var epochDone = false
// Iterate over training data.
trainData.reset()
while (!epochDone) {
var doReset = true
while (doReset && trainData.hasNext) {
val dataBatch = trainData.next()
executorManager.loadDataBatch(dataBatch)
monitor.foreach(_.tic())
executorManager.forward(isTrain = true)
executorManager.backward()
if (updateOnKVStore) {
updateParamsOnKVStore(executorManager.paramArrays,
executorManager.gradArrays,
kvStore, executorManager.paramNames)
} else {
updateParams(executorManager.paramArrays,
executorManager.gradArrays,
updaterLocal, ctx.length,
executorManager.paramNames,
kvStore)
}
monitor.foreach(_.tocPrint())
// evaluate at end, so out_cpu_array can lazy copy
executorManager.updateMetric(evalMetric, dataBatch.label)
ResourceScope.using() {
while (!epochDone) {
var doReset = true
while (doReset && trainData.hasNext) {
val dataBatch = trainData.next()
executorManager.loadDataBatch(dataBatch)
monitor.foreach(_.tic())
executorManager.forward(isTrain = true)
executorManager.backward()
if (updateOnKVStore) {
updateParamsOnKVStore(executorManager.paramArrays,
executorManager.gradArrays,
kvStore, executorManager.paramNames)
} else {
updateParams(executorManager.paramArrays,
executorManager.gradArrays,
updaterLocal, ctx.length,
executorManager.paramNames,
kvStore)
}
monitor.foreach(_.tocPrint())
// evaluate at end, so out_cpu_array can lazy copy
executorManager.updateMetric(evalMetric, dataBatch.label)

nBatch += 1
batchEndCallback.foreach(_.invoke(epoch, nBatch, evalMetric))
nBatch += 1
batchEndCallback.foreach(_.invoke(epoch, nBatch, evalMetric))

// this epoch is done possibly earlier
if (epochSize != -1 && nBatch >= epochSize) {
doReset = false
// this epoch is done possibly earlier
if (epochSize != -1 && nBatch >= epochSize) {
doReset = false
}
}
if (doReset) {
trainData.reset()
}
}
if (doReset) {
trainData.reset()
}

// this epoch is done
epochDone = (epochSize == -1 || nBatch >= epochSize)
// this epoch is done
epochDone = (epochSize == -1 || nBatch >= epochSize)
}
}

val (name, value) = evalMetric.get
name.zip(value).foreach { case (n, v) =>
logger.info(s"Epoch[$epoch] Train-$n=$v")
}
val toc = System.currentTimeMillis
logger.info(s"Epoch[$epoch] Time cost=${toc - tic}")

evalData.foreach { evalDataIter =>
evalMetric.reset()
evalDataIter.reset()
// TODO: make DataIter implement Iterator
while (evalDataIter.hasNext) {
val evalBatch = evalDataIter.next()
executorManager.loadDataBatch(evalBatch)
executorManager.forward(isTrain = false)
executorManager.updateMetric(evalMetric, evalBatch.label)
}
ResourceScope.using() {
evalData.foreach { evalDataIter =>
evalMetric.reset()
evalDataIter.reset()
// TODO: make DataIter implement Iterator
while (evalDataIter.hasNext) {
val evalBatch = evalDataIter.next()
executorManager.loadDataBatch(evalBatch)
executorManager.forward(isTrain = false)
executorManager.updateMetric(evalMetric, evalBatch.label)
}

val (name, value) = evalMetric.get
name.zip(value).foreach { case (n, v) =>
logger.info(s"Epoch[$epoch] Train-$n=$v")
val (name, value) = evalMetric.get
name.zip(value).foreach { case (n, v) =>
logger.info(s"Epoch[$epoch] Validation-$n=$v")
}
}
}

Expand All @@ -359,8 +364,7 @@ object Model {
epochEndCallback.foreach(_.invoke(epoch, symbol, argParams, auxParams))
}

updaterLocal.dispose()
executorManager.dispose()
}
}
// scalastyle:on parameterNum
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -562,16 +562,20 @@ object NDArray extends NDArrayBase {
*/
class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
val writable: Boolean = true,
addToCollector: Boolean = true) extends WarnIfNotDisposed {
addToCollector: Boolean = true) extends NativeResource {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why extend this? Is it possible to use composition instead (such as creating a NativeResource object inside NDArray)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the right pattern. NDArray is a NativeResource, and all the functionality of NativeResource is made available to NDArray through inheritance.

if (addToCollector) {
NDArrayCollector.collect(this)
}

override def nativeAddress: CPtrAddress = handle
override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxNDArrayFree
override val bytesAllocated: Long = DType.numOfBytes(this.dtype) * this.shape.product

override val ref: NativeResourceRef = super.register()

// record arrays who construct this array instance
// we use weak reference to prevent gc blocking
private[mxnet] val dependencies = mutable.HashMap.empty[Long, WeakReference[NDArray]]
@volatile private var disposed = false
def isDisposed: Boolean = disposed

def serialize(): Array[Byte] = {
val buf = ArrayBuffer.empty[Byte]
Expand All @@ -584,11 +588,10 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
* The NDArrays it depends on will NOT be disposed. <br />
* The object shall never be used after it is disposed.
*/
def dispose(): Unit = {
if (!disposed) {
_LIB.mxNDArrayFree(handle)
override def dispose(): Unit = {
if (!super.isDisposed) {
super.dispose()
dependencies.clear()
disposed = true
}
}

Expand Down Expand Up @@ -1034,6 +1037,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
// TODO: naive implementation
shape.hashCode + toArray.hashCode
}

}

private[mxnet] object NDArrayConversions {
Expand Down
Loading