diff --git a/scala-package/core/pom.xml b/scala-package/core/pom.xml index ea3a2d68c9f4..e93169f08faa 100644 --- a/scala-package/core/pom.xml +++ b/scala-package/core/pom.xml @@ -123,5 +123,12 @@ commons-io 2.1 + + + org.mockito + mockito-all + 1.10.19 + test + diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala index fc791d5cd9a3..19fb6fe5cee5 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala @@ -45,7 +45,7 @@ object Executor { * @see Symbol.bind : to create executor */ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, - private[mxnet] val symbol: Symbol) extends WarnIfNotDisposed { + private[mxnet] val symbol: Symbol) extends NativeResource { private[mxnet] var argArrays: Array[NDArray] = null private[mxnet] var gradArrays: Array[NDArray] = null private[mxnet] var auxArrays: Array[NDArray] = null @@ -59,14 +59,15 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, private[mxnet] var _group2ctx: Map[String, Context] = null private val logger: Logger = LoggerFactory.getLogger(classOf[Executor]) - private var disposed = false - protected def isDisposed = disposed - - def dispose(): Unit = { - if (!disposed) { - outputs.foreach(_.dispose()) - _LIB.mxExecutorFree(handle) - disposed = true + override def nativeAddress: CPtrAddress = handle + override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxExecutorFree + // cannot determine the off-heap size of this object + override val bytesAllocated: Long = 0 + override val ref: NativeResourceRef = super.register() + override def dispose(): Unit = { + if (!super.isDisposed) { + super.dispose() + outputs.foreach(o => o.dispose()) } } @@ -305,4 +306,5 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, checkCall(_LIB.mxExecutorPrint(handle, str)) str.value } + 
} diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/KVStore.scala b/scala-package/core/src/main/scala/org/apache/mxnet/KVStore.scala index 8e89ce76b877..45189a13aefc 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/KVStore.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/KVStore.scala @@ -52,22 +52,17 @@ object KVStore { } } -class KVStore(private[mxnet] val handle: KVStoreHandle) extends WarnIfNotDisposed { +class KVStore(private[mxnet] val handle: KVStoreHandle) extends NativeResource { private val logger: Logger = LoggerFactory.getLogger(classOf[KVStore]) private var updaterFunc: MXKVStoreUpdater = null - private var disposed = false - protected def isDisposed = disposed - /** - * Release the native memory. - * The object shall never be used after it is disposed. - */ - def dispose(): Unit = { - if (!disposed) { - _LIB.mxKVStoreFree(handle) - disposed = true - } - } + override def nativeAddress: CPtrAddress = handle + + override def nativeDeAllocator: CPtrAddress => MXUint = _LIB.mxKVStoreFree + + override val ref: NativeResourceRef = super.register() + + override val bytesAllocated: Long = 0L /** * Initialize a single or a sequence of key-value pairs into the store. 
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Model.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Model.scala index 4bb9cdd331a6..b835c4964dd0 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Model.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Model.scala @@ -259,7 +259,9 @@ object Model { workLoadList: Seq[Float] = Nil, monitor: Option[Monitor] = None, symGen: SymbolGenerator = null): Unit = { - val executorManager = new DataParallelExecutorManager( + ResourceScope.using() { + + val executorManager = new DataParallelExecutorManager( symbol = symbol, symGen = symGen, ctx = ctx, @@ -269,17 +271,17 @@ object Model { auxNames = auxNames, workLoadList = workLoadList) - monitor.foreach(executorManager.installMonitor) - executorManager.setParams(argParams, auxParams) + monitor.foreach(executorManager.installMonitor) + executorManager.setParams(argParams, auxParams) - // updater for updateOnKVStore = false - val updaterLocal = Optimizer.getUpdater(optimizer) + // updater for updateOnKVStore = false + val updaterLocal = Optimizer.getUpdater(optimizer) - kvStore.foreach(initializeKVStore(_, executorManager.paramArrays, - argParams, executorManager.paramNames, updateOnKVStore)) - if (updateOnKVStore) { - kvStore.foreach(_.setOptimizer(optimizer)) - } + kvStore.foreach(initializeKVStore(_, executorManager.paramArrays, + argParams, executorManager.paramNames, updateOnKVStore)) + if (updateOnKVStore) { + kvStore.foreach(_.setOptimizer(optimizer)) + } // Now start training for (epoch <- beginEpoch until endEpoch) { @@ -290,45 +292,46 @@ object Model { var epochDone = false // Iterate over training data. 
trainData.reset() - while (!epochDone) { - var doReset = true - while (doReset && trainData.hasNext) { - val dataBatch = trainData.next() - executorManager.loadDataBatch(dataBatch) - monitor.foreach(_.tic()) - executorManager.forward(isTrain = true) - executorManager.backward() - if (updateOnKVStore) { - updateParamsOnKVStore(executorManager.paramArrays, - executorManager.gradArrays, - kvStore, executorManager.paramNames) - } else { - updateParams(executorManager.paramArrays, - executorManager.gradArrays, - updaterLocal, ctx.length, - executorManager.paramNames, - kvStore) - } - monitor.foreach(_.tocPrint()) - // evaluate at end, so out_cpu_array can lazy copy - executorManager.updateMetric(evalMetric, dataBatch.label) + ResourceScope.using() { + while (!epochDone) { + var doReset = true + while (doReset && trainData.hasNext) { + val dataBatch = trainData.next() + executorManager.loadDataBatch(dataBatch) + monitor.foreach(_.tic()) + executorManager.forward(isTrain = true) + executorManager.backward() + if (updateOnKVStore) { + updateParamsOnKVStore(executorManager.paramArrays, + executorManager.gradArrays, + kvStore, executorManager.paramNames) + } else { + updateParams(executorManager.paramArrays, + executorManager.gradArrays, + updaterLocal, ctx.length, + executorManager.paramNames, + kvStore) + } + monitor.foreach(_.tocPrint()) + // evaluate at end, so out_cpu_array can lazy copy + executorManager.updateMetric(evalMetric, dataBatch.label) - nBatch += 1 - batchEndCallback.foreach(_.invoke(epoch, nBatch, evalMetric)) + nBatch += 1 + batchEndCallback.foreach(_.invoke(epoch, nBatch, evalMetric)) - // this epoch is done possibly earlier - if (epochSize != -1 && nBatch >= epochSize) { - doReset = false + // this epoch is done possibly earlier + if (epochSize != -1 && nBatch >= epochSize) { + doReset = false + } + } + if (doReset) { + trainData.reset() } - } - if (doReset) { - trainData.reset() - } - // this epoch is done - epochDone = (epochSize == -1 || nBatch >= 
epochSize) + // this epoch is done + epochDone = (epochSize == -1 || nBatch >= epochSize) + } } - val (name, value) = evalMetric.get name.zip(value).foreach { case (n, v) => logger.info(s"Epoch[$epoch] Train-$n=$v") @@ -336,20 +339,22 @@ object Model { val toc = System.currentTimeMillis logger.info(s"Epoch[$epoch] Time cost=${toc - tic}") - evalData.foreach { evalDataIter => - evalMetric.reset() - evalDataIter.reset() - // TODO: make DataIter implement Iterator - while (evalDataIter.hasNext) { - val evalBatch = evalDataIter.next() - executorManager.loadDataBatch(evalBatch) - executorManager.forward(isTrain = false) - executorManager.updateMetric(evalMetric, evalBatch.label) - } + ResourceScope.using() { + evalData.foreach { evalDataIter => + evalMetric.reset() + evalDataIter.reset() + // TODO: make DataIter implement Iterator + while (evalDataIter.hasNext) { + val evalBatch = evalDataIter.next() + executorManager.loadDataBatch(evalBatch) + executorManager.forward(isTrain = false) + executorManager.updateMetric(evalMetric, evalBatch.label) + } - val (name, value) = evalMetric.get - name.zip(value).foreach { case (n, v) => - logger.info(s"Epoch[$epoch] Train-$n=$v") + val (name, value) = evalMetric.get + name.zip(value).foreach { case (n, v) => + logger.info(s"Epoch[$epoch] Validation-$n=$v") + } } } @@ -359,8 +364,7 @@ object Model { epochEndCallback.foreach(_.invoke(epoch, symbol, argParams, auxParams)) } - updaterLocal.dispose() - executorManager.dispose() + } } // scalastyle:on parameterNum } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index 9b6a7dc66540..f2a7603caa85 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -562,16 +562,20 @@ object NDArray extends NDArrayBase { */ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, val writable: 
Boolean = true, - addToCollector: Boolean = true) extends WarnIfNotDisposed { + addToCollector: Boolean = true) extends NativeResource { if (addToCollector) { NDArrayCollector.collect(this) } + override def nativeAddress: CPtrAddress = handle + override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxNDArrayFree + override val bytesAllocated: Long = DType.numOfBytes(this.dtype) * this.shape.product + + override val ref: NativeResourceRef = super.register() + // record arrays who construct this array instance // we use weak reference to prevent gc blocking private[mxnet] val dependencies = mutable.HashMap.empty[Long, WeakReference[NDArray]] - @volatile private var disposed = false - def isDisposed: Boolean = disposed def serialize(): Array[Byte] = { val buf = ArrayBuffer.empty[Byte] @@ -584,11 +588,10 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, * The NDArrays it depends on will NOT be disposed.
* The object shall never be used after it is disposed. */ - def dispose(): Unit = { - if (!disposed) { - _LIB.mxNDArrayFree(handle) + override def dispose(): Unit = { + if (!super.isDisposed) { + super.dispose() dependencies.clear() - disposed = true } } @@ -1034,6 +1037,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, // TODO: naive implementation shape.hashCode + toArray.hashCode } + } private[mxnet] object NDArrayConversions { diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala new file mode 100644 index 000000000000..48d4b0c193b1 --- /dev/null +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mxnet + +import org.apache.mxnet.Base.CPtrAddress +import java.lang.ref.{PhantomReference, ReferenceQueue, WeakReference} +import java.util.concurrent._ + +import org.apache.mxnet.Base.checkCall +import java.util.concurrent.atomic.AtomicLong + + +/** + * NativeResource trait is used to manage MXNet Objects + * such as NDArray, Symbol, Executor, etc., + * The MXNet Object calls NativeResource.register + * and assign the returned NativeResourceRef to PhantomReference + * NativeResource also implements AutoCloseable so MXNetObjects + * can be used like Resources in try-with-resources paradigm + */ +private[mxnet] trait NativeResource + extends AutoCloseable with WarnIfNotDisposed { + + /** + * native Address associated with this object + */ + def nativeAddress: CPtrAddress + + /** + * Function Pointer to the NativeDeAllocator of nativeAddress + */ + def nativeDeAllocator: (CPtrAddress => Int) + + /** Call NativeResource.register to get the reference + */ + val ref: NativeResourceRef + + /** + * Off-Heap Bytes Allocated for this object + */ + // intentionally making it a val, so it gets evaluated when defined + val bytesAllocated: Long + + private[mxnet] var scope: Option[ResourceScope] = None + + @volatile private var disposed = false + + override def isDisposed: Boolean = disposed || isDeAllocated + + /** + * Register this object for PhantomReference tracking and in + * ResourceScope if used inside ResourceScope. 
+ * @return NativeResourceRef that tracks reachability of this object + * using PhantomReference + */ + def register(): NativeResourceRef = { + scope = ResourceScope.getCurrentScope() + if (scope.isDefined) scope.get.add(this) + + NativeResource.totalBytesAllocated.getAndAdd(bytesAllocated) + // register with PhantomRef tracking to release incase the objects go + // out of reference within scope but are held for long time + NativeResourceRef.register(this, nativeDeAllocator) + } + + // Implements [[@link AutoCloseable.close]] + override def close(): Unit = { + dispose() + } + + // Implements [[@link WarnIfNotDisposed.dispose]] + def dispose(): Unit = dispose(true) + + /** + * This method deAllocates nativeResource and deRegisters + * from PhantomRef and removes from Scope if + * removeFromScope is set to true. + * @param removeFromScope remove from the currentScope if true + */ + // the parameter here controls whether to remove from current scope. + // [[ResourceScope.close]] calls NativeResource.dispose + // if we remove from the ResourceScope ie., from the container in ResourceScope. + // while iterating on the container, calling iterator.next is undefined and not safe. + // Note that ResourceScope automatically disposes all the resources within. 
+ private[mxnet] def dispose(removeFromScope: Boolean = true): Unit = { + if (!disposed) { + checkCall(nativeDeAllocator(this.nativeAddress)) + NativeResourceRef.deRegister(ref) // removes from PhantomRef tracking + if (removeFromScope && scope.isDefined) scope.get.remove(this) + NativeResource.totalBytesAllocated.getAndAdd(-1*bytesAllocated) + disposed = true + } + } + + /* + this is used by the WarnIfNotDisposed finalizer, + the object could be disposed by the GC without the need for explicit disposal + but the finalizer might not have run, then the WarnIfNotDisposed throws a warning + */ + private[mxnet] def isDeAllocated(): Boolean = NativeResourceRef.isDeAllocated(ref) + +} + +private[mxnet] object NativeResource { + var totalBytesAllocated : AtomicLong = new AtomicLong(0) +} + +// Do not make [[NativeResource.resource]] a member of the class, +// this will hold reference and GC will not clear the object. +private[mxnet] class NativeResourceRef(resource: NativeResource, + val resourceDeAllocator: CPtrAddress => Int) + extends PhantomReference[NativeResource](resource, NativeResourceRef.refQ) {} + +private[mxnet] object NativeResourceRef { + + private[mxnet] val refQ: ReferenceQueue[NativeResource] + = new ReferenceQueue[NativeResource] + + private[mxnet] val refMap = new ConcurrentHashMap[NativeResourceRef, CPtrAddress]() + + private[mxnet] val cleaner = new ResourceCleanupThread() + + cleaner.start() + + def register(resource: NativeResource, nativeDeAllocator: (CPtrAddress => Int)): + NativeResourceRef = { + val ref = new NativeResourceRef(resource, nativeDeAllocator) + refMap.put(ref, resource.nativeAddress) + ref + } + + // remove from PhantomRef tracking + def deRegister(ref: NativeResourceRef): Unit = refMap.remove(ref) + + /** + * This method will check if the cleaner ran and deAllocated the object + * As a part of GC, when the object is unreachable GC inserts a phantomRef + * to the ReferenceQueue which the cleaner thread will deallocate, however + * 
the finalizer runs much later depending on the GC. + * @param resource resource to verify if it has been deAllocated + * @return true if already deAllocated + */ + def isDeAllocated(ref: NativeResourceRef): Boolean = { + !refMap.containsKey(ref) + } + + def cleanup: Unit = { + // remove is a blocking call + val ref: NativeResourceRef = refQ.remove().asInstanceOf[NativeResourceRef] + // phantomRef will be removed from the map when NativeResource.close is called. + val resource = refMap.get(ref) + if (resource != 0L) { // since CPtrAddress is Scala a Long, it cannot be null + ref.resourceDeAllocator(resource) + refMap.remove(ref) + } + } + + protected class ResourceCleanupThread extends Thread { + setPriority(Thread.MAX_PRIORITY) + setName("NativeResourceDeAllocatorThread") + setDaemon(true) + + override def run(): Unit = { + while (true) { + try { + NativeResourceRef.cleanup + } + catch { + case _: InterruptedException => Thread.currentThread().interrupt() + } + } + } + } +} diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Optimizer.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Optimizer.scala index 758cbc829618..c3f8aaec6d60 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Optimizer.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Optimizer.scala @@ -19,6 +19,8 @@ package org.apache.mxnet import java.io._ +import org.apache.mxnet.Base.CPtrAddress + import scala.collection.mutable import scala.util.Either @@ -38,8 +40,10 @@ object Optimizer { } override def dispose(): Unit = { - states.values.foreach(optimizer.disposeState) - states.clear() + if (!super.isDisposed) { + states.values.foreach(optimizer.disposeState) + states.clear() + } } override def serializeState(): Array[Byte] = { @@ -285,7 +289,8 @@ abstract class Optimizer extends Serializable { } } -trait MXKVStoreUpdater { +trait MXKVStoreUpdater extends + NativeResource { /** * user-defined updater for the kvstore * It's this updater's responsibility to 
delete recv and local @@ -294,9 +299,14 @@ trait MXKVStoreUpdater { * @param local the value stored on local on this key */ def update(key: Int, recv: NDArray, local: NDArray): Unit - def dispose(): Unit - // def serializeState(): Array[Byte] - // def deserializeState(bytes: Array[Byte]): Unit + + // This is a hack to make Optimizers work with ResourceScope + // otherwise the user has to manage calling dispose on this object. + override def nativeAddress: CPtrAddress = hashCode() + override def nativeDeAllocator: CPtrAddress => Int = doNothingDeAllocator + private def doNothingDeAllocator(dummy: CPtrAddress): Int = 0 + override val ref: NativeResourceRef = super.register() + override val bytesAllocated: Long = 0L } trait MXKVStoreCachedStates { diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala new file mode 100644 index 000000000000..1c5782d873a9 --- /dev/null +++ b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mxnet + +import java.util.HashSet + +import org.slf4j.LoggerFactory + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.util.Try +import scala.util.control.{ControlThrowable, NonFatal} + +/** + * This class manages automatically releasing of [[NativeResource]]s + */ +class ResourceScope extends AutoCloseable { + + // HashSet does not take a custom comparator + private[mxnet] val resourceQ = new mutable.TreeSet[NativeResource]()(nativeAddressOrdering) + + private object nativeAddressOrdering extends Ordering[NativeResource] { + def compare(a: NativeResource, b: NativeResource): Int = { + a.nativeAddress compare b.nativeAddress + } + } + + ResourceScope.addToThreadLocal(this) + + /** + * Releases all the [[NativeResource]] by calling + * the associated [[NativeResource.close()]] method + */ + override def close(): Unit = { + ResourceScope.removeFromThreadLocal(this) + resourceQ.foreach(resource => if (resource != null) resource.dispose(false) ) + resourceQ.clear() + } + + /** + * Add a NativeResource to the scope + * @param resource + */ + def add(resource: NativeResource): Unit = { + resourceQ.+=(resource) + } + + /** + * Remove NativeResource from the Scope, this uses + * object equality to find the resource in the stack. + * @param resource + */ + def remove(resource: NativeResource): Unit = { + resourceQ.-=(resource) + } +} + +object ResourceScope { + + private val logger = LoggerFactory.getLogger(classOf[ResourceScope]) + + /** + * Captures all Native Resources created using the ResourceScope and + * at the end of the body, de allocates all the Native resources by calling close on them. + * This method will not deAllocate NativeResources returned from the block. + * @param scope (Optional). 
Scope in which to capture the native resources + * @param body block of code to execute in this scope + * @tparam A return type + * @return result of the operation, if the result is of type NativeResource, it is not + * de allocated so the user can use it and then de allocate manually by calling + * close or enclose in another resourceScope. + */ + // inspired from slide 21 of https://www.slideshare.net/Odersky/fosdem-2009-1013261 + // and https://github.com/scala/scala/blob/2.13.x/src/library/scala/util/Using.scala + // TODO: we should move to the Scala util's Using method when we move to Scala 2.13 + def using[A](scope: ResourceScope = null)(body: => A): A = { + + val curScope = if (scope != null) scope else new ResourceScope() + + val prevScope: Option[ResourceScope] = ResourceScope.getPrevScope() + + @inline def resourceInGeneric(g: scala.collection.Iterable[_]) = { + g.foreach( n => + n match { + case nRes: NativeResource => { + removeAndAddToPrevScope(nRes) + } + case kv: scala.Tuple2[_, _] => { + if (kv._1.isInstanceOf[NativeResource]) removeAndAddToPrevScope( + kv._1.asInstanceOf[NativeResource]) + if (kv._2.isInstanceOf[NativeResource]) removeAndAddToPrevScope( + kv._2.asInstanceOf[NativeResource]) + } + } + ) + } + + @inline def removeAndAddToPrevScope(r: NativeResource) = { + curScope.remove(r) + if (prevScope.isDefined) { + prevScope.get.add(r) + r.scope = prevScope + } + } + + @inline def safeAddSuppressed(t: Throwable, suppressed: Throwable): Unit = { + if (!t.isInstanceOf[ControlThrowable]) t.addSuppressed(suppressed) + } + + var retThrowable: Throwable = null + + try { + val ret = body + ret match { + // don't de-allocate if returning any collection that contains NativeResource. 
+ case resInGeneric: scala.collection.Iterable[_] => resourceInGeneric(resInGeneric) + case nRes: NativeResource => removeAndAddToPrevScope(nRes) + case ndRet: NDArrayFuncReturn => ndRet.arr.foreach( nd => removeAndAddToPrevScope(nd) ) + case _ => // do nothing + } + ret + } catch { + case t: Throwable => + retThrowable = t + null.asInstanceOf[A] // we'll throw in finally + } finally { + var toThrow: Throwable = retThrowable + if (retThrowable eq null) curScope.close() + else { + try { + curScope.close + } catch { + case closeThrowable: Throwable => + if (NonFatal(retThrowable) && !NonFatal(closeThrowable)) toThrow = closeThrowable + else safeAddSuppressed(retThrowable, closeThrowable) + } finally { + throw toThrow + } + } + } + } + + // thread local Scopes + private[mxnet] val threadLocalScopes = new ThreadLocal[ArrayBuffer[ResourceScope]] { + override def initialValue(): ArrayBuffer[ResourceScope] = + new ArrayBuffer[ResourceScope]() + } + + /** + * Add resource to current ThreadLocal DataStructure + * @param r ResourceScope to add. + */ + private[mxnet] def addToThreadLocal(r: ResourceScope): Unit = { + threadLocalScopes.get() += r + } + + /** + * Remove resource from current ThreadLocal DataStructure + * @param r ResourceScope to remove + */ + private[mxnet] def removeFromThreadLocal(r: ResourceScope): Unit = { + threadLocalScopes.get() -= r + } + + /** + * Get the latest Scope in the stack + * @return + */ + private[mxnet] def getCurrentScope(): Option[ResourceScope] = { + Try(Some(threadLocalScopes.get().last)).getOrElse(None) + } + + /** + * Get the Last but one Scope from threadLocal Scopes. 
+ * @return n-1th scope or None when not found + */ + private[mxnet] def getPrevScope(): Option[ResourceScope] = { + val scopes = threadLocalScopes.get() + Try(Some(scopes(scopes.size - 2))).getOrElse(None) + } +} diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala index b1a3e392f41e..a009e7e343f2 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala @@ -29,21 +29,15 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer} * WARNING: it is your responsibility to clear this object through dispose(). * */ -class Symbol private(private[mxnet] val handle: SymbolHandle) extends WarnIfNotDisposed { +class Symbol private(private[mxnet] val handle: SymbolHandle) extends NativeResource { private val logger: Logger = LoggerFactory.getLogger(classOf[Symbol]) - private var disposed = false - protected def isDisposed = disposed - /** - * Release the native memory. - * The object shall never be used after it is disposed. 
- */ - def dispose(): Unit = { - if (!disposed) { - _LIB.mxSymbolFree(handle) - disposed = true - } - } + // unable to get the byteAllocated for Symbol + override val bytesAllocated: Long = 0L + override def nativeAddress: CPtrAddress = handle + override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxSymbolFree + override val ref: NativeResourceRef = super.register() + def +(other: Symbol): Symbol = Symbol.createFromListedSymbols("_Plus")(Array(this, other)) def +[@specialized(Int, Float, Double) V](other: V): Symbol = { @@ -793,7 +787,7 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends WarnIfNotD } val execHandle = new ExecutorHandleRef - val sharedHadle = if (sharedExec != null) sharedExec.handle else 0L + val sharedHandle = if (sharedExec != null) sharedExec.handle else 0L checkCall(_LIB.mxExecutorBindEX(handle, ctx.deviceTypeid, ctx.deviceId, @@ -806,7 +800,7 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends WarnIfNotD argsGradHandle, reqsArray, auxArgsHandle, - sharedHadle, + sharedHandle, execHandle)) val executor = new Executor(execHandle.value, this.clone()) executor.argArrays = argsNDArray @@ -832,6 +826,7 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends WarnIfNotD checkCall(_LIB.mxSymbolSaveToJSON(handle, jsonStr)) jsonStr.value } + } /** diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala index f7f858deb82d..998017750db2 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala @@ -33,7 +33,7 @@ import scala.collection.mutable.ListBuffer private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, dataName: String = "data", labelName: String = "label") - extends DataIter with WarnIfNotDisposed { + extends DataIter with NativeResource { private val logger = 
LoggerFactory.getLogger(classOf[MXDataIter]) @@ -67,20 +67,13 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, } } + override def nativeAddress: CPtrAddress = handle - private var disposed = false - protected def isDisposed = disposed + override def nativeDeAllocator: CPtrAddress => MXUint = _LIB.mxDataIterFree - /** - * Release the native memory. - * The object shall never be used after it is disposed. - */ - def dispose(): Unit = { - if (!disposed) { - _LIB.mxDataIterFree(handle) - disposed = true - } - } + override val ref: NativeResourceRef = super.register() + + override val bytesAllocated: Long = 0L /** * reset the iterator diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala index e20b433ed1ed..d349feac3e93 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala @@ -17,7 +17,7 @@ package org.apache.mxnet.optimizer -import org.apache.mxnet.{Optimizer, LRScheduler, NDArray} +import org.apache.mxnet._ import org.apache.mxnet.NDArrayConversions._ /** @@ -92,7 +92,13 @@ class SGD(val learningRate: Float = 0.01f, momentum: Float = 0.0f, if (momentum == 0.0f) { null } else { - NDArray.zeros(weight.shape, weight.context) + val s = NDArray.zeros(weight.shape, weight.context) + // this is created on the fly and shared between runs, + // we don't want it to be dispose from the scope + // and should be handled by the dispose + val scope = ResourceScope.getCurrentScope() + if (scope.isDefined) scope.get.remove(s) + s } } diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NativeResourceSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NativeResourceSuite.scala new file mode 100644 index 000000000000..81a9f605a887 --- /dev/null +++ b/scala-package/core/src/test/scala/org/apache/mxnet/NativeResourceSuite.scala @@ 
-0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+import java.lang.ref.ReferenceQueue
+import java.util.concurrent.ConcurrentHashMap
+
+import org.apache.mxnet.Base.CPtrAddress
+import org.mockito.Matchers.any
+import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers, TagAnnotation}
+import org.mockito.Mockito._
+
+/**
+ * Tests that explicitly releasing an NDArray, via close() or dispose(), marks it
+ * disposed and removes its reference from NativeResourceRef.refMap.
+ */
+@TagAnnotation("resource")
+class NativeResourceSuite extends FunSuite with BeforeAndAfterAll with Matchers {
+
+  /** Test-side accessors for the NativeResourceRef bookkeeping structures. */
+  object TestRef {
+    def getRefQueue: ReferenceQueue[NativeResource] = NativeResourceRef.refQ
+    def getRefMap: ConcurrentHashMap[NativeResourceRef, CPtrAddress] = NativeResourceRef.refMap
+    def getCleaner: Thread = NativeResourceRef.cleaner
+  }
+
+  /** Thin NativeResourceRef subclass; not instantiated by the tests in this suite. */
+  class TestRef(resource: NativeResource,
+                resourceDeAllocator: CPtrAddress => Int)
+    extends NativeResourceRef(resource, resourceDeAllocator) {
+  }
+
+  test("test native resource setup/teardown") {
+    val a = spy(NDArray.ones(Shape(2, 3)))
+    val aRef = a.ref
+    // NOTE(review): spy(aRef) yields a separate spy object; `a` still holds aRef itself, so the
+    // times(0) check below only covers calls made on the spy - confirm this is the intent.
+    val spyRef = spy(aRef)
+
+    assert(TestRef.getRefMap.containsKey(aRef))
+    a.close()
+    verify(a).dispose()
+    verify(a).nativeDeAllocator
+    // resourceDeAllocator does not get called when explicitly closing
+    verify(spyRef, times(0)).resourceDeAllocator
+
+    assert(!TestRef.getRefMap.containsKey(aRef))
+    assert(a.isDisposed, "isDisposed should be set to true after calling close")
+  }
+
+  test("test dispose") {
+    val a: NDArray = spy(NDArray.ones(Shape(3, 4)))
+    val aRef = a.ref
+    a.dispose()
+    verify(a).nativeDeAllocator
+    assert(!TestRef.getRefMap.containsKey(aRef))
+    // This test exercises dispose(), so the failure message names dispose rather than close.
+    assert(a.isDisposed, "isDisposed should be set to true after calling dispose")
+  }
+}
+
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala
new file mode 100644
index 000000000000..41dfa7d0ead2
--- /dev/null
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.mxnet + +import java.lang.ref.ReferenceQueue +import java.util.concurrent.ConcurrentHashMap + +import org.apache.mxnet.Base.CPtrAddress +import org.apache.mxnet.ResourceScope.logger +import org.mockito.Matchers.any +import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers} +import org.mockito.Mockito._ +import scala.collection.mutable.HashMap + +class ResourceScopeSuite extends FunSuite with BeforeAndAfterAll with Matchers { + + class TestNativeResource extends NativeResource { + /** + * native Address associated with this object + */ + override def nativeAddress: CPtrAddress = hashCode() + + /** + * Function Pointer to the NativeDeAllocator of nativeAddress + */ + override def nativeDeAllocator: CPtrAddress => Int = TestNativeResource.deAllocator + + /** Call NativeResource.register to get the reference + */ + override val ref: NativeResourceRef = super.register() + /** + * Off-Heap Bytes Allocated for this object + */ + override val bytesAllocated: Long = 0 + } + object TestNativeResource { + def deAllocator(handle: CPtrAddress): Int = 0 + } + + object TestPhantomRef { + def getRefQueue: ReferenceQueue[NativeResource] = { NativeResourceRef.refQ} + def getRefMap: ConcurrentHashMap[NativeResourceRef, CPtrAddress] + = {NativeResourceRef.refMap} + def getCleaner: Thread = { NativeResourceRef.cleaner } + + } + + class TestPhantomRef(resource: NativeResource, + resourceDeAllocator: CPtrAddress => Int) + extends NativeResourceRef(resource, resourceDeAllocator) { + } + + test(testName = "test NDArray Auto Release") { + var a: NDArray = null + var aRef: NativeResourceRef = null + var b: NDArray = null + + ResourceScope.using() { + b = ResourceScope.using() { + a = NDArray.ones(Shape(3, 4)) + aRef = a.ref + val x = NDArray.ones(Shape(3, 4)) + x + } + val bRef: NativeResourceRef = b.ref + assert(a.isDisposed == true, + "objects created within scope should have isDisposed set to true") + assert(b.isDisposed == false, + "returned NativeResource 
should not be released") + assert(TestPhantomRef.getRefMap.containsKey(aRef) == false, + "reference of resource in Scope should be removed refMap") + assert(TestPhantomRef.getRefMap.containsKey(bRef) == true, + "reference of resource outside scope should be not removed refMap") + } + assert(b.isDisposed, "resource returned from inner scope should be released in outer scope") + } + + test("test return object release from outer scope") { + var a: TestNativeResource = null + ResourceScope.using() { + a = ResourceScope.using() { + new TestNativeResource() + } + assert(a.isDisposed == false, "returned object should not be disposed within Using") + } + assert(a.isDisposed == true, "returned object should be disposed in the outer scope") + } + + test(testName = "test NativeResources in returned Lists are not disposed") { + var ndListRet: IndexedSeq[TestNativeResource] = null + ResourceScope.using() { + ndListRet = ResourceScope.using() { + val ndList: IndexedSeq[TestNativeResource] = + IndexedSeq(new TestNativeResource(), new TestNativeResource()) + ndList + } + ndListRet.foreach(nd => assert(nd.isDisposed == false, + "NativeResources within a returned collection should not be disposed")) + } + ndListRet.foreach(nd => assert(nd.isDisposed == true, + "NativeResources returned from inner scope should be disposed in outer scope")) + } + + test("test native resource inside a map") { + var nRInKeyOfMap: HashMap[TestNativeResource, String] = null + var nRInValOfMap: HashMap[String, TestNativeResource] = HashMap[String, TestNativeResource]() + + ResourceScope.using() { + nRInKeyOfMap = ResourceScope.using() { + val ret = HashMap[TestNativeResource, String]() + ret.put(new TestNativeResource, "hello") + ret + } + assert(!nRInKeyOfMap.isEmpty) + + nRInKeyOfMap.keysIterator.foreach(it => assert(it.isDisposed == false, + "NativeResources returned in Traversable should not be disposed")) + } + + nRInKeyOfMap.keysIterator.foreach(it => assert(it.isDisposed)) + + ResourceScope.using() 
{ + + nRInValOfMap = ResourceScope.using() { + val ret = HashMap[String, TestNativeResource]() + ret.put("world!", new TestNativeResource) + ret + } + assert(!nRInValOfMap.isEmpty) + nRInValOfMap.valuesIterator.foreach(it => assert(it.isDisposed == false, + "NativeResources returned in Collection should not be disposed")) + } + nRInValOfMap.valuesIterator.foreach(it => assert(it.isDisposed)) + } + +} diff --git a/scala-package/examples/scripts/run_train_mnist.sh b/scala-package/examples/scripts/run_train_mnist.sh new file mode 100755 index 000000000000..ea53c1ade66f --- /dev/null +++ b/scala-package/examples/scripts/run_train_mnist.sh @@ -0,0 +1,33 @@ +#!/bin/bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +set -e + +MXNET_ROOT=$(cd "$(dirname $0)/../../.."; pwd) +echo $MXNET_ROOT +CLASS_PATH=$MXNET_ROOT/scala-package/assembly/linux-x86_64-cpu/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*:$MXNET_ROOT/scala-package/infer/target/* + +# model dir +DATA_PATH=$2 + +java -XX:+PrintGC -Xms256M -Xmx512M -Dmxnet.traceLeakedObjects=false -cp $CLASS_PATH \ + org.apache.mxnetexamples.imclassification.TrainMnist \ + --data-dir /home/ubuntu/mxnet_scala/scala-package/examples/mnist/ \ + --num-epochs 10000000 \ + --batch-size 1024 \ No newline at end of file